Asked by Marco Caglia on April 22, 2021
I have used a pre-trained VGG-19 model to build an image classifier as part of a MOOC. I have implemented a classifier like this:
classifier = nn.Sequential(nn.Linear(25088, 512),
                           nn.ReLU(),
                           nn.Dropout(p=0.2),
                           nn.Linear(512, 350),
                           nn.ReLU(),
                           nn.Dropout(p=0.2),
                           nn.Linear(350, 250),
                           nn.ReLU(),
                           nn.Dropout(p=0.2),
                           nn.Linear(250, 102),
                           nn.LogSoftmax(dim=1))
model.classifier = classifier
In the next step, I was asked to save the trained model so it could be reused later. My understanding of saving such a model was that I had to create a dictionary with the input size (given by the pre-trained model), the output size (given by my specific problem), the model’s state dict and the hidden layers (given by my own code). I tried to do that like this:
checkpoint = {'input_size': 25088,
              'output_size': 102,
              'hidden_layers': [each for each in model.classifier],
              'state_dict': model.state_dict(),
              'optimizer': optimizer.state_dict(),
              'epochs': epochs,
              'class_to_index': cat_to_name}
Then I tried to load it like this:
def load_checkpoint(path):
    checkpoint = torch.load(path)
    model = nn.Sequential(checkpoint['input_size'],
                          checkpoint['output_size'],
                          checkpoint['hidden_layers'])
    model.load_state_dict(checkpoint['state_dict'])
    epochs = checkpoint['epochs']
    model.class_to_idx = checkpoint['class_to_index']
    return model, epochs
test_model, _ = load_checkpoint('trained_model.pth')
print(test_model)
However, this just raises the error:
TypeError: int is not a Module subclass
What did I do wrong? As far as I can tell, checkpoint['hidden_layers'] does not itself contain integers:
print(checkpoint['hidden_layers'])
Returns:
[Linear(in_features=25088, out_features=512, bias=True), ReLU(), Dropout(p=0.2), Linear(in_features=512, out_features=350, bias=True), ReLU(), Dropout(p=0.2), Linear(in_features=350, out_features=250, bias=True), ReLU(), Dropout(p=0.2), Linear(in_features=250, out_features=102, bias=True), LogSoftmax()]
Which is what I would expect.
Thank you for reading! Any thoughts or comments are highly appreciated!
There are a few mistakes. The following is how you need to change your code. I am not sure what your model consists of, but I am assuming it is a model object whose .classifier attribute holds the network.
def load_checkpoint(model, path):
    checkpoint = torch.load(path)
    # The checkpoint was saved from the full model (model.state_dict()),
    # so load it back into the full model rather than into model.classifier.
    model.load_state_dict(checkpoint['state_dict'])
    epochs = checkpoint['epochs']
    model.class_to_idx = checkpoint['class_to_index']
    return model, epochs
classifier = nn.Sequential(nn.Linear(25088, 512),
                           nn.ReLU(),
                           nn.Dropout(p=0.2),
                           nn.Linear(512, 350),
                           nn.ReLU(),
                           nn.Dropout(p=0.2),
                           nn.Linear(350, 250),
                           nn.ReLU(),
                           nn.Dropout(p=0.2),
                           nn.Linear(250, 102),
                           nn.LogSoftmax(dim=1))
model.classifier = classifier
test_model, _ = load_checkpoint(model, 'trained_model.pth')
print(test_model)
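For reference, the TypeError in the original code happens because nn.Sequential only accepts nn.Module instances, and checkpoint['input_size'] and checkpoint['output_size'] are plain integers. If you do want to rebuild the classifier from the saved layer list instead of re-declaring it, a minimal sketch (using the checkpoint keys from the question, and assuming the pre-trained VGG model is already constructed in scope as model) could look like this:
def load_checkpoint(path):
    checkpoint = torch.load(path)
    # Unpack the saved nn.Module objects; passing the integer size entries
    # to nn.Sequential is what caused the original TypeError.
    model.classifier = nn.Sequential(*checkpoint['hidden_layers'])
    model.load_state_dict(checkpoint['state_dict'])
    model.class_to_idx = checkpoint['class_to_index']
    return model, checkpoint['epochs']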
You also need to restore the optimizer, the scheduler, and any other training state in the same way if you want to retrain the network.
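A minimal sketch of restoring the optimizer state, assuming the checkpoint dictionary from the question and that the optimizer is recreated with the same configuration used during training (the optim.Adam call and learning rate below are placeholder assumptions):
import torch.optim as optim

checkpoint = torch.load('trained_model.pth')

# Recreate the optimizer exactly as it was configured for training
# (optimizer class and hyperparameters here are assumptions).
optimizer = optim.Adam(model.classifier.parameters(), lr=0.001)
optimizer.load_state_dict(checkpoint['optimizer'])

# Pick up the epoch counter where training left off
start_epoch = checkpoint['epochs']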
If your intention is to load only part of the network instead of the entire network, you can do something like the following:
classifier = nn.Sequential(nn.Linear(25088, 512),
                           nn.ReLU(),
                           nn.Dropout(p=0.2),
                           nn.Linear(512, 350),
                           nn.ReLU(),
                           nn.Dropout(p=0.2),
                           nn.Linear(350, 250))
model.classifier = classifier

weights = torch.load("weight.pth")
state_dict = model.classifier.state_dict()
# Copy over only the parameters that also exist in the saved checkpoint
for name, param in model.classifier.named_parameters():
    if name in weights["state_dict"]:
        state_dict[name] = weights["state_dict"][name]
model.classifier.load_state_dict(state_dict)
In this way, you can load only the part of the network that is required.
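Alternatively, if the saved state dict uses the same layer names as the truncated classifier, you can rely on load_state_dict with strict=False, which simply skips missing and unexpected keys; a short sketch under that assumption:
weights = torch.load("weight.pth")

# strict=False ignores keys that exist in only one of the two state dicts,
# so only the layers present in both are actually loaded.
model.classifier.load_state_dict(weights["state_dict"], strict=False)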
Answered by Prakash Vanapalli on April 22, 2021