Data Science Asked by Adam on October 15, 2020
Suppose we have a stacked neural network architecture with a layer that is to be shared between two “sub-networks”.
Example:
```python
from keras.layers import Input, Dense
from keras.models import Model

main_input = Input(shape=(5, ))

## Model A: main_input -> A_output
layer_A1 = Dense(10, name='A1')(main_input)
layer_A2 = Dense(10, name='A2')(layer_A1)
layer_A3 = Dense(10, name='A3')(layer_A2)
A_output = Dense(1, name='A_output')(layer_A3)

## Model B: main_input -> layer_A2 -> B_output
layer_B1 = Dense(10, name='B1')(layer_A2)
B_output = Dense(1, name='B_output')(layer_B1)

model = Model(inputs=main_input,
              outputs=[A_output, B_output],
              )
model.compile(optimizer='adam',
              loss={
                  'A_output': 'mean_squared_error',
                  'B_output': 'mean_squared_error'
              },
              )
```
The goal is to train model A first so that model B can learn from the pre-trained weights of layer A2. However, calling `fit` on the current architecture trains both sub-networks simultaneously and minimizes the sum of their losses.
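For reference, Keras minimizes the weighted sum of the per-output losses, with each weight defaulting to 1.0 (the `loss_weights` argument of `compile`). The NumPy sketch below mimics that combination on made-up numbers; `mse` and `total_loss` are illustrative helpers, not Keras functions. It also shows that a weight of 0.0 removes an output from the objective.

```python
import numpy as np

def mse(y_true, y_pred):
    # Mean squared error averaged over the batch (one unit per output here)
    return float(np.mean((y_true - y_pred) ** 2))

def total_loss(per_output, loss_weights=None):
    # Weighted sum of per-output losses; every weight defaults to 1.0
    if loss_weights is None:
        loss_weights = {name: 1.0 for name in per_output}
    return sum(loss_weights[name] * value for name, value in per_output.items())

# Made-up targets and predictions for a batch of two samples
y_true_A = np.array([[1.0], [2.0]])
y_pred_A = np.array([[1.5], [2.0]])
y_true_B = np.array([[0.0], [1.0]])
y_pred_B = np.array([[1.0], [1.0]])

losses = {'A_output': mse(y_true_A, y_pred_A), 'B_output': mse(y_true_B, y_pred_B)}
print(losses)                                                  # {'A_output': 0.125, 'B_output': 0.5}
print(total_loss(losses))                                      # 0.625
print(total_loss(losses, {'A_output': 1.0, 'B_output': 0.0}))  # 0.125
```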
How can I change the architecture so that model A is trained first without creating separate models? Ultimately, I’ll need to call `model.predict(new_sample)`, where `new_sample` is of shape `(5,)` in the example.
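One possible way to stay within a single model is sketched below. It is an alternative, not the approach taken in the answer that follows, and it assumes only the `loss_weights` argument of `compile` and the per-layer `trainable` flag. Phase 1 trains with B's loss weighted at zero; phase 2 freezes the A layers, recompiles, and trains with A's loss weighted at zero. The data arrays are random placeholders.

```python
import numpy as np
from keras.layers import Input, Dense
from keras.models import Model

# Hypothetical placeholder data: 100 samples, 5 features, one target per output
X = np.random.rand(100, 5).astype('float32')
y_A = np.random.rand(100, 1).astype('float32')
y_B = np.random.rand(100, 1).astype('float32')

# Same architecture as in the question
main_input = Input(shape=(5, ))
layer_A1 = Dense(10, name='A1')(main_input)
layer_A2 = Dense(10, name='A2')(layer_A1)
layer_A3 = Dense(10, name='A3')(layer_A2)
A_output = Dense(1, name='A_output')(layer_A3)
layer_B1 = Dense(10, name='B1')(layer_A2)
B_output = Dense(1, name='B_output')(layer_B1)
model = Model(inputs=main_input, outputs=[A_output, B_output])
losses = ['mean_squared_error', 'mean_squared_error']

# Phase 1: B's loss is weighted at zero, so only A's loss drives the updates
model.compile(optimizer='adam', loss=losses, loss_weights=[1.0, 0.0])
model.fit(X, [y_A, y_B], epochs=5, verbose=0)

# Phase 2: freeze the A branch, then recompile so the change takes effect
for name in ['A1', 'A2', 'A3', 'A_output']:
    model.get_layer(name).trainable = False
model.compile(optimizer='adam', loss=losses, loss_weights=[0.0, 1.0])
A2_before = [w.copy() for w in model.get_layer('A2').get_weights()]
model.fit(X, [y_A, y_B], epochs=5, verbose=0)

# Keras expects a leading batch axis, so a single (5,) sample becomes (1, 5)
new_sample = np.random.rand(5).astype('float32')
pred_A, pred_B = model.predict(new_sample.reshape(1, 5), verbose=0)
```

Keras' transfer-learning guidance is to call `compile()` again after changing `trainable`, which is why the model is recompiled between the two phases. The final `predict` call returns one array per output.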
Figured it out. To train certain sub-networks first and re-use the trained weights to initialize other layers:
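```python
import numpy as np
from keras.layers import Input, Dense
from keras.models import Model

# Hypothetical placeholder data (the original elided the fit arguments)
X = np.random.rand(100, 5)
y_A = np.random.rand(100, 1)
y_B = np.random.rand(100, 1)

main_input = Input(shape=(5, ))

## Model A: main_input -> A_output
layer_A1 = Dense(10, name='A1')(main_input)
layer_A2 = Dense(10, name='A2')(layer_A1)
layer_A3 = Dense(10, name='A3')(layer_A2)
A_output = Dense(1, name='A_output')(layer_A3)

# Train Model A first
model_A = Model(inputs=main_input, outputs=A_output)
model_A.compile(optimizer='adam', loss='mean_squared_error')
model_A.fit(X, y_A, epochs=10, verbose=0)

## Model B: main_input -> copy of A1 (B0) -> copy of A2 (B1) -> B_output
# A2's kernel has shape (10, 10), so B1 needs a 10-dim input; feeding it
# main_input (5 dims) would make set_weights fail, so A1 is copied as B0.
layer_B0 = Dense(10, name='B0')(main_input)
layer_B1 = Dense(10, name='B1')(layer_B0)
B_output = Dense(1, name='B_output')(layer_B1)
model_B = Model(inputs=main_input, outputs=B_output)

# Copy the pre-trained A1/A2 weights into B0/B1 and freeze them
for src, dst in [('A1', 'B0'), ('A2', 'B1')]:
    model_B.get_layer(dst).set_weights(model_A.get_layer(src).get_weights())
    model_B.get_layer(dst).trainable = False

# Compile after setting trainable so the freeze takes effect; only B_output trains
model_B.compile(optimizer='adam', loss='mean_squared_error')
model_B.fit(X, y_B, epochs=10, verbose=0)
```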
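Because `model_A` and `model_B` in the answer are both built from the same `main_input` tensor, their layers can afterwards be wrapped in one `Model`, so that a single `predict` call returns both outputs as the question asked. A minimal sketch, with training omitted and a random placeholder sample:

```python
import numpy as np
from keras.layers import Input, Dense
from keras.models import Model

main_input = Input(shape=(5, ))

# Branch A, as in the answer
layer_A1 = Dense(10, name='A1')(main_input)
layer_A2 = Dense(10, name='A2')(layer_A1)
layer_A3 = Dense(10, name='A3')(layer_A2)
A_output = Dense(1, name='A_output')(layer_A3)
model_A = Model(inputs=main_input, outputs=A_output)

# Branch B, built on the same main_input tensor
layer_B1 = Dense(10, name='B1')(main_input)
B_output = Dense(1, name='B_output')(layer_B1)
model_B = Model(inputs=main_input, outputs=B_output)

# (model_A and model_B would be trained here, as in the answer.)

# One model exposing both outputs; it reuses the same layer objects,
# so it carries whatever weights the two sub-models were trained to.
combined = Model(inputs=main_input, outputs=[A_output, B_output])

new_sample = np.random.rand(5).astype('float32')
batch = new_sample[np.newaxis, :]  # add the batch axis: shape (1, 5)
pred_A, pred_B = combined.predict(batch, verbose=0)
```

Since the combined model shares layer objects with the two sub-models, it needs neither retraining nor weight copying.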
Answered by Adam on October 15, 2020