
RichProgressBar

class lightning.pytorch.callbacks.RichProgressBar(refresh_rate=1, leave=False, theme=RichProgressBarTheme(description='white', progress_bar='#6206E0', progress_bar_finished='#6206E0', progress_bar_pulse='#6206E0', batch_progress='white', time='grey54', processing_speed='grey70', metrics='white'), console_kwargs=None)

Bases: lightning.pytorch.callbacks.progress.progress_bar.ProgressBar

Create a progress bar with rich text formatting.

Install it with pip:

pip install rich

Then pass it to the Trainer as a callback:

from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import RichProgressBar

trainer = Trainer(callbacks=RichProgressBar())

Parameters
  • refresh_rate (int) – How often, in number of batches, the progress bars are updated. Set it to 0 to disable the display. Default: 1

  • leave (bool) – Leaves the finished progress bar in the terminal at the end of the epoch. Default: False

  • theme (RichProgressBarTheme) – Styles used to render the progress bar.

  • console_kwargs (Optional[Dict[str, Any]]) – Keyword arguments used when constructing the rich Console.

Raises

ModuleNotFoundError – If the required rich package is not installed.

Note

PyCharm users need to enable “Emulate terminal in output console” in the Run/Debug configuration to see styled output. Reference: https://rich.readthedocs.io/en/latest/introduction.html#requirements

disable()

Disables the progress bar.

Return type

None

enable()

Enables the progress bar.

The Trainer calls this, together with disable(), in pre-training routines such as the learning rate finder to temporarily enable and disable the training progress bar.

Return type

None

on_exception(trainer, pl_module, exception)

Called when any trainer execution is interrupted by an exception.

Return type

None

on_predict_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0)

Called when a prediction batch ends.

Return type

None

on_predict_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx=0)

Called when a prediction batch begins.

Return type

None

on_predict_end(trainer, pl_module)

Called when prediction ends.

Return type

None

on_predict_start(trainer, pl_module)

Called when prediction begins.

Return type

None

on_sanity_check_end(trainer, pl_module)

Called when the validation sanity check ends.

Return type

None

on_sanity_check_start(trainer, pl_module)

Called when the validation sanity check starts.

Return type

None

on_test_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0)

Called when the test batch ends.

Return type

None

on_test_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx=0)

Called when the test batch begins.

Return type

None

on_test_end(trainer, pl_module)

Called when the test ends.

Return type

None

on_test_start(trainer, pl_module)

Called when the test begins.

Return type

None

on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)

Called when a training batch ends.

Note

Here, outputs["loss"] is the loss returned from training_step, normalized with respect to accumulate_grad_batches.

Return type

None

on_train_epoch_end(trainer, pl_module)

Called when the training epoch ends.

To access all batch outputs at the end of the epoch, you can cache step outputs as an attribute of the lightning.pytorch.LightningModule and access them in this hook:

import lightning as L
import torch


class MyLightningModule(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.training_step_outputs = []

    def training_step(self, batch, batch_idx):
        loss = ...  # compute the loss tensor for this batch
        self.training_step_outputs.append(loss)
        return loss


class MyCallback(L.Callback):
    def on_train_epoch_end(self, trainer, pl_module):
        # do something with all training_step outputs, for example:
        epoch_mean = torch.stack(pl_module.training_step_outputs).mean()
        pl_module.log("training_epoch_mean", epoch_mean)
        # free up the memory
        pl_module.training_step_outputs.clear()

Return type

None

on_train_epoch_start(trainer, pl_module)

Called when the training epoch begins.

Return type

None

on_train_start(trainer, pl_module)

Called when training begins.

Return type

None

on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0)

Called when the validation batch ends.

Return type

None

on_validation_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx=0)

Called when the validation batch begins.

Return type

None

on_validation_end(trainer, pl_module)

Called when the validation loop ends.

Return type

None

on_validation_epoch_end(trainer, pl_module)

Called when the validation epoch ends.

Return type

None

on_validation_start(trainer, pl_module)

Called when the validation loop begins.

Return type

None

teardown(trainer, pl_module, stage)

Called when fit, validate, test, predict, or tune ends.

Return type

None