Source code for pytorch_lightning.callbacks.progress.base
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Optional, Union

import pytorch_lightning as pl
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.utilities.logger import _version
from pytorch_lightning.utilities.rank_zero import rank_zero_warn
class ProgressBarBase(Callback):
    r"""
    The base class for progress bars in Lightning. It is a :class:`~pytorch_lightning.callbacks.Callback`
    that keeps track of the batch progress in the :class:`~pytorch_lightning.trainer.trainer.Trainer`.
    You should implement your highly customized progress bars with this as the base class.

    Example::

        class LitProgressBar(ProgressBarBase):

            def __init__(self):
                super().__init__()  # don't forget this :)
                self.enable = True

            def disable(self):
                self.enable = False

            def on_train_batch_end(self, trainer, pl_module, outputs, batch_idx):
                super().on_train_batch_end(trainer, pl_module, outputs, batch_idx)  # don't forget this :)
                percent = (self.train_batch_idx / self.total_train_batches) * 100
                sys.stdout.flush()
                sys.stdout.write(f'{percent:.01f} percent complete \r')

        bar = LitProgressBar()
        trainer = Trainer(callbacks=[bar])
    """

    def __init__(self) -> None:
        self._trainer: Optional["pl.Trainer"] = None
        self._current_eval_dataloader_idx: Optional[int] = None

    @property
    def trainer(self) -> "pl.Trainer":
        if self._trainer is None:
            raise TypeError(f"The `{self.__class__.__name__}._trainer` reference has not been set yet.")
        return self._trainer

    @property
    def sanity_check_description(self) -> str:
        return "Sanity Checking"

    @property
    def train_description(self) -> str:
        return "Training"

    @property
    def validation_description(self) -> str:
        return "Validation"

    @property
    def test_description(self) -> str:
        return "Testing"

    @property
    def predict_description(self) -> str:
        return "Predicting"

    @property
    def _val_processed(self) -> int:
        # use the total count in case validation runs more than once per training epoch
        return self.trainer.fit_loop.epoch_loop.val_loop.epoch_loop.batch_progress.total.processed

    @property
    def train_batch_idx(self) -> int:
        """The number of batches processed during training.

        Use this to update your progress bar.
        """
        return self.trainer.fit_loop.epoch_loop.batch_progress.current.processed

    @property
    def val_batch_idx(self) -> int:
        """The number of batches processed during validation.

        Use this to update your progress bar.
        """
        if self.trainer.state.fn == "fit":
            loop = self.trainer.fit_loop.epoch_loop.val_loop
        else:
            loop = self.trainer.validate_loop
        return loop.epoch_loop.batch_progress.current.processed

    @property
    def test_batch_idx(self) -> int:
        """The number of batches processed during testing.

        Use this to update your progress bar.
        """
        return self.trainer.test_loop.epoch_loop.batch_progress.current.processed

    @property
    def predict_batch_idx(self) -> int:
        """The number of batches processed during prediction.

        Use this to update your progress bar.
        """
        return self.trainer.predict_loop.epoch_loop.batch_progress.current.processed

    @property
    def total_train_batches(self) -> Union[int, float]:
        """The total number of training batches, which may change from epoch to epoch.

        Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the
        training dataloader is of infinite size.
        """
        return self.trainer.num_training_batches

    @property
    def total_val_batches_current_dataloader(self) -> Union[int, float]:
        """The total number of validation batches for the current dataloader, which may change from epoch to epoch.

        Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the
        validation dataloader is of infinite size.
"""assertself._current_eval_dataloader_idxisnotNoneifself.trainer.sanity_checking:returnself.trainer.num_sanity_val_batches[self._current_eval_dataloader_idx]returnself.trainer.num_val_batches[self._current_eval_dataloader_idx]@propertydeftotal_test_batches_current_dataloader(self)->Union[int,float]:"""The total number of testing batches, which may change from epoch to epoch for current dataloader. Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the test dataloader is of infinite size. """assertself._current_eval_dataloader_idxisnotNonereturnself.trainer.num_test_batches[self._current_eval_dataloader_idx]@propertydeftotal_predict_batches_current_dataloader(self)->Union[int,float]:"""The total number of prediction batches, which may change from epoch to epoch for current dataloader. Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the predict dataloader is of infinite size. """assertself._current_eval_dataloader_idxisnotNonereturnself.trainer.num_predict_batches[self._current_eval_dataloader_idx]@propertydeftotal_val_batches(self)->Union[int,float]:"""The total number of validation batches, which may change from epoch to epoch for all val dataloaders. Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the predict dataloader is of infinite size. """assertself._trainerisnotNonereturnsum(self.trainer.num_val_batches)ifself._trainer.fit_loop.epoch_loop._should_check_val_epoch()else0defhas_dataloader_changed(self,dataloader_idx:int)->bool:old_dataloader_idx=self._current_eval_dataloader_idxself._current_eval_dataloader_idx=dataloader_idxreturnold_dataloader_idx!=dataloader_idxdefreset_dataloader_idx_tracker(self)->None:self._current_eval_dataloader_idx=None
    def disable(self) -> None:
        """You should provide a way to disable the progress bar."""
        raise NotImplementedError
    def enable(self) -> None:
        """You should provide a way to enable the progress bar.

        The :class:`~pytorch_lightning.trainer.trainer.Trainer` will call this in e.g. pre-training
        routines like the :ref:`learning rate finder <advanced/training_tricks:Learning Rate Finder>`
        to temporarily enable and disable the main progress bar.
        """
        raise NotImplementedError
    def print(self, *args: Any, **kwargs: Any) -> None:
        """You should provide a way to print without breaking the progress bar."""
        print(*args, **kwargs)
    def get_metrics(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> Dict[str, Union[int, str]]:
        r"""
        Combines progress bar metrics collected from the trainer with standard metrics from
        ``get_standard_metrics``. Implement this to override the items displayed in the progress bar.

        Here is an example of how to override the defaults:

        .. code-block:: python

            def get_metrics(self, trainer, model):
                # don't show the version number
                items = super().get_metrics(trainer, model)
                items.pop("v_num", None)
                return items

        Return:
            Dictionary with the items to be displayed in the progress bar.
        """
        standard_metrics = pl_module.get_progress_bar_dict()
        pbar_metrics = trainer.progress_bar_metrics
        duplicates = list(standard_metrics.keys() & pbar_metrics.keys())
        if duplicates:
            rank_zero_warn(
                f"The progress bar already tracks a metric with the name(s) '{', '.join(duplicates)}' and"
                f" `self.log('{duplicates[0]}', ..., prog_bar=True)` will overwrite this value."
                " If this is undesired, change the name or override `get_metrics()` in the progress bar callback.",
            )
        return {**standard_metrics, **pbar_metrics}
def get_standard_metrics(trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> Dict[str, Union[int, str]]:
    r"""
    Returns several standard metrics displayed in the progress bar, including the average training loss,
    the split index of BPTT (if used) and the version of the experiment when using a logger.

    .. code-block::

        Epoch 1:   4%|▎         | 40/1095 [00:03<01:37, 10.84it/s, loss=4.501, v_num=10]

    Return:
        Dictionary with the standard metrics to be displayed in the progress bar.
    """
    # call .item() only once, but store the element without the graph
    running_train_loss = trainer.fit_loop.running_loss.mean()
    avg_training_loss = None
    if running_train_loss is not None:
        avg_training_loss = running_train_loss.cpu().item()
    elif pl_module.automatic_optimization:
        avg_training_loss = float("NaN")

    items_dict: Dict[str, Union[int, str]] = {}
    if avg_training_loss is not None:
        items_dict["loss"] = f"{avg_training_loss:.3g}"

    if pl_module.truncated_bptt_steps > 0:
        items_dict["split_idx"] = trainer.fit_loop.split_idx

    if trainer.loggers:
        version = _version(trainer.loggers)
        if version is not None:
            if isinstance(version, str):
                # show the last 4 characters of long version strings
                version = version[-4:]
            items_dict["v_num"] = version

    return items_dict
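Usage example. The following is a minimal, self-contained sketch of how a subclass typically uses the hooks above, along the lines of the ``LitProgressBar`` example in the class docstring. The ``SimpleBar`` name is illustrative, and ``*args``/``**kwargs`` are used in the hook because the exact ``on_train_batch_end`` parameters vary across Lightning 1.x releases::

    import sys

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks.progress.base import ProgressBarBase


    class SimpleBar(ProgressBarBase):
        """Prints the percentage of training batches processed in the current epoch."""

        def __init__(self):
            super().__init__()  # don't forget this :)
            self._enabled = True

        def disable(self):
            self._enabled = False

        def enable(self):
            self._enabled = True

        def on_train_batch_end(self, trainer, pl_module, *args, **kwargs):
            # `train_batch_idx` and `total_train_batches` come from ProgressBarBase
            if self._enabled and self.total_train_batches:
                percent = 100 * self.train_batch_idx / self.total_train_batches
                sys.stdout.write(f"\r{percent:.1f}% complete")
                sys.stdout.flush()


    # attach it like any other callback:
    # trainer = Trainer(callbacks=[SimpleBar()])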
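For evaluation-time bars with multiple dataloaders, ``has_dataloader_changed`` and the ``total_*_batches_current_dataloader`` properties are designed to be used together from a batch-start hook, with ``reset_dataloader_idx_tracker`` called once the run ends. A sketch under the same caveats as above (``MultiLoaderBar`` is hypothetical, and the hook signatures again vary by version)::

    class MultiLoaderBar(SimpleBar):
        def on_validation_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx=0):
            # has_dataloader_changed returns True exactly when validation moves on to a
            # new dataloader, and records the index so the per-dataloader totals resolve
            if self.has_dataloader_changed(dataloader_idx):
                total = self.total_val_batches_current_dataloader
                self.print(f"\nvalidating dataloader {dataloader_idx}: {total} batches")

        def on_validation_end(self, trainer, pl_module):
            # clear the tracker so the next evaluation run starts fresh
            self.reset_dataloader_idx_tracker()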