Originally from: Shreeyak Sajjan
I’m training with a strategy of alternating batches from 2 datasets.
I.e., one batch of images from dataset A only, then one batch of images from dataset B only. The datasets differ in size, but both use the same batch size.
Any pointers on how to achieve this with PyTorch Lightning? Normally, I’d look at the batch_idx and select a dataset to draw from based on whether it’s odd or even.
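This is possible by using a custom dataset that returns one sample from each dataset at every index. Each batch is then a pair (batch_A, batch_B), so in training_step you can pick which half to use based on whether batch_idx is odd or even:
https://pytorch-lightning.readthedocs.io/en/stable/multiple_loaders.html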
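import torch

class ConcatDataset(torch.utils.data.Dataset):
    def __init__(self, *datasets):
        self.datasets = datasets

    def __getitem__(self, i):
        # One sample from every dataset at the same index.
        return tuple(d[i] for d in self.datasets)

    def __len__(self):
        # Truncate to the shortest dataset so index i is valid for all of them.
        return min(len(d) for d in self.datasets)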
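from torchvision import datasets

# The methods below go inside your LightningModule subclass.
def train_dataloader(self):
    # Note: pass a transform (e.g. transforms.ToTensor()) to ImageFolder so the
    # default collate function receives tensors rather than PIL images.
    concat_dataset = ConcatDataset(
        datasets.ImageFolder(traindir_A),
        datasets.ImageFolder(traindir_B)
    )
    loader = torch.utils.data.DataLoader(
        concat_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.workers,
        pin_memory=True
    )
    return loader

def val_dataloader(self):
    # SAME
    ...

def test_dataloader(self):
    # SAME
    ...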
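
Note that the dataset above loads a sample from both A and B at every step. If you would rather have each batch come from exactly one dataset, one alternative is a custom batch sampler. The following is only a minimal sketch under some assumptions: the class name AlternatingBatchSampler and its parameters are made up for illustration, the two datasets are assumed to be concatenated end to end (e.g. with torch.utils.data.ConcatDataset) so that indices below len_a belong to A, and incomplete trailing batches are dropped. The epoch ends when the smaller dataset runs out of full batches.

```python
import random


class AlternatingBatchSampler:
    """Yield index batches that alternate between two concatenated datasets.

    Hypothetical helper: indices 0..len_a-1 are assumed to address dataset A
    and len_a..len_a+len_b-1 to address dataset B.
    """

    def __init__(self, len_a, len_b, batch_size, shuffle=True, seed=0):
        self.len_a = len_a
        self.len_b = len_b
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.rng = random.Random(seed)

    def _batches(self, start, length):
        indices = list(range(start, start + length))
        if self.shuffle:
            self.rng.shuffle(indices)
        # Drop the last incomplete batch so every batch has the same size.
        full = length - length % self.batch_size
        return [indices[i:i + self.batch_size]
                for i in range(0, full, self.batch_size)]

    def __iter__(self):
        batches_a = self._batches(0, self.len_a)
        batches_b = self._batches(self.len_a, self.len_b)
        # zip stops at the shorter list, i.e. when the smaller dataset is used up.
        for batch_a, batch_b in zip(batches_a, batches_b):
            yield batch_a  # even batch_idx: dataset A only
            yield batch_b  # odd batch_idx: dataset B only

    def __len__(self):
        return 2 * min(self.len_a // self.batch_size,
                       self.len_b // self.batch_size)
```

It could then be passed as DataLoader(torch.utils.data.ConcatDataset([dataset_A, dataset_B]), batch_sampler=sampler), in which case batch_size and shuffle must not be given to the DataLoader as well. I have not checked how this interacts with Lightning's automatic sampler replacement in distributed training, so treat that part as untested.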