Data Science Asked by minhduc0711 on May 17, 2021
I am implementing WGAN-GP using TensorFlow 2.0, but each training iteration of the critic is very slow (about 4 seconds on my CPU, and, strangely, 9 seconds on a Colab GPU).
Is WGAN-GP usually this slow, or is there a flaw in my code?
Here’s my code for training the critic:
def train_critic(self, X_real, batch_size, gp_loss_factor, optimizer):
    """Run one WGAN-GP optimisation step on the critic.

    Args:
        X_real: Batch of real samples. Its leading dimension must equal
            ``batch_size`` because it is blended element-wise with the
            generated batch.
        batch_size: Number of fake samples to generate.
        gp_loss_factor: Weight applied to the gradient-penalty term.
        optimizer: Optimizer used to apply the critic's gradients.

    Returns:
        Tuple ``(wasserstein_loss, gp_loss)`` of scalar tensors.

    NOTE(review): this step still runs eagerly; wrapping it in
    ``tf.function`` is the usual next step if it remains slow.
    """
    # Work in float32 throughout so NumPy inputs (often float64) mix cleanly
    # with model outputs inside TensorFlow ops.
    X_real = tf.cast(X_real, tf.float32)
    rank = X_real.shape.rank

    # Get a batch of generated samples. Call the model directly rather than
    # via `Model.predict`: `predict` is meant for batched inference over large
    # inputs and adds heavy per-call overhead inside a training loop.
    noise = np.random.normal(0, 1, (batch_size, self.z_dim)).astype(np.float32)
    X_fake = tf.cast(self.gen(noise, training=False), tf.float32)

    # Critic targets: +1 for real samples, -1 for generated ones.
    y_real = np.ones((batch_size, 1), dtype=np.float32)
    y_fake = -np.ones((batch_size, 1), dtype=np.float32)
    X = tf.concat([X_real, X_fake], axis=0)
    y = tf.constant(np.concatenate((y_real, y_fake)), dtype=tf.float32)

    # Random points on the straight lines between real and fake samples.
    # alpha holds one value per sample and broadcasts over every non-batch
    # axis, so inputs of any rank (not only 4-D images) are supported.
    alpha = tf.constant(
        np.random.uniform(size=(batch_size,) + (1,) * (rank - 1)),
        dtype=tf.float32)
    X_interpolated = alpha * X_real + (1. - alpha) * X_fake
    reduce_axes = list(range(1, rank))

    with tf.GradientTape() as outer_tape:
        # Gradient penalty. The inner tape gives d(critic)/d(input); it is
        # nested in the outer tape so the penalty can itself be differentiated
        # w.r.t. the critic weights (a second-order gradient).
        with tf.GradientTape() as inner_tape:
            inner_tape.watch(X_interpolated)
            y_interpolated = self.critic(X_interpolated, training=True)
        gradients = inner_tape.gradient(y_interpolated, X_interpolated)
        # The small epsilon keeps sqrt differentiable when the gradient is 0.
        norm = tf.sqrt(
            1e-8 + tf.reduce_sum(tf.square(gradients), axis=reduce_axes))
        gp_loss = gp_loss_factor * tf.reduce_mean((norm - 1.) ** 2)

        # Wasserstein loss over the concatenated real + fake batch.
        y_pred = self.critic(X, training=True)
        wasserstein_loss = wasserstein(y, y_pred)

        # Total loss also includes any regularization losses of the critic.
        loss = tf.add_n([wasserstein_loss, gp_loss] + self.critic.losses)
    gradients = outer_tape.gradient(loss, self.critic.trainable_variables)
    optimizer.apply_gradients(zip(gradients, self.critic.trainable_variables))
    return wasserstein_loss, gp_loss
def wasserstein(y_true, y_pred):
    """Return the mean of the element-wise product of targets and scores.

    ``train_critic`` passes targets of +1 for real and -1 for generated
    samples, so the result is a signed average of the critic's outputs.
    """
    signed_scores = y_true * y_pred
    return tf.reduce_mean(signed_scores)
This part computes a second derivative (the gradient penalty differentiates a gradient), so it is expected to be slow and time-consuming.
# Perform weight update
with tf.GradientTape() as outer_tape:
# Calculate gradient penalty loss
with tf.GradientTape() as inner_tape:
Try opening both tapes in a single `with` statement:
# Perform weight update
with tf.GradientTape() as outer_tape, tf.GradientTape() as inner_tape:
Answered by Petr Ivanov on May 17, 2021
Get help from others!
Recent Answers
Recent Questions
© 2024 TransWikia.com. All rights reserved. Sites we Love: PCI Database, UKBizDB, Menu Kuliner, Sharing RPP