
How can I perform weighted addition of three layers in Keras?

Data Science Asked by Shivam Pande on December 26, 2020

I would like to perform a weighted addition of the outputs of three different Keras layers, with trainable weights. How can I achieve this? I am using TensorFlow 2.0 as the backend for Keras.

2 Answers

I solved the problem by subclassing keras.layers.Layer. The code is shown below:

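import tensorflow as tf
from tensorflow import keras


class Wt_Add(keras.layers.Layer):
    """Weighted sum w1*x1 + w2*x2 + w3*x3 of three inputs, with trainable w1, w2, w3."""

    def __init__(self, units=1, input_dim=1, **kwargs):
        super().__init__(**kwargs)
        # add_weight registers each weight with the layer so it is trained and saved.
        # The default shape (1, 1) broadcasts against inputs of shape (batch, features).
        self.w1 = self.add_weight(name="w1", shape=(input_dim, units),
                                  initializer="random_normal", trainable=True)
        self.w2 = self.add_weight(name="w2", shape=(input_dim, units),
                                  initializer="random_normal", trainable=True)
        self.w3 = self.add_weight(name="w3", shape=(input_dim, units),
                                  initializer="random_normal", trainable=True)

    def call(self, inputs):
        # The three tensors arrive as one list so Keras tracks all of them as inputs.
        input1, input2, input3 = inputs
        return (tf.multiply(input1, self.w1)
                + tf.multiply(input2, self.w2)
                + tf.multiply(input3, self.w3))

Usage (pass the three layer outputs as a single list):

wt_add = Wt_Add(1, 1)
sum_layer = wt_add([input1, input2, input3])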
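A more compact variant of the same idea, sketched here under the assumption that each input needs only one scalar weight, keeps all the weights in a single trainable vector and works for any number of inputs. The class name WeightedSum is hypothetical (it is not part of Keras); the block also shows the layer merging three Dense branches in a small functional model.

```python
import numpy as np
import tensorflow as tf
from tensorflow import keras


class WeightedSum(keras.layers.Layer):
    """Hypothetical layer: sum_i w_i * x_i with one trainable scalar weight per input."""

    def build(self, input_shape):
        # input_shape holds one shape per input tensor, so its length is the number of inputs.
        self.w = self.add_weight(name="w", shape=(len(input_shape),),
                                 initializer="ones",  # start as a plain, unweighted sum
                                 trainable=True)

    def call(self, inputs):
        # Stack the inputs on a new last axis: (..., features, n_inputs),
        # scale each slice by its weight, then sum that axis away.
        stacked = tf.stack(inputs, axis=-1)
        return tf.reduce_sum(stacked * self.w, axis=-1)


# Usage in a functional model: three branches merged by a learned weighted sum.
inp = keras.Input(shape=(4,))
branches = [keras.layers.Dense(2)(inp) for _ in range(3)]
out = WeightedSum()(branches)
model = keras.Model(inp, out)

# Eager check: with the initial weights of 1, the result is 1 + 2 + 3 = 6 everywhere.
layer = WeightedSum()
y = layer([tf.fill((2, 4), v) for v in (1.0, 2.0, 3.0)])
print(y.numpy())
```

Storing the weights as one vector keeps the layer independent of the number of inputs; the trade-off is that each input gets a single scalar weight rather than a per-feature weight matrix.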

Correct answer by Shivam Pande on December 26, 2020

You have the following basic operations on layers:

  • tf.keras.layers.Lambda, which lets you multiply each of your three layer outputs by a weight with a simple lambda:
layer1 =  tf.keras.layers.Lambda(lambda x: x * weight1)(layer1)
layer2 =  tf.keras.layers.Lambda(lambda x: x * weight2)(layer2)
layer3 =  tf.keras.layers.Lambda(lambda x: x * weight3)(layer3)

  • tf.keras.layers.Average, which averages the outputs of several layers:

average_layer = tf.keras.layers.Average()([layer1, layer2, layer3])
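To see what this combination computes, here is a small check with fixed example weights (the values 0.2, 0.3 and 0.5 and the constant stand-in tensors are arbitrary assumptions). Because Average divides by the number of inputs, the result is (weight1*x1 + weight2*x2 + weight3*x3) / 3 rather than the plain weighted sum.

```python
import numpy as np
import tensorflow as tf

# Arbitrary fixed weights, chosen only for this illustration.
weight1, weight2, weight3 = 0.2, 0.3, 0.5

# Stand-ins for three layer outputs, each of shape (batch=1, features=4).
x1 = tf.fill((1, 4), 1.0)
x2 = tf.fill((1, 4), 2.0)
x3 = tf.fill((1, 4), 3.0)

# Scale each output with a Lambda layer, as in the answer.
layer1 = tf.keras.layers.Lambda(lambda x: x * weight1)(x1)
layer2 = tf.keras.layers.Lambda(lambda x: x * weight2)(x2)
layer3 = tf.keras.layers.Lambda(lambda x: x * weight3)(x3)

# Average divides the sum of its inputs by the number of inputs (3).
average_layer = tf.keras.layers.Average()([layer1, layer2, layer3])

expected = (weight1 * 1.0 + weight2 * 2.0 + weight3 * 3.0) / 3
print(average_layer.numpy(), expected)
```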

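This is a bit awkward: the Lambda layers treat weight1, weight2 and weight3 as fixed values rather than trainable weights, and Average also divides the weighted sum by 3. A trainable weighted average would be the best fit here, but as far as I know Keras does not provide one as a built-in layer yet.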
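A trainable weighted average can be sketched as a small subclassed layer. This is an assumption, not a built-in Keras API: the class name SoftmaxWeightedAverage is hypothetical. It keeps one unconstrained logit per input and applies a softmax, so the learned weights stay positive and sum to 1; with zero-initialized logits it starts out as a plain average, like tf.keras.layers.Average.

```python
import numpy as np
import tensorflow as tf
from tensorflow import keras


class SoftmaxWeightedAverage(keras.layers.Layer):
    """Hypothetical layer: sum_i softmax(logits)_i * x_i over a list of inputs."""

    def build(self, input_shape):
        # One logit per input tensor; zeros give equal weights 1/N at the start.
        self.logits = self.add_weight(name="logits", shape=(len(input_shape),),
                                      initializer="zeros", trainable=True)

    def call(self, inputs):
        weights = tf.nn.softmax(self.logits)   # shape (N,), positive, sums to 1
        stacked = tf.stack(inputs, axis=-1)    # shape (..., N)
        return tf.reduce_sum(stacked * weights, axis=-1)


layer = SoftmaxWeightedAverage()
xs = [tf.fill((2, 4), v) for v in (1.0, 2.0, 3.0)]

# With equal weights of 1/3 this matches a plain average: (1 + 2 + 3) / 3 = 2.
y0 = layer(xs)
print(y0.numpy())

# Shifting the logits toward the third input moves the result toward 3.
layer.logits.assign([0.0, 0.0, 5.0])
y1 = layer(xs)
print(y1.numpy())
```

The softmax keeps the output on the same scale as the inputs no matter what the weights learn, which is what distinguishes a weighted average from the unconstrained weighted sum in the accepted answer.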

Answered by RonsenbergVI on December 26, 2020
