Customize the progress bar
Lightning supports two different types of progress bars (tqdm and rich). TQDMProgressBar is used by default, but you can override it by passing a custom TQDMProgressBar or RichProgressBar to the callbacks argument of the Trainer. You can also subclass the ProgressBar class to implement your own progress bar.
TQDMProgressBar
The TQDMProgressBar uses the tqdm library internally and is the default progress bar used by Lightning. It prints to stdout and shows up to four different bars:
sanity check progress: the progress during the sanity check run.
train progress: shows the training progress. It will pause if validation starts and will resume when it ends, and also accounts for multiple validation runs during training when val_check_interval is used.
validation progress: only visible during validation; shows total progress over all validation datasets.
test progress: only active when testing; shows total progress over all test datasets.
For infinite datasets, the progress bar never ends.
You can change how often TQDMProgressBar redraws by setting refresh_rate, the number of batches between progress bar updates:
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import TQDMProgressBar
# redraw the bar every 10 batches instead of every batch
trainer = Trainer(callbacks=[TQDMProgressBar(refresh_rate=10)])
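The throttling rule can be modeled in a few lines of plain Python. This is a toy model, not Lightning's code. It assumes the bar redraws whenever the batch count is a multiple of refresh_rate, and once more on the final batch so it always finishes at 100%. It also assumes a refresh_rate of 0 means no redraws at all. The helper name refresh_points is hypothetical.

```python
def refresh_points(total_batches: int, refresh_rate: int) -> list[int]:
    """Toy model: 1-based batch counts at which a throttled bar would redraw."""
    if refresh_rate <= 0:
        # Assumption for this sketch: a refresh rate of 0 disables updates.
        return []
    # Redraw on every multiple of refresh_rate, plus the last batch
    # so the bar always ends at 100%.
    return [
        n
        for n in range(1, total_batches + 1)
        if n % refresh_rate == 0 or n == total_batches
    ]


print(refresh_points(25, 10))  # [10, 20, 25]
```

With 25 batches and refresh_rate=10, the bar is redrawn 3 times instead of 25.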
By default, the training progress bar is reset (overwritten) at each new epoch. If you want a new progress bar to be displayed at the end of every epoch, set TQDMProgressBar.leave to True.
trainer = Trainer(callbacks=[TQDMProgressBar(leave=True)])
If you want to customize the default TQDMProgressBar used by Lightning, you can override specific methods of the callback class and pass your custom implementation to the Trainer.
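from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import TQDMProgressBar

class LitProgressBar(TQDMProgressBar):
    def init_validation_tqdm(self):
        # reuse the default validation bar, only change its label
        bar = super().init_validation_tqdm()
        bar.set_description("running validation...")
        return bar

trainer = Trainer(callbacks=[LitProgressBar()])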
See also the TQDMProgressBar docs.
RichProgressBar
Rich is a Python library for rich text and beautiful formatting in the terminal.
To use the RichProgressBar as your progress bar, first install the package:
pip install rich
Then configure the callback and pass it to the Trainer:
from lightning.pytorch.callbacks import RichProgressBar
trainer = Trainer(callbacks=[RichProgressBar()])
Customize the theme for your RichProgressBar like this:
from lightning.pytorch.callbacks import RichProgressBar
from lightning.pytorch.callbacks.progress.rich_progress import RichProgressBarTheme

# create your own theme!
progress_bar = RichProgressBar(
    theme=RichProgressBarTheme(
        description="green_yellow",
        progress_bar="green1",
        progress_bar_finished="green1",
        progress_bar_pulse="#6206E0",
        batch_progress="green_yellow",
        time="grey82",
        processing_speed="grey82",
        metrics="grey82",
        metrics_text_delimiter="\n",
        metrics_format=".3e",
    )
)

trainer = Trainer(callbacks=progress_bar)
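The standalone toy renderer below shows what metrics_text_delimiter and metrics_format control. It mimics, as we understand it, how the metrics text is assembled: each metric becomes "name: value", non-string values are formatted with metrics_format using Python's format-spec mini-language, and the pieces are joined with metrics_text_delimiter. The function render_metrics is hypothetical and uses neither Lightning nor rich.

```python
def render_metrics(metrics: dict, delimiter: str, fmt: str) -> str:
    """Toy renderer: 'name: value' pairs joined by `delimiter`."""
    parts = []
    for name, value in metrics.items():
        if not isinstance(value, str):
            # Same mini-language as f"{value:{fmt}}", e.g. ".3e" or ".3f".
            value = f"{value:{fmt}}"
        parts.append(f"{name}: {value}")
    return delimiter.join(parts)


print(render_metrics({"loss": 0.123456, "v_num": 0}, delimiter="\n", fmt=".3e"))
# loss: 1.235e-01
# v_num: 0.000e+00
```

With "\n" as the delimiter, each metric gets its own line. ".3e" switches numbers to scientific notation with three digits after the decimal point.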
You can easily customize the components used within RichProgressBar by overriding the configure_columns() method.
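from lightning.pytorch import Trainer
from rich.progress import TextColumn

# a single text column replaces all of the default columns
custom_column = TextColumn("[progress.description]Custom Rich Progress Bar!")

class CustomRichProgressBar(RichProgressBar):
    def configure_columns(self, trainer):
        return [custom_column]

progress_bar = CustomRichProgressBar()
trainer = Trainer(callbacks=progress_bar)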
If you want a new progress bar to be displayed at the end of every epoch, enable RichProgressBar.leave by passing True:
from lightning.pytorch.callbacks import RichProgressBar
trainer = Trainer(callbacks=[RichProgressBar(leave=True)])
See also the RichProgressBar docs, and the RichModelSummary docs to customize the model summary table.
Note
The progress bar is enabled by default in the Trainer. To disable it, pass enable_progress_bar=False:
trainer = Trainer(enable_progress_bar=False)