TransWikia.com

WGAN-GP slow critic training time

Data Science: asked by minhduc0711 on May 17, 2021

I am implementing WGAN-GP using TensorFlow 2.0, but each training iteration of the critic is very slow: about 4 seconds on my CPU and, oddly, about 9 seconds on a Colab GPU.

Is WGAN-GP usually this slow, or is there a flaw in my code?

Here’s my code for training the critic:

```python
import numpy as np
import tensorflow as tf

def train_critic(self, X_real, batch_size, gp_loss_factor, optimizer):
    y_real = np.ones((batch_size, 1))

    # Get batch of generated images
    noise = np.random.normal(0, 1, (batch_size, self.z_dim))
    X_fake = self.gen.predict(noise)
    y_fake = -np.ones((batch_size, 1))

    X = np.vstack((X_real, X_fake))
    y = np.concatenate((y_real, y_fake))

    # Interpolate images
    alpha = np.random.uniform(size=(batch_size, 1, 1, 1))
    X_interpolated = alpha * X_real + (1 - alpha) * X_fake
    X_interpolated = tf.constant(X_interpolated, dtype=tf.float32)

    # Perform weight update
    with tf.GradientTape() as outer_tape:
        # Calculate gradient penalty loss
        with tf.GradientTape() as inner_tape:
            inner_tape.watch(X_interpolated)
            y_interpolated = self.critic(X_interpolated)
        gradients = inner_tape.gradient(y_interpolated, X_interpolated)
        norm = tf.sqrt(
            1e-8 + tf.reduce_sum(tf.square(gradients), axis=[1, 2, 3]))
        gp_loss = gp_loss_factor * tf.reduce_mean((norm - 1.) ** 2)

        # Calculate Wasserstein loss
        y_pred = self.critic(X)
        wasserstein_loss = wasserstein(y, y_pred)

        # Add two losses
        loss = tf.add_n([wasserstein_loss, gp_loss] + self.critic.losses)
    gradients = outer_tape.gradient(loss, self.critic.trainable_variables)

    optimizer.apply_gradients(zip(gradients, self.critic.trainable_variables))

    return wasserstein_loss, gp_loss

def wasserstein(y_true, y_pred):
    return tf.reduce_mean(y_true * y_pred)
```
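A NumPy-only sketch of what the loss terms in train_critic compute, with hypothetical random arrays standing in for the critic's scores and for its input gradients (no model is involved, and gp_loss_factor = 10.0 is an assumed value). It checks three things: with labels +1 (real) and -1 (fake) and equal halves, wasserstein(y, y_pred) equals half the gap between the mean real and mean fake scores; each interpolated image lies between its real and fake partner; and a hypothetical linear critic whose input gradient has unit L2 norm gets an (almost) zero penalty.

```python
import numpy as np

rng = np.random.default_rng(0)
batch_size, h, w, c = 4, 8, 8, 3

# Hypothetical critic scores standing in for self.critic(X): real half, then fake half.
pred_real = rng.normal(size=(batch_size, 1))
pred_fake = rng.normal(size=(batch_size, 1))
y_pred = np.concatenate((pred_real, pred_fake))
y = np.concatenate((np.ones((batch_size, 1)), -np.ones((batch_size, 1))))

# Same reduction as wasserstein(): mean of label * score over the whole batch.
wasserstein_loss = np.mean(y * y_pred)
# With +1/-1 labels and equal halves, this is half the real-minus-fake score gap.
half_gap = (pred_real.mean() - pred_fake.mean()) / 2

# Interpolation as in train_critic: one alpha per sample, broadcast over (h, w, c).
X_real = rng.normal(size=(batch_size, h, w, c))
X_fake = rng.normal(size=(batch_size, h, w, c))
alpha = rng.uniform(size=(batch_size, 1, 1, 1))
X_interpolated = alpha * X_real + (1 - alpha) * X_fake

# Hypothetical linear critic D(x) = sum(W * x): its input gradient is W for every sample.
# Scaling W to unit L2 norm should make the penalty (almost) zero.
W = rng.normal(size=(h, w, c))
W /= np.linalg.norm(W)
gradients = np.broadcast_to(W, X_interpolated.shape)
norm = np.sqrt(1e-8 + np.sum(np.square(gradients), axis=(1, 2, 3)))
gp_loss = 10.0 * np.mean((norm - 1.0) ** 2)  # 10.0 is an assumed gp_loss_factor
```

The per-sample norm is reduced over axes 1 to 3, so the penalty assumes 4-D image batches, as in the question's code.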

One Answer

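This part calculates a second derivative (the gradient penalty is itself built from a gradient, and the critic's weights are then updated through it), so it is expected to be slow and time-consuming.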

```python
# Perform weight update
with tf.GradientTape() as outer_tape:
    # Calculate gradient penalty loss
    with tf.GradientTape() as inner_tape:
        ...
```
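A sketch of where the second derivative comes from, for a single interpolated sample $\hat{x}$, a critic $D_\theta$ with weights $\theta$, and $\lambda$ standing for gp_loss_factor (the 1e-8 stabiliser is ignored for clarity):

```latex
\begin{aligned}
g_\theta &= \nabla_{\hat{x}} D_\theta(\hat{x}), \qquad
\mathcal{L}_{\mathrm{GP}}(\theta) = \lambda \bigl(\lVert g_\theta \rVert_2 - 1\bigr)^2, \\
\nabla_\theta \mathcal{L}_{\mathrm{GP}}
  &= 2\lambda \bigl(\lVert g_\theta \rVert_2 - 1\bigr)
     \Bigl(\frac{\partial g_\theta}{\partial \theta}\Bigr)^{\!\top}
     \frac{g_\theta}{\lVert g_\theta \rVert_2},
\qquad
\Bigl[\frac{\partial g_\theta}{\partial \theta}\Bigr]_{ij}
  = \frac{\partial^2 D_\theta(\hat{x})}{\partial \hat{x}_i \, \partial \theta_j}.
\end{aligned}
```

The mixed second derivatives in $\partial g_\theta / \partial \theta$ mean the outer tape has to backpropagate through the inner tape's backward pass, not just through the critic's forward pass.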

Try opening both tapes in a single `with` statement:

```python
# Perform weight update
with tf.GradientTape() as outer_tape, tf.GradientTape() as inner_tape:
    ...
```
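In Python, `with A as a, B as b:` enters `A` and then `B`, so `inner_tape` is still opened inside `outer_tape` and the outer tape still records the inner gradient computation. A minimal sketch with a hypothetical one-weight toy critic D(x; w) = w * x**2 (not the asker's model) compares the two forms on the penalty (|dD/dx| - 1)**2; at w = 1 and x = 1.5, dD/dx = 3, so the penalty's derivative with respect to w is 2 * (3 - 1) * 2x = 12.

```python
import tensorflow as tf

# Hypothetical toy critic D(x; w) = w * x**2 with one trainable weight w.
w = tf.Variable(1.0)
x_hat = tf.constant(1.5)  # stands in for one interpolated sample


def penalty_grad_nested():
    """Gradient of (|dD/dx| - 1)**2 w.r.t. w, using explicitly nested tapes."""
    with tf.GradientTape() as outer_tape:
        with tf.GradientTape() as inner_tape:
            inner_tape.watch(x_hat)
            d = w * tf.square(x_hat)
        # dD/dx = 2 * w * x; outer_tape records this computation, so it can differentiate it.
        grad_x = inner_tape.gradient(d, x_hat)
        gp = tf.square(tf.abs(grad_x) - 1.0)
    return outer_tape.gradient(gp, w)


def penalty_grad_single_with():
    """Same quantity, with both tapes opened in one `with` statement."""
    # inner_tape is entered after outer_tape, i.e. it is still nested inside it.
    with tf.GradientTape() as outer_tape, tf.GradientTape() as inner_tape:
        inner_tape.watch(x_hat)
        d = w * tf.square(x_hat)
        grad_x = inner_tape.gradient(d, x_hat)
        gp = tf.square(tf.abs(grad_x) - 1.0)
    return outer_tape.gradient(gp, w)


print(float(penalty_grad_nested()), float(penalty_grad_single_with()))
```

Because both forms nest the tapes the same way, they compute the same second derivative; whether the single-statement form changes the run time is worth measuring on the real critic.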

Answered by Petr Ivanov on May 17, 2021
