Trainer.predict() call in a callback raises an error

What I am trying to do is make predictions on some data after every epoch to see the training trend.

So I created a callback like this:

import torch
import pytorch_lightning as pl
import matplotlib.pyplot as plt

class Visualizer(pl.callbacks.Callback):
    def __init__(self, X, y, test):
        self.test = test
        self.X = X
        self.y = y

    def on_epoch_end(self, trainer, pl_module):
        # predict on the held-out data after every epoch and plot the results
        preds = trainer.predict(pl_module, self.test)
        plt.plot(torch.cat(preds), label='preds')
        plt.plot(self.y, label='y')
        plt.plot(self.X, label='X')
        plt.legend()
        plt.show()

and added it to the callbacks parameter:

trainer = pl.Trainer(max_epochs=1, limit_train_batches=1., limit_val_batches=1., num_sanity_val_steps=0, val_check_interval=1.0, callbacks=[visualizer])
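where visualizer was constructed beforehand from the held-out data (a sketch with illustrative names; test_dl is the DataLoader I want predictions on):

visualizer = Visualizer(X, y, test_dl)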

but after I call trainer.fit() I get my visualization plot, and then this error immediately after it:

AttributeError                            Traceback (most recent call last)
<ipython-input-29-0caff10202d0> in <module>
     11 
     12 trainer = pl.Trainer(max_epochs=1, limit_train_batches=1., limit_val_batches=1., num_sanity_val_steps=0, val_check_interval=1.0, callbacks=[ visualizer])
---> 13 res = trainer.fit(model, train_dl, val_dl)
     14 # plot_metrics(metrics_callback)

~/base_env/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py in fit(self, model, train_dataloaders, val_dataloaders, datamodule, train_dataloader)
    551         self.checkpoint_connector.resume_start()
    552 
--> 553         self._run(model)
    554 
    555         assert self.state.stopped

~/base_env/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py in _run(self, model)
    916 
    917         # dispatch `start_training` or `start_evaluating` or `start_predicting`
--> 918         self._dispatch()
    919 
    920         # plugin will finalized fitting (e.g. ddp_spawn will load trained model)

~/base_env/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py in _dispatch(self)
    984             self.accelerator.start_predicting(self)
    985         else:
--> 986             self.accelerator.start_training(self)
    987 
    988     def run_stage(self):

~/base_env/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py in start_training(self, trainer)
     90 
     91     def start_training(self, trainer: "pl.Trainer") -> None:
---> 92         self.training_type_plugin.start_training(trainer)
     93 
     94     def start_evaluating(self, trainer: "pl.Trainer") -> None:

~/base_env/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py in start_training(self, trainer)
    159     def start_training(self, trainer: "pl.Trainer") -> None:
    160         # double dispatch to initiate the training loop
--> 161         self._results = trainer.run_stage()
    162 
    163     def start_evaluating(self, trainer: "pl.Trainer") -> None:

~/base_env/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py in run_stage(self)
    994         if self.predicting:
    995             return self._run_predict()
--> 996         return self._run_train()
    997 
    998     def _pre_training_routine(self):

~/base_env/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py in _run_train(self)
   1043             # reset trainer on this loop and all child loops in case user connected a custom loop
   1044             self.fit_loop.trainer = self
-> 1045             self.fit_loop.run()
   1046         except KeyboardInterrupt:
   1047             rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")

~/base_env/lib/python3.8/site-packages/pytorch_lightning/loops/base.py in run(self, *args, **kwargs)
    109             try:
    110                 self.on_advance_start(*args, **kwargs)
--> 111                 self.advance(*args, **kwargs)
    112                 self.on_advance_end()
    113                 self.iteration_count += 1

~/base_env/lib/python3.8/site-packages/pytorch_lightning/loops/fit_loop.py in advance(self)
    198         with self.trainer.profiler.profile("run_training_epoch"):
    199             # run train epoch
--> 200             epoch_output = self.epoch_loop.run(train_dataloader)
    201 
    202             if epoch_output is None:

~/base_env/lib/python3.8/site-packages/pytorch_lightning/loops/base.py in run(self, *args, **kwargs)
    110                 self.on_advance_start(*args, **kwargs)
    111                 self.advance(*args, **kwargs)
--> 112                 self.on_advance_end()
    113                 self.iteration_count += 1
    114                 self.restarting = False

~/base_env/lib/python3.8/site-packages/pytorch_lightning/loops/epoch/training_epoch_loop.py in on_advance_end(self)
    175         if should_check_val:
    176             self.trainer.validating = True
--> 177             self._run_validation()
    178             self.trainer.training = True
    179 

~/base_env/lib/python3.8/site-packages/pytorch_lightning/loops/epoch/training_epoch_loop.py in _run_validation(self)
    254 
    255         with torch.no_grad():
--> 256             self.val_loop.run()
    257 
    258     def _on_train_epoch_end_hook(self, processed_epoch_output: List[List[STEP_OUTPUT]]) -> None:

~/base_env/lib/python3.8/site-packages/pytorch_lightning/loops/base.py in run(self, *args, **kwargs)
    116                 break
    117 
--> 118         output = self.on_run_end()
    119         return output
    120 

~/base_env/lib/python3.8/site-packages/pytorch_lightning/loops/dataloader/evaluation_loop.py in on_run_end(self)
    135 
    136         # hook
--> 137         self.on_evaluation_epoch_end()
    138 
    139         # log epoch metrics

~/base_env/lib/python3.8/site-packages/pytorch_lightning/loops/dataloader/evaluation_loop.py in on_evaluation_epoch_end(self)
    258         self.trainer.call_hook(hook_name)
    259         self.trainer.call_hook("on_epoch_end")
--> 260         self.trainer.logger_connector.on_epoch_end()
    261 
    262     def teardown(self) -> None:

~/base_env/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py in on_epoch_end(self)
    248     def on_epoch_end(self) -> None:
    249         assert self._epoch_end_reached
--> 250         metrics = self.metrics
    251         self._progress_bar_metrics.update(metrics[MetricSource.PBAR])
    252         self._callback_metrics.update(metrics[MetricSource.CALLBACK])

~/base_env/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py in metrics(self)
    284         """This function returns either batch or epoch metrics depending on ``_epoch_end_reached``."""
    285         on_step = not self._epoch_end_reached
--> 286         return self.trainer._results.metrics(on_step)
    287 
    288     @property

AttributeError: 'NoneType' object has no attribute 'metrics'

So my questions are the following:

  1. What is the reason for this behaviour?
  2. What is the best practice for doing what I want?

@Algernone have you been able to solve the issue? I am experiencing the same problem.

I’d also be interested in a solution to this.

@gogothorr @algernone see discussion here for an answer: How do I call predict periodically during fit, e.g. at the end of validation? · Discussion #16258 · Lightning-AI/lightning · GitHub
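In short, calling trainer.predict() from inside fit re-enters the same Trainer object and resets its internal state, leaving trainer._results as None when fit resumes, which is why it fails with AttributeError: 'NoneType' object has no attribute 'metrics'. Below is a minimal sketch of that kind of workaround (not verbatim from the linked discussion): run the forward pass yourself inside the callback instead of nesting trainer.predict(); test_dl and the batch unpacking are illustrative and need adapting to your data.

import torch
import pytorch_lightning as pl
import matplotlib.pyplot as plt

class Visualizer(pl.callbacks.Callback):
    def __init__(self, X, y, test_dl):
        self.X = X
        self.y = y
        self.test_dl = test_dl  # a plain DataLoader over the data to visualize

    def on_validation_epoch_end(self, trainer, pl_module):
        # run inference manually instead of calling trainer.predict() mid-fit
        pl_module.eval()
        preds = []
        with torch.no_grad():
            for batch in self.test_dl:
                # adapt unpacking to your dataset; this assumes the inputs come first
                x = batch[0] if isinstance(batch, (list, tuple)) else batch
                preds.append(pl_module(x.to(pl_module.device)).cpu())
        pl_module.train()

        plt.plot(torch.cat(preds), label='preds')
        plt.plot(self.y, label='y')
        plt.plot(self.X, label='X')
        plt.legend()
        plt.show()

Alternatively, constructing a second, separate pl.Trainer used only for prediction avoids touching the state of the Trainer that is running fit.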