Hi folks,
I’m quite new to PyTorch and Lightning, and I seem to be running into the same error that others have reported when using lr_scheduler.OneCycleLR (my traceback is below), but none of the solutions suggested there apply to my problem:
My optimizer configuration and training step are as follows. Perhaps there is an implementation error?
def configure_optimizers(self):
    """Create the Adam optimizer and a OneCycleLR schedule stepped per batch.

    The peak learning rate is ``self.lr``. Because ``div_factor=1``, the
    schedule's starting LR (``max_lr / div_factor``) is also ``self.lr``.

    The schedule length is fixed at ``self.epochs * self.steps_per_epoch``
    scheduler steps. OneCycleLR raises ``ValueError("Tried to step ...")``
    as soon as it is stepped past that count, so the Trainer must not run
    more optimizer steps than this.

    NOTE(review): the Trainer in the traceback is constructed without
    ``max_epochs``. Presumably it should be ``Trainer(max_epochs=self.epochs)``,
    and ``self.steps_per_epoch`` must equal the number of optimizer steps per
    epoch (batches per epoch, adjusted for any gradient accumulation).
    Confirm against how these attributes are set.

    Returns:
        A pair of lists, in the form Lightning accepts:
        ``[optimizer]`` and ``[{"scheduler": ..., "interval": "step"}]``.
    """
    adam = optim.Adam(self.parameters(), lr=self.lr)

    # One scheduler step per optimizer step, for the whole training run.
    schedule_length = self.epochs * self.steps_per_epoch
    one_cycle = torch.optim.lr_scheduler.OneCycleLR(
        adam,
        max_lr=self.lr,
        total_steps=schedule_length,
        div_factor=1,
    )

    # "interval": "step" tells Lightning to call one_cycle.step() every
    # batch rather than once per epoch.
    scheduler_config = dict(scheduler=one_cycle, interval="step")
    return [adam], [scheduler_config]
def training_step(self, batch, batch_idx):
    """Compute the training loss on per-sample predictions averaged over sub-frames.

    The model's output is cut into consecutive chunks of 9 rows along dim 0.
    Each chunk is treated as one sample and averaged along dim 0, and the
    stacked averages are compared with the targets using ``self.loss_func``.

    Assumes the model output holds 9 sub-frame predictions per sample,
    stored consecutively along dim 0 (TODO confirm against the dataloader).
    If dim 0 is not a multiple of 9, the final chunk is smaller and is
    averaged as-is.

    Returns:
        A ``pl.TrainResult`` wrapping the loss, with the epoch-level loss logged.
    """
    inputs, targets = batch
    frame_preds = self(inputs)

    # One chunk per sample (9 sub-frames each); average its sub-frame predictions.
    per_sample_preds = torch.stack(
        [chunk.mean(0) for chunk in frame_preds.split(9)]
    )

    loss = self.loss_func(per_sample_preds, targets)

    result = pl.TrainResult(loss)
    # NOTE(review): the metric key has a trailing colon ('train_loss:').
    # Kept as-is because renaming it changes the logged metric name.
    result.log('train_loss:', loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)
    return result
The error trace is below. It always occurs at the end of training and reports 2 extra steps. It seems that scheduler.step()
is called 2 times too many during training, regardless of the number of training epochs. If anyone has any ideas for a solution, I would be grateful. Thanks!
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-31-0144cd8befb8> in <module>
9 trainer = Trainer(gpus=1, deterministic=True, auto_lr_find=False, callbacks=[lr_monitor])
10 # net.epochs, net.steps_per_epoch
---> 11 trainer.fit(net, data)
/opt/conda/envs/fastai/lib/python3.8/site-packages/pytorch_lightning/trainer/states.py in wrapped_fn(self, *args, **kwargs)
46 if entering is not None:
47 self.state = entering
---> 48 result = fn(self, *args, **kwargs)
49
50 # The INTERRUPTED state can be set inside the run function. To indicate that run was interrupted
/opt/conda/envs/fastai/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py in fit(self, model, train_dataloader, val_dataloaders, datamodule)
1071 self.accelerator_backend = GPUBackend(self)
1072 model = self.accelerator_backend.setup(model)
-> 1073 results = self.accelerator_backend.train(model)
1074
1075 elif self.use_tpu:
/opt/conda/envs/fastai/lib/python3.8/site-packages/pytorch_lightning/accelerators/gpu_backend.py in train(self, model)
49
50 def train(self, model):
---> 51 results = self.trainer.run_pretrain_routine(model)
52 return results
53
/opt/conda/envs/fastai/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py in run_pretrain_routine(self, model)
1237
1238 # CORE TRAINING LOOP
-> 1239 self.train()
1240
1241 def _run_sanity_check(self, ref_model, model):
/opt/conda/envs/fastai/lib/python3.8/site-packages/pytorch_lightning/trainer/training_loop.py in train(self)
392 # RUN TNG EPOCH
393 # -----------------
--> 394 self.run_training_epoch()
395
396 if self.max_steps and self.max_steps <= self.global_step:
/opt/conda/envs/fastai/lib/python3.8/site-packages/pytorch_lightning/trainer/training_loop.py in run_training_epoch(self)
529 monitor_metrics = deepcopy(self.callback_metrics)
530 monitor_metrics.update(batch_output.batch_log_metrics)
--> 531 self.update_train_loop_lr_schedulers(monitor_metrics=monitor_metrics)
532
533 # progress global step according to grads progress
/opt/conda/envs/fastai/lib/python3.8/site-packages/pytorch_lightning/trainer/training_loop.py in update_train_loop_lr_schedulers(self, monitor_metrics)
597 or (self.batch_idx + 1) == self.num_training_batches):
598 # update lr
--> 599 self.update_learning_rates(interval='step', monitor_metrics=monitor_metrics)
600
601 def run_on_epoch_end_hook(self, model):
/opt/conda/envs/fastai/lib/python3.8/site-packages/pytorch_lightning/trainer/training_loop.py in update_learning_rates(self, interval, monitor_metrics)
1304
1305 # update LR
-> 1306 lr_scheduler['scheduler'].step()
1307
1308 if self.dev_debugger.enabled:
/opt/conda/envs/fastai/lib/python3.8/site-packages/torch/optim/lr_scheduler.py in step(self, epoch)
139 if epoch is None:
140 self.last_epoch += 1
--> 141 values = self.get_lr()
142 else:
143 warnings.warn(EPOCH_DEPRECATION_WARNING, UserWarning)
/opt/conda/envs/fastai/lib/python3.8/site-packages/torch/optim/lr_scheduler.py in get_lr(self)
1210
1211 if step_num > self.total_steps:
-> 1212 raise ValueError("Tried to step {} times. The specified number of total steps is {}"
1213 .format(step_num + 1, self.total_steps))
1214
ValueError: Tried to step 402 times. The specified number of total steps is 400