TransWikia.com

Python (PyTorch) loss function syntax

Data Science Asked by ProJaqf on January 22, 2021

I have seen many examples that use the following syntax for a loss function:

loss = nn.BCEWithLogitsLoss()(pred, y)

Can anyone explain what the (pred, y) does exactly? How does it compute the loss directly, rather than requiring a separate call to loss with these two arguments?

One Answer

This is an example of Python's special __call__ method. In short: BCEWithLogitsLoss is a class. The first set of parentheses (empty, in your case) passes any needed arguments to the class initializer, __init__, and returns an instance. The second set of parentheses then calls that instance, which passes pred and y to its __call__ method. So this is convenient syntax that lets you instantiate the class and invoke the resulting object in one line.
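A minimal sketch of the same pattern in plain Python may help. The SquaredError class below is a hypothetical stand-in for a loss class (it is not part of PyTorch); it only shows that the one-line form, the two-step form, and an explicit __call__ all do the same thing.

```python
class SquaredError:
    """Hypothetical stand-in for a loss class; not part of PyTorch."""

    def __init__(self, scale=1.0):
        # Arguments in the first set of parentheses arrive here.
        self.scale = scale

    def __call__(self, pred, y):
        # Arguments in the second set of parentheses arrive here.
        return self.scale * (pred - y) ** 2


# One line: build the instance, then immediately call it.
loss_one_line = SquaredError()(3.0, 1.0)

# Two steps: keep the instance around, then call it.
criterion = SquaredError()
loss_two_steps = criterion(3.0, 1.0)

# Calling an instance is shorthand for invoking its __call__ method.
loss_explicit = criterion.__call__(3.0, 1.0)

print(loss_one_line, loss_two_steps, loss_explicit)
```

The two-step form is usually preferable in a training loop, since the loss object is created once and reused on every batch.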

Confirming this in the source code is actually a bit difficult. The source for BCEWithLogitsLoss (in torch/nn/modules/loss.py) confirms that it is a class. But apart from its initializer, its only method is forward; where is __call__? The answer is inheritance: BCEWithLogitsLoss derives from PyTorch's internal loss base class _Loss, which in turn inherits from Module. Module is the class that implements __call__ and has it dispatch to forward (along with running any registered hooks).
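The inheritance described above can be sketched without PyTorch. MiniModule and MiniBCEWithLogits are hypothetical, heavily simplified stand-ins (the real Module does considerably more, such as running hooks), and the loss here works on a single float rather than a tensor. The point is only that a subclass which defines nothing but forward still becomes callable through its base class.

```python
import math


class MiniModule:
    """Hypothetical, simplified stand-in for torch.nn.Module."""

    def __call__(self, *args, **kwargs):
        # Simplification: the real Module does more work around this call.
        return self.forward(*args, **kwargs)

    def forward(self, *args, **kwargs):
        raise NotImplementedError


class MiniBCEWithLogits(MiniModule):
    """Hypothetical scalar loss; defines forward only, never __call__."""

    def forward(self, logit, target):
        # Numerically stable binary cross-entropy on one raw logit:
        # max(x, 0) - x * z + log(1 + exp(-|x|))
        return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))


# Same shape as nn.BCEWithLogitsLoss()(pred, y): instantiate, then call.
loss = MiniBCEWithLogits()(0.0, 1.0)
print(loss)

# __call__ is found on the base class, not on the subclass itself.
print("__call__" in vars(MiniBCEWithLogits))
print("__call__" in vars(MiniModule))
```

With a logit of 0 the predicted probability is 0.5, so the loss for a target of 1 is log(2).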

Correct answer by cag51 on January 22, 2021
