TransWikia.com

Help interpreting GAN output, and how to fix it?

Data Science Asked on November 27, 2021

After a few tries, I trained a GAN that produces semi-sensible output. The model almost instantly found a solution and got stuck there. The losses for both the discriminator and the generator were 0.68 (I used a BCE loss), and the accuracies for both went to around 50%. At first glance the generator's output looked good enough to be real data, but after analysing it I could see it was still not very good.
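
For reference, a loss near 0.68 with about 50% accuracy is close to what BCE gives when the discriminator outputs roughly 0.5 for every sample. The sketch below checks that arithmetic. It assumes the discriminator loss is the average of the real and fake terms, which is not stated above, and `bce` is a helper written for this illustration rather than a library function.

```python
import math

def bce(p, label):
    """Binary cross-entropy for one predicted probability p and a 0/1 label."""
    return -(label * math.log(p) + (1 - label) * math.log(1 - p))

# A discriminator that cannot tell real from fake outputs p = 0.5 for everything.
p = 0.5
d_loss = 0.5 * (bce(p, 1) + bce(p, 0))  # real labelled 1, fake labelled 0, averaged
g_loss = bce(p, 1)                       # generator wants its fakes labelled as real

print(round(d_loss, 4), round(g_loss, 4))  # 0.6931 0.6931
```

Both values equal ln 2, so losses sitting near 0.68 to 0.69 suggest the discriminator is close to guessing, which by itself says little about sample quality.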

My solution here was to increase the power of the discriminator (by increasing its size) and re-train. I hoped that making it larger would force the generator to create better samples. I got the following output.

[Two plots of the training run; images not available]

It seems that as the GAN loss increases and the generator produces worse samples, the discriminator can pick them out more easily.

When I check the output of the trained generator, I see that it follows some basic rules the real data follows, but again, under closer scrutiny, it fails more complex tests the real data would pass.

My questions are:

  • Is my above interpretation of the plots correct?
  • For this run, have I made the discriminator too powerful? Should I increase the power of the generator?
  • Is there another technique I should investigate to stop this form of mode collapse?

Thank you

EDIT:
The architecture I am using is a form of Graph GAN. The generator is just a series of linear layers. The discriminator is 3 Graph Conv Layers, then some linear layers. Slightly similar to this paper.
Two potentially unconventional things I am doing:

  • There is no batch normalisation; I found it had a very negative effect on training, though I could try to persevere with it.
  • I am using StandardScaler to scale my data. I chose it because it makes unscaling easy: I can take the output of the generator and transform it back to the original scale. However, StandardScaler does not scale values to between -1 and 1, so I cannot use tanh as the final activation function of my generator; instead, the final layer of the generator is just Linear.
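
On the scaling point, min-max scaling to [-1, 1] is also straightforward to invert, so it would be compatible with a tanh output layer. scikit-learn's MinMaxScaler accepts feature_range=(-1, 1) and has an inverse_transform method, like StandardScaler. The dependency-free sketch below shows the same idea; the function names are hypothetical, and the sample rows are the truth example from this question treated as three rows of four features, which is an assumption about the data layout.

```python
def fit_min_max(rows):
    """Return per-column minima and maxima."""
    cols = list(zip(*rows))
    return [min(c) for c in cols], [max(c) for c in cols]

def _span(l, h):
    # Constant columns get a span of 1.0 to avoid dividing by zero.
    return (h - l) if h != l else 1.0

def scale(rows, lo, hi):
    """Map each column linearly onto [-1, 1]."""
    return [[2.0 * (x - l) / _span(l, h) - 1.0 for x, l, h in zip(r, lo, hi)]
            for r in rows]

def unscale(rows, lo, hi):
    """Invert scale(), returning values in the original units."""
    return [[(x + 1.0) / 2.0 * _span(l, h) + l for x, l, h in zip(r, lo, hi)]
            for r in rows]

real = [
    [45.6, 30.294546, -17.218746, -29.41284],
    [1.8186008, 1.7064333, 0.5984112, 0.19312467],
    [44.31433, 28.234058, -17.615921, -29.262213],
]

lo, hi = fit_min_max(real)
scaled = scale(real, lo, hi)        # every value now lies in [-1, 1]
restored = unscale(scaled, lo, hi)  # back to the original scale
```

One trade-off is that a tanh output can then never produce values outside the minimum and maximum seen in the training data.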

The outputs of the GAN (once rescaled and the shape has been changed) are similar to:

[[ 46.09169   -25.462175   20.705683  -31.696495 ]
 [ 35.10637   -18.956036   15.20579   -24.803787 ]
 [ 10.253135   -5.759581    5.9068713  -6.3003526]]

An example of the truth is:

[[ 45.6         30.294546   -17.218746   -29.41284   ]
 [  1.8186008    1.7064333    0.5984112    0.19312467]
 [ 44.31433     28.234058   -17.615921   -29.262213  ]]

Notably, the top-left value in the matrix will always be 45.6. My Generator does not even consistently produce this.

One Answer

In theory, we should first train the discriminator to optimality. However, if it becomes too good, it will reject fake data every time, leaving the generator unable to learn. Unlike in other ML tasks, lower GAN losses do not mean that training is converging. Newer GANs use techniques that make training easier, such as Wasserstein loss, spectral normalization and progressive growing.
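
The "unable to learn" point can be illustrated numerically. The sketch below is a simplified illustration, not the asker's setup: `a` is assumed to be the discriminator's pre-sigmoid score on a fake sample, the function names are hypothetical, and the Wasserstein case ignores the Lipschitz constraint a real critic needs. It compares the gradient the generator receives through the discriminator's score under three generator losses.

```python
import math

def sigmoid(a):
    return 1.0 / (1.0 + math.exp(-a))

def grad_saturating(a):
    # Original minimax generator loss log(1 - sigmoid(a)); derivative is -sigmoid(a).
    return -sigmoid(a)

def grad_non_saturating(a):
    # Common alternative -log(sigmoid(a)); derivative is sigmoid(a) - 1.
    return sigmoid(a) - 1.0

def grad_wasserstein(a):
    # Wasserstein-style generator loss -a; derivative is constant.
    return -1.0

# A very negative score means the discriminator confidently rejects the fake.
for a in (0.0, -5.0, -10.0):
    print(a, grad_saturating(a), grad_non_saturating(a), grad_wasserstein(a))
```

With the saturating loss the gradient shrinks towards zero as the discriminator grows more confident, while the other two keep a usable signal. Which BCE variant was used in the question is not stated, so this only shows why the choice of loss matters when the discriminator is strong.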

Answered by SoonYau on November 27, 2021
