Asked by D_H on August 1, 2021
The Deep Implicit Layers tutorial is a nice resource that dives into Neural ODEs, Deep Equilibrium Models, etc., using the JAX library. Chapter 3 (of 5), http://implicit-layers-tutorial.org/neural_odes/, contains the following example ResNet, which I am confused about:
import jax.numpy as jnp

def mlp(params, inputs):
    # A multi-layer perceptron, i.e. a fully-connected neural network.
    for w, b in params:
        outputs = jnp.dot(inputs, w) + b  # Linear transform
        inputs = jnp.tanh(outputs)        # Nonlinearity
    return outputs

def resnet(params, inputs, depth):
    for i in range(depth):
        outputs = mlp(params, inputs) + inputs
    return outputs
What is particularly confusing to me is that in the resnet function, the for loop over depth seems completely redundant.
The output of one iteration is never fed in as the input to the next (inputs is never updated), and the loop index i is unused, so every iteration computes exactly the same outputs. See the sketch below for what I would have expected instead.
Am I missing some fundamental information about ResNets? Is this a mistake? Can someone explain the point of the for loop?
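For reference, here is a minimal sketch of the weight-tied variant I would have expected, where each residual block's output becomes the next block's input. The function name is mine, and the reuse of the same params across all depth iterations is my assumption about the intent; it relies on the mlp defined above and on the MLP's output having the same shape as its input (which the residual addition already requires):

def resnet_recurrent(params, inputs, depth):
    # Hypothetical variant: feed each block's output back in as the
    # next block's input, reusing the same (shared) weights each time.
    for _ in range(depth):
        inputs = mlp(params, inputs) + inputs  # residual connection
    return inputs

With this version, depth actually matters: each iteration transforms the running activations, instead of recomputing the same single residual block depth times.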