I’m in a situation where I’m setting up training for multiple datasets of different lengths, which means a different number of batches in each DataLoader. Since I want to use a loss function that combines the losses across datasets (e.g. weighted differently for each dataset), I was trying to keep the dataloaders separate in a dictionary:
def train_dataloader(self):  # returns a dict of dataloaders
    train_loaders = {}
    for key, value in self.train_dict.items():
        train_loaders[key] = DataLoader(value,
                                        batch_size=self.batch_size,
                                        collate_fn=collate)
    return train_loaders
Then in training_step() I’m doing the following (I’m working on a contrastive learning project):
def training_step(self, batch, batch_idx):
    total_batch_loss = 0
    for key, value in batch.items():
        # each value is a triplet of (anchor, positive, negative) batches
        anc, pos, neg = value
        emb_anc = F.normalize(self.forward(anc.x,
                                           anc.edge_index,
                                           anc.weights,
                                           anc.batch,
                                           training=True), 2, dim=1)
        emb_pos = F.normalize(self.forward(pos.x,
                                           pos.edge_index,
                                           pos.weights,
                                           pos.batch,
                                           training=True), 2, dim=1)
        emb_neg = F.normalize(self.forward(neg.x,
                                           neg.edge_index,
                                           neg.weights,
                                           neg.batch,
                                           training=True), 2, dim=1)
        loss_dataset = LossFunc(emb_anc, emb_pos, emb_neg, anc.y, pos.y, neg.y)
        total_batch_loss += loss_dataset
    self.log("Loss", total_batch_loss, prog_bar=True, on_epoch=True)
    return total_batch_loss
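To make the weighting concrete, this is roughly the combination I have in mind (a minimal sketch; combine_losses and the weights dict are names I made up for illustration, they are not part of my current code):

import torch

def combine_losses(per_dataset_losses, weights):
    # per_dataset_losses: dict mapping dataset key -> scalar loss tensor
    # weights:            dict mapping dataset key -> float weight
    device = next(iter(per_dataset_losses.values())).device
    total = torch.zeros((), device=device)
    for key, loss in per_dataset_losses.items():
        # missing keys fall back to a weight of 1.0
        total = total + weights.get(key, 1.0) * loss
    return total

In training_step() I would then collect each loss_dataset into a dict keyed by dataset name and return combine_losses(losses, self.loss_weights) instead of the plain sum.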
The problem is that the datasets have different numbers of batches, so Lightning throws a StopIteration as soon as the smallest dataset is exhausted. I was considering simply concatenating everything into a single DataLoader, but then I don’t think I would be able to weight the loss per dataset the way I want.
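What I would ideally like is to keep the loaders separate but have the shorter ones restart instead of raising StopIteration. If I read the docs correctly, Lightning’s CombinedLoader with mode="max_size_cycle" does something like that (sketch below; the exact import path depends on the Lightning version, this one is for 2.x), but I’m not sure it is the intended way to handle my case:

from torch.utils.data import DataLoader
from lightning.pytorch.utilities import CombinedLoader  # import path differs in 1.x

def train_dataloader(self):
    loaders = {key: DataLoader(value,
                               batch_size=self.batch_size,
                               collate_fn=collate)
               for key, value in self.train_dict.items()}
    # "max_size_cycle" cycles the shorter loaders so every dataset
    # contributes a batch until the longest loader is exhausted
    return CombinedLoader(loaders, mode="max_size_cycle")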