I need to switch on class-balanced sampling from a certain epoch. What is the most efficient way to achieve this?
I have tried creating a callback that sets the trainer.train_dataloader attribute inside the on_train_epoch_start hook. It checks trainer.current_epoch and, if needed, tries to set the train_dataloader attribute to the appropriate DataLoader. However, at run time this step fails with an AttributeError saying that the attribute cannot be set.
The best alternative I can think of right now is to redefine train_dataloader in my DataModule so that it chooses a sampler based on self.trainer.current_epoch, and to accompany this by setting reload_dataloaders_every_n_epochs to an appropriate number.
Is there a better way to do this? I just need random sampling until, say, epoch n, and class-balanced sampling thereafter.
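
For reference, here is a rough sketch of the sampler-switching logic I have in mind for the DataModule approach. It uses only plain PyTorch so it can be tried in isolation. The helper names (class_balanced_weights, make_train_loader) and the switch_epoch parameter are hypothetical, made up for this illustration, and it assumes the labels are available up front as a plain list of ints.

```python
from collections import Counter

import torch
from torch.utils.data import DataLoader, WeightedRandomSampler


def class_balanced_weights(labels):
    """Per-sample weights inversely proportional to class frequency.

    `labels` is assumed to be a list of hashable class ids (e.g. ints).
    """
    counts = Counter(labels)
    return [1.0 / counts[y] for y in labels]


def make_train_loader(dataset, labels, current_epoch, switch_epoch, batch_size=32):
    """Uniform random sampling before `switch_epoch`, class-balanced from it onward."""
    if current_epoch < switch_epoch:
        # Ordinary shuffling: every sample is drawn once per epoch.
        return DataLoader(dataset, batch_size=batch_size, shuffle=True)

    # Draw with replacement so that each class is picked with roughly equal probability.
    sampler = WeightedRandomSampler(
        weights=class_balanced_weights(labels),
        num_samples=len(labels),
        replacement=True,
    )
    # `shuffle` must be left unset when an explicit sampler is passed.
    return DataLoader(dataset, batch_size=batch_size, sampler=sampler)
```

In the DataModule, train_dataloader would then return make_train_loader(..., current_epoch=self.trainer.current_epoch, switch_epoch=n), and the Trainer would be created with reload_dataloaders_every_n_epochs=1 so that the choice is re-evaluated each epoch. Rebuilding the loader every epoch is presumably more than strictly needed, since the sampler only changes once.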