Data Science Asked on June 19, 2021
I am trying to implement a BiLSTM layer for a text classification problem, using PyTorch.
self.bilstm = nn.LSTM(embedding_dim, lstm_hidden_dim//2, batch_first=True, bidirectional=True)
lstm_out, (ht, ct) = self.bilstm(embeddings)
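For reference, here is a minimal runnable version of the two lines above that makes the shapes concrete. The wrapper class BiLSTMEncoder, the embedding layer and all sizes (vocabulary 100, embedding_dim=32, lstm_hidden_dim=64, m=4 sentences of 7 tokens) are hypothetical and chosen only for illustration.

```python
import torch
import torch.nn as nn


class BiLSTMEncoder(nn.Module):
    """Hypothetical wrapper around the question's two lines."""

    def __init__(self, vocab_size: int, embedding_dim: int, lstm_hidden_dim: int):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        # Each direction gets half the hidden size, so the concatenated
        # forward + backward features in lstm_out have lstm_hidden_dim channels.
        self.bilstm = nn.LSTM(embedding_dim, lstm_hidden_dim // 2,
                              batch_first=True, bidirectional=True)

    def forward(self, token_ids: torch.Tensor):
        embeddings = self.embedding(token_ids)      # (m, seq_len, embedding_dim)
        lstm_out, (ht, ct) = self.bilstm(embeddings)
        return lstm_out, ht, ct


torch.manual_seed(0)
encoder = BiLSTMEncoder(vocab_size=100, embedding_dim=32, lstm_hidden_dim=64)
token_ids = torch.randint(0, 100, (4, 7))           # m=4 sentences, 7 tokens each
lstm_out, ht, ct = encoder(token_ids)

print(lstm_out.shape)  # torch.Size([4, 7, 64]): batch-first, both directions concatenated
print(ht.shape)        # torch.Size([2, 4, 32]): (num_directions, m, n), not batch-first
print(ct.shape)        # torch.Size([2, 4, 32])
```

Note that batch_first=True only changes the layout of the input and of lstm_out; ht and ct keep the (num_layers * num_directions, batch, hidden) layout, which is why the question sees (2, m, n) with n = lstm_hidden_dim // 2.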
Now, in some examples I came across on the internet, people pass ht through a Linear layer to generate the output, while others use lstm_out for the same purpose. I have two questions:
ht, for bidirectional=True, has the shape (2, m, n) (where m is the batch size and n = lstm_hidden_dim//2), but we need (1, m, n). What is the best way to convert it?
Nice question! I'm looking at the PyTorch documentation: https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html
If I read it correctly, lstm_out gives you the output features of the LSTM's last layer for every token in the sequence. So if your LSTM has two layers and the input has 10 words, with a batch size of 1, you get an output tensor of shape (10, 1, h), assuming a unidirectional LSTM and the default sequence-first layout (batch_first=False; see the docs).
For the same LSTM, ht has shape (2*1, 1, h), i.e. (num_layers * num_directions, batch, h). It only holds the hidden state after the last word in the sentence, but it contains that state for every layer of the LSTM, unlike lstm_out, which only covers the top layer.
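The shapes described in the two paragraphs above can be checked directly. This sketch assumes a unidirectional two-layer LSTM with the default sequence-first layout; the sizes (input size 8, h=16, 10 tokens, batch 1) are arbitrary.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
h = 16
lstm = nn.LSTM(input_size=8, hidden_size=h, num_layers=2)  # batch_first=False by default
x = torch.randn(10, 1, 8)                                  # (seq_len=10, batch=1, features=8)

lstm_out, (ht, ct) = lstm(x)
print(lstm_out.shape)  # torch.Size([10, 1, 16]): top layer only, one vector per token
print(ht.shape)        # torch.Size([2, 1, 16]): last time step only, one vector per layer

# For a unidirectional LSTM, the top layer's entry in ht is exactly the last
# time step of lstm_out; ht[0] is the lower layer's final state, which
# lstm_out does not expose at all.
print(torch.allclose(ht[-1], lstm_out[-1]))  # True
```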
In a nutshell, I think you want lstm_out for token-level tasks such as sequence labeling, because it gives you features for every token. You want ht for sentence-level tasks like text classification: in a unidirectional LSTM the features of the last token are assumed to be the richest and most representative of the sentence, because through the recurrence they are linked to all the preceding tokens. For a bidirectional LSTM, ht holds the final state of each direction, i.e. the forward state after the last token and the backward state after reading back to the first token. That is not the same as the last time step of lstm_out, whose backward half has only seen the last token (the PyTorch docs point this out).
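To illustrate the bidirectional point for the question's setup, here is a sketch assuming batch_first=True, a single layer, and hypothetical sizes (m=4 sentences, 7 tokens, n=32 per direction, 5 tags). The Linear tagging head is a placeholder, not a recommended architecture.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
m, seq_len, emb_dim, n = 4, 7, 32, 32
num_tags = 5                                          # hypothetical tag count

bilstm = nn.LSTM(emb_dim, n, batch_first=True, bidirectional=True)
embeddings = torch.randn(m, seq_len, emb_dim)
lstm_out, (ht, ct) = bilstm(embeddings)               # lstm_out: (4, 7, 64); ht: (2, 4, 32)

# The forward half of lstm_out at the LAST token is the final forward state ht[0].
print(torch.allclose(lstm_out[:, -1, :n], ht[0]))     # True
# The backward direction finishes at the FIRST token, so ht[1] matches position 0 ...
print(torch.allclose(lstm_out[:, 0, n:], ht[1]))      # True
# ... and not the backward half at the last position.
print(torch.allclose(lstm_out[:, -1, n:], ht[1]))     # False

# Token-level head (e.g. sequence labeling): one score vector per token.
tagger = nn.Linear(2 * n, num_tags)
print(tagger(lstm_out).shape)                         # torch.Size([4, 7, 5])
```

In practice this means that taking lstm_out[:, -1, :] as a "sentence vector" from a BiLSTM gives a backward half that has seen only one token, whereas ht does not have that problem.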
If you want ht to be (1, m, n) instead of (2, m, n) straight out of the LSTM, you need a single layer and a single direction. This is closely tied to the design of your architecture, though: you would need to think through whether you really want to give up bidirectionality and keep a shallow LSTM.
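The sketch below confirms that a single-layer unidirectional LSTM yields (1, m, n) directly. It also shows a common workaround that is not part of the original answer: keep a deep bidirectional LSTM, use the (num_layers * num_directions, m, n) layout of ht to pick out the top layer, and concatenate its forward and backward states into an (m, 2n) tensor, adding a leading dimension if something downstream expects (1, m, ·). All sizes (m=3, n=32, two layers) are hypothetical.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
m, seq_len, emb_dim, n, num_layers = 3, 7, 32, 32, 2
embeddings = torch.randn(m, seq_len, emb_dim)

# Option 1: one layer, one direction -> ht is already (1, m, n).
uni = nn.LSTM(emb_dim, n, batch_first=True)
_, (ht_uni, _) = uni(embeddings)
print(ht_uni.shape)                                     # torch.Size([1, 3, 32])

# Option 2: keep a deep bidirectional LSTM and combine the directions.
bi = nn.LSTM(emb_dim, n, num_layers=num_layers, batch_first=True, bidirectional=True)
_, (ht_bi, _) = bi(embeddings)
print(ht_bi.shape)                                      # torch.Size([4, 3, 32]) = (num_layers*2, m, n)

# ht is ordered layer by layer, direction within layer, so this view separates them.
ht_layers = ht_bi.view(num_layers, 2, m, n)
top_fwd, top_bwd = ht_layers[-1, 0], ht_layers[-1, 1]   # each (m, n)
sentence_repr = torch.cat([top_fwd, top_bwd], dim=1)    # (m, 2n)
print(sentence_repr.shape)                              # torch.Size([3, 64])
print(sentence_repr.unsqueeze(0).shape)                 # torch.Size([1, 3, 64])
```

Concatenation doubles the feature size to 2n, so the following Linear layer needs 2n inputs; summing or averaging the two directions instead would keep it at n. Which combination works better for a given task is an empirical question.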
Answered by Nitin on June 19, 2021