TransWikia.com

Generator losses in WGAN and potential convergence failure

Data Science Asked on April 11, 2021

I have been training a WGAN for a while now, with my generator training once in every five epochs.

I have tried several model architectures (varying the number of filters) and also tried varying the relative sizes of the generator and discriminator. No matter what I do, my output is essentially noise. On further reading, it seems to be a classic case of convergence failure.

Over time, my generator loss gets more and more negative while my discriminator loss remains around -0.4.

My guess is that since the discriminator isn't improving enough, the generator doesn't improve enough either.

gen_loss = 0.0, disc_loss = -0.03792113810777664
Time for epoch 567 is 3.381150007247925 sec - gen_loss = 0.0, disc_loss = -0.037839196622371674
Time for epoch 568 is 3.3113789558410645 sec - gen_loss = 0.0, disc_loss = -0.040219761431217194
Time for epoch 569 is 3.2963240146636963 sec - gen_loss = 0.0, disc_loss = -0.04105686396360397
Time for epoch 570 is 9.3097665309906 sec - gen_loss = -0.3167822062969208, disc_loss = -0.04042121022939682
Time for epoch 571 is 3.314333200454712 sec - gen_loss = 0.0, disc_loss = -0.03815283626317978
Time for epoch 572 is 3.3485965728759766 sec - gen_loss = 0.0, disc_loss = -0.038569360971450806
Time for epoch 573 is 3.3110241889953613 sec - gen_loss = 0.0, disc_loss = -0.03980369493365288
Time for epoch 574 is 3.4034907817840576 sec - gen_loss = 0.0, disc_loss = -0.0400879867374897
Time for epoch 575 is 9.47887134552002 sec - gen_loss = -0.29222938418388367, disc_loss = -0.04063987731933594
Time for epoch 576 is 3.329411745071411 sec - gen_loss = 0.0, disc_loss = -0.03907758742570877
Time for epoch 577 is 3.3210532665252686 sec - gen_loss = 0.0, disc_loss = -0.03997725248336792
Time for epoch 578 is 3.303483247756958 sec - gen_loss = 0.0, disc_loss = -0.041025031358003616
Time for epoch 579 is 3.3799941539764404 sec - gen_loss = 0.0, disc_loss = -0.04217495769262314
Time for epoch 580 is 9.339993000030518 sec - gen_loss = -0.324483722448349, disc_loss = -0.04243335872888565
Time for epoch 581 is 3.300795555114746 sec - gen_loss = 0.0, disc_loss = -0.04034840315580368
Time for epoch 582 is 3.322876453399658 sec - gen_loss = 0.0, disc_loss = -0.0420600101351738
Time for epoch 583 is 3.328361749649048 sec - gen_loss = 0.0, disc_loss = -0.04354345053434372
Time for epoch 584 is 3.277684211730957 sec - gen_loss = 0.0, disc_loss = -0.04367030784487724
Time for epoch 585 is 9.337851762771606 sec - gen_loss = -0.29389116168022156, disc_loss = -0.04380493611097336
Time for epoch 586 is 3.282655954360962 sec - gen_loss = 0.0, disc_loss = -0.041637666523456573
Time for epoch 587 is 3.2763051986694336 sec - gen_loss = 0.0, disc_loss = -0.04260318726301193
Time for epoch 588 is 3.3087923526763916 sec - gen_loss = 0.0, disc_loss = -0.043825842440128326
Time for epoch 589 is 3.321415662765503 sec - gen_loss = 0.0, disc_loss = -0.044720184057950974
Time for epoch 590 is 9.419749975204468 sec - gen_loss = -0.3067258596420288, disc_loss = -0.044646285474300385
Time for epoch 591 is 3.4093453884124756 sec - gen_loss = 0.0, disc_loss = -0.04444003105163574
Time for epoch 592 is 3.3971118927001953 sec - gen_loss = 0.0, disc_loss = -0.04441792890429497
Time for epoch 593 is 3.4011573791503906 sec - gen_loss = 0.0, disc_loss = -0.04336284473538399

I am not entirely sure about the zeros and I guess it has something to do with my printing:

for epoch in range(epochs):
    start = time.time()
    disc_loss = 0
    gen_loss = 0
    for images in train_dataset:
        #images=np.expand_dims(images, axis=0)
        #images=images/255.
        #images=images.resize
        disc_loss += train_discriminator(images)
        if disc_optimizer.iterations.numpy() % n_critic == 0:
            gen_loss += train_generator()

    print('Time for epoch {} is {} sec - gen_loss = {}, disc_loss = {}'.format(epoch + 1, time.time() - start, gen_loss / batch_size, disc_loss / (batch_size*n_critic)))

    if epoch % save_interval == 0:
        save_imgs(epoch, generator, seed)
        #model.save_weights(checkpoint_path.format(epoch=0))
        #save_weights(checkpoint_path.format(+=0))
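One possible reading of the zeros, sketched without TensorFlow: in the loop above, gen_loss is reset to 0 every epoch and only changes when the discriminator step counter is a multiple of n_critic. The sketch below assumes one batch per epoch and n_critic = 5 (neither is stated in the question; both are inferred from the five-epoch pattern in the log), and it uses made-up constant losses in place of train_discriminator and train_generator. The function name run_epochs and the per-update averaging are illustrative choices, not part of the original code.

```python
def run_epochs(num_epochs, batches_per_epoch, n_critic):
    """Mimic the bookkeeping of the training loop with stand-in losses."""
    critic_steps = 0  # plays the role of disc_optimizer.iterations
    history = []
    for epoch in range(1, num_epochs + 1):
        disc_total, gen_total = 0.0, 0.0
        disc_updates, gen_updates = 0, 0
        for _ in range(batches_per_epoch):
            disc_total += -1.0       # stand-in for train_discriminator(images)
            critic_steps += 1
            disc_updates += 1
            if critic_steps % n_critic == 0:
                gen_total += -80.0   # stand-in for train_generator()
                gen_updates += 1
        # Average over the updates that actually happened, and report None
        # (instead of a misleading 0.0) when the generator was not trained.
        gen_mean = gen_total / gen_updates if gen_updates else None
        disc_mean = disc_total / disc_updates
        history.append((epoch, gen_updates, gen_mean, disc_mean))
    return history


if __name__ == "__main__":
    for epoch, gen_updates, gen_mean, disc_mean in run_epochs(10, 1, 5):
        print(epoch, gen_updates, gen_mean, disc_mean)
```

Under these assumptions only epochs 5 and 10 report a generator loss, which matches the shape of the log: a value every fifth epoch and 0.0 elsewhere. Dividing by the number of updates rather than by batch_size also keeps the printed numbers on the same scale as the raw loss.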

The snippet above covers only a short period of training. The losses can go as low as:

Time for epoch 420 is 4.274262428283691 sec - gen_loss = -9.779035568237305, disc_loss = -0.9567102193832397
Time for epoch 421 is 3.0774970054626465 sec - gen_loss = 0.0, disc_loss = -0.9157863855361938
Time for epoch 422 is 3.0417258739471436 sec - gen_loss = 0.0, disc_loss = -0.8936088681221008
Time for epoch 423 is 3.0934689044952393 sec - gen_loss = 0.0, disc_loss = -0.896615207195282
Time for epoch 424 is 3.0459794998168945 sec - gen_loss = 0.0, disc_loss = -0.9322511553764343
tf.Tensor(-81.94704, shape=(), dtype=float32)
Time for epoch 425 is 4.237549543380737 sec - gen_loss = -10.243379592895508, disc_loss = -0.9419326782226562
Time for epoch 426 is 3.03023099899292 sec - gen_loss = 0.0, disc_loss = -0.9134870767593384
Time for epoch 427 is 2.998375177383423 sec - gen_loss = 0.0, disc_loss = -0.9380167126655579
Time for epoch 428 is 2.9811060428619385 sec - gen_loss = 0.0, disc_loss = -0.9134092330932617
Time for epoch 429 is 3.087916374206543 sec - gen_loss = 0.0, disc_loss = -0.9051135778427124
tf.Tensor(-83.00238, shape=(), dtype=float32)
Time for epoch 430 is 4.272223949432373 sec - gen_loss = -10.375297546386719, disc_loss = -1.003989577293396
Time for epoch 431 is 3.0454840660095215 sec - gen_loss = 0.0, disc_loss = -0.9496141672134399
Time for epoch 432 is 3.090559959411621 sec - gen_loss = 0.0, disc_loss = -0.9521171450614929
Time for epoch 433 is 3.1101419925689697 sec - gen_loss = 0.0, disc_loss = -0.9876922369003296
Time for epoch 434 is 3.0372989177703857 sec - gen_loss = 0.0, disc_loss = -0.9995473623275757
tf.Tensor(-79.649864, shape=(), dtype=float32)
Time for epoch 435 is 4.32908034324646 sec - gen_loss = -9.956233024597168, disc_loss = -1.003030776977539
Time for epoch 436 is 3.106421947479248 sec - gen_loss = 0.0, disc_loss = -0.9365862011909485
Time for epoch 437 is 3.1067636013031006 sec - gen_loss = 0.0, disc_loss = -1.0536330938339233
Time for epoch 438 is 3.063079833984375 sec - gen_loss = 0.0, disc_loss = -0.9735730886459351
Time for epoch 439 is 3.1522281169891357 sec - gen_loss = 0.0, disc_loss = -0.9937177896499634
tf.Tensor(-75.338615, shape=(), dtype=float32)
Time for epoch 440 is 4.324256896972656 sec - gen_loss = -9.417326927185059, disc_loss = -1.058488130569458
Time for epoch 441 is 3.1664624214172363 sec - gen_loss = 0.0, disc_loss = -0.9834483861923218
Time for epoch 442 is 3.086495876312256 sec - gen_loss = 0.0, disc_loss = -0.9649847149848938
Time for epoch 443 is 3.0835535526275635 sec - gen_loss = 0.0, disc_loss = -1.0420929193496704
Time for epoch 444 is 3.0816898345947266 sec - gen_loss = 0.0, disc_loss = -1.0146191120147705
tf.Tensor(-74.671135, shape=(), dtype=float32)
Time for epoch 445 is 4.27916693687439 sec - gen_loss = -9.333891868591309, disc_loss = -0.9707431793212891
Time for epoch 446 is 3.0552937984466553 sec - gen_loss = 0.0, disc_loss = -1.0040273666381836
Time for epoch 447 is 3.032083034515381 sec - gen_loss = 0.0, disc_loss = -1.078584909439087
Time for epoch 448 is 3.1026992797851562 sec - gen_loss = 0.0, disc_loss = -0.9929893612861633
Time for epoch 449 is 3.111077070236206 sec - gen_loss = 0.0, disc_loss = -1.0746228694915771
tf.Tensor(-76.32341, shape=(), dtype=float32)
Time for epoch 450 is 4.429989337921143 sec - gen_loss = -9.540426254272461, disc_loss = -0.9597347378730774
Time for epoch 451 is 3.0502800941467285 sec - gen_loss = 0.0, disc_loss = -0.9540794491767883
Time for epoch 452 is 3.0913126468658447 sec - gen_loss = 0.0, disc_loss = -1.008721113204956
Time for epoch 453 is 3.05441951751709 sec - gen_loss = 0.0, disc_loss = -1.0138479471206665
Time for epoch 454 is 3.041020631790161 sec - gen_loss = 0.0, disc_loss = -0.9122379422187805
tf.Tensor(-73.73149, shape=(), dtype=float32)
Time for epoch 455 is 4.271183967590332 sec - gen_loss = -9.216436386108398, disc_loss = -0.9789434671401978
Time for epoch 456 is 3.10066556930542 sec - gen_loss = 0.0, disc_loss = -0.9811899065971375
Time for epoch 457 is 3.1411514282226562 sec - gen_loss = 0.0, disc_loss = -0.9947725534439087
Time for epoch 458 is 3.15008807182312 sec - gen_loss = 0.0, disc_loss = -1.0282094478607178
Time for epoch 459 is 3.1146531105041504 sec - gen_loss = 0.0, disc_loss = -1.0895274877548218
tf.Tensor(-79.197815, shape=(), dtype=float32)
Time for epoch 460 is 4.334632396697998 sec - gen_loss = -9.899726867675781, disc_loss = -1.0837827920913696
Time for epoch 461 is 3.084667205810547 sec - gen_loss = 0.0, disc_loss = -1.0000011920928955
Time for epoch 462 is 2.9984536170959473 sec - gen_loss = 0.0, disc_loss = -1.064965009689331
Time for epoch 463 is 3.0733871459960938 sec - gen_loss = 0.0, disc_loss = -1.088195562362671
Time for epoch 464 is 3.0906479358673096 sec - gen_loss = 0.0, disc_loss = -1.0385420322418213
tf.Tensor(-72.012695, shape=(), dtype=float32)
Time for epoch 465 is 4.305361747741699 sec - gen_loss = -9.0015869140625, disc_loss = -1.0921838283538818
Time for epoch 466 is 3.0221426486968994 sec - gen_loss = 0.0, disc_loss = -1.087393045425415
Time for epoch 467 is 3.08805775642395 sec - gen_loss = 0.0, disc_loss = -1.0309044122695923
Time for epoch 468 is 3.0641579627990723 sec - gen_loss = 0.0, disc_loss = -1.021532416343689
Time for epoch 469 is 3.0942575931549072 sec - gen_loss = 0.0, disc_loss = -1.016531229019165
tf.Tensor(-72.69449, shape=(), dtype=float32)
Time for epoch 470 is 4.357362985610962 sec - gen_loss = -9.086811065673828, disc_loss = -1.0745207071304321
Time for epoch 471 is 3.0401017665863037 sec - gen_loss = 0.0, disc_loss = -1.0113626718521118
Time for epoch 472 is 3.0733351707458496 sec - gen_loss = 0.0, disc_loss = -1.0494859218597412
Time for epoch 473 is 3.0829317569732666 sec - gen_loss = 0.0, disc_loss = -1.0223619937896729
Time for epoch 474 is 3.008283853530884 sec - gen_loss = 0.0, disc_loss = -1.0144643783569336
tf.Tensor(-72.84456, shape=(), dtype=float32)
Time for epoch 475 is 4.384555339813232 sec - gen_loss = -9.105569839477539, disc_loss = -1.0208587646484375
Time for epoch 476 is 3.0660083293914795 sec - gen_loss = 0.0, disc_loss = -0.9962297677993774
Time for epoch 477 is 3.0829591751098633 sec - gen_loss = 0.0, disc_loss = -1.1114609241485596
Time for epoch 478 is 3.0963735580444336 sec - gen_loss = 0.0, disc_loss = -1.1005897521972656
Time for epoch 479 is 3.0595879554748535 sec - gen_loss = 0.0, disc_loss = -1.0915520191192627
tf.Tensor(-71.82799, shape=(), dtype=float32)
Time for epoch 480 is 4.281854629516602 sec - gen_loss = -8.978498458862305, disc_loss = -1.1100273132324219
Time for epoch 481 is 3.091787338256836 sec - gen_loss = 0.0, disc_loss = -1.0624439716339111
Time for epoch 482 is 3.006270170211792 sec - gen_loss = 0.0, disc_loss = -1.0699169635772705
Time for epoch 483 is 2.963466167449951 sec - gen_loss = 0.0, disc_loss = -1.0502550601959229
Time for epoch 484 is 3.079402446746826 sec - gen_loss = 0.0, disc_loss = -1.0949811935424805
tf.Tensor(-75.14867, shape=(), dtype=float32)
Time for epoch 485 is 4.3246009349823 sec - gen_loss = -9.393583297729492, disc_loss = -1.0904223918914795
Time for epoch 486 is 3.068676710128784 sec - gen_loss = 0.0, disc_loss = -1.0440151691436768
Time for epoch 487 is 3.0410079956054688 sec - gen_loss = 0.0, disc_loss = -1.0777149200439453
Time for epoch 488 is 3.0098347663879395 sec - gen_loss = 0.0, disc_loss = -1.0182307958602905
Time for epoch 489 is 3.0173022747039795 sec - gen_loss = 0.0, disc_loss = -1.0987876653671265
tf.Tensor(-72.26952, shape=(), dtype=float32)
Time for epoch 490 is 4.272535085678101 sec - gen_loss = -9.033690452575684, disc_loss = -1.1007413864135742
Time for epoch 491 is 2.986027479171753 sec - gen_loss = 0.0, disc_loss = -1.1220260858535767
Time for epoch 492 is 3.033278226852417 sec - gen_loss = 0.0, disc_loss = -1.0299326181411743
Time for epoch 493 is 3.047642946243286 sec - gen_loss = 0.0, disc_loss = -1.0932931900024414
Time for epoch 494 is 2.9984898567199707 sec - gen_loss = 0.0, disc_loss = -1.1239049434661865
tf.Tensor(-79.26546, shape=(), dtype=float32)
Time for epoch 495 is 4.306024551391602 sec - gen_loss = -9.908182144165039, disc_loss = -1.0246700048446655
Time for epoch 496 is 2.9821512699127197 sec - gen_loss = 0.0, disc_loss = -1.0074901580810547
Time for epoch 497 is 2.9943675994873047 sec - gen_loss = 0.0, disc_loss = -1.0436185598373413
Time for epoch 498 is 2.981405735015869 sec - gen_loss = 0.0, disc_loss = -1.0111920833587646
Time for epoch 499 is 3.0199577808380127 sec - gen_loss = 0.0, disc_loss = -1.0668728351593018
tf.Tensor(-73.99868, shape=(), dtype=float32)
Time for epoch 500 is 4.266490459442139 sec - gen_loss = -9.249835014343262, disc_loss = -1.1307373046875
Time for epoch 501 is 3.0170788764953613 sec - gen_loss = 0.0, disc_loss = -1.1350711584091187
Time for epoch 502 is 3.077758312225342 sec - gen_loss = 0.0, disc_loss = -1.1081371307373047
Time for epoch 503 is 3.056257724761963 sec - gen_loss = 0.0, disc_loss = -1.1385138034820557
Time for epoch 504 is 3.099039077758789 sec - gen_loss = 0.0, disc_loss = -1.0430777072906494
tf.Tensor(-80.41304, shape=(), dtype=float32)
Time for epoch 505 is 4.288381099700928 sec - gen_loss = -10.051630020141602, disc_loss = -1.0938444137573242
Time for epoch 506 is 2.9821603298187256 sec - gen_loss = 0.0, disc_loss = -1.0110225677490234
Time for epoch 507 is 3.041010618209839 sec - gen_loss = 0.0, disc_loss = -1.036383032798767
Time for epoch 508 is 2.995178699493408 sec - gen_loss = 0.0, disc_loss = -1.0879932641983032
Time for epoch 509 is 3.0064549446105957 sec - gen_loss = 0.0, disc_loss = -1.065393090248108
tf.Tensor(-77.57954, shape=(), dtype=float32)
Time for epoch 510 is 4.323476076126099 sec - gen_loss = -9.697442054748535, disc_loss = -1.0769941806793213
Time for epoch 511 is 2.997896194458008 sec - gen_loss = 0.0, disc_loss = -1.1603240966796875
Time for epoch 512 is 3.009216785430908 sec - gen_loss = 0.0, disc_loss = -1.0865892171859741
Time for epoch 513 is 3.0198848247528076 sec - gen_loss = 0.0, disc_loss = -1.0485632419586182
Time for epoch 514 is 3.0211234092712402 sec - gen_loss = 0.0, disc_loss = -1.0616661310195923
tf.Tensor(-78.452515, shape=(), dtype=float32)
Time for epoch 515 is 4.19796895980835 sec - gen_loss = -9.806564331054688, disc_loss = -1.0954179763793945
Time for epoch 516 is 3.022573709487915 sec - gen_loss = 0.0, disc_loss = -1.0778796672821045
Time for epoch 517 is 3.0197014808654785 sec - gen_loss = 0.0, disc_loss = -1.0623410940170288
Time for epoch 518 is 3.0477373600006104 sec - gen_loss = 0.0, disc_loss = -1.163627028465271
Time for epoch 519 is 3.0799355506896973 sec - gen_loss = 0.0, disc_loss = -1.149423599243164
tf.Tensor(-76.386566, shape=(), dtype=float32)
Time for epoch 520 is 4.313087224960327 sec - gen_loss = -9.548320770263672, disc_loss = -1.211676836013794
Time for epoch 521 is 3.0512874126434326 sec - gen_loss = 0.0, disc_loss = -1.0809478759765625
Time for epoch 522 is 3.0870730876922607 sec - gen_loss = 0.0, disc_loss = -1.099151611328125
Time for epoch 523 is 3.0168159008026123 sec - gen_loss = 0.0, disc_loss = -1.1511256694793701
Time for epoch 524 is 3.005920171737671 sec - gen_loss = 0.0, disc_loss = -1.069382667541504
tf.Tensor(-79.15112, shape=(), dtype=float32)
Time for epoch 525 is 4.212509870529175 sec - gen_loss = -9.893890380859375, disc_loss = -1.2109851837158203
Time for epoch 526 is 2.9921183586120605 sec - gen_loss = 0.0, disc_loss = -1.0628010034561157
Time for epoch 527 is 2.9992029666900635 sec - gen_loss = 0.0, disc_loss = -1.0484745502471924
Time for epoch 528 is 3.058972120285034 sec - gen_loss = 0.0, disc_loss = -1.0775582790374756
Time for epoch 529 is 2.911708116531372 sec - gen_loss = 0.0, disc_loss = -1.1414921283721924
tf.Tensor(-87.31295, shape=(), dtype=float32)
Time for epoch 530 is 4.178253412246704 sec - gen_loss = -10.914118766784668, disc_loss = -0.9588018655776978
Time for epoch 531 is 2.9323318004608154 sec - gen_loss = 0.0, disc_loss = -1.011317491531372
Time for epoch 532 is 2.960914134979248 sec - gen_loss = 0.0, disc_loss = -1.131879448890686
Time for epoch 533 is 2.939626455307007 sec - gen_loss = 0.0, disc_loss = -1.1429363489151
Time for epoch 534 is 2.974912643432617 sec - gen_loss = 0.0, disc_loss = -1.1597099304199219
tf.Tensor(-79.52698, shape=(), dtype=float32)
Time for epoch 535 is 4.104842662811279 sec - gen_loss = -9.940872192382812, disc_loss = -1.099220633506775
Time for epoch 536 is 2.9220681190490723 sec - gen_loss = 0.0, disc_loss = -1.0265023708343506
Time for epoch 537 is 2.9370899200439453 sec - gen_loss = 0.0, disc_loss = -1.1378921270370483
Time for epoch 538 is 3.0328328609466553 sec - gen_loss = 0.0, disc_loss = -1.0911519527435303
Time for epoch 539 is 3.024807929992676 sec - gen_loss = 0.0, disc_loss = -1.1400941610336304
tf.Tensor(-73.72937, shape=(), dtype=float32)
Time for epoch 540 is 4.289034366607666 sec - gen_loss = -9.216171264648438, disc_loss = -1.2123310565948486
Time for epoch 541 is 2.9510338306427 sec - gen_loss = 0.0, disc_loss = -1.1270792484283447
Time for epoch 542 is 2.9778366088867188 sec - gen_loss = 0.0, disc_loss = -0.9578657150268555
Time for epoch 543 is 3.0087802410125732 sec - gen_loss = 0.0, disc_loss = -1.1833038330078125
Time for epoch 544 is 2.9507062435150146 sec - gen_loss = 0.0, disc_loss = -0.9873024225234985
tf.Tensor(-81.716965, shape=(), dtype=float32)
Time for epoch 545 is 4.178644180297852 sec - gen_loss = -10.214620590209961, disc_loss = -1.1175721883773804
Time for epoch 546 is 3.0262935161590576 sec - gen_loss = 0.0, disc_loss = -0.9683782458305359
Time for epoch 547 is 2.8921780586242676 sec - gen_loss = 0.0, disc_loss = -1.0920101404190063
Time for epoch 548 is 2.9628119468688965 sec - gen_loss = 0.0, disc_loss = -1.1132243871688843
Time for epoch 549 is 3.0580990314483643 sec - gen_loss = 0.0, disc_loss = -1.1871049404144287
tf.Tensor(-80.54263, shape=(), dtype=float32)
Time for epoch 550 is 4.344376564025879 sec - gen_loss = -10.067829132080078, disc_loss = -1.184781789779663
Time for epoch 551 is 3.001291036605835 sec - gen_loss = 0.0, disc_loss = -1.0630102157592773
Time for epoch 552 is 3.0104005336761475 sec - gen_loss = 0.0, disc_loss = -1.1383047103881836
Time for epoch 553 is 2.959202527999878 sec - gen_loss = 0.0, d10.067829132080078, disc_loss = -1.18478178977966

I really don’t know if I should beef up the generator or discriminator as both give me the same issue.

I tried to make my generator twice as powerful as the discriminator.

One Answer

Learning with GANs is known to be unstable due to the game-theoretic nature of training. In addition, the loss is not always a good indicator of progress for GANs. Therefore, it sometimes requires a trial-and-error approach.

Here is what I would try:

I. If you have not done so: Check intermediate image results of G and see if there is any development in terms of image quality.

II. Check the literature for GANs which have been applied to a task similar to yours, i.e. similar images, and try them.

III. Implement the WGAN exactly as described in one of the WGAN papers, e.g. the ResNet-like net in "Improved Training of Wasserstein GANs" on page 14/15 (here for 32x32 images):

[Image: ResNet-like WGAN]

Also use the other hyperparameters from the paper, including the loss function, optimizer, optimizer parameters, and skipped epochs for G (like your 5/1 scheme), and do not forget batch normalization (part of the residual blocks in G in the network above).

Moreover, implement a simple DCGAN to obtain a baseline.

IV. If your output is truly just noise, it could also be an implementation issue. By "just noise" I mean that your GAN does not even learn very basic patterns, as the GAN in the following example does, where the left-hand side is a real example and the right-hand side is a fake example (output of G):

[Image: real example (left) and fake example (right)]

However, even if your GAN is not able to learn these basic patterns and G outputs pure noise, this does not necessarily imply any implementation errors.
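For item III, the loss functions can be sketched in a framework-free way. This is a minimal sketch, assuming the critic scores and the gradient norms at the interpolated samples have already been computed by your framework. The names critic_loss, generator_loss and WGAN_GP_DEFAULTS are illustrative, and the listed defaults are the ones I recall from Algorithm 1 of "Improved Training of Wasserstein GANs", so verify them against the paper.

```python
from statistics import mean

# Defaults as recalled from Algorithm 1 of the WGAN-GP paper (an assumption
# here; check the paper before relying on them).
WGAN_GP_DEFAULTS = {
    "gp_weight": 10.0,   # lambda, weight of the gradient penalty
    "n_critic": 5,       # critic updates per generator update
    "adam_lr": 1e-4,
    "adam_beta_1": 0.0,
    "adam_beta_2": 0.9,
}


def critic_loss(real_scores, fake_scores, grad_norms, gp_weight=10.0):
    """Critic loss: mean(D(fake)) - mean(D(real)) + lambda * penalty.

    grad_norms are the norms of the critic's gradient with respect to its
    input, evaluated at samples interpolated between real and fake data.
    """
    wasserstein_term = mean(fake_scores) - mean(real_scores)
    penalty = mean([(g - 1.0) ** 2 for g in grad_norms])
    return wasserstein_term + gp_weight * penalty


def generator_loss(fake_scores):
    """Generator loss: the negated mean critic score on generated samples."""
    return -mean(fake_scores)


if __name__ == "__main__":
    print(critic_loss([2.0, 4.0], [1.0, 1.0], [1.0, 1.0]))
    print(generator_loss([1.0, 3.0]))
```

With losses defined this way, a negative critic loss means the critic scores real samples higher than generated ones, and a strongly negative generator loss only means the critic currently assigns high scores to generated samples. Neither says much about image quality on its own, which is why checking the intermediate images (item I) matters.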

Answered by Sammy on April 11, 2021
