
Sequentially Training Certain Layers/Sub-Networks in Keras Functional API

Data Science: Asked by Adam on October 15, 2020

Suppose we have a stacked neural network architecture with a
layer that is to be shared between two “sub-networks”.

Example:

from keras.layers import Input, Dense
from keras.models import Model

main_input = Input(shape=(5, ))

## Model A: main_input -> A_output
layer_A1 = Dense(10, name='A1')(main_input)
layer_A2 = Dense(10, name='A2')(layer_A1)
layer_A3 = Dense(10, name='A3')(layer_A2)
A_output = Dense(1, name='A_output')(layer_A3)

## Model B: main_input -> layer_A2 -> B_output
layer_B1 = Dense(10, name='B1')(layer_A2)
B_output = Dense(1, name='B_output')(layer_B1)

model = Model(inputs=main_input,
              outputs=[A_output, B_output])

model.compile(optimizer='adam',
              loss={'A_output': 'mean_squared_error',
                    'B_output': 'mean_squared_error'})

[Architecture diagram]

The goal is to train model A first so that model B can learn from the pre-trained weights of layer A2. However, calling fit on the model above trains both sub-networks simultaneously and sums their losses.

How can I change the architecture so that model A is trained first without creating separate models? Ultimately, I’ll need to call model.predict(new_sample) where new_sample is of shape (5,) in the example.
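
For reference, the per-output losses of the single compiled model can be re-weighted at compile time, so one branch can in principle be silenced while the other trains. The following is only a sketch of that idea, continuing from the model defined above and assuming placeholder training arrays X, y_A and y_B that are not part of the original post:

import numpy as np

# Placeholder data, purely for illustration
X = np.random.rand(100, 5)
y_A = np.random.rand(100, 1)
y_B = np.random.rand(100, 1)

# Stage 1: give B's loss zero weight so only the A branch receives gradient
model.compile(optimizer='adam',
              loss={'A_output': 'mean_squared_error',
                    'B_output': 'mean_squared_error'},
              loss_weights={'A_output': 1.0, 'B_output': 0.0})
model.fit(X, {'A_output': y_A, 'B_output': y_B}, epochs=10, verbose=0)

# Stage 2: freeze the A-side layers, then re-compile with B's loss active
# (re-compiling applies the new loss weights and trainable flags without
# resetting the weights learned in stage 1)
for name in ('A1', 'A2', 'A3', 'A_output'):
    model.get_layer(name).trainable = False
model.compile(optimizer='adam',
              loss={'A_output': 'mean_squared_error',
                    'B_output': 'mean_squared_error'},
              loss_weights={'A_output': 0.0, 'B_output': 1.0})
model.fit(X, {'A_output': y_A, 'B_output': y_B}, epochs=10, verbose=0)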

One Answer

Figured it out. To train certain sub-networks first and re-use the trained weights to initialize other layers:

from keras.layers import Input, Dense
from keras.models import Model

main_input = Input(shape=(5, ))

## Model A: main_input -> A_output
layer_A1 = Dense(10, name='A1')(main_input)
layer_A2 = Dense(10, name='A2')(layer_A1)
layer_A3 = Dense(10, name='A3')(layer_A2)
A_output = Dense(1, name='A_output')(layer_A3)

# Train Model A first
model_A = Model(inputs=main_input, outputs=A_output)
model_A.compile(optimizer='adam', loss='mean_squared_error')
model_A.fit(X, y_A)  # X, y_A: placeholder training data

## Model B: main_input -> trained A1 -> B1 (copy of pre-trained A2) -> B_output
# B1 needs a 10-dimensional input for A2's weights to fit, so it is stacked
# on the already-trained layer_A1 rather than directly on main_input
layer_B1 = Dense(10, name='B1')(layer_A1)
B_output = Dense(1, name='B_output')(layer_B1)

model_B = Model(inputs=main_input, outputs=B_output)

# Set weights of layer_B1 and freeze; essentially a copy of pre-trained layer_A2
trained_A2_weights = model_A.get_layer('A2').get_weights()
model_B.get_layer('B1').set_weights(trained_A2_weights)
model_B.get_layer('B1').trainable = False

# layer_A1 is shared with model_A, so freeze it as well if its trained
# weights should not change while model B is fitted
model_B.get_layer('A1').trainable = False

model_B.compile(optimizer='adam', loss='mean_squared_error')
model_B.fit(X, y_B)  # X, y_B: placeholder training data
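To answer the predict part of the question: a single new sample of shape (5,) just needs a batch dimension before being passed to predict. A short usage sketch, with new_sample as a placeholder array:

import numpy as np

new_sample = np.random.rand(5)       # placeholder sample of shape (5,)
batch = new_sample.reshape(1, -1)    # Keras predicts on batches, so reshape to (1, 5)

pred_A = model_A.predict(batch)      # prediction from the A sub-network, shape (1, 1)
pred_B = model_B.predict(batch)      # prediction from the B sub-network, shape (1, 1)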

Answered by Adam on October 15, 2020
