I am training my model using an IterableDataset, and I need to use the CosineAnnealingWarmRestarts scheduler. The dataset is extremely large and I do not know its length. How can I configure the scheduler so that it restarts after n iterations instead of n epochs in PyTorch Lightning?
Here is the code I use. The key is to tell Lightning that the lr_scheduler is updated every n steps (in my code, n = 1) by setting "interval" to "step" in the lr_scheduler dict.
def configure_optimizers(self):
    opt = torch.optim.Adam(self.parameters(),
                           lr=train_config.G_lr,
                           betas=(train_config.beta1, train_config.beta2))
    sch = LinearWarmupCosineAnnealingLR(
        opt,
        # warmup_epochs actually counts steps when "interval" is set to "step"
        warmup_epochs=train_config.warmup_step,
        max_epochs=train_config.total_step,
    )
    return {
        'optimizer': opt,
        'lr_scheduler': {
            'name': 'train/lr',  # log the lr under the "train" group in TensorBoard
            'scheduler': sch,
            'interval': 'step',
            'frequency': 1,
        },
    }
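If you want the CosineAnnealingWarmRestarts scheduler from the original question rather than LinearWarmupCosineAnnealingLR, the same "interval": "step" trick applies: with that setting, Lightning calls scheduler.step() after every optimizer step, so T_0 counts iterations, not epochs. Here is a minimal standalone sketch of what Lightning does under the hood; the model, T_0=100, and eta_min values are illustrative, not from the original post.

```python
import torch
from torch import nn
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

# Illustrative setup; the restart period T_0 is measured in scheduler.step()
# calls, which equal training iterations when interval='step', frequency=1.
model = nn.Linear(4, 2)
opt = torch.optim.Adam(model.parameters(), lr=1e-3)
sch = CosineAnnealingWarmRestarts(opt, T_0=100, eta_min=1e-6)

# Simulate what Lightning does with {'interval': 'step', 'frequency': 1}:
lrs = []
for step in range(200):
    opt.step()  # gradients omitted in this sketch
    sch.step()  # stepped once per iteration, not once per epoch
    lrs.append(opt.param_groups[0]['lr'])

# The lr anneals toward eta_min over the first 100 iterations,
# then restarts at the base lr, regardless of the dataset length.
```

Because the schedule depends only on the number of optimizer steps, this works even with an IterableDataset whose length is unknown.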
That seems to work. Thank you, @pull_request!