TransWikia.com

Copying Weights using NetTrain in Mathematica

Mathematica Asked by Terrell N. on February 2, 2021

I would like to implement the PPO (Proximal Policy Optimization) reinforcement-learning algorithm in Mathematica. To do that, I need to copy network weights from, say, subnetwork NN1 to subnetwork NN2 as an update step during training.

Is this possible in Mathematica? My question can be split into two parts:

  1. How can I copy weights?

  2. How can I do this as part of the training process, i.e., inside NetTrain?

One Answer

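Recurrent layers have a "State" port (all of them) and a "CellState" port (LongShortTermMemoryLayer only). These carry the recurrent hidden state rather than the weights, but they let one layer's final state initialize another layer directly in a NetGraph: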

NetGraph[
 {
  BasicRecurrentLayer[8],
  BasicRecurrentLayer[8]
  },
 {
  NetPort[1, "State"] -> NetPort[2, "State"]
  }
 ]
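For an LSTM, the same wiring can also carry the cell state. A minimal sketch following the pattern above, with the layers additionally chained so that layer 2 reads layer 1's output sequence; the layer size 8 is arbitrary, and the port names are the ones listed above:

```mathematica
(* Two stacked LSTM layers: layer 1's output sequence feeds layer 2, and
   layer 1's final "State" and "CellState" become layer 2's initial states *)
lstmGraph = NetGraph[
  {LongShortTermMemoryLayer[8], LongShortTermMemoryLayer[8]},
  {NetPort[1, "Output"] -> NetPort[2, "Input"],
   NetPort[1, "State"] -> NetPort[2, "State"],
   NetPort[1, "CellState"] -> NetPort[2, "CellState"]}
 ]
```

As with the basic recurrent example, this moves activations between layers, not trainable arrays.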


ConvolutionLayer and LinearLayer are more complicated: they have no ports other than "Input" and "Output", so their weights cannot be routed through a connection. We can extract the weights from one layer and pass them to the constructor of another, but I do not know how to do this inside NetTrain without breaking the training loop.

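(* Initialize a layer so that it has concrete weights *)
conv = ConvolutionLayer[1, {3, 3}, "Input" -> {1, Automatic, Automatic}] // NetInitialize;
(* Extract its trained/initialized arrays *)
w = NetExtract[conv, "Weights"];
b = NetExtract[conv, "Biases"];
(* Build a new layer that starts from the same arrays *)
convNew = ConvolutionLayer[1, {3, 3}, "Weights" -> w, "Biases" -> b]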
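One possible workaround for the second part of the question, sketched under assumptions rather than offered as a tested PPO implementation: instead of modifying weights inside a single NetTrain call, call NetTrain repeatedly for one round at a time, freeze the snapshot branch with the LearningRateMultipliers option, and copy the arrays between calls with NetExtract and NetReplacePart. The layer names "new" and "old", the helper copyWeights, the TotalLayer head and the random regression data are hypothetical placeholders standing in for a real PPO objective. The sketch also assumes NetReplacePart accepts {layer, array} part specifications and a list of rules, which is worth checking against the documentation for your version.

```mathematica
(* Hypothetical helper: copy the "Weights" and "Biases" arrays of layer src
   into layer dst of the same NetGraph (both assumed to be LinearLayers
   of identical shape) *)
copyWeights[net_, src_String, dst_String] :=
  NetReplacePart[net, {
    {dst, "Weights"} -> NetExtract[net, {src, "Weights"}],
    {dst, "Biases"} -> NetExtract[net, {src, "Biases"}]}]

(* Placeholder graph: "new" is trained, "old" keeps a frozen snapshot;
   TotalLayer and the regression data only stand in for a PPO loss *)
net = NetGraph[
   <|"new" -> LinearLayer[2], "old" -> LinearLayer[2], "sum" -> TotalLayer[]|>,
   {NetPort["Input"] -> "new", NetPort["Input"] -> "old",
    {"new", "old"} -> "sum"},
   "Input" -> 3] // NetInitialize;

data = Table[RandomReal[1, 3] -> RandomReal[1, 2], 64];

(* Start both branches from the same weights *)
net = copyWeights[net, "new", "old"];

Do[
  (* Train for one round with the "old" branch frozen *)
  net = NetTrain[net, data, MaxTrainingRounds -> 1,
    LearningRateMultipliers -> {"old" -> 0},
    TrainingProgressReporting -> None];
  (* Between NetTrain calls, sync the snapshot: old <- new *)
  net = copyWeights[net, "new", "old"],
  {5}]
```

The design choice here is to keep the copy step outside NetTrain, so the training loop itself is never interrupted. The trade-off is the overhead of starting a fresh NetTrain run for each round, which may be acceptable for PPO, where the old policy is only refreshed between update phases anyway.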

Answered by Alexey Golyshev on February 2, 2021
