Source code for pytorch_lightning.callbacks.progress.tqdm_progress
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# `import importlib` alone does not guarantee that the `importlib.util` submodule is loaded;
# import it explicitly since `importlib.util.find_spec` is used below (this still binds `importlib`).
import importlib.util
import math
import os
import sys
from typing import Any, Dict, Optional, Union

# check if ipywidgets is installed before importing tqdm.auto
# to ensure it won't fail and a progress bar is displayed
if importlib.util.find_spec("ipywidgets") is not None:
    from tqdm.auto import tqdm as _tqdm
else:
    from tqdm import tqdm as _tqdm

import pytorch_lightning as pl
from pytorch_lightning.callbacks.progress.base import ProgressBarBase
from pytorch_lightning.utilities.rank_zero import rank_zero_debug

# Minimum width (in characters) that padded numbers are right-filled to by `Tqdm.format_num`.
_PAD_SIZE = 5


class Tqdm(_tqdm):
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Custom tqdm progressbar where we append 0 to floating points/strings to prevent the progress bar from
        flickering."""
        # Defined only so the generated API docs use this docstring; otherwise they pull in tqdm's own
        # docstring, which has rendering issues.
        super().__init__(*args, **kwargs)

    @staticmethod
    def format_num(n: Union[int, float, str]) -> str:
        """Add additional padding to the formatted numbers.

        Ints are returned exactly as tqdm formats them. Floats (after tqdm formatting) and strings that
        contain no ``"e"`` are right-padded with ``"0"`` up to ``_PAD_SIZE`` characters, so the displayed
        width stays stable between refreshes. A value with no ``"."`` that is shorter than ``_PAD_SIZE``
        first gets a ``"."`` appended, but only if it parses as a float; otherwise it is returned unchanged.

        Example: ``0.5`` -> ``"0.500"``; ``"12"`` -> ``"12.00"``; ``3`` -> ``"3"``.

        NOTE: ``"nan"``/``"inf"`` also parse as floats and are therefore padded too (``"nan"`` -> ``"nan.0"``).
        """
        should_be_padded = isinstance(n, (float, str))
        if not isinstance(n, str):
            n = _tqdm.format_num(n)
            # narrows the type for static checkers; tqdm's formatter returns a string
            assert isinstance(n, str)
        # scientific notation (e.g. "1e-05") is left alone; padding it would change its meaning
        if should_be_padded and "e" not in n:
            if "." not in n and len(n) < _PAD_SIZE:
                try:
                    _ = float(n)
                except ValueError:
                    # short non-numeric string: nothing to pad
                    return n
                n += "."
            n += "0" * (_PAD_SIZE - len(n))
        return n
class TQDMProgressBar(ProgressBarBase):
    r"""
    This is the default progress bar used by Lightning. It prints to ``stdout`` using the
    :mod:`tqdm` package and shows the following bars:

        - **sanity check progress:** the progress during the sanity check run
        - **main progress:** shows training + validation progress combined. It also accounts for
          multiple validation runs during training when
          :paramref:`~pytorch_lightning.trainer.trainer.Trainer.val_check_interval` is used.
        - **validation progress:** only visible during validation;
          shows total progress over all validation datasets.
        - **test progress:** only active when testing; shows total progress over all test datasets.
        - **predict progress:** only active when predicting.

    For infinite datasets, the progress bar never ends.

    If you want to customize the default ``tqdm`` progress bars used by Lightning, you can override
    specific methods of the callback class and pass your custom implementation to the
    :class:`~pytorch_lightning.trainer.trainer.Trainer`.

    Example:

        >>> class LitProgressBar(TQDMProgressBar):
        ...     def init_validation_tqdm(self):
        ...         bar = super().init_validation_tqdm()
        ...         bar.set_description('running validation ...')
        ...         return bar
        ...
        >>> bar = LitProgressBar()
        >>> from pytorch_lightning import Trainer
        >>> trainer = Trainer(callbacks=[bar])

    Args:
        refresh_rate: Determines at which rate (in number of batches) the progress bars get updated.
            Set it to ``0`` to disable the display. When the ``COLAB_GPU`` environment variable is set,
            a value of ``1`` is raised to ``20``. By default, the
            :class:`~pytorch_lightning.trainer.trainer.Trainer` uses this implementation of the progress
            bar and sets the refresh rate to the value provided to the
            :paramref:`~pytorch_lightning.trainer.trainer.Trainer.progress_bar_refresh_rate` argument in the
            :class:`~pytorch_lightning.trainer.trainer.Trainer`.
        process_position: Set this to a value greater than ``0`` to offset the progress bars (they are
            placed at tqdm ``position`` ``2 * process_position``). This is useful when you have progress
            bars defined elsewhere and want to show all of them together. This corresponds to
            :paramref:`~pytorch_lightning.trainer.trainer.Trainer.process_position` in the
            :class:`~pytorch_lightning.trainer.trainer.Trainer`.
    """

    def __init__(self, refresh_rate: int = 1, process_position: int = 0) -> None:
        super().__init__()
        self._refresh_rate = self._resolve_refresh_rate(refresh_rate)
        self._process_position = process_position
        self._enabled = True
        # Bar slots start empty; the public properties below raise `TypeError` until a bar is assigned.
        self._main_progress_bar: Optional[_tqdm] = None
        self._val_progress_bar: Optional[_tqdm] = None
        self._test_progress_bar: Optional[_tqdm] = None
        self._predict_progress_bar: Optional[_tqdm] = None

    def __getstate__(self) -> Dict:
        # tqdm objects can't be pickled, so they are replaced by `None` in the pickled state
        return {k: v if not isinstance(v, _tqdm) else None for k, v in vars(self).items()}

    def _get_bar(self, attr_name: str) -> _tqdm:
        """Return the bar stored in the private attribute ``attr_name``.

        Raises:
            TypeError: If that attribute is still ``None`` (the bar has not been set yet).
        """
        bar = getattr(self, attr_name)
        if bar is None:
            raise TypeError(f"The `{self.__class__.__name__}.{attr_name}` reference has not been set yet.")
        return bar

    @property
    def main_progress_bar(self) -> _tqdm:
        """The main (training + validation) bar; raises ``TypeError`` if not set yet."""
        return self._get_bar("_main_progress_bar")

    @main_progress_bar.setter
    def main_progress_bar(self, bar: _tqdm) -> None:
        self._main_progress_bar = bar

    @property
    def val_progress_bar(self) -> _tqdm:
        """The validation (or sanity check) bar; raises ``TypeError`` if not set yet."""
        return self._get_bar("_val_progress_bar")

    @val_progress_bar.setter
    def val_progress_bar(self, bar: _tqdm) -> None:
        self._val_progress_bar = bar

    @property
    def test_progress_bar(self) -> _tqdm:
        """The test bar; raises ``TypeError`` if not set yet."""
        return self._get_bar("_test_progress_bar")

    @test_progress_bar.setter
    def test_progress_bar(self, bar: _tqdm) -> None:
        self._test_progress_bar = bar

    @property
    def predict_progress_bar(self) -> _tqdm:
        """The prediction bar; raises ``TypeError`` if not set yet."""
        return self._get_bar("_predict_progress_bar")

    @predict_progress_bar.setter
    def predict_progress_bar(self, bar: _tqdm) -> None:
        self._predict_progress_bar = bar

    @property
    def refresh_rate(self) -> int:
        """Number of batches between bar refreshes (after the Colab adjustment)."""
        return self._refresh_rate

    @property
    def process_position(self) -> int:
        """Offset used to compute the tqdm ``position`` of the bars."""
        return self._process_position

    @property
    def is_enabled(self) -> bool:
        """Whether bars are displayed: requires the internal flag and a positive refresh rate."""
        return self._enabled and self.refresh_rate > 0

    @property
    def is_disabled(self) -> bool:
        return not self.is_enabled

    def init_sanity_tqdm(self) -> Tqdm:
        """Override this to customize the tqdm bar for the validation sanity run."""
        bar = Tqdm(
            desc=self.sanity_check_description,
            position=(2 * self.process_position),
            disable=self.is_disabled,
            leave=False,
            dynamic_ncols=True,
            file=sys.stdout,
        )
        return bar

    def init_train_tqdm(self) -> Tqdm:
        """Override this to customize the tqdm bar for training."""
        bar = Tqdm(
            desc=self.train_description,
            initial=self.train_batch_idx,
            position=(2 * self.process_position),
            disable=self.is_disabled,
            leave=True,
            dynamic_ncols=True,
            file=sys.stdout,
            smoothing=0,
        )
        return bar

    def init_predict_tqdm(self) -> Tqdm:
        """Override this to customize the tqdm bar for predicting."""
        # No `initial=self.train_batch_idx` here (previously copied from `init_train_tqdm`): that index
        # belongs to training, so seeding the prediction bar with it made it start at the wrong count.
        bar = Tqdm(
            desc=self.predict_description,
            position=(2 * self.process_position),
            disable=self.is_disabled,
            leave=True,
            dynamic_ncols=True,
            file=sys.stdout,
            smoothing=0,
        )
        return bar

    def init_validation_tqdm(self) -> Tqdm:
        """Override this to customize the tqdm bar for validation."""
        # The main progress bar doesn't exist in `trainer.validate()`
        has_main_bar = self.trainer.state.fn != "validate"
        bar = Tqdm(
            desc=self.validation_description,
            # sit one line below the main bar when it exists
            position=(2 * self.process_position + has_main_bar),
            disable=self.is_disabled,
            # only keep the bar on screen when it is the sole bar (`trainer.validate()`)
            leave=not has_main_bar,
            dynamic_ncols=True,
            file=sys.stdout,
        )
        return bar

    def init_test_tqdm(self) -> Tqdm:
        """Override this to customize the tqdm bar for testing."""
        bar = Tqdm(
            desc="Testing",
            position=(2 * self.process_position),
            disable=self.is_disabled,
            leave=True,
            dynamic_ncols=True,
            file=sys.stdout,
        )
        return bar

    def on_sanity_check_start(self, *_: Any) -> None:
        self.val_progress_bar = self.init_sanity_tqdm()
        self.main_progress_bar = Tqdm(disable=True)  # dummy progress bar

    def on_train_epoch_start(self, trainer: "pl.Trainer", *_: Any) -> None:
        """Set the main bar's total to this epoch's train batches plus all validation batches."""
        total_train_batches = self.total_train_batches
        total_val_batches = self.total_val_batches
        if total_train_batches != float("inf") and total_val_batches != float("inf"):
            # val can be checked multiple times per epoch
            val_checks_per_epoch = total_train_batches // trainer.val_check_batch
            total_val_batches = total_val_batches * val_checks_per_epoch

        total_batches = total_train_batches + total_val_batches
        # an infinite total is shown as an unknown (`None`) total by tqdm
        self.main_progress_bar.total = convert_inf(total_batches)
        self.main_progress_bar.set_description(f"Epoch {trainer.current_epoch}")

    def _should_update(self, current: int, total: int) -> bool:
        """Whether to redraw: enabled and ``current`` is a multiple of the refresh rate or equals ``total``."""
        # `is_enabled` implies `refresh_rate > 0`, so the short-circuit also guards the modulo
        return self.is_enabled and (current % self.refresh_rate == 0 or current == total)

    @staticmethod
    def _resolve_refresh_rate(refresh_rate: int) -> int:
        """Return ``refresh_rate``, raised from ``1`` to ``20`` when ``COLAB_GPU`` is set."""
        if os.getenv("COLAB_GPU") and refresh_rate == 1:
            # smaller refresh rate on colab causes crashes, choose a higher value
            rank_zero_debug("Using a higher refresh rate on Colab. Setting it to `20`")
            refresh_rate = 20
        return refresh_rate
def convert_inf(x: Optional[Union[int, float]]) -> Optional[Union[int, float]]:
    """Replace values that tqdm cannot display with ``None``.

    tqdm does not support ``inf``/``nan`` values, so those (as well as ``None``) are mapped to ``None``.
    Any other number is returned unchanged.
    """
    unsupported = x is None or math.isinf(x) or math.isnan(x)
    return None if unsupported else x


def _update_n(bar: _tqdm, value: int) -> None:
    """Set ``bar``'s counter to ``value`` and redraw it; does nothing when the bar is disabled."""
    if bar.disable:
        return
    bar.n = value
    bar.refresh()
To analyze traffic and optimize your experience, we serve cookies on this site. By clicking or navigating, you agree to allow our usage of cookies. Read PyTorch Lightning's Privacy Policy.