TQDMProgressBar
- class pytorch_lightning.callbacks.TQDMProgressBar(refresh_rate=1, process_position=0)
Bases: pytorch_lightning.callbacks.progress.base.ProgressBarBase
This is the default progress bar used by Lightning. It prints to stdout using the tqdm package and shows up to four different bars:
sanity check progress: the progress during the sanity check run.
main progress: shows training + validation progress combined. It also accounts for multiple validation runs during training when val_check_interval is used.
validation progress: only visible during validation; shows total progress over all validation datasets.
test progress: only active when testing; shows total progress over all test datasets.
For infinite datasets, the progress bar never ends.
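The way the main bar can account for several validation runs per epoch is easiest to see as arithmetic. The following is a minimal sketch under stated assumptions, not Lightning's implementation: the function `main_bar_total` and its parameter `val_check_batch` (the number of training batches between validation runs) are hypothetical names introduced only for illustration.

```python
def main_bar_total(train_batches: int, val_batches: int, val_check_batch: int) -> int:
    """Estimate the length of a combined training + validation bar.

    Hypothetical helper: assumes validation runs once every
    `val_check_batch` training batches within a single epoch.
    """
    val_runs = train_batches // val_check_batch  # validation runs per epoch
    return train_batches + val_runs * val_batches


# 100 training batches, 20 validation batches, validation every 25 batches:
# 4 validation runs, so the bar covers 100 + 4 * 20 steps.
total = main_bar_total(train_batches=100, val_batches=20, val_check_batch=25)
print(total)  # 180
```

With validation only at the end of the epoch (`val_check_batch` equal to the number of training batches), the same sketch gives 100 + 20 = 120 steps.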
If you want to customize the default tqdm progress bars used by Lightning, you can override specific methods of the callback class and pass your custom implementation to the Trainer.
Example
>>> class LitProgressBar(TQDMProgressBar):
...     def init_validation_tqdm(self):
...         bar = super().init_validation_tqdm()
...         bar.set_description('running validation ...')
...         return bar
...
>>> bar = LitProgressBar()
>>> from pytorch_lightning import Trainer
>>> trainer = Trainer(callbacks=[bar])
- Parameters
refresh_rate (int) – Determines at which rate (in number of batches) the progress bars get updated. Set it to 0 to disable the display.
process_position (int) – Set this to a value greater than 0 to offset the progress bars by this many lines. This is useful when you have progress bars defined elsewhere and want to show all of them together. This corresponds to process_position in the Trainer.
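To make the meaning of refresh_rate concrete, here is a minimal sketch of the gating logic it describes. It is an illustration under assumptions, not the library's code: `should_refresh` is a hypothetical helper, batches are counted from 1, and redrawing on the final batch so the bar ends full is an assumption of this sketch.

```python
def should_refresh(batch_idx: int, total_batches: int, refresh_rate: int) -> bool:
    """Decide whether a bar would be redrawn after batch `batch_idx` (1-based).

    Hypothetical helper illustrating the documented behaviour:
    update every `refresh_rate` batches, and never when it is 0.
    """
    if refresh_rate <= 0:  # 0 disables the display entirely
        return False
    # Assumption: also redraw on the last batch so the bar finishes complete.
    return batch_idx % refresh_rate == 0 or batch_idx == total_batches


# With 10 batches and refresh_rate=4, the bar is redrawn after batches 4, 8 and 10.
refreshed = [i for i in range(1, 11) if should_refresh(i, 10, 4)]
print(refreshed)  # [4, 8, 10]
```

A larger refresh_rate trades display granularity for fewer terminal writes, which can matter when output is captured to a log file.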
- enable()
You should provide a way to enable the progress bar.
The Trainer will call this in, e.g., pre-training routines like the learning rate finder, to temporarily enable and disable the main progress bar.
- Return type
- init_predict_tqdm()
Override this to customize the tqdm bar for predicting.
- Return type
Tqdm
- init_sanity_tqdm()
Override this to customize the tqdm bar for the validation sanity run.
- Return type
Tqdm
- init_validation_tqdm()
Override this to customize the tqdm bar for validation.
- Return type
Tqdm
- on_predict_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx)
Called when the predict batch begins.
- Return type
- on_test_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx)
Called when the test batch begins.
- Return type
- on_train_batch_end(trainer, pl_module, *_)
Called when the train batch ends.
Note
The value outputs["loss"] here will be the normalized value w.r.t. accumulate_grad_batches of the loss returned from training_step.
- Return type
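The note on outputs["loss"] for on_train_batch_end can be read as simple arithmetic. The sketch below assumes that "normalized" means the loss returned from training_step is divided by accumulate_grad_batches; the helper name `normalized_loss` is hypothetical and used only for illustration.

```python
def normalized_loss(step_loss: float, accumulate_grad_batches: int) -> float:
    """Scale a per-step loss by the gradient accumulation factor.

    Assumption: normalization is a plain division, so that gradients
    summed over the accumulated batches average out.
    """
    return step_loss / accumulate_grad_batches


# A training_step loss of 2.0 with accumulate_grad_batches=4
# would appear in the hook as 0.5 under this assumption.
print(normalized_loss(2.0, 4))  # 0.5
```

In practice this means a loss read inside the hook should be multiplied back by accumulate_grad_batches before being compared with the value computed in training_step.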
- on_train_epoch_end(trainer, pl_module)
Called when the train epoch ends.
To access all batch outputs at the end of the epoch, either:
Implement training_epoch_end in the LightningModule and access outputs via the module OR
Cache data across train batch hooks inside the callback implementation to post-process in this hook.
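The second option, caching data across train batch hooks, can be sketched as follows. This is an illustration under assumptions rather than a definitive implementation: `EpochLossCache` is a hypothetical class written in plain Python so that it runs without Lightning installed, and the `on_train_batch_end` argument list (`outputs`, `batch`, `batch_idx`) is assumed from the callback hooks of this release. In real use the class would subclass pytorch_lightning.callbacks.Callback.

```python
class EpochLossCache:
    """Collect per-batch losses and summarise them at the end of each epoch."""

    def __init__(self):
        self.losses = []       # losses seen during the current epoch
        self.epoch_means = []  # one mean loss per finished epoch

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        # Cache only what is needed; float() keeps the cache free of tensors.
        self.losses.append(float(outputs["loss"]))

    def on_train_epoch_end(self, trainer, pl_module):
        # Post-process the cached values, then reset for the next epoch.
        if self.losses:
            self.epoch_means.append(sum(self.losses) / len(self.losses))
        self.losses.clear()


# Simulate one epoch of three batches without a real Trainer.
cache = EpochLossCache()
for batch_idx, loss in enumerate([1.0, 0.5, 0.0]):
    cache.on_train_batch_end(None, None, {"loss": loss}, None, batch_idx)
cache.on_train_epoch_end(None, None)
print(cache.epoch_means)  # [0.5]
```

Clearing the cache in on_train_epoch_end keeps memory bounded and prevents values from one epoch leaking into the next.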
- Return type
- on_validation_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx)
Called when the validation batch begins.
- Return type