progress

Functions

convert_inf

tqdm does not support inf/nan values.

reset

Resets the tqdm bar to 0 progress with a new total, unless it is disabled.

Classes

ProgressBar

This is the default progress bar used by Lightning.

ProgressBarBase

The base class for progress bars in Lightning.

tqdm

Custom tqdm progress bar that pads floating-point numbers and numeric strings with trailing zeros to prevent the progress bar from flickering.

Progress Bars

Use or override one of the progress bar callbacks.

class pytorch_lightning.callbacks.progress.ProgressBar(refresh_rate=1, process_position=0)[source]

Bases: pytorch_lightning.callbacks.progress.ProgressBarBase

This is the default progress bar used by Lightning. It prints to stdout using the tqdm package and shows up to four different bars:

  • sanity check progress: the progress during the sanity check run

  • main progress: shows training + validation progress combined. It also accounts for multiple validation runs during training when val_check_interval is used.

  • validation progress: only visible during validation; shows total progress over all validation datasets.

  • test progress: only active when testing; shows total progress over all test datasets.

For infinite datasets, the progress bar never ends.
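The way a combined bar can account for several validation runs per epoch comes down to simple arithmetic. The sketch below is an illustration under stated assumptions, not Lightning's implementation: the helper name `main_bar_total` and the parameter `val_check_batch` (the number of training batches between two validation runs) are invented for the example.

```python
import math


def main_bar_total(train_batches, val_batches, val_check_batch):
    """Estimate the length of a combined training + validation bar.

    Hypothetical helper: val_check_batch is the assumed number of training
    batches between two validation runs.
    """
    if math.isinf(train_batches) or math.isinf(val_batches):
        # An infinite dataloader means the bar has no finite end.
        return math.inf
    val_runs = train_batches // val_check_batch
    return train_batches + val_batches * val_runs


# 100 training batches, validation (20 batches) every 25 training batches:
# 4 validation runs, so 100 + 4 * 20 steps in total.
print(main_bar_total(100, 20, 25))  # 180
```

With an infinite training dataloader the helper returns `inf`, which matches the statement that the bar never ends.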

If you want to customize the default tqdm progress bars used by Lightning, you can override specific methods of the callback class and pass your custom implementation to the Trainer:

Example:

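from pytorch_lightning import Trainer
from pytorch_lightning.callbacks.progress import ProgressBar


class LitProgressBar(ProgressBar):

    def init_validation_tqdm(self):
        # Reuse the default validation bar and change only its description.
        bar = super().init_validation_tqdm()
        bar.set_description('running validation ...')
        return bar


bar = LitProgressBar()
trainer = Trainer(callbacks=[bar])

Parameters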
  • refresh_rate (int) – Determines at which rate (in number of batches) the progress bars get updated. Set it to 0 to disable the display. By default, the Trainer uses this implementation of the progress bar and sets the refresh rate to the value provided to the progress_bar_refresh_rate argument in the Trainer.

  • process_position (int) – Set this to a value greater than 0 to offset the progress bars by this many lines. This is useful when you have progress bars defined elsewhere and want to show all of them together. This corresponds to process_position in the Trainer.
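To make the effect of refresh_rate concrete, here is a minimal sketch of a redraw rule. It assumes a bar is redrawn every refresh_rate batches and once more on the final batch; the function name `should_refresh` is hypothetical and the rule is an assumption for illustration, not a statement of what Lightning does internally.

```python
def should_refresh(batch_idx, total_batches, refresh_rate):
    """Return True if the bar should be redrawn after batch number batch_idx (1-based)."""
    if refresh_rate <= 0:
        # A refresh rate of 0 disables the display entirely.
        return False
    return batch_idx % refresh_rate == 0 or batch_idx == total_batches


# With 10 batches and refresh_rate=4, the bar redraws after batches 4, 8 and 10.
redraws = [i for i in range(1, 11) if should_refresh(i, 10, refresh_rate=4)]
print(redraws)  # [4, 8, 10]
```

A larger refresh_rate reduces terminal output, which can matter when stdout is slow or is being written to a log file.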

disable()[source]

You should provide a way to disable the progress bar. The Trainer will call this to disable the output on processes that have a rank different from 0, e.g., in multi-node training.

Return type

None

enable()[source]

You should provide a way to enable the progress bar. The Trainer calls this in, e.g., pre-training routines such as the learning rate finder, which temporarily disable and re-enable the main progress bar.

Return type

None

init_predict_tqdm()[source]

Override this to customize the tqdm bar for predicting.

Return type

tqdm

init_sanity_tqdm()[source]

Override this to customize the tqdm bar for the validation sanity run.

Return type

tqdm

init_test_tqdm()[source]

Override this to customize the tqdm bar for testing.

Return type

tqdm

init_train_tqdm()[source]

Override this to customize the tqdm bar for training.

Return type

tqdm

init_validation_tqdm()[source]

Override this to customize the tqdm bar for validation.

Return type

tqdm

on_predict_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)[source]

Called when the predict batch ends.

on_predict_end(trainer, pl_module)[source]

Called when predict ends.

on_predict_epoch_start(trainer, pl_module)[source]

Called when the predict epoch begins.

on_sanity_check_end(trainer, pl_module)[source]

Called when the validation sanity check ends.

on_sanity_check_start(trainer, pl_module)[source]

Called when the validation sanity check starts.

on_test_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)[source]

Called when the test batch ends.

on_test_end(trainer, pl_module)[source]

Called when the test ends.

on_test_start(trainer, pl_module)[source]

Called when the test begins.

on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)[source]

Called when the train batch ends.

on_train_end(trainer, pl_module)[source]

Called when training ends.

on_train_epoch_start(trainer, pl_module)[source]

Called when the train epoch begins.

on_train_start(trainer, pl_module)[source]

Called when training begins.

on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)[source]

Called when the validation batch ends.

on_validation_end(trainer, pl_module)[source]

Called when the validation loop ends.

on_validation_start(trainer, pl_module)[source]

Called when the validation loop begins.

print(*args, sep=' ', end='\n', file=None, nolock=False)[source]

You should provide a way to print without breaking the progress bar.

class pytorch_lightning.callbacks.progress.ProgressBarBase[source]

Bases: pytorch_lightning.callbacks.base.Callback

The base class for progress bars in Lightning. It is a Callback that keeps track of the batch progress in the Trainer. You should implement your highly custom progress bars with this as the base class.

Example:

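import sys

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks.progress import ProgressBarBase


class LitProgressBar(ProgressBarBase):

    def __init__(self):
        super().__init__()  # don't forget this :)
        self._is_active = True  # flag named so that it does not shadow enable()

    def enable(self):
        self._is_active = True

    def disable(self):
        self._is_active = False

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)  # don't forget this :)
        if not self._is_active:
            return
        percent = (self.train_batch_idx / self.total_train_batches) * 100
        sys.stdout.write(f'{percent:.01f} percent complete \r')
        sys.stdout.flush()  # flush after writing so the update appears immediately


bar = LitProgressBar()
trainer = Trainer(callbacks=[bar])

disable()[source]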

You should provide a way to disable the progress bar. The Trainer will call this to disable the output on processes that have a rank different from 0, e.g., in multi-node training.

enable()[source]

You should provide a way to enable the progress bar. The Trainer calls this in, e.g., pre-training routines such as the learning rate finder, which temporarily disable and re-enable the main progress bar.

on_init_end(trainer)[source]

Called when the trainer initialization ends; the model has not yet been set.

on_predict_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)[source]

Called when the predict batch ends.

on_predict_epoch_start(trainer, pl_module)[source]

Called when the predict epoch begins.

on_test_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)[source]

Called when the test batch ends.

on_test_start(trainer, pl_module)[source]

Called when the test begins.

on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)[source]

Called when the train batch ends.

on_train_epoch_start(trainer, pl_module)[source]

Called when the train epoch begins.

on_train_start(trainer, pl_module)[source]

Called when training begins.

on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)[source]

Called when the validation batch ends.

on_validation_start(trainer, pl_module)[source]

Called when the validation loop begins.

print(*args, **kwargs)[source]

You should provide a way to print without breaking the progress bar.
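One way to print without breaking a single-line bar is to blank the bar, write the message on its own line, and then redraw the bar. The sketch below shows that pattern with the standard library only. `TextBar` is a hypothetical stand-in written for this example; it is not a Lightning or tqdm class, and a real subclass of ProgressBarBase would apply the same idea inside its own print method.

```python
import io
import sys


class TextBar:
    """Hypothetical single-line bar, used only to illustrate the print pattern."""

    def __init__(self, total, stream=None):
        self.total = total
        self.done = 0
        self.stream = stream if stream is not None else sys.stdout

    def _line(self):
        return f"{self.done}/{self.total} batches"

    def render(self):
        # '\r' returns to the start of the line so the bar overwrites itself.
        self.stream.write("\r" + self._line())
        self.stream.flush()

    def print(self, *args, sep=" ", end="\n"):
        # Blank out the bar, emit the message on its own line, then redraw.
        self.stream.write("\r" + " " * len(self._line()) + "\r")
        self.stream.write(sep.join(map(str, args)) + end)
        self.render()


# Write into a buffer here so the resulting character stream can be inspected.
buffer = io.StringIO()
bar = TextBar(total=10, stream=buffer)
bar.done = 3
bar.render()
bar.print("validation loss:", 0.25)
```

After the call, the message sits on its own line and the bar is redrawn below it, so later updates keep overwriting the bar rather than the message.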

property predict_batch_idx: int

The current batch index being processed during predicting. Use this to update your progress bar.

property test_batch_idx: int

The current batch index being processed during testing. Use this to update your progress bar.

property total_predict_batches: int

The total number of predicting batches during predicting, which may change from epoch to epoch. Use this to set the total number of iterations in the progress bar. Can return inf if the predict dataloader is of infinite size.

property total_test_batches: int

The total number of testing batches during testing, which may change from epoch to epoch. Use this to set the total number of iterations in the progress bar. Can return inf if the test dataloader is of infinite size.

property total_train_batches: int

The total number of training batches during training, which may change from epoch to epoch. Use this to set the total number of iterations in the progress bar. Can return inf if the training dataloader is of infinite size.

property total_val_batches: int

The total number of validation batches during validation, which may change from epoch to epoch. Use this to set the total number of iterations in the progress bar. Can return inf if the validation dataloader is of infinite size.

property train_batch_idx: int

The current batch index being processed during training. Use this to update your progress bar.

property val_batch_idx: int

The current batch index being processed during validation. Use this to update your progress bar.
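Because the total_* properties above can return inf, a custom bar that computes a percentage from a batch index and a total should guard against totals that are infinite or zero. A minimal sketch, assuming a hypothetical helper named `percent_complete` that is not part of Lightning:

```python
import math


def percent_complete(batch_idx, total_batches):
    """Return progress in percent, or None when the total is unknown or empty."""
    if total_batches is None or math.isinf(total_batches) or total_batches <= 0:
        return None
    return 100.0 * batch_idx / total_batches


print(percent_complete(25, 200))           # 12.5
print(percent_complete(25, float("inf")))  # None
```

Returning None rather than 0 lets the caller decide how to display an open-ended run, for example by showing only the batch count.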

class pytorch_lightning.callbacks.progress.tqdm(*args, **kwargs)[source]

Bases: tqdm.auto.

Custom tqdm progress bar that pads floating-point numbers and numeric strings with trailing zeros to prevent the progress bar from flickering.

static format_num(n)[source]

Add additional padding to the formatted numbers

Return type

str
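The padding idea can be sketched as follows: floats and numeric strings are extended with trailing zeros to a fixed width, so the displayed text keeps the same length between updates. The width of 5 characters, the `.3g` formatting, and the name `format_num_padded` are assumptions made for this illustration and are not taken from the library.

```python
PAD_SIZE = 5  # assumed target width, chosen for illustration


def format_num_padded(n):
    """Format n and pad floats / numeric strings with trailing zeros."""
    should_pad = isinstance(n, (float, str))
    text = n if isinstance(n, str) else f"{n:.3g}"
    if should_pad and "e" not in text:
        if "." not in text and len(text) < PAD_SIZE:
            try:
                float(text)
            except ValueError:
                # Not a number (e.g. a label): leave it untouched.
                return text
            text += "."
        text += "0" * (PAD_SIZE - len(text))
    return text


print(format_num_padded(0.5))    # 0.500
print(format_num_padded(1.0))    # 1.000
print(format_num_padded(3))      # 3
print(format_num_padded("abc"))  # abc
```

Integers are left unpadded here because their width changes only when a new digit is added, whereas a float such as a loss value would otherwise change its printed length on almost every update.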

pytorch_lightning.callbacks.progress.convert_inf(x)[source]

tqdm does not support inf/nan values, so they are converted to None.

Return type

Union[int, float, None]
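A conversion of this kind can be sketched in a few lines. The function name `to_tqdm_total` is hypothetical and the body is an assumption about one reasonable way to do it, not a copy of the library code.

```python
import math


def to_tqdm_total(x):
    """Map inf/nan (and None) to None; return any other number unchanged."""
    if x is None or math.isinf(x) or math.isnan(x):
        return None
    return x


print(to_tqdm_total(float("inf")))  # None
print(to_tqdm_total(128))           # 128
```

A total of None tells tqdm that the length is unknown, in which case it shows a running count instead of a percentage.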

pytorch_lightning.callbacks.progress.reset(bar, total=None)[source]

Resets the tqdm bar to 0 progress with a new total, unless it is disabled.

Return type

None
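The behaviour described above can be illustrated without tqdm installed. In the sketch below, `FakeBar` is a hypothetical stand-in that exposes only the two tqdm members the helper relies on, the `disable` attribute and the `reset(total=...)` method. The helper name `reset_bar` is also invented for the example.

```python
class FakeBar:
    """Hypothetical stand-in for a tqdm bar (only `disable` and `reset`)."""

    def __init__(self, total, disable=False):
        self.total = total
        self.n = 0  # number of completed iterations
        self.disable = disable

    def reset(self, total=None):
        self.n = 0
        if total is not None:
            self.total = total


def reset_bar(bar, total=None):
    """Restart the bar at 0 with a new total, unless the bar is disabled."""
    if not bar.disable:
        bar.reset(total=total)


active = FakeBar(total=10)
active.n = 7
reset_bar(active, total=50)
print(active.n, active.total)  # 0 50

silent = FakeBar(total=10, disable=True)
silent.n = 7
reset_bar(silent, total=50)
print(silent.n, silent.total)  # 7 10
```

Skipping the reset for a disabled bar avoids touching a bar that produces no output, for instance on processes with a non-zero rank.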