TransWikia.com

Style transfer model outputs only zero valued pixels

Data Science: asked on December 30, 2021

I am currently implementing the style transfer model proposed in the paper "Arbitrary Style Transfer in Real-time with Adaptive Instance Normalization" (Huang and Belongie, ICCV 2017).

The model takes two RGB images as input, a content image and a style image, and generates a new image that depicts the content image in the style of the style image.

The model has an autoencoder-like structure, with a pretrained (frozen) VGG19 network as the encoder. In the bottleneck, an AdaIN layer takes the encoded content and style features and outputs the content features rescaled so that each channel has the mean and standard deviation of the corresponding style channel. A decoder is then trained to generate a new image from the AdaIN output.
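In formula form, AdaIN(x, y) = σ(y) · (x − μ(x)) / σ(x) + μ(y), where μ and σ are computed per sample and per channel over the spatial dimensions. The following is a minimal NumPy sketch of that operation, assuming NHWC feature maps as in TensorFlow and the same eps = 1e-5 as the AdaIn layer in the code below; the function name adain is illustrative only. It checks that the output really takes on the per-channel statistics of the style features.

```python
import numpy as np

def adain(content, style, eps=1e-5):
    """AdaIN on NHWC feature maps: give each channel of `content` the
    per-channel mean and standard deviation of `style`."""
    # Statistics over the spatial axes (H, W), separately per sample and channel
    mean_c = content.mean(axis=(1, 2), keepdims=True)
    std_c = np.sqrt(content.var(axis=(1, 2), keepdims=True) + eps)
    mean_s = style.mean(axis=(1, 2), keepdims=True)
    std_s = np.sqrt(style.var(axis=(1, 2), keepdims=True) + eps)
    # Normalize the content features, then re-scale and re-shift with the style statistics
    return std_s * (content - mean_c) / std_c + mean_s

rng = np.random.default_rng(0)
c = rng.normal(loc=2.0, scale=3.0, size=(2, 8, 8, 4))   # stand-in "content features"
s = rng.normal(loc=-1.0, scale=0.5, size=(2, 8, 8, 4))  # stand-in "style features"
t = adain(c, s)
print(np.allclose(t.mean(axis=(1, 2)), s.mean(axis=(1, 2))))            # True
print(np.allclose(t.std(axis=(1, 2)), s.std(axis=(1, 2)), atol=1e-3))   # True
```

The standard deviation matches only up to the small eps term, which is why the second check uses a tolerance.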

The implementation uses TensorFlow 2.0. Content images are taken from the COCO dataset and style images from WikiArt.

If I pass a content image and a style image to the model before training it, the model outputs random noise in which the underlying content is sometimes (barely) visible, which is as expected. When I train the model on a subset of the content and style datasets, the loss decreases continuously, as desired.

However, after training, the model has apparently learned to generate images in which every pixel is zero, which gives a completely "blank" image.
How is this possible? Does it have something to do with the TensorFlow training regime, which I do not fully understand?

Any help would be greatly appreciated.

Code


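import os

import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Conv2D, UpSampling2D
from tensorflow.keras.models import Model

class AdaInLayer(tf.keras.layers.Layer):
    """Adaptive instance normalization: re-normalizes the content features to the
    per-channel mean and standard deviation of the style features."""

    def __init__(self, name = "AdaIn"):
        super(AdaInLayer, self).__init__(name = name)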

    def call(self, c, s):
        epsilon = 1e-5
        axes = [1, 2] # Across spatial dimensions
        meanC, varC = tf.nn.moments(c, axes = axes, keepdims = True)
        meanS, varS = tf.nn.moments(s, axes = axes, keepdims = True)
        stdevC, stdevS = tf.sqrt(varC + epsilon), tf.sqrt(varS + epsilon)
        # AdaIN(c, s) = sigma(s) * (c - mu(c)) / sigma(c) + mu(s)
        normalizedC = stdevS * (c - meanC) / stdevC + meanS
        return normalizedC

def BuildEncoder():
    VGG19 = tf.keras.applications.VGG19(include_top = False)
    VGG19.trainable = False

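    # Keras VGG19 conv layers include the ReLU, so these outputs correspond to
    # relu1_1, relu2_1, relu3_1 and relu4_1 in the paper
    LayerNames = ['block1_conv1',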
                  'block2_conv1',
                  'block3_conv1',
                  'block4_conv1']

    VGGLayers = [VGG19.get_layer(name).output for name in LayerNames]
    encoder = Model(inputs = [VGG19.input], outputs = VGGLayers, name = "Encoder")
    return encoder

def BuildDecoder(input_shape, kernel_size, upsampling_size, activation_fn):
    input = tf.keras.Input(shape = input_shape, name = "AdaIn_output")
    x = Conv2D(filters = 256, kernel_size = kernel_size, activation = activation_fn, padding = 'same', kernel_initializer='he_normal')(input)
    x = UpSampling2D(size = upsampling_size)(x)

    x = Conv2D(filters = 256, kernel_size = kernel_size, activation = activation_fn, padding = 'same', kernel_initializer='he_normal')(x)
    x = Conv2D(filters = 256, kernel_size = kernel_size, activation = activation_fn, padding = 'same', kernel_initializer='he_normal')(x)
    x = Conv2D(filters = 256, kernel_size = kernel_size, activation = activation_fn, padding = 'same', kernel_initializer='he_normal')(x)
    x = Conv2D(filters = 128, kernel_size = kernel_size, activation = activation_fn, padding = 'same', kernel_initializer='he_normal')(x)
    x = UpSampling2D(size = upsampling_size)(x)

    x = Conv2D(filters = 128, kernel_size = kernel_size, activation = activation_fn, padding = 'same', kernel_initializer='he_normal')(x)
    x = Conv2D(filters = 64, kernel_size = kernel_size, activation = activation_fn, padding = 'same', kernel_initializer='he_normal')(x)

    x = UpSampling2D(size = upsampling_size)(x)
    x = Conv2D(filters = 64, kernel_size = kernel_size, activation = activation_fn, padding = 'same', kernel_initializer='he_normal')(x)
    x = Conv2D(filters = 3, kernel_size = kernel_size, activation = activation_fn, padding = 'same', kernel_initializer='he_normal')(x)

    decoder = Model(input, x, name = "Decoder")
    return decoder

la = 1  # Weight of the style loss (lambda in the paper)
encoder = BuildEncoder()
output_shape = encoder.layers[-1].output.shape[1:]
adaIn = AdaInLayer()
decoder = BuildDecoder(input_shape = output_shape, kernel_size = 3, upsampling_size = 2, activation_fn = 'relu')


model_input_shape = encoder.input.shape[1:]
content_input = tf.keras.Input(shape = model_input_shape, name = "Content_input")
style_input = tf.keras.Input(shape = model_input_shape, name = "Style_input")

processed_content_input = tf.keras.applications.vgg19.preprocess_input(content_input)
processed_style_input = tf.keras.applications.vgg19.preprocess_input(style_input)

encoded_content = encoder(processed_content_input)
encoded_style = encoder(processed_style_input)

t = adaIn(encoded_content[-1], encoded_style[-1])
T = decoder(t)
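# Undo the VGG19 ("caffe"-style) preprocessing: add back the BGR channel means,
# then reverse the channel axis to go from BGR to RGB
T = T + np.array([103.939, 116.779, 123.68])
T = tf.reverse(T, axis = [-1])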
T = tf.clip_by_value(T, 0.0, 255.0)

y = encoder(tf.keras.applications.vgg19.preprocess_input(T))

y0 = y[-1]
# Content loss
content_loss = tf.math.reduce_mean(tf.math.reduce_mean(tf.math.squared_difference(y0, t), axis = [1, 2]))

# Style loss
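# Mean and standard deviation of the difference yi - si (each taken over the whole
# tensor and squared), summed over the four encoder layers
style_loss = (tf.add_n([(tf.reduce_mean(yi - si))**2 for yi, si in zip(y, encoded_style)])
              + tf.add_n([(tf.math.reduce_std(yi - si))**2 for yi, si in zip(y, encoded_style)]))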

# Total loss
loss = content_loss + la * style_loss

# Build model
StyleNetModel = Model(inputs = [content_input, style_input], outputs = [T])
StyleNetModel.add_loss(loss)
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)
StyleNetModel.compile(optimizer)

path_to_training_set = os.getcwd() + '/Data/training_set'
train_set = dh.load_dataset(path_to_training_set) # Load a training set of content and style images
num_train_samples = 500
batch_size = 4

content_train_set = train_set['content'][:num_train_samples] # Shape: (num_train_samples, 256, 256, 3)
style_train_set = train_set['style'][:num_train_samples] # Shape: (num_train_samples, 256, 256, 3)

# Train the model
StyleNetModel.fit([content_train_set, style_train_set], batch_size = batch_size, epochs = 1)
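For comparison with the loss terms in the code above, the AdaIN paper defines the content loss as L_c = ||f(g(t)) − t||_2 (f the encoder, g the decoder) and the style loss on per-channel statistics of the VGG layers φ_i (relu1_1, relu2_1, relu3_1, relu4_1): L_s = Σ_i ||μ(φ_i(g(t))) − μ(φ_i(s))||_2 + Σ_i ||σ(φ_i(g(t))) − σ(φ_i(s))||_2, with total loss L = L_c + λ·L_s. The NumPy sketch below follows those formulas under stated assumptions: NHWC feature maps, a small eps inside the standard deviation, and averaging over the batch. The function names and the random feature shapes are illustrative, not taken from the question or from any library.

```python
import numpy as np

def channel_stats(features, eps=1e-5):
    """Per-sample, per-channel mean and std over the spatial axes of NHWC features."""
    mean = features.mean(axis=(1, 2))               # shape (N, C)
    std = np.sqrt(features.var(axis=(1, 2)) + eps)  # shape (N, C); eps is an assumption
    return mean, std

def content_loss(out_feats, target):
    """L_c = ||f(g(t)) - t||_2, computed per sample and averaged over the batch."""
    diff = (out_feats - target).reshape(len(target), -1)
    return np.linalg.norm(diff, axis=1).mean()

def style_loss(out_layers, style_layers):
    """L_s = sum_i ||mu(out_i) - mu(style_i)||_2 + ||sigma(out_i) - sigma(style_i)||_2,
    norms taken over channels, then averaged over the batch (assumed reduction)."""
    total = 0.0
    for out_f, style_f in zip(out_layers, style_layers):
        mu_o, sd_o = channel_stats(out_f)
        mu_s, sd_s = channel_stats(style_f)
        total += np.linalg.norm(mu_o - mu_s, axis=1).mean()
        total += np.linalg.norm(sd_o - sd_s, axis=1).mean()
    return total

# Random stand-ins for relu1_1 ... relu4_1 features of a batch of 2 (shapes are illustrative)
rng = np.random.default_rng(0)
shapes = [(2, 32, 32, 64), (2, 16, 16, 128), (2, 8, 8, 256), (2, 4, 4, 512)]
style_feats = [rng.normal(size=shape) for shape in shapes]

# Spatially flipped features have identical per-channel statistics
flipped = [f[:, ::-1, ::-1, :] for f in style_feats]
print(style_loss(flipped, style_feats) < 1e-9)  # True

# Shifting every channel mean by 1 costs sqrt(C) per layer
shifted = [f + 1.0 for f in style_feats]
print(style_loss(shifted, style_feats))  # ≈ 8 + 11.31 + 16 + 22.63 ≈ 57.94
```

Because μ and σ are taken per channel over H and W, this formulation compares only feature statistics, so spatially rearranged features incur (near) zero style loss.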
