
How does Batch normalization help optimization? Proof

Data Science: asked by NeverneverNever on July 6, 2021

I am reading the paper "How Does Batch Normalization Help Optimization?" (Santurkar et al., NeurIPS 2018).

$\newcommand{\norm}[1]{\left\lVert#1\right\rVert}$

But I am having trouble understanding the proof in the paper. It concerns the effect of BatchNorm on the Lipschitzness of the loss. Specifically, I am having trouble deriving:

$\norm{\dfrac{\partial \hat{L}}{\partial y_{j}}}^2 = \Big( \dfrac{\gamma^2}{\sigma_j^2} \Big)\Bigg( \norm{\dfrac{\partial \hat{L}}{\partial z_{j}}}^2-\dfrac{1}{m}\Big\langle 1,\dfrac{\partial \hat{L}}{\partial z_{j}} \Big\rangle^2 - \dfrac{1}{m}\Big\langle\dfrac{\partial \hat{L}}{\partial z_{j}},\hat{y}_j \Big\rangle^2 \Bigg)$

from

$\dfrac{\partial \hat{L}}{\partial y_{j}^{(b)}} = \Big( \dfrac{\gamma}{m\sigma_j} \Big)\Bigg(m \dfrac{\partial \hat{L}}{\partial z_{j}^{(b)}}-\sum^m_{k=1}\dfrac{\partial \hat{L}}{\partial z_{j}^{(k)}} - \hat{y}_j^{(b)} \sum^m_{k=1}\dfrac{\partial \hat{L}}{\partial z_{j}^{(k)}} \hat{y}_j^{(k)} \Bigg)$
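Numerically the claim does seem to hold. Here is a small NumPy sanity check I put together before worrying about the algebra (the batch size $m$, $\gamma$, and $\sigma_j$ below are arbitrary placeholders, and `g` is my stand-in for $\partial\hat{L}/\partial z_j$):

```python
import numpy as np

rng = np.random.default_rng(0)
m, gamma, sigma = 8, 1.3, 0.7                # arbitrary batch size and BN parameters

g = rng.normal(size=m)                       # stand-in for dL/dz_j (one scalar per example)
y_hat = rng.normal(size=m)
y_hat -= y_hat.mean()                        # enforce mean zero...
y_hat *= np.sqrt(m) / np.linalg.norm(y_hat)  # ...and norm sqrt(m), as BN guarantees

ones = np.ones(m)
# the starting identity, written over the whole batch at once
dL_dy = (gamma / (m * sigma)) * (m * g - ones * (ones @ g) - y_hat * (g @ y_hat))

# the claimed squared norm
lhs = np.linalg.norm(dL_dy) ** 2
rhs = (gamma ** 2 / sigma ** 2) * (np.linalg.norm(g) ** 2
                                   - (ones @ g) ** 2 / m
                                   - (g @ y_hat) ** 2 / m)
print(np.isclose(lhs, rhs))                  # True
```

So the identity itself checks out; my problem is only with the steps in between.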

— Details

$\dfrac{\partial \hat{L}}{\partial y_j^{(b)}}=\dfrac{\gamma}{m\sigma_j}\Bigg(m\dfrac{\partial \hat{L}}{\partial z_j^{(b)}}-\sum^{m}_{k=1}\dfrac{\partial \hat{L}}{\partial z_j^{(k)}}-\hat{y}_j^{(b)}\sum^{m}_{k=1}\dfrac{\partial \hat{L}}{\partial z_j^{(k)}}\hat{y}_j^{(k)}\Bigg)$ (1)

$\dfrac{\partial \hat{L}}{\partial y_j}=\dfrac{\gamma}{m\sigma_j}\Bigg(m\dfrac{\partial \hat{L}}{\partial z_j}-1\Big\langle 1, \dfrac{\partial \hat{L}}{\partial z_j} \Big\rangle-\hat{y}_j\Big\langle \dfrac{\partial \hat{L}}{\partial z_j}, \hat{y}_j\Big\rangle\Bigg)$ (2)

First, I can't fully understand how (1) becomes (2). Since $\langle \cdot, \cdot \rangle$ is an inner product, shouldn't its result be a scalar? And since the result of $\sum^{m}_{k=1}\dfrac{\partial \hat{L}}{\partial z_j^{(k)}}$ is a vector, how can it and $\Big\langle 1, \dfrac{\partial \hat{L}}{\partial z_j} \Big\rangle$ be the same?
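To make the shapes concrete, here is a minimal sketch of the objects involved (the variable names are my own):

```python
import numpy as np

g = np.arange(1.0, 6.0)          # stand-in for dL/dz_j over a batch of m = 5
ones = np.ones_like(g)

s = ones @ g                     # the inner product <1, g>: a single scalar
print(np.isclose(s, g.sum()))    # True: it equals the sum over k in (1)

# in (2) that scalar multiplies the all-ones vector, so the b-th component
# of 1*<1, g> is exactly the sum that appears in (1), for every b at once
print(ones * s)                  # [15. 15. 15. 15. 15.]
```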

Let $\mu_g=\dfrac{1}{m} \Big\langle 1, \dfrac{\partial \hat{L}}{\partial z_j} \Big\rangle$, and note that $\hat{y}_j$ is mean-zero with norm $\sqrt{m}$. Then:

$\dfrac{\partial \hat{L}}{\partial y_j}=\dfrac{\gamma}{\sigma_j}\Bigg(\Big( \dfrac{\partial \hat{L}}{\partial z_j}-1\mu_g \Big)-\dfrac{1}{m}\hat{y}_j\Big\langle \Big( \dfrac{\partial \hat{L}}{\partial z_j}-1\mu_g \Big), \hat{y}_j\Big\rangle\Bigg)$ (3)

$=\dfrac{\gamma}{\sigma_j}\Bigg(\Big( \dfrac{\partial \hat{L}}{\partial z_j}-1\mu_g \Big)-\dfrac{\hat{y}_j}{\norm{\hat{y}_j}}\Big\langle \Big( \dfrac{\partial \hat{L}}{\partial z_j}-1\mu_g \Big), \dfrac{\hat{y}_j}{\norm{\hat{y}_j}}\Big\rangle\Bigg)$ (4)

$\norm{\dfrac{\partial \hat{L}}{\partial y_j}}^2=\dfrac{\gamma^2}{\sigma_j^2}\norm{\Big( \dfrac{\partial \hat{L}}{\partial z_j}-1\mu_g \Big)-\dfrac{\hat{y}_j}{\norm{\hat{y}_j}}\Big\langle \Big( \dfrac{\partial \hat{L}}{\partial z_j}-1\mu_g \Big), \dfrac{\hat{y}_j}{\norm{\hat{y}_j}}\Big\rangle}^2$ (5)

$\norm{\dfrac{\partial \hat{L}}{\partial y_j}}^2=\dfrac{\gamma^2}{\sigma_j^2}\Bigg(\norm{\dfrac{\partial \hat{L}}{\partial z_j}-1\mu_g}^2-\Big\langle \Big( \dfrac{\partial \hat{L}}{\partial z_j}-1\mu_g \Big), \dfrac{\hat{y}_j}{\norm{\hat{y}_j}}\Big\rangle^2\Bigg)$ (6)

I am having trouble deriving (6) from (5). Can you show me how to derive this?
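For what it's worth, the step from (5) to (6) does hold numerically; here is a check of the underlying identity, where `v` plays the role of the centered gradient and `u` the unit vector $\hat{y}_j/\norm{\hat{y}_j}$ (both filled with placeholder values):

```python
import numpy as np

rng = np.random.default_rng(1)
v = rng.normal(size=8)             # plays the role of dL/dz_j - 1*mu_g
u = rng.normal(size=8)
u /= np.linalg.norm(u)             # unit vector, plays the role of y_hat/||y_hat||

lhs = np.linalg.norm(v - u * (v @ u)) ** 2   # squared norm inside (5)
rhs = np.linalg.norm(v) ** 2 - (v @ u) ** 2  # bracket in (6)
print(np.isclose(lhs, rhs))        # True
```

For reference, the final form I am trying to reach is: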

$\norm{\dfrac{\partial \hat{L}}{\partial y_{j}}}^2 = \Big( \dfrac{\gamma^2}{\sigma_j^2} \Big)\Bigg( \norm{\dfrac{\partial \hat{L}}{\partial z_{j}}}^2-\dfrac{1}{m}\Big\langle 1,\dfrac{\partial \hat{L}}{\partial z_{j}} \Big\rangle^2 - \dfrac{1}{m}\Big\langle\dfrac{\partial \hat{L}}{\partial z_{j}},\hat{y}_j \Big\rangle^2 \Bigg)$ (7)
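Likewise, (6) agrees numerically with this final form (7) once the mean-zero, norm-$\sqrt{m}$ properties of $\hat{y}_j$ are enforced; a sketch comparing the two brackets (the common $\gamma^2/\sigma_j^2$ prefactor is dropped, and the values are again placeholders):

```python
import numpy as np

rng = np.random.default_rng(2)
m = 8
g = rng.normal(size=m)                       # stand-in for dL/dz_j
y_hat = rng.normal(size=m)
y_hat -= y_hat.mean()                        # mean zero
y_hat *= np.sqrt(m) / np.linalg.norm(y_hat)  # norm sqrt(m)
ones = np.ones(m)
mu_g = (ones @ g) / m

centered = g - ones * mu_g
u = y_hat / np.linalg.norm(y_hat)
six = np.linalg.norm(centered) ** 2 - (centered @ u) ** 2   # bracket in (6)
seven = (np.linalg.norm(g) ** 2 - (ones @ g) ** 2 / m
         - (g @ y_hat) ** 2 / m)                            # bracket in (7)
print(np.isclose(six, seven))                # True
```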

Thank you.
