TransWikia.com

Latent loss in variational autoencoder drowns generative loss

Data Science Asked by Ali250 on April 20, 2021

I’m trying to run a variational auto-encoder on the CIFAR-10 dataset, for which I’ve put together a simple network in TensorFlow with 4 layers each in the encoder and decoder and an encoded vector size of 256. For calculating the latent loss, I’m forcing the encoder part of my network to output log variances instead of standard deviations, so the latent loss function looks like:

latent_loss = -0.5 * tf.reduce_sum(1 + log_var_vector - tf.square(mean_vector) - tf.exp(log_var_vector), axis=1)

I found this formulation to be more stable than directly using the logarithms in the KL-divergence formula, since the latter often results in an infinite loss value. I’m applying a sigmoid activation function on the last layer of the decoder, and the generative loss is computed using mean-squared error. The combined loss is simply the sum of the latent and generative losses. I train the network in batches of 40 using the Adam optimizer with a learning rate of 0.001.
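For reference, the latent_loss expression above matches the closed-form KL divergence between a diagonal Gaussian and a standard normal prior. The symbols below are introduced here for illustration and are not from the original post: \mu_j stands for an entry of mean_vector, \sigma_j^2 for the corresponding variance, so that log_var_vector holds \log\sigma_j^2 and tf.exp(log_var_vector) recovers \sigma_j^2.

```latex
D_{\mathrm{KL}}\!\left(\mathcal{N}(\mu, \operatorname{diag}(\sigma^2)) \,\middle\|\, \mathcal{N}(0, I)\right)
  = -\frac{1}{2} \sum_{j=1}^{256} \left( 1 + \log \sigma_j^2 - \mu_j^2 - \sigma_j^2 \right)
```

Because the network outputs \log\sigma_j^2 directly, the loss applies an exponential to the network output rather than a logarithm, which is consistent with the formulation not blowing up when a variance approaches zero.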

The problem is that my network doesn’t train. The latent loss immediately drops to zero, and the generative loss doesn’t go down. However, when I optimize only for the generative loss, it does decrease as expected. In that setting, the value of the latent loss quickly jumps to very large values (on the order of 10e4 – 10e6).

I have a hunch that the culprit is the extreme mismatch between the magnitudes of the two losses. The KL-divergence is unbounded, whereas the mean-squared error always remains <1, so when optimizing for both, the generative loss basically becomes irrelevant.
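The mismatch described above can be checked with a little arithmetic. The following is a minimal sketch in plain Python rather than TensorFlow. It assumes image values are scaled to [0, 1], and the helper names mean_squared_error and latent_loss are hypothetical, chosen only for this illustration.

```python
import math

LATENT_DIM = 256          # encoded vector size from the question
NUM_PIXELS = 32 * 32 * 3  # number of values in one CIFAR-10 image

def mean_squared_error(target, output):
    """Per-pixel mean of squared differences for one image."""
    return sum((t - o) ** 2 for t, o in zip(target, output)) / len(target)

def latent_loss(mean, log_var):
    """Same expression as the question's latent_loss, for one sample."""
    return -0.5 * sum(1.0 + lv - m * m - math.exp(lv)
                      for m, lv in zip(mean, log_var))

# Worst case for outputs and targets in [0, 1]: every pixel is off by 1.
worst_mse = mean_squared_error([0.0] * NUM_PIXELS, [1.0] * NUM_PIXELS)

# A fairly mild encoding: every latent mean at 0.5, log variance at 0.
mild_kl = latent_loss([0.5] * LATENT_DIM, [0.0] * LATENT_DIM)

print(worst_mse, mild_kl)
```

Even the worst possible reconstruction contributes at most 1 to the total, while a modest shift of the latent means already costs far more, so the optimizer gains the most by pushing the latent loss to zero.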

Any suggestions to solve the problem are welcome.

2 Answers

I think your hunch is right. The generative loss can't improve because any movement the network would make towards reducing it comes with a huge penalty in the form of the latent loss. It looks like your generative loss is bounded because the decoder output is squashed through a sigmoid; maybe try bounding the latent loss in a similar way?

Answered by Matthew on April 20, 2021

I don't like the reduce_sum version of the KL loss because it depends on the size of your latent vector. My advice is to use the mean instead.
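To illustrate the dependence on latent size, here is a small sketch in plain Python. The function names and the fixed per-dimension statistics (mean 0.5, log variance 0) are assumptions made only for this example.

```python
import math

def kl_per_dimension(mean, log_var):
    """Per-dimension KL terms against a standard normal prior."""
    return [-0.5 * (1.0 + lv - m * m - math.exp(lv))
            for m, lv in zip(mean, log_var)]

def kl_sum_and_mean(latent_dim, mean_value=0.5):
    """Return (sum, mean) of the KL terms when every dimension looks the same."""
    terms = kl_per_dimension([mean_value] * latent_dim, [0.0] * latent_dim)
    total = sum(terms)
    return total, total / latent_dim

sum_64, mean_64 = kl_sum_and_mean(64)
sum_256, mean_256 = kl_sum_and_mean(256)

# The summed loss grows with the latent size; the mean does not.
print(sum_64, sum_256)
print(mean_64, mean_256)
```

Note that taking the mean is the same as multiplying the summed KL term by 1 / latent_dim, so it also lowers the weight of the latent loss relative to the reconstruction loss.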

Moreover, it is well known that training a VAE with the KL loss is difficult. You may need to progressively increase the contribution of the KL loss to your total loss. Add a weight w_kl that controls this contribution:

Loss = recons_loss + w_kl * kl_loss

You start with w_kl=0 and progressively increase it every epoch (or batch) until it reaches 1. This is a classic trick. Your learning rate seems fine, though you could also try 4e-4.
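The weight schedule described above could be sketched as follows. This is only one possible shape: the linear ramp, the function names kl_weight and total_loss, and the warmup_steps parameter are assumptions for illustration, since the answer only says the weight should increase progressively.

```python
def kl_weight(step, warmup_steps):
    """Ramp w_kl linearly from 0 to 1 over warmup_steps, then hold it at 1."""
    if warmup_steps <= 0:
        return 1.0  # no warm-up requested: use the full KL term from the start
    return min(1.0, step / warmup_steps)

def total_loss(recons_loss, kl_loss, step, warmup_steps):
    """Loss = recons_loss + w_kl * kl_loss, with w_kl taken from the schedule."""
    return recons_loss + kl_weight(step, warmup_steps) * kl_loss

# Weights at epochs 0, 5, 10 and 20 for a 10-epoch warm-up.
weights = [kl_weight(epoch, 10) for epoch in (0, 5, 10, 20)]
print(weights)
```

Here step can count epochs or batches, as long as warmup_steps uses the same unit. At step 0 the model trains on the reconstruction loss alone, and the KL term is phased in afterwards.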

If you don't like the tricks, the Wasserstein auto-encoder may be your friend.

Answered by Adrien D on April 20, 2021
