Data Science Asked by ProJaqf on January 22, 2021
I have seen many examples of this syntax being used, specifically for the loss function:
loss = nn.BCEWithLogitsLoss()(pred, y)
Can anyone explain what the (pred, y) does exactly? Why does it compute the loss directly, instead of first creating the loss and then calling it as a function of these two arguments?
This is an example of Python's special __call__ method, as described here. In short: BCEWithLogitsLoss is a class. The first set of parentheses (empty, in your case) provides any needed arguments to the class initializer, __init__. The arguments in the second set of parentheses, (pred, y), are then passed to the __call__ method of the newly created instance. So this is convenient syntax that lets you instantiate the class and evaluate one of its methods in one line.
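Here is a minimal sketch of the same pattern in plain Python. It uses a hypothetical Scale class (not part of any library). The first parentheses run __init__ and the second run __call__, so the one-liner and the two-step version give the same result.

```python
class Scale:
    """Hypothetical callable class, used only to illustrate __call__."""

    def __init__(self, factor):
        # First set of parentheses: Scale(factor) runs __init__.
        self.factor = factor

    def __call__(self, x):
        # Second set of parentheses: instance(x) runs __call__.
        return self.factor * x


# One line: create the instance, then immediately call it.
one_liner = Scale(3)(5)

# Two steps: keep the instance around and call it afterwards.
scaler = Scale(3)
two_step = scaler(5)

print(one_liner, two_step)  # 15 15
```

The two-step form avoids re-creating the object on every call. That is why training loops often create the loss object once, up front, and then call it at each step.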
Confirming this in the source code is actually a bit difficult. You can see the source for BCEWithLogitsLoss here, which confirms it is a class. But apart from __init__, the only method it defines is forward, so where is __call__? To answer that, notice that the BCEWithLogitsLoss class inherits from the _WeightedLoss class, which inherits from _Loss, which inherits from Module. This base class, Module, implements __call__ so that calling an instance runs its forward method, along with any registered hooks.
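The sketch below shows the __call__-to-forward dispatch in simplified form. MiniModule and MiniBCEWithLogitsLoss are hypothetical stand-ins, not PyTorch's actual code. The real Module.__call__ also runs hooks, and the real loss works on tensors and supports options such as reduction. The subclass, like BCEWithLogitsLoss, implements forward but not __call__, yet calling an instance still works because __call__ is inherited.

```python
import math


class MiniModule:
    """Simplified, hypothetical stand-in for torch.nn.Module.

    The real Module.__call__ also runs registered hooks; this sketch keeps
    only the part that matters here: calling an instance runs forward().
    """

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)

    def forward(self, *args, **kwargs):
        raise NotImplementedError("subclasses must implement forward()")


class MiniBCEWithLogitsLoss(MiniModule):
    """Hypothetical loss that, like BCEWithLogitsLoss, only implements forward()."""

    def forward(self, pred, y):
        # Numerically stable binary cross-entropy on logits, averaged over
        # elements: max(x, 0) - x * t + log(1 + exp(-|x|))
        terms = [
            max(x, 0.0) - x * t + math.log1p(math.exp(-abs(x)))
            for x, t in zip(pred, y)
        ]
        return sum(terms) / len(terms)


# Same shape as nn.BCEWithLogitsLoss()(pred, y): instantiate, then call.
loss = MiniBCEWithLogitsLoss()([0.0, 2.0], [1.0, 1.0])
print(round(loss, 4))
```

Because MiniBCEWithLogitsLoss never defines __call__, Python finds it on MiniModule through normal inheritance, and that inherited method forwards to the subclass's forward.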
Correct answer by cag51 on January 22, 2021