I am looking for a way to reinitialize my DataModule with different parameters. I currently pass the height of my images as an argument to my datamodule, and I want to change this height at some point during training. The simple way is to call trainer.fit multiple times with different datamodules, but I am wondering: is there a way to do this in a callback, in the same way as you do when you change the optimizer or lr_scheduler?
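For reference, a minimal sketch of the "simple way" mentioned above, i.e. staged fitting. MyModel, MyDataModule, and its height argument are hypothetical names, not from any actual codebase:

import pytorch_lightning as pl

model = MyModel()  # hypothetical LightningModule

# Stage 1: train on small images.
trainer = pl.Trainer(max_epochs=10)
trainer.fit(model, datamodule=MyDataModule(height=128))

# Stage 2: continue training on larger images with a fresh datamodule.
trainer = pl.Trainer(max_epochs=10)
trainer.fit(model, datamodule=MyDataModule(height=256))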
Hello, in such a case you need to force an update of the dataloaders in use. I would consider adding a reset/update method to your data module, which would eventually be called from a model hook or a callback…
Or shall we also add more hooks to the data module, like the model has? @nate @teddy
I have done this using a callback:
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from torchvision import transforms as T
from torchvision.datasets import ImageFolder

# Assumed here: the usual ImageNet normalization; adjust to your dataset.
normalize = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

class Scheduler(pl.Callback):
    def _prepare_epoch(self, trainer, model, epoch):
        phase = ...  # look up the settings for `epoch` in your schedule
        trainer.datamodule.set_phase(phase)

    def on_epoch_end(self, trainer, model):
        # Reconfigure the datamodule for the *next* epoch.
        self._prepare_epoch(trainer, model, trainer.current_epoch + 1)

class Data(pl.LightningDataModule):
    def set_phase(self, phase: dict):
        # Rebuild the training dataset with the (possibly new) image size.
        self.size = phase.get("size", self.size)
        train_transforms = T.Compose(
            [
                T.RandomResizedCrop(self.size, scale=(self.min_scale, 1.0)),
                T.RandomHorizontalFlip(),
                T.ToTensor(),
                normalize,
            ]
        )
        self.train_ds = ImageFolder(self.train_dir, transform=train_transforms)

    def train_dataloader(self):
        train_dl = DataLoader(
            self.train_ds,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            pin_memory=True,
        )
        return train_dl
It's important to note:
- You can access your datamodule from a callback via trainer.datamodule.
- To have train_dataloader() / val_dataloader() called every epoch, you must set reload_dataloaders_every_epoch=True in your trainer (see the wiring sketch below).
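For completeness, a minimal sketch of how the pieces above might be wired together. The Data constructor arguments are assumptions (the class above does not show its __init__), and newer Lightning versions have renamed the flag to reload_dataloaders_every_n_epochs:

import pytorch_lightning as pl

# Constructor arguments are hypothetical; Data and Scheduler are defined above.
dm = Data(train_dir="data/train", batch_size=32, num_workers=4)
model = ...  # your LightningModule

trainer = pl.Trainer(
    max_epochs=30,
    reload_dataloaders_every_epoch=True,  # re-invoke train_dataloader() each epoch
    callbacks=[Scheduler()],
)
trainer.fit(model, datamodule=dm)  # passing the datamodule here populates trainer.datamodule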
Is the proposed solution still valid?
In my codebase, trainer.datamodule is None.
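One possible cause (an assumption, since the codebase is not shown): trainer.datamodule is only populated when the datamodule is passed to trainer.fit; if you pass dataloaders directly, it stays None:

dm = Data(train_dir="data/train", batch_size=32, num_workers=4)  # hypothetical args

trainer.fit(model, datamodule=dm)            # trainer.datamodule is set
# trainer.fit(model, dm.train_dataloader())  # trainer.datamodule stays None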