Artificial Intelligence: asked by Seewoo Lee on January 28, 2021
I’m reading the paper Neural Ordinary Differential Equations and I have a simple question about the adjoint method. When we train a NODE, it uses a black-box ODESolver to compute gradients with respect to the model parameters, hidden states, and time. It uses another quantity $\mathbf{a}(t) = \partial L / \partial \mathbf{z}(t)$, called the adjoint, which also satisfies an ODE. As I understand it, the authors build a single augmented ODE and obtain all the gradients $\partial L / \partial \mathbf{z}(t_{0})$ and $\partial L / \partial \theta$ by solving it. However, I can’t understand how we know the value $\partial L / \partial \mathbf{z}(t_1)$, which is the initial condition for the ODE that the adjoint satisfies. I’m using this tutorial as a reference, and it defines custom forward and backward methods for solving the ODE. For the backward computation (especially the ODEAdjoint class in the tutorial) we need to pass $\partial L / \partial \mathbf{z}$ for backpropagation. This lets us compute $\partial L / \partial \mathbf{z}(t_i)$ from $\partial L / \partial \mathbf{z}(t_{i+1})$, but we still need to know the adjoint value $\partial L / \partial \mathbf{z}(t_N)$. I do not understand well how PyTorch’s autograd package works, and this seems to be a barrier to understanding this. Could anyone explain how it operates, and where $\partial L / \partial \mathbf{z}(t_1)$ (or $\partial L / \partial \mathbf{z}(t_N)$, if that is more convenient) comes from? Thanks in advance.
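A minimal sketch of the mechanism in question, in plain Python: in reverse-mode autodiff, calling backward on a scalar loss starts from the seed $dL/dL = 1$, and each operation's backward receives the gradient with respect to its output (`grad_output`) and returns the gradient with respect to its inputs. PyTorch's `torch.autograd.Function` follows this forward/backward pattern, but the classes `EulerSolve` and `SquaredError` below are hypothetical stand-ins, not PyTorch code. The solver here backpropagates through its Euler steps instead of solving an adjoint ODE, and the matrix, target and step count are arbitrary choices. The point is where $\partial L / \partial \mathbf{z}(t_1)$ comes from: it is whatever the loss's backward returns, and it arrives at the solver as `grad_output`.

```python
# Hypothetical pure-Python stand-ins for the forward/backward pattern of
# torch.autograd.Function. This is NOT PyTorch code.

def matvec(M, v):
    return [sum(M[i][j] * v[j] for j in range(len(v))) for i in range(len(M))]


def transpose(M):
    return [list(row) for row in zip(*M)]


class EulerSolve:
    """Maps z(t0) to z(t1) with n_steps explicit Euler steps of dz/dt = A z."""

    def __init__(self, A, t0, t1, n_steps):
        n = len(A)
        h = (t1 - t0) / n_steps
        # One Euler step is the linear map z -> (I + hA) z.
        self.M = [[(1.0 if i == j else 0.0) + h * A[i][j] for j in range(n)] for i in range(n)]
        self.n_steps = n_steps

    def forward(self, z0):
        z = list(z0)
        for _ in range(self.n_steps):
            z = matvec(self.M, z)
        return z

    def backward(self, grad_output):
        # grad_output is dL/dz(t1). The solver never computes it itself;
        # it is handed over by the backward of whatever consumed z(t1).
        g = list(grad_output)
        MT = transpose(self.M)
        for _ in range(self.n_steps):
            g = matvec(MT, g)  # chain rule through one step: multiply by (I + hA)^T
        return g  # dL/dz(t0)


class SquaredError:
    """L = ||target - z||_2^2."""

    def __init__(self, target):
        self.target = list(target)

    def forward(self, z):
        self.z = list(z)  # saved for backward, similar in spirit to ctx.save_for_backward
        return sum((t - zi) ** 2 for t, zi in zip(self.target, self.z))

    def backward(self, grad_output=1.0):
        # dL/dz = 2 (z - target), scaled by the incoming gradient (1.0 at the root).
        return [grad_output * 2.0 * (zi - t) for zi, t in zip(self.z, self.target)]


def run_backward(solver, loss_fn, z0):
    z1 = solver.forward(z0)
    loss = loss_fn.forward(z1)
    grad_z1 = loss_fn.backward(1.0)     # seed dL/dL = 1; result is a(t1) = dL/dz(t1)
    grad_z0 = solver.backward(grad_z1)  # the solver receives a(t1) as grad_output
    return loss, grad_z1, grad_z0


A = [[0.0, 1.0], [-1.0, 0.0]]
solver = EulerSolve(A, 0.0, 1.0, 100)
loss_fn = SquaredError(target=[0.5, -0.8])
loss, grad_z1, grad_z0 = run_backward(solver, loss_fn, [1.0, 0.0])
print(loss, grad_z1, grad_z0)
```

If the tutorial's `ODEAdjoint` is a `torch.autograd.Function` as described, its backward would receive $\partial L / \partial \mathbf{z}$ from autograd in the same way `EulerSolve.backward` receives `grad_output` here.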
Here’s my guess for the initial adjoint from a simple example. Let $d\mathbf{z}/dt = A\mathbf{z}$ be a 2-dimensional linear ODE with a given $A \in \mathbb{R}^{2\times 2}$. If we use Euler’s method as the ODE solver, then the estimate for $\mathbf{z}(t_1)$ is explicitly given as $$\hat{\mathbf{z}}(t_1) = \mathrm{ODESolve}(\mathbf{z}(t_0), f, t_0, t_1, \theta) = \left(I + \frac{t_1 - t_0}{N}A\right)^{N} \mathbf{z}(t_0)$$ where $N$ is the number of steps for Euler’s method (so that $h = (t_1 - t_0)/N$ is the step size). If we use MSE loss for training, then the loss will be
$$
L(\mathbf{z}(t_1)) = \Bigl\| \mathbf{z}_1 - \left(I + \frac{t_1 - t_0}{N}A\right)^N\mathbf{z}(t_0)\Bigr\|_2^2
$$
where $\mathbf{z}_1$ is the true value at time $t_1$, which is $\mathbf{z}_1 = e^{A(t_1 - t_0)}\mathbf{z}(t_0)$. Since the adjoint $\mathbf{a}(t) = \partial L / \partial \mathbf{z}(t)$ satisfies $$\frac{d\mathbf{a}(t)}{dt} = -\mathbf{a}(t)^{T} \frac{\partial f(\mathbf{z}(t), t, \theta)}{\partial \mathbf{z}} = \mathbf{0},$$
$\mathbf{a}(t)$ is constant and we get $\mathbf{a}(t_0) = \mathbf{a}(t_1)$. So we do not need to use the augmented ODE for computing $\mathbf{a}(t)$. However, I still don’t know what $\mathbf{a}(t_1) = \partial L / \partial \mathbf{z}(t_1)$ should be. If my understanding is correct, since $L = \|\mathbf{z}_1 - \mathbf{z}(t_1)\|^{2}_{2}$, it seems that the answer might be
$$
\frac{\partial L}{\partial \mathbf{z}(t_1)} = 2(\mathbf{z}(t_1) - \mathbf{z}_1).
$$
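A quick numerical check of this guess, as a pure-Python sketch. The matrix $A$ (a rotation generator, chosen because $e^{A(t_1-t_0)}\mathbf{z}(t_0)$ then has the closed form $(\cos 1, -\sin 1)$ for $\mathbf{z}(t_0) = (1, 0)$), $t_0 = 0$, $t_1 = 1$ and the step counts are arbitrary assumptions. It shows the Euler loss shrinking as $N$ grows and compares $2(\mathbf{z}(t_1) - \mathbf{z}_1)$ with a central finite difference of $L$.

```python
import math


def euler_solve(A, z0, t0, t1, n_steps):
    """Explicit Euler for dz/dt = A z, i.e. (I + hA)^N z0."""
    h = (t1 - t0) / n_steps
    z = list(z0)
    for _ in range(n_steps):
        Az = [sum(A[i][j] * z[j] for j in range(len(z))) for i in range(len(z))]
        z = [zi + h * azi for zi, azi in zip(z, Az)]  # z <- (I + hA) z
    return z


def loss(z_true, z_pred):
    """L = ||z_1 - z(t1)||_2^2."""
    return sum((a - b) ** 2 for a, b in zip(z_true, z_pred))


def grad_loss(z_true, z_pred):
    """The guessed initial adjoint: dL/dz(t1) = 2 (z(t1) - z_1)."""
    return [2.0 * (b - a) for a, b in zip(z_true, z_pred)]


A = [[0.0, 1.0], [-1.0, 0.0]]
z0 = [1.0, 0.0]
z_true = [math.cos(1.0), -math.sin(1.0)]  # e^{A(t1 - t0)} z0 for this A

for n in (10, 100, 1000):
    print(n, loss(z_true, euler_solve(A, z0, 0.0, 1.0, n)))

# Central finite differences of L with respect to z(t1), at the N = 10 prediction.
z_pred = euler_solve(A, z0, 0.0, 1.0, 10)
eps = 1e-6
fd = []
for k in range(2):
    zp, zm = list(z_pred), list(z_pred)
    zp[k] += eps
    zm[k] -= eps
    fd.append((loss(z_true, zp) - loss(z_true, zm)) / (2 * eps))
print(fd, grad_loss(z_true, z_pred))
```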
However, this doesn’t seem to be true: if it is, and if we have multiple data points at $t_1, t_2, \dots, t_N$, then the loss is
$$
L = \frac{1}{N} \sum_{i=1}^{N}\|\mathbf{z}_i - \mathbf{z}(t_i)\|_{2}^{2}
$$
and we may have
$$
\frac{\partial L}{\partial \mathbf{z}(t_i)} = \frac{2}{N} (\mathbf{z}(t_i) - \mathbf{z}_i),
$$
which means that we don’t need to solve the ODE associated with $\mathbf{a}(t)$.
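One way to see what this leaves out: $\frac{\partial L}{\partial \mathbf{z}(t_i)}$ above is only the direct (partial) dependence of $L$ on $\mathbf{z}(t_i)$. The adjoint is the total derivative, which also counts how $\mathbf{z}(t_i)$ influences every later state $\mathbf{z}(t_{i+1}), \dots, \mathbf{z}(t_N)$, so it still has to be carried backwards between observation times, with the direct term added as a jump at each $t_i$. The pure-Python sketch below does this for the Euler-discretized linear ODE, backpropagating through the Euler steps rather than calling a black-box adjoint solver. The observation times, targets and step counts are made-up values, `K` plays the role of the $N$ in the loss above, and `solve_forward`, `adjoint_backward` and `steps_per_interval` are hypothetical names. It prints $dL/d\mathbf{z}(t_0)$ and the total adjoint at $t_1$ next to the direct partial alone.

```python
def matvec(M, v):
    return [sum(M[i][j] * v[j] for j in range(len(v))) for i in range(len(M))]


def transpose(M):
    return [list(row) for row in zip(*M)]


def step_matrix(A, h):
    """I + hA, the matrix of one explicit Euler step for dz/dt = A z."""
    n = len(A)
    return [[(1.0 if i == j else 0.0) + h * A[i][j] for j in range(n)] for i in range(n)]


def solve_forward(A, z0, obs_times, steps_per_interval, t0=0.0):
    """Euler-integrate dz/dt = A z and return z at each observation time."""
    states, z, t = [], list(z0), t0
    for t_obs in obs_times:
        M = step_matrix(A, (t_obs - t) / steps_per_interval)
        for _ in range(steps_per_interval):
            z = matvec(M, z)
        states.append(z)
        t = t_obs
    return states


def mean_sq_loss(states, targets):
    """L = (1/K) sum_i ||z_i - z(t_i)||^2 over K observations."""
    K = len(states)
    return sum(sum((y - z) ** 2 for y, z in zip(yi, zi))
               for yi, zi in zip(targets, states)) / K


def adjoint_backward(A, states, targets, obs_times, steps_per_interval, t0=0.0):
    """Discrete adjoint sweep from the last observation back to t0.

    Returns dL/dz(t0) and the total adjoint dL/dz(t_i) at each observation."""
    K = len(states)
    bounds = [t0] + list(obs_times)
    a = [0.0] * len(states[0])  # nothing depends on z after the last observation
    adjoints = [None] * K
    for i in reversed(range(K)):
        # Jump: add the direct partial (2/K)(z(t_i) - z_i) from the i-th loss term.
        a = [ak + 2.0 / K * (z - y) for ak, z, y in zip(a, states[i], targets[i])]
        adjoints[i] = a
        # Carry the adjoint back over [bounds[i], bounds[i+1]], one step at a time.
        MT = transpose(step_matrix(A, (bounds[i + 1] - bounds[i]) / steps_per_interval))
        for _ in range(steps_per_interval):
            a = matvec(MT, a)
    return a, adjoints


A = [[0.0, 1.0], [-1.0, 0.0]]
z0 = [1.0, 0.0]
obs_times = [0.5, 1.0, 1.5]
targets = [[0.9, -0.5], [0.5, -0.8], [0.1, -1.0]]
steps = 50

states = solve_forward(A, z0, obs_times, steps)
grad_z0, adjoints = adjoint_backward(A, states, targets, obs_times, steps)
direct_t1 = [2.0 / 3 * (z - y) for z, y in zip(states[0], targets[0])]
print(grad_z0)
print(adjoints[0], direct_t1)  # total adjoint at t_1 vs. its direct partial only
```

In this sketch the two printed vectors for $t_1$ differ, because the total adjoint at $t_1$ also contains the contributions of the loss terms at the later observation times.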
The first thing I spotted was that $a(t)=\frac{\partial L}{\partial z}(t)$ should be $a(t)=-\frac{\partial L}{\partial z}(t)$. Later you have the correct value, so this is probably a typo.
I do not fully understand every step of N-ODE and I do not fully understand your question. Nevertheless...
First, a forward pass is done to obtain predictions of $z$ at every $t$. Then the adjoint state is run backwards in time through every $t$, which gives the gradient signal used for learning.
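The two phases described above can be sketched for the linear example from the question. For $f(\mathbf{z}) = A\mathbf{z}$ the Jacobian $\partial f / \partial \mathbf{z}$ is $A$, so the adjoint equation $d\mathbf{a}/dt = -\mathbf{a}^T \partial f / \partial \mathbf{z}$ becomes, for a column vector, $d\mathbf{a}/dt = -A^T\mathbf{a}$. The pure-Python sketch below assumes a rotation generator for $A$ and a made-up target $\mathbf{z}_1$. It runs the forward pass with the exact rotation flow, starts the adjoint at $2(\mathbf{z}(t_1) - \mathbf{z}_1)$, integrates it backwards in time with explicit Euler, and compares the result with the gradient obtained by differentiating the loss with respect to $\mathbf{z}(t_0)$ directly.

```python
import math


def matvec(M, v):
    return [sum(M[i][j] * v[j] for j in range(len(v))) for i in range(len(M))]


def transpose(M):
    return [list(row) for row in zip(*M)]


A = [[0.0, 1.0], [-1.0, 0.0]]
t0, t1 = 0.0, 1.0
z0 = [1.0, 0.0]
target = [0.5, -0.8]  # hypothetical observation z_1

# Forward pass. For this A the exact flow is a rotation: z(t1) = R z(t0).
c, s = math.cos(t1 - t0), math.sin(t1 - t0)
R = [[c, s], [-s, c]]
z1 = matvec(R, z0)

# Initial condition of the adjoint: a(t1) = dL/dz(t1) = 2 (z(t1) - z_1).
a = [2.0 * (z - y) for z, y in zip(z1, target)]

# Backward pass: integrate da/dt = -A^T a from t1 down to t0 with Euler steps,
# using a(t - h) ~= a(t) + h A^T a(t).
n_steps = 10_000
h = (t1 - t0) / n_steps
AT = transpose(A)
for _ in range(n_steps):
    ATa = matvec(AT, a)
    a = [ai + h * gi for ai, gi in zip(a, ATa)]
grad_adjoint = a  # approximates a(t0) = dL/dz(t0)

# Reference without the adjoint ODE: L(z0) = ||z_1 - R z0||^2, so dL/dz0 = R^T 2 (R z0 - z_1).
grad_direct = matvec(transpose(R), [2.0 * (z - y) for z, y in zip(z1, target)])
print(grad_adjoint, grad_direct)
```

For this linear $f$, one backward Euler step of the adjoint multiplies by $I + hA^T = (I + hA)^T$, which is the same factor that backpropagating through one forward Euler step would apply.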
This is probably not enough to answer your question; hopefully you can make it more specific (i.e., refine the question).
Answered by Quibus on January 28, 2021