This question is taken from GitHub:
I have a labeled and an unlabeled dataset that I am using for a semi-supervised segmentation problem.
Here is what I want to implement in the training loop:
Train an epoch (or, alternatively, a batch) on the labeled dataset.
Train an epoch (or, alternatively, a batch) on the unlabeled dataset.
What is the best way to do it?
Hi, it is currently not possible to return multiple dataloaders for training (that only works for validation).
A feature for this is in progress in #1959.
However, in your case, I think it is more elegant to do this:
Step 1:
Return the right dataloader for each epoch:
def train_dataloader(self):
    # Alternate datasets: labeled data on even epochs, unlabeled data on odd epochs
    if self.current_epoch % 2 == 0:
        labeled_dataloader = ...
        return labeled_dataloader
    else:
        unlabeled_dataloader = ...
        return unlabeled_dataloader
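For reference, here is one way the elided dataloader construction could look. The self.labeled_dataset and self.unlabeled_dataset attributes are hypothetical placeholders standing in for your own Dataset objects:

from torch.utils.data import DataLoader

def train_dataloader(self):
    # Hypothetical dataset attributes; substitute your own labeled/unlabeled datasets
    if self.current_epoch % 2 == 0:
        return DataLoader(self.labeled_dataset, batch_size=8, shuffle=True)
    return DataLoader(self.unlabeled_dataset, batch_size=8, shuffle=True)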
Step 2:
Modify your training_step like this:
def training_step(self, batch, batch_idx):
    if self.current_epoch % 2 == 0:
        # batch comes from the labeled dataloader: apply the loss with labels
        ...
    else:
        # batch comes from the unlabeled dataloader: apply the unsupervised loss
        ...
    return ...
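As an illustration (not part of the original answer), the two branches could look roughly like this for segmentation, assuming the labeled batch is an (image, mask) pair and the unsupervised branch uses a simple entropy-minimization loss as one possible choice:

import torch
import torch.nn.functional as F

def training_step(self, batch, batch_idx):
    if self.current_epoch % 2 == 0:
        images, masks = batch                      # labeled batch: images with ground-truth masks (long dtype)
        logits = self(images)                      # (N, C, H, W)
        loss = F.cross_entropy(logits, masks)      # supervised segmentation loss
    else:
        images = batch                             # unlabeled batch: images only
        logits = self(images)
        probs = torch.softmax(logits, dim=1)
        # entropy minimization as one possible unsupervised objective
        loss = -(probs * torch.log(probs + 1e-8)).sum(dim=1).mean()
    return loss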
Step 3:
Finally, tell the Trainer to call the train_dataloader method every epoch, so that it switches to the new dataset.
trainer = Trainer(..., reload_dataloaders_every_epoch=True) # False by default
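Note: in later PyTorch Lightning releases the reload_dataloaders_every_epoch flag was replaced by reload_dataloaders_every_n_epochs. If your version no longer accepts the old name, the equivalent call is roughly:

trainer = Trainer(..., reload_dataloaders_every_n_epochs=1)  # reload the train dataloader every epoch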
I think self.current_epoch is not accessible if we use pl.LightningDataModule. Any suggestions?
You can use self.trainer.current_epoch inside pl.LightningDataModule to get the current epoch.
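For example, a minimal sketch of a DataModule that alternates datasets this way; the labeled_dataset and unlabeled_dataset arguments are hypothetical placeholders for your own Dataset objects:

import pytorch_lightning as pl
from torch.utils.data import DataLoader

class SemiSupervisedDataModule(pl.LightningDataModule):
    def __init__(self, labeled_dataset, unlabeled_dataset, batch_size=8):
        super().__init__()
        self.labeled_dataset = labeled_dataset
        self.unlabeled_dataset = unlabeled_dataset
        self.batch_size = batch_size

    def train_dataloader(self):
        # The DataModule has no current_epoch of its own, but the attached
        # Trainer does, so read it via self.trainer
        if self.trainer.current_epoch % 2 == 0:
            return DataLoader(self.labeled_dataset, batch_size=self.batch_size, shuffle=True)
        return DataLoader(self.unlabeled_dataset, batch_size=self.batch_size, shuffle=True)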