
PyTorch LSTM not training

Asked on Data Science, May 18, 2021

I am currently trying to implement an LSTM in PyTorch, but for some reason the loss is not decreasing. Here is my network:

class MyNN(nn.Module):
    def __init__(self, input_size=3, seq_len=107, pred_len=68, hidden_size=50, num_layers=1, dropout=0.2):
        super().__init__()
        
        self.pred_len = pred_len
        
        # Bidirectional LSTM; note that dropout is only applied between
        # stacked layers, so it has no effect when num_layers=1
        self.rnn = nn.LSTM(
            input_size=input_size, 
            hidden_size=hidden_size, 
            num_layers=num_layers, 
            dropout=dropout, 
            bidirectional=True,
            batch_first=True
        )
        
        # hidden_size*2 because the LSTM is bidirectional
        self.linear = nn.Linear(hidden_size*2, 5)
    
    def forward(self, X):
        lstm_output, (hidden_state, cell_state) = self.rnn(X)
        
        # Predict 5 targets for the first pred_len timesteps only
        labels = self.linear(lstm_output[:, :self.pred_len, :])
        
        return lstm_output, labels
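For reference, a quick shape check with a dummy batch (the batch size of 4 here is arbitrary, chosen just for illustration) shows what the network returns:

import torch

net = MyNN(num_layers=1, dropout=0)
X = torch.randn(4, 107, 3)    # (batch, seq_len, input_size)
lstm_output, labels = net(X)
print(lstm_output.shape)      # torch.Size([4, 107, 100]): hidden_size * 2 directions
print(labels.shape)           # torch.Size([4, 68, 5]): (batch, pred_len, 5)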

And here is my training loop:

LEARNING_RATE = 1e-2


net = MyNN(num_layers=1, dropout=0)

compute_loss = nn.MSELoss()

optimizer = optim.Adam(net.parameters(), lr=LEARNING_RATE)


all_loss = []
for data in tqdm(train_loader):
    X, y = data
    
    optimizer.zero_grad()

    lstm_output, output = net(X.float())
    
    # Computing the loss (prediction first, target second)
    loss = compute_loss(output, y)
    all_loss.append(loss.item())  # .item() detaches the scalar from the graph
    loss.backward()
    
    optimizer.step()
    
# Plot the per-batch loss
plt.plot(all_loss, marker=".")
plt.xlabel("Batch")
plt.ylabel("Loss")
plt.show()

And this is what I got:

[Image: plot of the training loss per batch, which fluctuates without decreasing]

I have been trying to figure out what I am doing wrong, but I have no idea. Also, I previously used a Keras LSTM on the same dataset and it worked well.

Any help?
Thanks!

One Answer

You are looking at the loss at every batch. You should average your loss over all batches in an epoch. When you look at individual batches, the loss may increase simply because one batch happens to be harder to predict than another, so the per-batch curve is not really interpretable. Start with that. If the problem persists, it is probably exploding gradients; in that case, lower your learning rate to 1e-3 or 1e-4, or even less if it continues.
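As a minimal sketch (reusing net, compute_loss, optimizer, and train_loader from the question; the epoch count and the gradient-clipping safeguard are my own additions for illustration, not something the question's code had), the epoch-averaged version of the loop could look like this:

NUM_EPOCHS = 10  # arbitrary choice for illustration

epoch_losses = []
for epoch in range(NUM_EPOCHS):
    batch_losses = []
    for X, y in train_loader:
        optimizer.zero_grad()
        _, output = net(X.float())
        loss = compute_loss(output, y)
        loss.backward()
        # Optional safeguard against exploding gradients
        torch.nn.utils.clip_grad_norm_(net.parameters(), max_norm=1.0)
        optimizer.step()
        batch_losses.append(loss.item())
    # Average over all batches so the curve is comparable across epochs
    epoch_losses.append(sum(batch_losses) / len(batch_losses))

plt.plot(epoch_losses, marker=".")
plt.xlabel("Epoch")
plt.ylabel("Mean loss")
plt.show()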

Correct answer by YuseqYaseq on May 18, 2021
