Cross Validated
Asked by Mark.F on January 2, 2021
From the comments, I realized that my question wasn’t clear enough, so I’ll start with a short background.
I am trying to construct an attention model that performs classification based on just a small region of the original image (as small as possible).
For this, I have one branch of my model output a binary mask of 0's and 1's, which is then multiplied element-wise with the original image before the masked image is passed on to the classification branch. The output of the attention branch is an H×W×1 map taken after a hard sigmoid activation. The architecture is demonstrated in the following figure:
I do not have the GT of the attention mask. This part of the model is trained in an unsupervised manner, but I’m trying to avoid using reinforcement learning.
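For concreteness, here is a minimal sketch of the forward pass described above, assuming PyTorch (the layer choices, channel sizes, and the class name MaskedClassifier are placeholders, not my actual model):

```python
import torch
import torch.nn as nn

class MaskedClassifier(nn.Module):
    """Two-branch model: an attention branch predicts an HxWx1 mask,
    and the classification branch only sees mask * image."""

    def __init__(self, num_classes: int):
        super().__init__()
        # Attention branch: ends in a 1-channel map followed by a hard sigmoid,
        # so mask values lie in [0, 1] and saturate at both ends.
        self.attention = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, 1, kernel_size=3, padding=1),
            nn.Hardsigmoid(),
        )
        # Classification branch operating on the masked image.
        self.classifier = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(32, num_classes),
        )

    def forward(self, x):
        mask = self.attention(x)         # (N, 1, H, W), values in [0, 1]
        masked = x * mask                # broadcast over the channel dimension
        logits = self.classifier(masked)
        return logits, mask
```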
How do I define a loss function that, at the same time:

- maximizes the classification accuracy, and
- minimizes the number of non-zero (active) pixels in the mask?
Attempts:
I initially tried using an MSE loss between the mask and an all-zero mask, but that didn't work: the output is a mask of small non-zero values of similar magnitude. The loss gave the model no incentive to use high values or to reduce the number of non-zero elements.
I then moved on to using the following loss function on the mask output:
This loss is summed with the standard cross-entropy loss of the classification output for the total loss of the model:
Here I_xy is the value of the mask at pixel (x, y). The parameter α∈[0,1] is the sparsity regularizer: the higher its value, the stronger the incentive for the model to reduce the number of active pixels. The parameter β is the mask-regularization weight and determines the contribution of the mask loss to the total weighted sum (the weight of the classification branch is kept at 1.0), and λ is a small coefficient added for numerical stability.
However, training with this loss function is very unstable (the mask tends to zero out), and I don't feel like I'm on the right track.
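For reference, here is a minimal PyTorch sketch of this kind of combined objective. The L1/L2 ratio below is only one plausible sparsity surrogate, not the exact mask term from above, and mask_beta / eps_lambda are illustrative stand-ins for β and λ:

```python
import torch.nn.functional as F

def combined_loss(logits, targets, mask, mask_beta=0.1, eps_lambda=1e-6):
    """Cross-entropy plus a weighted sparsity penalty on the attention mask.

    The mask term is an L1/L2 ratio, a common differentiable sparsity
    surrogate that is small for masks with few active pixels. mask_beta
    weights it against the (unit-weighted) classification term, and
    eps_lambda guards against division by zero.
    """
    ce = F.cross_entropy(logits, targets)
    flat = mask.flatten(start_dim=1)                 # (N, H*W)
    l1 = flat.abs().sum(dim=1)
    l2 = flat.pow(2).sum(dim=1).sqrt()
    mask_term = (l1 / (l2 + eps_lambda)).mean()      # low when the mask is sparse
    return ce + mask_beta * mask_term
```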
The issue is that a CNN by definition will have an ever-growing receptive field as you go deeper. At some point you'll apply a final global pooling, and this somewhat captures the attention aspect of the image. On the one hand, it tends to light up regions of the image that are relevant. On the other hand, each element of the global pooling layer has a rather large receptive field. For example, if the global pooling gives you a 10x10 output, the top-left value corresponds to roughly a tenth of the image along each dimension (the top-left region), and its effective receptive field is usually even larger.
You could take your current loss function and add a term that minimizes the L1 norm of the output of the global pooling layer. This roughly amounts to forcing the model to use as little of the image as possible for the final classification.
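A minimal sketch of that suggestion, assuming PyTorch and that the pooled feature map is exposed by the model (pooled_features, current_mask_loss, and l1_weight are illustrative names):

```python
import torch.nn.functional as F

def loss_with_pooling_l1(logits, targets, pooled_features, current_mask_loss, l1_weight=1e-3):
    """Existing objective (cross-entropy + mask loss) plus an L1 penalty on the
    global-pooling activations, pushing the classifier to rely on few regions."""
    ce = F.cross_entropy(logits, targets)
    l1_term = pooled_features.abs().mean()   # averaged L1 norm of the pooled features
    return ce + current_mask_loss + l1_weight * l1_term
```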
Answered by Alex R. on January 2, 2021