Cross Validated Asked by londumas on November 3, 2020

Thanks for having a look at my post.

I had an extensive look at the difference in weight initialization between PyTorch and Keras, and it appears that the definitions of `he_normal` (Keras) and `kaiming_normal_` (PyTorch) differ across the two platforms.

They both claim to apply the solution presented in He et al. 2015 (https://arxiv.org/abs/1502.01852):

- https://pytorch.org/docs/stable/nn.init.html
- https://www.tensorflow.org/api_docs/python/tf/keras/initializers/HeNormal

However, I found no mention of truncation in that paper.

To me truncation makes a lot of sense.

Is there a bug in the simple code below, or do these two platforms indeed claim to apply the solution from the same paper but differ in their implementation?

If so, which one is correct? What is best?

```python
import numpy as np
import matplotlib.pyplot as plt
import torch
from keras.layers import Input, Dense
from keras.models import Model

real = 100  # number of independent initializations to pool

### PyTorch
params = np.array([])
for _ in range(real):
    lin = torch.nn.Linear(in_features=16, out_features=16)
    torch.nn.init.kaiming_normal_(lin.weight)
    # np.append without an axis flattens, so params stays 1-D
    params = np.append(params, lin.weight.detach().numpy())
plt.hist(params, bins=50, alpha=0.4, label='PyTorch')

### Keras
params = np.array([])
for _ in range(real):
    X_input = Input(shape=(16,))
    X = Dense(units=16, activation='relu', kernel_initializer='he_normal')(X_input)
    model = Model(inputs=X_input, outputs=X)
    # get_weights()[0] is the (16, 16) kernel; [1] would be the bias
    params = np.append(params, model.get_weights()[0])
plt.hist(params, bins=50, alpha=0.4, label='Keras')

###
plt.xlabel('Weights')
plt.ylabel('#')
plt.yscale('log')
plt.legend()
plt.grid()
plt.show()
```

I think you're correct that the two initializers are different; this difference is consistent with the description in the documentation.

For Keras, the documentation says

> It draws samples from a truncated normal distribution centered on 0 with `stddev = sqrt(2 / fan_in)` where `fan_in` is the number of input units in the weight tensor.
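To make the Keras description concrete, here is a minimal pure-Python sketch of a truncated He-normal sampler. It assumes the convention used in recent Keras source (`VarianceScaling` with `distribution="truncated_normal"`): values more than two standard deviations from the mean are re-drawn, and the underlying standard deviation is divided by about 0.8796 (the standard deviation of a unit normal truncated to [-2, 2]) so that the *truncated* samples end up with standard deviation `sqrt(2 / fan_in)`. The function name `he_normal_truncated` is hypothetical, and it is worth checking the source of your Keras version before relying on the exact constant.

```python
import math
import random
import statistics

# Std of a unit normal truncated to [-2, 2]; assumed to match the constant in recent Keras source.
TRUNC_STD_CORRECTION = 0.87962566103423978


def he_normal_truncated(fan_in, n, rng=random):
    """Hypothetical Keras-style he_normal sketch: truncated normal with std sqrt(2/fan_in)."""
    target_std = math.sqrt(2.0 / fan_in)
    # Widen the underlying normal so that the truncated samples have the target std.
    sigma = target_std / TRUNC_STD_CORRECTION
    samples = []
    while len(samples) < n:
        w = rng.gauss(0.0, sigma)
        # Re-draw anything more than 2 underlying standard deviations from 0.
        if abs(w) <= 2.0 * sigma:
            samples.append(w)
    return samples


rng = random.Random(0)
weights = he_normal_truncated(fan_in=16, n=25_600, rng=rng)
print(max(abs(w) for w in weights))  # never above 2 * 0.3536 / 0.8796 ≈ 0.804
print(statistics.pstdev(weights))    # close to sqrt(2 / 16) ≈ 0.354
```

The rejection loop is the simplest way to express "re-draw values outside two standard deviations"; a production implementation would vectorize it.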

By contrast, for Torch, the documentation says

> Fills the input Tensor with values according to the method described in Delving deep into rectifiers: Surpassing human-level performance on ImageNet classification - He, K. et al. (2015), using a normal distribution. The resulting tensor will have values sampled from $\mathcal{N}(0, \text{std}^2)$ where
>
> $$\text{std} = \frac{\text{gain}}{\sqrt{\text{fan\_mode}}}$$
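For comparison with the Keras description above, here is the untruncated sampler that the Torch formula describes, as a stdlib-only sketch rather than PyTorch's actual code. `kaiming_normal_sketch` is a hypothetical name; the default `gain=math.sqrt(2)` corresponds to the ReLU case.

```python
import math
import random
import statistics


def kaiming_normal_sketch(fan_in, n, gain=math.sqrt(2.0), rng=random):
    """Plain (untruncated) normal with std = gain / sqrt(fan_in), as in the Torch docs."""
    std = gain / math.sqrt(fan_in)
    return [rng.gauss(0.0, std) for _ in range(n)]


rng = random.Random(0)
weights = kaiming_normal_sketch(fan_in=16, n=25_600, rng=rng)
print(statistics.pstdev(weights))  # close to sqrt(2) / sqrt(16) ≈ 0.354
# No truncation: for a normal, about 4.6% of samples land beyond 2 std.
print(sum(abs(w) > 2 * math.sqrt(2 / 16) for w in weights) / len(weights))
```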

Choosing $\text{gain}=\sqrt{2}$ and $\text{fan\_mode}=\text{fan\_in}$ (which is what `kaiming_normal_` does with its default arguments `a=0`, `mode='fan_in'`, `nonlinearity='leaky_relu'`) makes the standard deviations the same, but the Keras function uses a truncated distribution while the Torch function does not, so the resulting distributions differ. Again, this is consistent with your findings.
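As a worked example with the layer size from the question ($\text{fan\_in} = 16$), and assuming Keras re-draws values beyond two standard deviations of the underlying normal and divides that standard deviation by $0.8796$ so the truncated samples keep the target spread (the behavior of `VarianceScaling` in recent Keras versions; worth checking against your installed version):

$$
\text{std} = \frac{\sqrt{2}}{\sqrt{16}} = \sqrt{\tfrac{2}{16}} \approx 0.354,
\qquad
\sigma_{\text{Keras}} = \frac{0.354}{0.8796} \approx 0.402,
\qquad
|w_{\text{Keras}}| \le 2\,\sigma_{\text{Keras}} \approx 0.804 .
$$

A Torch weight exceeds that bound when $|z| > 0.804 / 0.354 = 2 / 0.8796 \approx 2.27$ for a standard normal $z$, which happens with probability of roughly 2.3% (the ratio does not depend on $\text{fan\_in}$). Among the 25,600 pooled Torch weights in the question's script, on the order of 600 therefore fall in tails that the Keras histogram cannot have.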

So the Torch function isn't truncating, while the Keras function is.

Turning to the paper, you're again correct: the He 2015 paper does not describe a truncation in the text. Since the cited article doesn't seem to support the truncation used by the Keras function, it could be reasonable to open an issue on the library's GitHub or another official Keras channel. It's also possible that the Keras authors meant to cite a different article, or something like that.

As for your last question,

> What is best?

Best for what? Does one of the initializations suit your task well? If so, use that one. If not, use a different one. The He paper describes a network design and finds that this initialization works well, and provides some commentary and theoretical justification. But the network that you want to build may not match the models He was examining, or it may not conform to some of the assumptions that He made in the theoretical analysis. In particular, the He paper is focused on ReLU and PReLU networks; if you're using a different activation function, your results may require an alternative initialization scheme.
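To illustrate why the activation matters: the He paper's derivation for PReLU/leaky-ReLU units with negative slope $a$ gives the condition $\tfrac{1}{2}(1 + a^2)\,\text{fan\_in}\,\mathrm{Var}[w] = 1$, which reduces to $\mathrm{Var}[w] = 2/\text{fan\_in}$ for ReLU ($a = 0$). PyTorch exposes the same idea through the `a` argument of `kaiming_normal_`. The helper below is a hypothetical stdlib sketch of that formula, not either library's code.

```python
import math


def he_std(fan_in, negative_slope=0.0):
    """Std from the He et al. (2015) condition 0.5 * (1 + a^2) * fan_in * Var[w] = 1."""
    return math.sqrt(2.0 / ((1.0 + negative_slope ** 2) * fan_in))


print(he_std(16))                       # ReLU (a = 0): sqrt(2/16) ≈ 0.3536
print(he_std(16, negative_slope=0.25))  # leaky/PReLU slope a = 0.25: sqrt(2/17) ≈ 0.343
print(he_std(16, negative_slope=1.0))   # a = 1 is a linear unit: sqrt(1/16) = 0.25
```

The larger the negative slope, the less the rectifier halves the signal's variance, so the required weight variance shrinks toward the linear-unit value $1/\text{fan\_in}$.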

Answered by Sycorax on November 3, 2020
