Data Science Asked on July 22, 2021
I recently tried computing the derivative of the layer norm function (https://arxiv.org/abs/1607.06450), an essential component of transformers, but the result suggests that no gradient flows through the operation, which can’t be true.
Here are my calculations:
Given a vector of real numbers $X$ of length $N$, indexed as $x_i$, we define the following operations:

$$\mu = \frac{\sum_{k=1}^{N} x_k}{N}$$

$$\sigma = \sqrt{\frac{\sum_{k=1}^{N} (x_k-\mu)^2}{N}}$$

$$y_i = \frac{x_i-\mu}{\sigma}$$

We seek to calculate the derivative of $y_i$ w.r.t. $X$. That is,

$$\frac{dy_i}{dX} = \sum_{k=1}^{N} \frac{dy_i}{dx_k}$$

By the quotient rule (where $'$ denotes differentiation with respect to $x_j$):

$$\frac{dy_i}{dx_j} = \frac{(x_i-\mu)'\sigma - (x_i-\mu)\sigma'}{\sigma^2}$$

$$(x_i-\mu)' = \delta_{ij} - \mu'$$

$$\mu' = \frac{1}{N} \implies (x_i-\mu)' = \delta_{ij} - \frac{1}{N}$$

$$\sigma' = \frac{1}{2}\left(\frac{\sum_{k=1}^{N}(x_k-\mu)^2}{N}\right)^{-\frac{1}{2}} \cdot \left[\frac{\sum_{k=1}^{N}(x_k-\mu)^2}{N}\right]'$$

$$\begin{aligned}
\left[\frac{\sum_{k=1}^{N}(x_k-\mu)^2}{N}\right]' &= \frac{1}{N}\sum_{k=1}^{N} 2(x_k-\mu)\left(\delta_{kj}-\frac{1}{N}\right)\\
&= \frac{2}{N}\sum_{k=1}^{N}\left[(x_k-\mu)\delta_{kj} - (x_k-\mu)\frac{1}{N}\right]
\end{aligned}$$

Note that $\delta_{kj}$ is 1 only when $k=j$ and 0 otherwise, so we can reduce further:

$$\begin{aligned}
&= \frac{2}{N}\left((x_j-\mu) - \frac{1}{N}\sum_{k=1}^{N}(x_k-\mu)\right)\\
&= \frac{2}{N}\left((x_j-\mu) - \frac{1}{N}\sum_{k=1}^{N}x_k + \frac{1}{N}\sum_{k=1}^{N}\mu\right)\\
&= \frac{2}{N}\left((x_j-\mu) - \mu + \frac{1}{N}N\mu\right)\\
&= \frac{2}{N}(x_j-\mu)
\end{aligned}$$

Thus, plugging that back into $\sigma'$, we get:

$$\begin{aligned}
\sigma' &= \frac{1}{2}\left(\frac{\sum_{k=1}^{N}(x_k-\mu)^2}{N}\right)^{-\frac{1}{2}} \cdot \frac{2}{N}(x_j-\mu)\\
&= \frac{1}{N}\cdot\frac{1}{\sigma}\cdot(x_j-\mu)\\
&= \frac{x_j-\mu}{N\sigma}
\end{aligned}$$

Now that we have all the components, we can return to the derivative $\frac{dy_i}{dx_j}$:

$$\begin{aligned}
\frac{dy_i}{dx_j} &= \frac{(x_i-\mu)'\sigma - (x_i-\mu)\sigma'}{\sigma^2}\\
&= \frac{(x_i-\mu)'\sigma}{\sigma^2} - \frac{(x_i-\mu)\sigma'}{\sigma^2}\\
&= \frac{\delta_{ij}-\frac{1}{N}}{\sigma} - \frac{(x_i-\mu)\frac{x_j-\mu}{N\sigma}}{\sigma^2}\\
&= \frac{\delta_{ij}-\frac{1}{N}}{\sigma} - \frac{(x_i-\mu)(x_j-\mu)}{N\sigma^3}\\
&= \frac{1}{N\sigma}\left(N\delta_{ij} - 1 - \frac{(x_i-\mu)(x_j-\mu)}{\sigma^2}\right)\\
&= \frac{1}{N\sigma}\left(N\delta_{ij} - 1 - \frac{x_i-\mu}{\sigma}\cdot\frac{x_j-\mu}{\sigma}\right)\\
&= \frac{1}{N\sigma}\left(N\delta_{ij} - 1 - y_i y_j\right)
\end{aligned}$$

Finally, returning to $\frac{dy_i}{dX}$:

$$\frac{dy_i}{dX} = \sum_{j=1}^{N}\frac{1}{N\sigma}\left(N\delta_{ij} - 1 - y_i y_j\right)$$

Note that we are adding $N$ once (when $i=j$) and $-1$ a total of $N$ times, so we can simplify to:

$$\begin{aligned}
\frac{dy_i}{dX} &= \frac{1}{N\sigma}\left(N + (-1)N - \sum_{j=1}^{N} y_i y_j\right)\\
&= \frac{1}{N\sigma}\left(-\sum_{j=1}^{N} y_i y_j\right)\\
&= \frac{1}{N\sigma}\left(-y_i\sum_{j=1}^{N} y_j\right)\\
&= \frac{-y_i}{\sigma}\cdot\frac{\sum_{j=1}^{N} y_j}{N}\\
&= \frac{-y_i}{\sigma}\mu_y
\end{aligned}$$

where $\mu_y$ is the mean of the $y_j$. BUT $y$ has zero mean by construction (each $y_j$ is $x_j$ minus the mean $\mu$, rescaled by $\sigma$), so $\mu_y = 0$ and

$$\frac{dy_i}{dX} = \frac{-y_i}{\sigma}\cdot 0 = 0$$

which means no gradient flows through a layer normalization.
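To check the per-entry formula $\frac{dy_i}{dx_j}=\frac{1}{N\sigma}(N\delta_{ij}-1-y_iy_j)$ numerically, here is a minimal plain-Python sketch (no autodiff library assumed) that compares it against central finite differences on an arbitrary test vector. The helper names (`mean_and_std`, `layer_norm`, `analytic_jacobian`, `numeric_jacobian`), the test vector, and the step size `eps` are illustrative choices; `layer_norm` here uses the same $1/N$ variance as the definitions above and omits any learned gain or bias.

```python
import math


def mean_and_std(x):
    """Population mean and standard deviation (1/N variance), as in the definitions above."""
    n = len(x)
    mu = sum(x) / n
    sigma = math.sqrt(sum((xk - mu) ** 2 for xk in x) / n)
    return mu, sigma


def layer_norm(x):
    """y_i = (x_i - mu) / sigma, with no learned gain or bias."""
    mu, sigma = mean_and_std(x)
    return [(xi - mu) / sigma for xi in x]


def analytic_jacobian(x):
    """J[i][j] = (N*delta_ij - 1 - y_i*y_j) / (N*sigma), the per-entry formula being checked."""
    n = len(x)
    _, sigma = mean_and_std(x)
    y = layer_norm(x)
    return [[((n if i == j else 0) - 1 - y[i] * y[j]) / (n * sigma)
             for j in range(n)]
            for i in range(n)]


def numeric_jacobian(x, eps=1e-6):
    """Central finite differences: J[i][j] ~ (y_i(x + eps*e_j) - y_i(x - eps*e_j)) / (2*eps)."""
    n = len(x)
    cols = []
    for j in range(n):
        x_plus = list(x)
        x_plus[j] += eps
        x_minus = list(x)
        x_minus[j] -= eps
        y_plus, y_minus = layer_norm(x_plus), layer_norm(x_minus)
        cols.append([(y_plus[i] - y_minus[i]) / (2 * eps) for i in range(n)])
    # cols[j][i] holds dy_i/dx_j; transpose so that rows are indexed by i
    return [[cols[j][i] for j in range(n)] for i in range(n)]


x = [0.5, -1.2, 3.0, 2.2]  # arbitrary test vector (entries not all equal, so sigma > 0)
J_a = analytic_jacobian(x)
J_n = numeric_jacobian(x)

max_err = max(abs(a - b) for row_a, row_n in zip(J_a, J_n) for a, b in zip(row_a, row_n))
print("max |analytic - numeric|:", max_err)
for i, row in enumerate(J_a):
    print(f"row {i}: entries {[round(v, 4) for v in row]}, sum over j = {sum(row):.2e}")
```

If the formula is right, the maximum difference should be on the order of finite-difference error, and the printed rows show that individual entries $\frac{dy_i}{dx_j}$ are nonzero even though each row sums to (numerically) zero, which is the quantity $\sum_j \frac{dy_i}{dx_j}$ computed above.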
I’m almost certain I’ve simply made a mistake somewhere, so if someone could point it out I’d greatly appreciate it. Thanks!
If you're computing the derivative of the layer norm for the purpose of using it in backprop, then you need to compute the derivative with respect to the parameters of the layer, not its inputs.
i.e., fix $x$ and compute $\frac{\partial y}{\partial \mu}$ and $\frac{\partial y}{\partial \sigma}$
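Taking this suggestion literally, with $x$ fixed and $\mu$, $\sigma$ treated as two free variables in $y_i = \frac{x_i-\mu}{\sigma}$, the partials are $\frac{\partial y_i}{\partial \mu} = -\frac{1}{\sigma}$ and $\frac{\partial y_i}{\partial \sigma} = -\frac{x_i-\mu}{\sigma^2}$. Below is a minimal plain-Python sketch that checks these closed forms with central finite differences; the function names and the sample values for `x`, `mu`, `sigma`, and `eps` are arbitrary illustrations, and `mu`/`sigma` are deliberately not computed from `x`.

```python
def y(x_i, mu, sigma):
    """Normalized value, with mu and sigma treated as independent inputs (x_i held fixed)."""
    return (x_i - mu) / sigma


def dy_dmu(x_i, mu, sigma):
    """Closed form of dy/dmu; x_i and mu are unused but kept for a uniform signature."""
    return -1.0 / sigma


def dy_dsigma(x_i, mu, sigma):
    """Closed form of dy/dsigma."""
    return -(x_i - mu) / sigma ** 2


x = [0.5, -1.2, 3.0, 2.2]  # arbitrary fixed inputs
mu, sigma = 1.0, 1.5       # arbitrary values for the two free variables
eps = 1e-6

errors = []
for x_i in x:
    fd_mu = (y(x_i, mu + eps, sigma) - y(x_i, mu - eps, sigma)) / (2 * eps)
    fd_sigma = (y(x_i, mu, sigma + eps) - y(x_i, mu, sigma - eps)) / (2 * eps)
    errors.append(abs(fd_mu - dy_dmu(x_i, mu, sigma)))
    errors.append(abs(fd_sigma - dy_dsigma(x_i, mu, sigma)))
    print(f"x_i={x_i}: dy/dmu={dy_dmu(x_i, mu, sigma):.4f}, "
          f"dy/dsigma={dy_dsigma(x_i, mu, sigma):.4f}")

print("max finite-difference error:", max(errors))
```

With $\mu$ and $\sigma$ held as independent variables, neither partial involves a sum over the other inputs, which is the simplification this approach trades on.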
I did not look at it closely, partly because I suspect that it's not actually the derivative you care about, as I said above. But it looks like you're overcomplicating things, in the sense that you should treat $\mu$ and $\sigma$ as fixed variables and not as functions of $x$.
If they're functions of $x$, it may be that you're running into something like this:
You define $y = \frac{x - f(x)}{g(x)}$ and you compute $\frac{\partial y}{\partial x}$
... quotient rule ... computation... re-arrange...whatever
then you plug $f(x) = x$ into the result and correctly observe that $\frac{\partial y}{\partial x} = 0$, which is kind of obvious if you plug it in at the start but has been obscured by the notation. Now, where you started is not as simple as what I've written, but the same thing could happen.
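Spelled out for this toy example, assuming $f$ and $g$ are differentiable and $g(x) \neq 0$: the quotient rule gives the general result first, and substituting $f(x) = x$ (so $f'(x) = 1$) afterwards makes both terms of the numerator vanish:

$$\frac{\partial y}{\partial x} = \frac{\bigl(1 - f'(x)\bigr)\,g(x) - \bigl(x - f(x)\bigr)\,g'(x)}{g(x)^2} \;\overset{f(x)=x}{=}\; \frac{(1-1)\,g(x) - (x-x)\,g'(x)}{g(x)^2} = 0$$

This matches substituting at the start, where $y = \frac{x - x}{g(x)} = 0$ for every $x$.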
Correct answer by bogovicj on July 22, 2021