How to use CosineAnnealingWarmRestarts at the iteration level instead of the epoch level

I am training my model on an IterableDataset and need to use the CosineAnnealingWarmRestarts scheduler. The dataset is extremely large and I do not know its length. How can I make the scheduler restart after n iterations instead of n epochs in PyTorch Lightning?

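Here is the code I use. The key is to set `'interval': 'step'` in the `lr_scheduler` dict, so Lightning calls `scheduler.step()` every `frequency` optimizer steps (1 in my code) instead of once per epoch. The scheduler's length arguments are then counted in steps, not epochs. The same dict works for `CosineAnnealingWarmRestarts`: with `'interval': 'step'`, its `T_0` is a number of iterations.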

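    import torch
    from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR

    # Method of your LightningModule subclass; train_config is your own config object
    def configure_optimizers(self):
        opt = torch.optim.Adam(self.parameters(),
                               lr=train_config.G_lr,
                               betas=(train_config.beta1, train_config.beta2))
        # warmup_epochs / max_epochs actually count steps because "interval" is "step" below
        sch = LinearWarmupCosineAnnealingLR(
            opt,
            warmup_epochs=train_config.warmup_step,
            max_epochs=train_config.total_step,
        )
        return {
            'optimizer': opt,
            'lr_scheduler': {
                'name': 'train/lr',  # put lr inside the train group in TensorBoard
                'scheduler': sch,
                'interval': 'step',  # call sch.step() after optimizer steps, not epochs
                'frequency': 1,      # ...after every single optimizer step
            }
        }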
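A minimal sketch of the same idea with `CosineAnnealingWarmRestarts`, assuming only `torch` is installed: it builds the kind of dict `configure_optimizers` could return and then simulates Lightning's `'interval': 'step'` behaviour with a plain loop. `make_optim_config` and `restart_every` are hypothetical names for illustration; in a real LightningModule you would return the dict from `configure_optimizers` and let the Trainer call `scheduler.step()` for you.

```python
import torch
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts


def make_optim_config(params, lr=1e-3, restart_every=1000):
    """Hypothetical helper returning a dict configure_optimizers could return.

    restart_every is the cycle length in optimizer steps (iterations).
    """
    opt = torch.optim.Adam(params, lr=lr)
    # T_0 counts scheduler.step() calls; with 'interval': 'step' that means iterations
    sch = CosineAnnealingWarmRestarts(opt, T_0=restart_every, T_mult=1, eta_min=0.0)
    return {
        "optimizer": opt,
        "lr_scheduler": {
            "scheduler": sch,
            "interval": "step",  # step the scheduler after optimizer steps, not epochs
            "frequency": 1,      # ...after every single optimizer step
        },
    }


# Simulate what Lightning does with 'interval': 'step' on a stream of batches
torch.manual_seed(0)
model = torch.nn.Linear(4, 1)
cfg = make_optim_config(model.parameters(), lr=0.1, restart_every=5)
opt = cfg["optimizer"]
sch = cfg["lr_scheduler"]["scheduler"]

lrs = []
for it in range(12):
    lrs.append(opt.param_groups[0]["lr"])  # lr used for this iteration
    loss = model(torch.randn(8, 4)).pow(2).mean()
    opt.zero_grad()
    loss.backward()
    opt.step()
    sch.step()  # one scheduler step per iteration

print([round(lr, 4) for lr in lrs])
```

With `restart_every=5`, the learning rate jumps back to 0.1 at iterations 0, 5 and 10, so the restart period depends only on the step count and never on the (unknown) dataset length. Setting `T_mult=2` would double the length of each successive cycle.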

That seems to work. Thank you @pull_request
