RichProgressBar

class lightning.pytorch.callbacks.RichProgressBar(refresh_rate=1, leave=False, theme=RichProgressBarTheme(description='', progress_bar='#6206E0', progress_bar_finished='#6206E0', progress_bar_pulse='#6206E0', batch_progress='', time='dim', processing_speed='dim underline', metrics='italic', metrics_text_delimiter=' ', metrics_format='.3f'), console_kwargs=None)

Bases: ProgressBar

Create a progress bar with rich text formatting.

Install it with pip:

pip install rich

Then pass it to the Trainer:

from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import RichProgressBar

trainer = Trainer(callbacks=RichProgressBar())

Parameters:
  • refresh_rate (int) – How often, in number of batches, the progress bars are updated. Set it to 0 to disable the display. Default: 1

  • leave (bool) – Leaves the finished progress bar in the terminal at the end of the epoch. Default: False

  • theme (RichProgressBarTheme) – The styles used to render the components of the progress bar.

  • console_kwargs (Optional[dict[str, Any]]) – Keyword arguments used to construct the rich Console.
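The refresh_rate parameter and the theme's metrics_format and metrics_text_delimiter fields (see the signature above) can be illustrated with a small, framework-free sketch. It assumes, rather than quotes from the Lightning source, that a bar is redrawn when the current batch count is a multiple of refresh_rate or equals the total, and that metrics are rendered as `name: value` pairs, with floats formatted using metrics_format and pairs joined by metrics_text_delimiter. The helpers should_refresh and render_metrics are hypothetical.

```python
from typing import Mapping, Union


def should_refresh(current: int, total: int, refresh_rate: int) -> bool:
    """Hypothetical helper: decide whether the bar is redrawn at batch `current`.

    Assumption: refresh_rate == 0 disables the display entirely; otherwise the
    bar is redrawn every `refresh_rate` batches and on the final batch.
    """
    if refresh_rate == 0:
        return False
    return current % refresh_rate == 0 or current == total


def render_metrics(
    metrics: Mapping[str, Union[float, int, str]],
    metrics_format: str = ".3f",
    delimiter: str = " ",
) -> str:
    """Hypothetical helper: format metrics the way the theme fields suggest."""
    parts = []
    for name, value in metrics.items():
        # Only floats get the numeric format spec; other values are shown as-is.
        text = format(value, metrics_format) if isinstance(value, float) else str(value)
        parts.append(f"{name}: {text}")
    return delimiter.join(parts)


# With refresh_rate=5 and 12 batches, the bar is redrawn at batches 5, 10 and 12.
print([i for i in range(1, 13) if should_refresh(i, 12, 5)])
print(render_metrics({"loss": 0.123456, "v_num": 0}))
```

Raising refresh_rate trades display smoothness for less terminal output, which can matter for very fast training loops.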

Raises:

ModuleNotFoundError – If the required rich package is not installed.
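To avoid this error, a script can check whether rich is importable before constructing the callback. A minimal sketch using the standard-library importlib.util.find_spec; the helper name rich_is_installed is hypothetical.

```python
import importlib.util


def rich_is_installed() -> bool:
    """Return True if the `rich` package can be imported in this environment."""
    # find_spec returns None when no top-level module named "rich" can be located.
    return importlib.util.find_spec("rich") is not None


if rich_is_installed():
    print("rich found: RichProgressBar can be used")
else:
    print("rich missing: run `pip install rich` or keep the default progress bar")
```

A caller could then write, for example, callbacks=[RichProgressBar()] if rich_is_installed() else [], so that environments without rich fall back to the Trainer's default progress bar.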

Note

PyCharm users need to enable the “Emulate terminal in output console” option in their Run/Debug configuration to see styled output. Reference: https://rich.readthedocs.io/en/latest/introduction.html#requirements

disable()

Disable the progress bar.

Return type:

None

enable()

Enable the progress bar.

The Trainer calls this, for example, in pre-training routines such as the learning rate finder, to temporarily enable and disable the training progress bar.
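The temporary disable/enable pattern can be sketched as a context manager. This is an illustration, not how the Trainer implements it: progress_bar_disabled and DemoBar are hypothetical, and the only assumption about the bar is that it exposes enable(), disable() and an is_enabled attribute.

```python
from contextlib import contextmanager
from typing import Iterator, Protocol


class SupportsEnableDisable(Protocol):
    # Assumed interface: an on/off flag plus the two methods documented above.
    is_enabled: bool

    def enable(self) -> None: ...

    def disable(self) -> None: ...


@contextmanager
def progress_bar_disabled(bar: SupportsEnableDisable) -> Iterator[None]:
    """Hypothetical helper: silence `bar` inside the block, then restore it."""
    was_enabled = bar.is_enabled
    bar.disable()
    try:
        yield
    finally:
        # Re-enable only if the bar was enabled before, even if the block raised.
        if was_enabled:
            bar.enable()


class DemoBar:
    """Hypothetical stand-in with the same enable()/disable() surface."""

    def __init__(self) -> None:
        self.is_enabled = True

    def enable(self) -> None:
        self.is_enabled = True

    def disable(self) -> None:
        self.is_enabled = False


bar = DemoBar()
with progress_bar_disabled(bar):
    print("inside:", bar.is_enabled)  # e.g. while a learning rate search runs
print("after:", bar.is_enabled)
```

Restoring the previous state in finally keeps the bar usable for the main training run even if the pre-training routine fails.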

Return type:

None

on_exception(trainer, pl_module, exception)

Called when any trainer execution is interrupted by an exception.

Return type:

None

on_predict_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0)

Called when the predict batch ends.

Return type:

None

on_predict_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx=0)

Called when the predict batch begins.

Return type:

None

on_predict_end(trainer, pl_module)

Called when prediction ends.

Return type:

None

on_predict_start(trainer, pl_module)

Called when prediction begins.

Return type:

None

on_sanity_check_end(trainer, pl_module)

Called when the validation sanity check ends.

Return type:

None

on_sanity_check_start(trainer, pl_module)

Called when the validation sanity check starts.

Return type:

None

on_test_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0)

Called when the test batch ends.

Return type:

None

on_test_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx=0)

Called when the test batch begins.

Return type:

None

on_test_end(trainer, pl_module)

Called when testing ends.

Return type:

None

on_test_start(trainer, pl_module)

Called when testing begins.

Return type:

None

on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)

Called when the training batch ends.

Return type:

None

Note

Here, outputs["loss"] is the loss returned from training_step, normalized with respect to accumulate_grad_batches.
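A worked example of that normalization in plain Python. It assumes the normalization is a plain division by accumulate_grad_batches, which is the usual meaning of normalizing a loss for gradient accumulation; the helper normalized_batch_loss is hypothetical.

```python
def normalized_batch_loss(step_loss: float, accumulate_grad_batches: int) -> float:
    """Hypothetical helper: scale a training_step loss for gradient accumulation."""
    if accumulate_grad_batches < 1:
        raise ValueError("accumulate_grad_batches must be >= 1")
    # Each of the accumulated batches contributes 1/accumulate_grad_batches of a step.
    return step_loss / accumulate_grad_batches


# training_step returned 2.0 while gradients are accumulated over 4 batches:
reported = normalized_batch_loss(2.0, 4)
print(reported)  # 0.5
# Multiplying back recovers the loss that training_step actually returned.
print(reported * 4)  # 2.0
```

A callback that needs the unscaled value can multiply back by the accumulation factor configured on the Trainer.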

on_train_epoch_end(trainer, pl_module)

Called when the training epoch ends.

To access all batch outputs at the end of the epoch, you can cache step outputs as an attribute of the lightning.pytorch.core.LightningModule and access them in this hook:

import lightning as L
import torch


class MyLightningModule(L.LightningModule):
    def __init__(self):
        super().__init__()
        # cache of per-step losses collected during the current epoch
        self.training_step_outputs = []

    def training_step(self, batch, batch_idx):
        loss = ...  # compute the loss for this batch here
        self.training_step_outputs.append(loss)
        return loss


class MyCallback(L.Callback):
    def on_train_epoch_end(self, trainer, pl_module):
        # do something with all training_step outputs, for example:
        epoch_mean = torch.stack(pl_module.training_step_outputs).mean()
        pl_module.log("training_epoch_mean", epoch_mean)
        # free up the memory
        pl_module.training_step_outputs.clear()

Return type:

None

on_train_epoch_start(trainer, pl_module)

Called when the training epoch begins.

Return type:

None

on_train_start(trainer, pl_module)

Called when training begins.

Return type:

None

on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0)

Called when the validation batch ends.

Return type:

None

on_validation_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx=0)

Called when the validation batch begins.

Return type:

None

on_validation_end(trainer, pl_module)

Called when the validation loop ends.

Return type:

None

on_validation_epoch_end(trainer, pl_module)

Called when the validation epoch ends.

Return type:

None

on_validation_start(trainer, pl_module)

Called when the validation loop begins.

Return type:

None

teardown(trainer, pl_module, stage)

Called when fit, validate, test, predict, or tune ends.

Return type:

None