What I am trying to do is make predictions on some data after every epoch, so I can see how training is progressing.
So I created a callback like this:
import matplotlib.pyplot as plt
import pytorch_lightning as pl
import torch


class Visualizer(pl.callbacks.Callback):
    def __init__(self, X, y, test):
        self.test = test  # dataloader used for the predictions
        self.X = X        # inputs to plot
        self.y = y        # targets to plot

    def on_epoch_end(self, trainer, pl_module):
        preds = trainer.predict(pl_module, self.test)
        plt.plot(torch.cat(preds), label='preds')
        plt.plot(self.y, label='y')
        plt.plot(self.X, label='X')
        plt.legend()
        plt.show()
and added it to the callbacks parameter:
trainer = pl.Trainer(max_epochs=1, limit_train_batches=1., limit_val_batches=1., num_sanity_val_steps=0, val_check_interval=1.0, callbacks=[visualizer])
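Here visualizer is an instance of the callback above, built from my validation tensors and a test dataloader, roughly like this (X_val, y_val and test_dl are placeholder names for my actual data):

visualizer = Visualizer(X_val, y_val, test_dl)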
but after I call trainer.fit() I get my visualization plot, followed immediately by this error:
AttributeError Traceback (most recent call last)
<ipython-input-29-0caff10202d0> in <module>
11
12 trainer = pl.Trainer(max_epochs=1, limit_train_batches=1., limit_val_batches=1., num_sanity_val_steps=0, val_check_interval=1.0, callbacks=[ visualizer])
---> 13 res = trainer.fit(model, train_dl, val_dl)
14 # plot_metrics(metrics_callback)
~/base_env/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py in fit(self, model, train_dataloaders, val_dataloaders, datamodule, train_dataloader)
551 self.checkpoint_connector.resume_start()
552
--> 553 self._run(model)
554
555 assert self.state.stopped
~/base_env/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py in _run(self, model)
916
917 # dispatch `start_training` or `start_evaluating` or `start_predicting`
--> 918 self._dispatch()
919
920 # plugin will finalized fitting (e.g. ddp_spawn will load trained model)
~/base_env/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py in _dispatch(self)
984 self.accelerator.start_predicting(self)
985 else:
--> 986 self.accelerator.start_training(self)
987
988 def run_stage(self):
~/base_env/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py in start_training(self, trainer)
90
91 def start_training(self, trainer: "pl.Trainer") -> None:
---> 92 self.training_type_plugin.start_training(trainer)
93
94 def start_evaluating(self, trainer: "pl.Trainer") -> None:
~/base_env/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py in start_training(self, trainer)
159 def start_training(self, trainer: "pl.Trainer") -> None:
160 # double dispatch to initiate the training loop
--> 161 self._results = trainer.run_stage()
162
163 def start_evaluating(self, trainer: "pl.Trainer") -> None:
~/base_env/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py in run_stage(self)
994 if self.predicting:
995 return self._run_predict()
--> 996 return self._run_train()
997
998 def _pre_training_routine(self):
~/base_env/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py in _run_train(self)
1043 # reset trainer on this loop and all child loops in case user connected a custom loop
1044 self.fit_loop.trainer = self
---> 1045             self.fit_loop.run()
1046 except KeyboardInterrupt:
1047 rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")
~/base_env/lib/python3.8/site-packages/pytorch_lightning/loops/base.py in run(self, *args, **kwargs)
109 try:
110 self.on_advance_start(*args, **kwargs)
--> 111 self.advance(*args, **kwargs)
112 self.on_advance_end()
113 self.iteration_count += 1
~/base_env/lib/python3.8/site-packages/pytorch_lightning/loops/fit_loop.py in advance(self)
198 with self.trainer.profiler.profile("run_training_epoch"):
199 # run train epoch
--> 200 epoch_output = self.epoch_loop.run(train_dataloader)
201
202 if epoch_output is None:
~/base_env/lib/python3.8/site-packages/pytorch_lightning/loops/base.py in run(self, *args, **kwargs)
110 self.on_advance_start(*args, **kwargs)
111 self.advance(*args, **kwargs)
--> 112 self.on_advance_end()
113 self.iteration_count += 1
114 self.restarting = False
~/base_env/lib/python3.8/site-packages/pytorch_lightning/loops/epoch/training_epoch_loop.py in on_advance_end(self)
175 if should_check_val:
176 self.trainer.validating = True
--> 177 self._run_validation()
178 self.trainer.training = True
179
~/base_env/lib/python3.8/site-packages/pytorch_lightning/loops/epoch/training_epoch_loop.py in _run_validation(self)
254
255 with torch.no_grad():
--> 256 self.val_loop.run()
257
258 def _on_train_epoch_end_hook(self, processed_epoch_output: List[List[STEP_OUTPUT]]) -> None:
~/base_env/lib/python3.8/site-packages/pytorch_lightning/loops/base.py in run(self, *args, **kwargs)
116 break
117
--> 118 output = self.on_run_end()
119 return output
120
~/base_env/lib/python3.8/site-packages/pytorch_lightning/loops/dataloader/evaluation_loop.py in on_run_end(self)
135
136 # hook
--> 137 self.on_evaluation_epoch_end()
138
139 # log epoch metrics
~/base_env/lib/python3.8/site-packages/pytorch_lightning/loops/dataloader/evaluation_loop.py in on_evaluation_epoch_end(self)
258 self.trainer.call_hook(hook_name)
259 self.trainer.call_hook("on_epoch_end")
--> 260 self.trainer.logger_connector.on_epoch_end()
261
262 def teardown(self) -> None:
~/base_env/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py in on_epoch_end(self)
248 def on_epoch_end(self) -> None:
249 assert self._epoch_end_reached
--> 250 metrics = self.metrics
251 self._progress_bar_metrics.update(metrics[MetricSource.PBAR])
252 self._callback_metrics.update(metrics[MetricSource.CALLBACK])
~/base_env/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py in metrics(self)
284 """This function returns either batch or epoch metrics depending on ``_epoch_end_reached``."""
285 on_step = not self._epoch_end_reached
--> 286 return self.trainer._results.metrics(on_step)
287
288 @property
AttributeError: 'NoneType' object has no attribute 'metrics'
So my questions are the following:
- what is the reason for this behaviour?
- what is the best practice for doing what I want? (A sketch of the direction I'm currently considering follows below.)
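For context, the direction I'm leaning towards is to not call trainer.predict from inside fit at all, and instead run the model manually inside the callback. A minimal sketch of that idea, assuming self.test is a plain DataLoader that yields input tensors directly and that calling the module runs its forward pass (the class name here is just mine), looks like this, but I'm not sure it's the idiomatic Lightning way:

import matplotlib.pyplot as plt
import pytorch_lightning as pl
import torch


class ManualVisualizer(pl.callbacks.Callback):
    def __init__(self, X, y, test):
        self.test = test  # assumption: a DataLoader yielding input tensors
        self.X = X
        self.y = y

    def on_epoch_end(self, trainer, pl_module):
        # run inference manually instead of re-entering the Trainer
        pl_module.eval()
        preds = []
        with torch.no_grad():
            for batch in self.test:
                batch = batch.to(pl_module.device)
                preds.append(pl_module(batch).cpu())
        pl_module.train()

        plt.plot(torch.cat(preds), label='preds')
        plt.plot(self.y, label='y')
        plt.plot(self.X, label='X')
        plt.legend()
        plt.show()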