Source code for pytorch_lightning.callbacks.progress.rich_progress
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from dataclasses import dataclass
from datetime import timedelta
from typing import Any, Dict, Optional, Union

import pytorch_lightning as pl
from pytorch_lightning.callbacks.progress.base import ProgressBarBase
from pytorch_lightning.utilities.imports import _RICH_AVAILABLE

# Placeholders so module-level annotations such as ``Union[str, Style]`` in
# ``RichProgressBarTheme`` still evaluate when ``rich`` is not installed; the real
# names are bound by the ``rich`` imports below when it is available.
Task, Style = None, None
if _RICH_AVAILABLE:
    from rich.console import Console, RenderableType
    from rich.progress import BarColumn, Progress, ProgressColumn, Task, TaskID, TextColumn
    from rich.progress_bar import ProgressBar
    from rich.style import Style
    from rich.text import Text

    class CustomBarColumn(BarColumn):
        """Overrides ``BarColumn`` to provide support for dataloaders that do not define a size (infinite size)
        such as ``IterableDataset``."""

        def render(self, task: "Task") -> ProgressBar:
            """Gets a progress bar widget for a task."""
            return ProgressBar(
                total=max(0, task.total),
                completed=max(0, task.completed),
                width=None if self.bar_width is None else max(1, self.bar_width),
                # Pulse until the task has started, or indefinitely when the remaining
                # amount is not finite (e.g. a task with an infinite total).
                pulse=not task.started or not math.isfinite(task.remaining),
                animation_time=task.get_time(),
                style=self.style,
                complete_style=self.complete_style,
                finished_style=self.finished_style,
                pulse_style=self.pulse_style,
            )

    @dataclass
    class CustomInfiniteTask(Task):
        """Overrides ``Task`` to define an infinite task.

        This is useful for datasets that do not define a size (infinite size) such as ``IterableDataset``.
        """

        @property
        def time_remaining(self) -> Optional[float]:
            # A task with no finite size has no meaningful remaining-time estimate.
            return None

    class CustomProgress(Progress):
        """Overrides ``Progress`` to support adding tasks that have an infinite total size."""

        def add_task(
            self,
            description: str,
            start: bool = True,
            total: float = 100.0,
            completed: int = 0,
            visible: bool = True,
            **fields: Any,
        ) -> TaskID:
            """Add a task, creating a :class:`CustomInfiniteTask` when ``total`` is not finite.

            Args:
                description: Text shown next to the bar.
                start: Whether to start the task immediately.
                total: Number of steps; a non-finite value (e.g. ``float("inf")``) creates an infinite task.
                completed: Number of steps already completed.
                visible: Whether the task is rendered.
                **fields: Extra fields stored on the task.

            Returns:
                The id of the newly added task.
            """
            if not math.isfinite(total):
                task = CustomInfiniteTask(
                    self._task_index,
                    description,
                    total,
                    completed,
                    visible=visible,
                    fields=fields,
                    _get_time=self.get_time,
                    _lock=self._lock,
                )
                # Forward ``start`` so infinite tasks honour it exactly like finite ones.
                return self.add_custom_task(task, start=start)
            return super().add_task(description, start, total, completed, visible, **fields)

        def add_custom_task(self, task: CustomInfiniteTask, start: bool = True) -> TaskID:
            """Register an already-constructed task and return its id.

            Args:
                task: The task instance to register under the next free task index.
                start: Whether to start the task immediately.
            """
            with self._lock:
                self._tasks[self._task_index] = task
                if start:
                    self.start_task(self._task_index)
                new_task_index = self._task_index
                self._task_index = TaskID(int(self._task_index) + 1)
            self.refresh()
            return new_task_index

    class CustomTimeColumn(ProgressColumn):
        """Renders elapsed time and estimated time remaining as ``elapsed • remaining``."""

        # Only refresh twice a second to prevent jitter
        max_refresh = 0.5

        def __init__(self, style: Union[str, Style]) -> None:
            self.style = style
            super().__init__()

        def render(self, task: "Task") -> Text:
            elapsed = task.finished_time if task.finished else task.elapsed
            remaining = task.time_remaining
            elapsed_delta = "-:--:--" if elapsed is None else str(timedelta(seconds=int(elapsed)))
            remaining_delta = "-:--:--" if remaining is None else str(timedelta(seconds=int(remaining)))
            return Text(f"{elapsed_delta} • {remaining_delta}", style=self.style)

    class BatchesProcessedColumn(ProgressColumn):
        """Renders ``completed/total`` batch counts, using ``--`` when the total is unknown."""

        def __init__(self, style: Union[str, Style]):
            self.style = style
            super().__init__()

        def render(self, task: "Task") -> RenderableType:
            total = task.total
            # Totals may be floats (``add_task`` defaults to ``100.0``): show them as integers
            # ("10/50", not "10/50.0"). Infinite or unset totals have no count to show.
            total_text = "--" if total is None or not math.isfinite(total) else str(int(total))
            return Text(f"{int(task.completed)}/{total_text}", style=self.style)

    class ProcessingSpeedColumn(ProgressColumn):
        """Renders the task's processing speed in iterations per second."""

        def __init__(self, style: Union[str, Style]):
            self.style = style
            super().__init__()

        def render(self, task: "Task") -> RenderableType:
            task_speed = f"{task.speed:>.2f}" if task.speed is not None else "0.00"
            return Text(f"{task_speed}it/s", style=self.style)

    class MetricsTextColumn(ProgressColumn):
        """A column containing text."""

        def __init__(self, trainer, style):
            self._trainer = trainer
            self._tasks = {}
            self._current_task_id = 0
            self._metrics = {}
            self._style = style
            super().__init__()

        def update(self, metrics):
            # Called when metrics are ready to be rendered.
            # This is to prevent render from causing deadlock issues by requesting metrics
            # in separate threads.
            self._metrics = metrics

        def render(self, task) -> Text:
            # Metrics are only shown on the main progress bar while fitting (not during sanity checking).
            if (
                self._trainer.state.fn != "fit"
                or self._trainer.sanity_checking
                or self._trainer.progress_bar_callback.main_progress_bar_id != task.id
            ):
                return Text()
            if self._trainer.training and task.id not in self._tasks:
                self._tasks[task.id] = "None"
                # Freeze the last rendered text of the previous task before switching to the new one.
                if self._renderable_cache:
                    self._tasks[self._current_task_id] = self._renderable_cache[self._current_task_id][1]
                self._current_task_id = task.id
            if self._trainer.training and task.id != self._current_task_id:
                return self._tasks[task.id]
            # Each entry keeps its trailing space, matching the original concatenation.
            text = "".join(
                f"{k}: {round(v, 3) if isinstance(v, float) else v} " for k, v in self._metrics.items()
            )
            return Text(text, justify="left", style=self._style)


@dataclass
class RichProgressBarTheme:
    """Styles to associate to different base components.

    Args:
        description: Style for the progress bar description. For eg., Epoch x, Testing, etc.
        progress_bar: Style for the bar in progress.
        progress_bar_finished: Style for the finished progress bar.
        progress_bar_pulse: Style for the progress bar when `IterableDataset` is being processed.
        batch_progress: Style for the progress tracker (i.e 10/50 batches completed).
        time: Style for the processed time and estimate time remaining.
        processing_speed: Style for the speed of the batches being processed.
        metrics: Style for the metrics

    https://rich.readthedocs.io/en/stable/style.html
    """

    description: Union[str, Style] = "white"
    progress_bar: Union[str, Style] = "#6206E0"
    progress_bar_finished: Union[str, Style] = "#6206E0"
    progress_bar_pulse: Union[str, Style] = "#6206E0"
    batch_progress: Union[str, Style] = "white"
    time: Union[str, Style] = "grey54"
    processing_speed: Union[str, Style] = "grey70"
    metrics: Union[str, Style] = "white"


class RichProgressBar(ProgressBarBase):
    """Create a progress bar with `rich text formatting <https://github.com/willmcgugan/rich>`_.

    Install it with pip:

    .. code-block:: bash

        pip install rich

    .. code-block:: python

        from pytorch_lightning import Trainer
        from pytorch_lightning.callbacks import RichProgressBar

        trainer = Trainer(callbacks=RichProgressBar())

    Args:
        refresh_rate: Determines at which rate (in number of batches) the progress bars get updated.
            Set it to ``0`` to disable the display.
        leave: Leaves the finished progress bar in the terminal at the end of the epoch. Default: False
        theme: Contains styles used to stylize the progress bar. When omitted, a new
            :class:`RichProgressBarTheme` is created for this instance.
        console_kwargs: Args for constructing a `Console`

    Raises:
        ModuleNotFoundError:
            If required `rich` package is not installed on the device.

    Note:
        PyCharm users will need to enable “emulate terminal” in output console option in
        run/debug configuration to see styled output.
        Reference: https://rich.readthedocs.io/en/latest/introduction.html#requirements
    """

    def __init__(
        self,
        refresh_rate: int = 1,
        leave: bool = False,
        theme: Optional[RichProgressBarTheme] = None,
        console_kwargs: Optional[Dict[str, Any]] = None,
    ) -> None:
        if not _RICH_AVAILABLE:
            raise ModuleNotFoundError(
                "`RichProgressBar` requires `rich` >= 10.2.2. Install it by running `pip install -U rich`."
            )

        super().__init__()
        self._refresh_rate: int = refresh_rate
        self._leave: bool = leave
        self._console_kwargs = console_kwargs or {}
        self._enabled: bool = True
        self.progress: Optional[Progress] = None
        self.val_sanity_progress_bar_id: Optional[int] = None
        self._reset_progress_bar_ids()
        self._metric_component = None
        self._progress_stopped: bool = False
        # Build a fresh theme per instance: a default-argument instance would be shared by every
        # bar, and ``_update_for_light_colab_theme`` mutates the theme in place.
        self.theme = theme if theme is not None else RichProgressBarTheme()
        self._update_for_light_colab_theme()

    @property
    def refresh_rate(self) -> float:
        return self._refresh_rate

    @property
    def is_enabled(self) -> bool:
        return self._enabled and self.refresh_rate > 0

    @property
    def is_disabled(self) -> bool:
        return not self.is_enabled

    def _update_for_light_colab_theme(self) -> None:
        """Swap default ``"white"`` text styles for ``"black"`` when a light Colab theme is detected."""
        if _detect_light_colab_theme():
            attributes = ["description", "batch_progress", "metrics"]
            for attr in attributes:
                if getattr(self.theme, attr) == "white":
                    setattr(self.theme, attr, "black")

    def _init_progress(self, trainer):
        """Create and start a new ``CustomProgress`` if enabled and none is currently running."""
        if self.is_enabled and (self.progress is None or self._progress_stopped):
            self._reset_progress_bar_ids()
            self._console = Console(**self._console_kwargs)
            self._console.clear_live()
            self._metric_component = MetricsTextColumn(trainer, self.theme.metrics)
            self.progress = CustomProgress(
                *self.configure_columns(trainer),
                self._metric_component,
                auto_refresh=False,
                disable=self.is_disabled,
                console=self._console,
            )
            self.progress.start()
            # progress has started
            self._progress_stopped = False

    def refresh(self) -> None:
        if self.progress:
            self.progress.refresh()

    def on_train_epoch_start(self, trainer, pl_module):
        total_train_batches = self.total_train_batches
        total_val_batches = self.total_val_batches
        if total_train_batches != float("inf"):
            # val can be checked multiple times per epoch
            val_checks_per_epoch = total_train_batches // trainer.val_check_batch
            total_val_batches = total_val_batches * val_checks_per_epoch

        total_batches = total_train_batches + total_val_batches

        train_description = self._get_train_description(trainer.current_epoch)
        # With ``leave=True`` the previous epoch's bar is stopped so it stays printed in the terminal.
        if self.main_progress_bar_id is not None and self._leave:
            self._stop_progress()
        self._init_progress(trainer)
        if self.main_progress_bar_id is None:
            self.main_progress_bar_id = self._add_task(total_batches, train_description)
        elif self.progress is not None:
            self.progress.reset(
                self.main_progress_bar_id, total=total_batches, description=train_description, visible=True
            )
        self.refresh()

    def on_validation_batch_start(
        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int
    ) -> None:
        if not self.has_dataloader_changed(dataloader_idx):
            return

        if trainer.sanity_checking:
            if self.val_sanity_progress_bar_id is not None:
                self.progress.update(self.val_sanity_progress_bar_id, advance=0, visible=False)

            self.val_sanity_progress_bar_id = self._add_task(
                self.total_val_batches_current_dataloader, self.sanity_check_description, visible=False
            )
        else:
            if self.val_progress_bar_id is not None:
                self.progress.update(self.val_progress_bar_id, advance=0, visible=False)

            # TODO: remove old tasks when new ones are created
            self.val_progress_bar_id = self._add_task(
                self.total_val_batches_current_dataloader, self.validation_description, visible=False
            )

        self.refresh()

    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        if trainer.sanity_checking:
            self._update(self.val_sanity_progress_bar_id, self.val_batch_idx)
        elif self.val_progress_bar_id is not None:
            # check to see if we should update the main training progress bar
            if self.main_progress_bar_id is not None:
                self._update(self.main_progress_bar_id, self.train_batch_idx + self._val_processed)
            self._update(self.val_progress_bar_id, self.val_batch_idx)
        self.refresh()

    def _get_train_description(self, current_epoch: int) -> str:
        """Return ``"Epoch N"`` right-padded with spaces to line up with the validation description."""
        train_description = f"Epoch {current_epoch}"
        if len(self.validation_description) > len(train_description):
            # Padding is required to avoid flickering due of uneven lengths of "Epoch X"
            # and "Validation" Bar description
            num_digits = len(str(current_epoch))
            required_padding = (len(self.validation_description) - len(train_description) + 1) - num_digits
            # A non-positive padding yields an empty string, i.e. no padding.
            train_description += " " * required_padding
        return train_description

    def _stop_progress(self) -> None:
        if self.progress is not None:
            self.progress.stop()
            # signals for progress to be re-initialized for next stages
            self._progress_stopped = True

    def _reset_progress_bar_ids(self):
        self.main_progress_bar_id: Optional[int] = None
        self.val_progress_bar_id: Optional[int] = None
        self.test_progress_bar_id: Optional[int] = None
        self.predict_progress_bar_id: Optional[int] = None

    def _update_metrics(self, trainer, pl_module) -> None:
        """Push the latest metrics to the metrics column (no-op before the progress is initialised)."""
        metrics = self.get_metrics(trainer, pl_module)
        if self._metric_component:
            self._metric_component.update(metrics)


def_detect_light_colab_theme()->bool:"""Detect if it's light theme in Colab."""try:get_ipython# type: ignoreexceptNameError:returnFalseipython=get_ipython()# noqa: F821if"google.colab"instr(ipython.__class__):try:fromgoogle.colabimportoutputreturnoutput.eval_js('document.documentElement.matches("[theme=light]")')exceptModuleNotFoundError:returnFalsereturnFalse
To analyze traffic and optimize your experience, we serve cookies on this site. By clicking or navigating, you agree to allow our usage of cookies. Read PyTorch Lightning's Privacy Policy.