RichProgressBar
- class lightning.pytorch.callbacks.RichProgressBar(refresh_rate=1, leave=False, theme=RichProgressBarTheme(description='white', progress_bar='#6206E0', progress_bar_finished='#6206E0', progress_bar_pulse='#6206E0', batch_progress='white', time='grey54', processing_speed='grey70', metrics='white'), console_kwargs=None)
Bases: ProgressBar
Create a progress bar with rich text formatting.
Install it with pip:
pip install rich
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import RichProgressBar

trainer = Trainer(callbacks=RichProgressBar())
- Parameters:
  - refresh_rate (int) – Determines at which rate (in number of batches) the progress bars get updated. Set it to 0 to disable the display.
  - leave (bool) – Leaves the finished progress bar in the terminal at the end of the epoch. Default: False
  - theme (RichProgressBarTheme) – Contains styles used to stylize the progress bar.
  - console_kwargs (Optional[Dict[str, Any]]) – Args for constructing a Console
- Raises:
ModuleNotFoundError – If the required rich package is not installed.
Note
PyCharm users need to enable the “emulate terminal in output console” option in the run/debug configuration to see styled output. Reference: https://rich.readthedocs.io/en/latest/introduction.html#requirements
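The refresh_rate parameter counts batches between redraws, and 0 turns the display off. The following is a minimal sketch of that behavior in plain Python. The helper should_refresh is hypothetical and not part of the Lightning API, and the redraw on the final batch is an assumption made here so the bar ends complete:

```python
def should_refresh(current: int, total: int, refresh_rate: int) -> bool:
    """Hypothetical helper: decide whether the bar redraws after `current` batches."""
    if refresh_rate <= 0:
        # refresh_rate=0 disables the display, so nothing is ever redrawn
        return False
    # Redraw every `refresh_rate` batches and (assumption) on the last batch
    return current % refresh_rate == 0 or current == total


# Batches (1-based) after which a 10-batch epoch redraws with refresh_rate=4
redraws = [n for n in range(1, 11) if should_refresh(n, 10, 4)]
print(redraws)
```

A larger refresh_rate reduces terminal output, which can help when logs are captured to a file or each batch is very short.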
- enable()
You should provide a way to enable the progress bar.
The Trainer will call this in e.g. pre-training routines like the learning rate finder to temporarily enable and disable the training progress bar.
- Return type:
- on_exception(trainer, pl_module, exception)
Called when any trainer execution is interrupted by an exception.
- Return type:
- on_predict_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0)
Called when the predict batch ends.
- Return type:
- on_predict_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx=0)
Called when the predict batch begins.
- Return type:
- on_sanity_check_end(trainer, pl_module)
Called when the validation sanity check ends.
- Return type:
- on_sanity_check_start(trainer, pl_module)
Called when the validation sanity check starts.
- Return type:
- on_test_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0)
Called when the test batch ends.
- Return type:
- on_test_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx=0)
Called when the test batch begins.
- Return type:
- on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)
Called when the train batch ends.
- Return type: None
Note
The value outputs["loss"] here will be the normalized value w.r.t. accumulate_grad_batches of the loss returned from training_step.
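To make the note concrete, here is a small arithmetic sketch in plain Python. It assumes that "normalized" means the loss returned from training_step is divided by accumulate_grad_batches; the helper normalized_loss is hypothetical and only illustrates the relationship between the two values:

```python
def normalized_loss(step_loss: float, accumulate_grad_batches: int) -> float:
    """Hypothetical helper: the value a hook would see, assuming division-based normalization."""
    if accumulate_grad_batches < 1:
        raise ValueError("accumulate_grad_batches must be at least 1")
    return step_loss / accumulate_grad_batches


raw_loss = 2.0  # value returned from training_step (illustrative number)
seen_in_hook = normalized_loss(raw_loss, accumulate_grad_batches=4)

# Multiplying back recovers the un-normalized loss for logging or display
recovered = seen_in_hook * 4
print(seen_in_hook, recovered)
```

Under this assumption, a callback that wants to report the raw per-batch loss would need to multiply outputs["loss"] by the accumulation factor.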
- on_train_epoch_end(trainer, pl_module)
Called when the train epoch ends.
To access all batch outputs at the end of the epoch, you can cache step outputs as an attribute of the pytorch_lightning.LightningModule and access them in this hook:

    import lightning as L
    import torch

    class MyLightningModule(L.LightningModule):
        def __init__(self):
            super().__init__()
            # cache for the loss of every training step in the current epoch
            self.training_step_outputs = []

        def training_step(self, batch, batch_idx):
            loss = ...
            self.training_step_outputs.append(loss)
            return loss

    class MyCallback(L.Callback):
        def on_train_epoch_end(self, trainer, pl_module):
            # do something with all training_step outputs, for example:
            epoch_mean = torch.stack(pl_module.training_step_outputs).mean()
            pl_module.log("training_epoch_mean", epoch_mean)
            # free up the memory
            pl_module.training_step_outputs.clear()
- Return type:
- on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0)
Called when the validation batch ends.
- Return type:
- on_validation_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx=0)
Called when the validation batch begins.
- Return type: