Data Science Asked on June 20, 2021
Initially, a data loader is created with certain samples. During training, I need to replace a sample that is in the data loader. How can I replace it?
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=batch_size)
for sample, label in train_dataloader:
    prediction = model(sample)
    # Select the misclassified samples and change them in train_dataloader,
    # but how can a sample in train_dataloader be changed?
During training, the misclassified samples need to be modified.
So how can I replace a sample within train_dataloader?
Usually, you would process data at the sample level by creating your own class that inherits from torch.utils.data.Dataset and doing the processing inside its __getitem__ method.
from torch.utils.data import Dataset, DataLoader

class MyDataset(Dataset):
    def __init__(self, base_dataset):
        # base_dataset is a placeholder for any indexable object holding your samples
        self.base_dataset = base_dataset

    def __len__(self):
        return len(self.base_dataset)

    def __getitem__(self, i):
        # The base Dataset class does not implement __getitem__, so read from the wrapped data
        sample = self.base_dataset[i]
        # ...Do some processing on the sample here
        return sample

train_data = MyDataset(base_dataset)
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=batch_size)
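To make the replacement itself concrete, here is a minimal sketch of one possible approach, not necessarily what the answer had in mind. It assumes the samples fit in memory as tensors. MutableDataset, replace, and modify are hypothetical names introduced for illustration, and the noise-based modification is only a stand-in for whatever change you need. The dataset returns each sample's index so the training loop knows which stored rows to overwrite.

```python
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader


class MutableDataset(Dataset):
    """In-memory dataset whose stored samples can be overwritten (hypothetical helper)."""

    def __init__(self, samples, labels):
        self.samples = samples.clone()
        self.labels = labels.clone()

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, i):
        # Return the index as well, so the loop knows which rows to replace
        return self.samples[i], self.labels[i], i

    def replace(self, indices, new_samples):
        # Overwrite the stored samples; later reads return the new values
        self.samples[indices] = new_samples


def modify(samples):
    # Placeholder modification (assumption): add small Gaussian noise
    return samples + 0.01 * torch.randn_like(samples)


torch.manual_seed(0)
train_data = MutableDataset(torch.randn(32, 4), torch.randint(0, 2, (32,)))
train_dataloader = DataLoader(train_data, batch_size=8, shuffle=True)
model = nn.Linear(4, 2)  # toy model; the optimisation step is omitted here

num_replaced = 0
for sample, label, idx in train_dataloader:
    with torch.no_grad():
        prediction = model(sample).argmax(dim=1)
    wrong = prediction != label  # boolean mask of misclassified samples
    if wrong.any():
        train_data.replace(idx[wrong], modify(sample[wrong]))
        num_replaced += int(wrong.sum())
```

The DataLoader only holds a reference to the dataset, so changes made through train_data are seen the next time those indices are read. With num_workers > 0, each worker process works on its own copy of the dataset, so changes made in the main process would probably not be visible to workers that are already running.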
You can also use the collate_fn parameter of torch.utils.data.DataLoader to specify a custom function which processes your batch.
import torch
from torch.utils.data import DataLoader

def collate_fn(samples):
    # samples is a list of the items returned by __getitem__ of your torch.utils.data.Dataset instance
    # You can write here whatever processing you need before stacking samples into one batch of data
    # torch.stack assumes each item is a single tensor; for (sample, label) pairs, stack each field separately
    batch = torch.stack(samples, dim=0)
    return batch

# train_data must be an instance of a Dataset subclass; the base Dataset class cannot be indexed
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=batch_size, collate_fn=collate_fn)
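Since the loop in the question unpacks each batch as sample, label, a collate function for that case has to handle pairs. The following is a rough sketch under that assumption. TensorDataset stands in for the real data, and the per-batch standardisation is just a hypothetical example of processing.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def collate_fn(samples):
    # samples is a list of (sample, label) tuples returned by __getitem__
    xs, ys = zip(*samples)
    batch_x = torch.stack(xs, dim=0)
    batch_y = torch.stack(ys, dim=0)
    # Hypothetical processing step: standardise each feature over the batch
    batch_x = (batch_x - batch_x.mean(dim=0)) / (batch_x.std(dim=0) + 1e-8)
    return batch_x, batch_y


# Toy data standing in for the real training set (assumption)
train_data = TensorDataset(torch.randn(16, 4), torch.randint(0, 2, (16,)))
train_dataloader = DataLoader(train_data, batch_size=8, collate_fn=collate_fn)
batch_x, batch_y = next(iter(train_dataloader))
```

Note that collate_fn works on a freshly stacked copy of the data, so anything changed here affects only the current batch and is not written back to the dataset.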
Answered by Adam Oudad on June 20, 2021