Stack Overflow Asked on January 26, 2021
I’m training ResNet34 on CIFAR-10. I see very weird behavior when I try to manipulate param.grad for model.parameters().
The following function is where all the mess happens. It currently doesn’t do anything useful as a result of trying to understand what happens.
def add_error(error):
    params = (param for param in model.parameters() if param.requires_grad)
    # [param.grad + err for param, err in zip(params, error)]  # Line 2
    # new_error = [param.grad + err for param, err in zip(params, error)]  # Line 3
    for param in params:
        param.grad.zero_()
    # new_error = [torch.zeros(param.grad.shape, device=device) for param in params]  # Line 6
    return new_error
It’s used in the gradient descent step:
def step(model, optimizer, batch, labels, error):
    optimizer.zero_grad()
    loss = compute_loss(model, batch, labels)
    loss.backward()
    new_error = add_error(error=error)  # <- add_error is called here
    optimizer.step()
    return new_error
where optimizer is optim.SGD(model.parameters(), lr=0.1) and compute_loss essentially calls nn.CrossEntropyLoss() on model(batch) and labels.
What I expect: since I set the gradient to 0, nothing should change no matter what I do: the loss should stay around its original value (2.4) the whole time.
What actually happens:
[...] 2.4.
[...] add_error at all. I.e. the loss decreases at the same rate as with usual SGD: 2.4 -> 1.7 -> 1.3 -> ... (per epoch). In other words, somehow the gradients are propagated.
[...] 4.3, and then slowly decreases 4.3 -> 4.2 -> 4.14 -> 4.1 -> ... (I suspect this decrease is a result of batch normalization).
Note that in none of these cases do I actually use error: I never use it to update the gradient.
Also, adding more lines like Line 2 doesn’t affect the outcome.
Question: What’s happening?
If it helps, I can try to produce an MCVE and post it on pastebin (it’ll be a wall of code, too large to fit here).
The problem was incredibly stupid: params is a generator and becomes exhausted after the first iteration over it. Creating a list instead of a generator solves the issue.
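To make the exhaustion visible, here is a minimal sketch that runs without PyTorch. The names make_params and zero_grads and the SimpleNamespace objects are hypothetical stand-ins for model.parameters() and the zeroing loop in add_error, not part of the original code. It shows that once a comprehension has consumed the generator, the later for loop never executes its body, whereas a list can be iterated repeatedly.

```python
from types import SimpleNamespace


def make_params():
    # Hypothetical stand-ins for model parameters: each object has a `grad`
    # list instead of a tensor, so the sketch runs without PyTorch.
    return [SimpleNamespace(grad=[1.0, 2.0]), SimpleNamespace(grad=[3.0])]


def zero_grads(params):
    # Mirrors the `for param in params: param.grad.zero_()` loop and
    # returns how many parameters were actually visited.
    visited = 0
    for p in params:
        p.grad = [0.0 for _ in p.grad]
        visited += 1
    return visited


# Generator version: the comprehension consumes it, so the loop sees nothing.
model_params = make_params()
gen = (p for p in model_params)
_ = [p.grad for p in gen]                # first pass exhausts the generator
zeroed_with_generator = zero_grads(gen)  # loop body never runs

# List version: can be iterated as many times as needed.
model_params_2 = make_params()
lst = [p for p in model_params_2]
_ = [p.grad for p in lst]
zeroed_with_list = zero_grads(lst)

print(zeroed_with_generator, zeroed_with_list)
```

Applied to the function in the question, the same idea would presumably amount to building params with square brackets, i.e. params = [param for param in model.parameters() if param.requires_grad], so that every later pass over params sees all parameters.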
Answered by Dmitry on January 26, 2021