Data Science Asked by DukeLover on September 16, 2020
I am trying to train an LSTM model. Is this model suffering from overfitting?
Here is the training and validation loss graph:
Yes, this is an overfitting problem: your validation loss curve reaches a turning point and starts rising again. This is a sign of training for too many epochs. In this case, training could be stopped at the turning point (early stopping), or the number of training examples could be increased.
Overfitting can also be caused by a model that is too deep for the amount of training data. In that case, you'll observe the training and validation losses diverging very early.
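To make the "stop at the turning point" idea concrete, here is a minimal sketch of a patience-based stopping rule in plain Python. The function name, the `patience` and `min_delta` parameters, and the loss values are all made up for illustration; they are not taken from the question's graph.

```python
def early_stopping_epoch(val_losses, patience=3, min_delta=0.0):
    """Return (stop_epoch, best_epoch), both 0-indexed.

    Training stops once the validation loss has failed to improve on its
    best value by more than `min_delta` for `patience` epochs in a row.
    """
    best_loss = float("inf")
    best_epoch = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best_loss - min_delta:
            # New best validation loss: remember where it happened.
            best_loss = loss
            best_epoch = epoch
        elif epoch - best_epoch >= patience:
            # No improvement for `patience` epochs: stop here.
            return epoch, best_epoch
    # Never triggered: training ran to the last epoch.
    return len(val_losses) - 1, best_epoch


# Made-up validation losses that bottom out and then rise again.
val_losses = [0.90, 0.70, 0.55, 0.50, 0.52, 0.56, 0.61, 0.70]
stop_epoch, best_epoch = early_stopping_epoch(val_losses, patience=3)
print(f"stop at epoch {stop_epoch}, keep weights from epoch {best_epoch}")
```

If you are using Keras, the built-in `tf.keras.callbacks.EarlyStopping(monitor="val_loss", patience=..., restore_best_weights=True)` callback applies the same kind of rule during `model.fit`, so you would not normally need to write this yourself.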
Answered by Mohit Banerjee on September 16, 2020
The model is overfitting from about epoch 10 onward: the validation loss is increasing while the training loss keeps decreasing.
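If you want to locate that point from the numbers rather than by eye, something like the following might help. It is only a rough sketch: the function name, the `window` parameter, and the loss lists are made up for illustration. With Keras, the two lists would typically come from `history.history["loss"]` and `history.history["val_loss"]`.

```python
def first_divergence_epoch(train_losses, val_losses, window=3):
    """Return the first 0-indexed epoch after which the validation loss
    rises while the training loss falls for `window` consecutive epochs.
    Returns None if no such stretch exists."""
    n = min(len(train_losses), len(val_losses))
    for start in range(n - window):
        steps = range(start, start + window)
        val_rising = all(val_losses[i + 1] > val_losses[i] for i in steps)
        train_falling = all(train_losses[i + 1] < train_losses[i] for i in steps)
        if val_rising and train_falling:
            return start
    return None


# Made-up loss curves: training keeps improving, validation turns upward.
train_losses = [1.0, 0.8, 0.6, 0.5, 0.4, 0.3, 0.2]
val_losses = [1.1, 0.9, 0.8, 0.85, 0.9, 1.0, 1.1]
divergence_epoch = first_divergence_epoch(train_losses, val_losses, window=3)
print("curves diverge after epoch", divergence_epoch)
```

A larger `window` makes the check less sensitive to noisy epochs, at the cost of needing more epochs after the turning point before it fires.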
Dealing with such a model:
There are many other options for reducing overfitting as well. Assuming you are using Keras, visit this link.
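As one illustration of what such options can look like in Keras, here is a small sketch of an LSTM with dropout and L2 weight regularization. The input shape, layer sizes, and rates are arbitrary assumptions chosen for the example, not recommendations for your data.

```python
import tensorflow as tf

# Hypothetical input shape: 20 time steps with 8 features each.
TIMESTEPS, FEATURES = 20, 8

model = tf.keras.Sequential([
    tf.keras.Input(shape=(TIMESTEPS, FEATURES)),
    tf.keras.layers.LSTM(
        32,
        dropout=0.2,            # dropout on the layer inputs
        recurrent_dropout=0.2,  # dropout on the recurrent state
        kernel_regularizer=tf.keras.regularizers.l2(1e-4),  # L2 weight penalty
    ),
    tf.keras.layers.Dropout(0.3),  # dropout before the output layer
    tf.keras.layers.Dense(1),
])
model.compile(optimizer="adam", loss="mse")
model.summary()
```

Reducing the number of units or layers is another way to lower capacity. Which combination works best depends on the dataset, so treat these values as starting points to tune against the validation loss.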
Answered by user5722540 on September 16, 2020
Another possible cause of overfitting is improper data augmentation. If you're augmenting, make sure it's really doing what you expect.
I had a similar problem, and it turned out to be due to a bug in my TensorFlow data pipeline where I was augmenting before caching:
def get_dataset(inputfile, batchsize):
    # Load the data into a TensorFlow dataset.
    signals, labels = read_data_from_file(inputfile)
    dataset = tf.data.Dataset.from_tensor_slices((signals, labels))

    # Augment the data by dynamically tweaking each training sample on the fly.
    dataset = dataset.map(
        map_func=(lambda signals, labels: (
            tuple(tf.py_function(func=augment, inp=[signals], Tout=[tf.float32])),
            labels)))

    # Oops! Should have called cache() before augmenting
    dataset = dataset.cache()

    dataset = ... # Shuffle, repeat, batch, etc.
    return dataset

training_data = get_dataset("training.txt", 32)
val_data = get_dataset("validation.txt", 32)
model.fit(training_data, validation_data=val_data, ...)
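Because cache() stores whatever the pipeline has produced up to that point, the augmented samples from the first pass were cached and replayed afterwards. As a result, the training data was only being augmented for the first epoch, but the validation data was being augmented on every epoch. This caused the model to quickly overfit on the training data while the validation loss continually increased. Moving the augment call after cache() solved the problem.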
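Here is a small self-contained sketch that contrasts the two orderings. It is not my original pipeline: `augment` here is a hypothetical stand-in that just adds Gaussian noise, the data is an array of zeros, and the `cache_first` flag exists only to switch between the two variants.

```python
import numpy as np
import tensorflow as tf


def augment(signal):
    # Hypothetical stand-in for a real augmentation: add Gaussian noise.
    return signal + tf.random.normal(tf.shape(signal), stddev=0.1)


def build_dataset(signals, labels, batch_size, cache_first):
    dataset = tf.data.Dataset.from_tensor_slices((signals, labels))
    if cache_first:
        # Cache the raw samples; the map below then re-runs on every epoch.
        dataset = dataset.cache()
        dataset = dataset.map(lambda signal, label: (augment(signal), label))
    else:
        # Buggy order: the augmented samples themselves get cached and replayed.
        dataset = dataset.map(lambda signal, label: (augment(signal), label))
        dataset = dataset.cache()
    return dataset.batch(batch_size)


def one_epoch(dataset):
    # Read the whole dataset once and return all signals as a single array.
    return np.concatenate([batch_signals.numpy() for batch_signals, _ in dataset])


signals = np.zeros((8, 4), dtype=np.float32)
labels = np.arange(8, dtype=np.int64)

fixed = build_dataset(signals, labels, batch_size=4, cache_first=True)
buggy = build_dataset(signals, labels, batch_size=4, cache_first=False)

# Two full passes over each dataset stand in for two training epochs.
fixed_changes = not np.allclose(one_epoch(fixed), one_epoch(fixed))
buggy_changes = not np.allclose(one_epoch(buggy), one_epoch(buggy))
print("augmentation varies across epochs (cache first):", fixed_changes)
print("augmentation varies across epochs (cache last):", buggy_changes)
```

Comparing two passes over the same dataset like this is a cheap sanity check that the augmentation is really being re-applied on every epoch.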
Answered by Kevin D. on September 16, 2020