Shortcuts

Source code for pytorch_lightning.callbacks.progress.tqdm_progress

# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import math
import os
import sys
from typing import Any, Dict, Optional, Union

# check if ipywidgets is installed before importing tqdm.auto
# to ensure it won't fail and a progress bar is displayed

# `import importlib` alone does not guarantee that the `importlib.util` submodule is loaded;
# import it explicitly so the attribute access below cannot fail with an AttributeError.
import importlib.util

# Use the notebook-aware `tqdm.auto` only when ipywidgets is importable, otherwise fall back
# to the plain terminal implementation.
if importlib.util.find_spec("ipywidgets") is not None:
    from tqdm.auto import tqdm as _tqdm
else:
    from tqdm import tqdm as _tqdm

import pytorch_lightning as pl
from pytorch_lightning.callbacks.progress.base import ProgressBarBase
from pytorch_lightning.utilities.rank_zero import rank_zero_debug

# Target character width used by `Tqdm.format_num`: float/str values shorter than this are
# right-padded with zeros so that the rendered width stays stable between refreshes.
_PAD_SIZE = 5


class Tqdm(_tqdm):
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Custom tqdm progressbar where we append 0 to floating points/strings to prevent the progress bar from
        flickering."""
        # this just to make the make docs happy, otherwise it pulls docs which has some issues...
        super().__init__(*args, **kwargs)

    @staticmethod
    def format_num(n: Union[int, float, str]) -> str:
        """Format ``n`` like tqdm does and zero-pad floats/numeric strings to ``_PAD_SIZE`` characters.

        Args:
            n: The value to format. Integers are formatted by tqdm and never padded. Floats and
                strings are padded with trailing zeros when they represent a finite number.

        Returns:
            The formatted string. Values in scientific notation (containing ``"e"``), strings
            that cannot be parsed as a number and non-finite values (``nan``/``inf``) are
            returned without padding.
        """
        should_be_padded = isinstance(n, (float, str))
        if not isinstance(n, str):
            n = _tqdm.format_num(n)
            assert isinstance(n, str)  # type narrowing for static checkers
        if not should_be_padded or "e" in n:
            return n

        # Only pad text that is a finite number. This avoids output such as "nan.0", "inf.0"
        # or "1.5%0", where zeros were appended to something that is not a plain number.
        try:
            is_finite = math.isfinite(float(n))
        except ValueError:
            return n
        if not is_finite:
            return n

        if "." not in n and len(n) < _PAD_SIZE:
            n += "."
        # a non-positive repeat count yields an empty string, so long values are left untouched
        n += "0" * (_PAD_SIZE - len(n))
        return n


class TQDMProgressBar(ProgressBarBase):
    r"""
    This is the default progress bar used by Lightning. It prints to ``stdout`` using the
    :mod:`tqdm` package and shows the following bars:

        - **sanity check progress:** the progress during the sanity check run
        - **main progress:** shows training + validation progress combined. It also accounts for
          multiple validation runs during training when
          :paramref:`~pytorch_lightning.trainer.trainer.Trainer.val_check_interval` is used.
        - **validation progress:** only visible during validation;
          shows total progress over all validation datasets.
        - **test progress:** only active when testing; shows total progress over all test datasets.
        - **predict progress:** only active when predicting.

    For infinite datasets, the progress bar never ends.

    If you want to customize the default ``tqdm`` progress bars used by Lightning, you can override
    specific methods of the callback class and pass your custom implementation to the
    :class:`~pytorch_lightning.trainer.trainer.Trainer`.

    Example:

        >>> class LitProgressBar(TQDMProgressBar):
        ...     def init_validation_tqdm(self):
        ...         bar = super().init_validation_tqdm()
        ...         bar.set_description('running validation ...')
        ...         return bar
        ...
        >>> bar = LitProgressBar()
        >>> from pytorch_lightning import Trainer
        >>> trainer = Trainer(callbacks=[bar])

    Args:
        refresh_rate:
            Determines at which rate (in number of batches) the progress bars get updated.
            Set it to ``0`` to disable the display. By default, the
            :class:`~pytorch_lightning.trainer.trainer.Trainer` uses this implementation of the progress
            bar and sets the refresh rate to the value provided to the
            :paramref:`~pytorch_lightning.trainer.trainer.Trainer.progress_bar_refresh_rate` argument in the
            :class:`~pytorch_lightning.trainer.trainer.Trainer`.
        process_position:
            Set this to a value greater than ``0`` to offset the progress bars by this many lines.
            This is useful when you have progress bars defined elsewhere and want to show all of them
            together. This corresponds to
            :paramref:`~pytorch_lightning.trainer.trainer.Trainer.process_position` in the
            :class:`~pytorch_lightning.trainer.trainer.Trainer`.
    """

    def __init__(self, refresh_rate: int = 1, process_position: int = 0):
        super().__init__()
        self._refresh_rate = self._resolve_refresh_rate(refresh_rate)
        self._process_position = process_position
        self._enabled = True
        # the bars are created lazily by the corresponding `on_*_start` hooks
        self._main_progress_bar: Optional[_tqdm] = None
        self._val_progress_bar: Optional[_tqdm] = None
        self._test_progress_bar: Optional[_tqdm] = None
        self._predict_progress_bar: Optional[_tqdm] = None

    def __getstate__(self) -> Dict:
        # can't pickle the tqdm objects
        return {k: v if not isinstance(v, _tqdm) else None for k, v in vars(self).items()}

    @property
    def main_progress_bar(self) -> _tqdm:
        """The main (training) bar; raises ``TypeError`` if it has not been set yet."""
        if self._main_progress_bar is None:
            raise TypeError(f"The `{self.__class__.__name__}._main_progress_bar` reference has not been set yet.")
        return self._main_progress_bar

    @main_progress_bar.setter
    def main_progress_bar(self, bar: _tqdm) -> None:
        self._main_progress_bar = bar

    @property
    def val_progress_bar(self) -> _tqdm:
        """The validation bar; raises ``TypeError`` if it has not been set yet."""
        if self._val_progress_bar is None:
            raise TypeError(f"The `{self.__class__.__name__}._val_progress_bar` reference has not been set yet.")
        return self._val_progress_bar

    @val_progress_bar.setter
    def val_progress_bar(self, bar: _tqdm) -> None:
        self._val_progress_bar = bar

    @property
    def test_progress_bar(self) -> _tqdm:
        """The test bar; raises ``TypeError`` if it has not been set yet."""
        if self._test_progress_bar is None:
            raise TypeError(f"The `{self.__class__.__name__}._test_progress_bar` reference has not been set yet.")
        return self._test_progress_bar

    @test_progress_bar.setter
    def test_progress_bar(self, bar: _tqdm) -> None:
        self._test_progress_bar = bar

    @property
    def predict_progress_bar(self) -> _tqdm:
        """The predict bar; raises ``TypeError`` if it has not been set yet."""
        if self._predict_progress_bar is None:
            raise TypeError(f"The `{self.__class__.__name__}._predict_progress_bar` reference has not been set yet.")
        return self._predict_progress_bar

    @predict_progress_bar.setter
    def predict_progress_bar(self, bar: _tqdm) -> None:
        self._predict_progress_bar = bar

    @property
    def refresh_rate(self) -> int:
        return self._refresh_rate

    @property
    def process_position(self) -> int:
        return self._process_position

    @property
    def is_enabled(self) -> bool:
        return self._enabled and self.refresh_rate > 0

    @property
    def is_disabled(self) -> bool:
        return not self.is_enabled

    @property
    def _val_processed(self) -> int:
        # use total in case validation runs more than once per training epoch
        return self.trainer.fit_loop.epoch_loop.val_loop.epoch_loop.batch_progress.total.processed

    def disable(self) -> None:
        """Turn the display off; newly created bars are constructed with ``disable=True``."""
        self._enabled = False

    def enable(self) -> None:
        """Turn the display back on (still subject to ``refresh_rate > 0``)."""
        self._enabled = True

    def init_sanity_tqdm(self) -> Tqdm:
        """Override this to customize the tqdm bar for the validation sanity run."""
        bar = Tqdm(
            desc=self.sanity_check_description,
            position=(2 * self.process_position),
            disable=self.is_disabled,
            leave=False,
            dynamic_ncols=True,
            file=sys.stdout,
        )
        return bar

    def init_train_tqdm(self) -> Tqdm:
        """Override this to customize the tqdm bar for training."""
        bar = Tqdm(
            desc=self.train_description,
            initial=self.train_batch_idx,
            position=(2 * self.process_position),
            disable=self.is_disabled,
            leave=True,
            dynamic_ncols=True,
            file=sys.stdout,
            smoothing=0,
        )
        return bar

    def init_predict_tqdm(self) -> Tqdm:
        """Override this to customize the tqdm bar for predicting."""
        bar = Tqdm(
            desc=self.predict_description,
            # seed the counter from the predict loop (not the training loop), matching the
            # value that `on_predict_batch_end` later writes into the bar
            initial=self.predict_batch_idx,
            position=(2 * self.process_position),
            disable=self.is_disabled,
            leave=True,
            dynamic_ncols=True,
            file=sys.stdout,
            smoothing=0,
        )
        return bar

    def init_validation_tqdm(self) -> Tqdm:
        """Override this to customize the tqdm bar for validation."""
        # The main progress bar doesn't exist in `trainer.validate()`
        has_main_bar = self.trainer.state.fn != "validate"
        bar = Tqdm(
            desc=self.validation_description,
            # placed one line below the main bar when there is one
            position=(2 * self.process_position + has_main_bar),
            disable=self.is_disabled,
            leave=not has_main_bar,
            dynamic_ncols=True,
            file=sys.stdout,
        )
        return bar

    def init_test_tqdm(self) -> Tqdm:
        """Override this to customize the tqdm bar for testing."""
        bar = Tqdm(
            # use the description property like the other bars and `on_test_batch_start` do
            desc=self.test_description,
            position=(2 * self.process_position),
            disable=self.is_disabled,
            leave=True,
            dynamic_ncols=True,
            file=sys.stdout,
        )
        return bar

    def on_sanity_check_start(self, *_: Any) -> None:
        """Create the sanity-check bar and a disabled placeholder for the main bar."""
        self.val_progress_bar = self.init_sanity_tqdm()
        self.main_progress_bar = Tqdm(disable=True)  # dummy progress bar

    def on_sanity_check_end(self, *_: Any) -> None:
        """Close the bars created for the sanity check."""
        self.main_progress_bar.close()
        self.val_progress_bar.close()

    def on_train_start(self, *_: Any) -> None:
        """Create the main progress bar."""
        self.main_progress_bar = self.init_train_tqdm()

    def on_train_epoch_start(self, trainer: "pl.Trainer", *_: Any) -> None:
        """Set the main bar's total (train + validation batches) and its ``Epoch N`` description."""
        total_train_batches = self.total_train_batches
        total_val_batches = self.total_val_batches
        if total_train_batches != float("inf") and total_val_batches != float("inf"):
            # val can be checked multiple times per epoch
            val_checks_per_epoch = total_train_batches // trainer.val_check_batch
            total_val_batches = total_val_batches * val_checks_per_epoch
        total_batches = total_train_batches + total_val_batches
        self.main_progress_bar.total = convert_inf(total_batches)
        self.main_progress_bar.set_description(f"Epoch {trainer.current_epoch}")

    def on_train_batch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", *_: Any) -> None:
        """Advance the main bar and refresh its metrics postfix when an update is due."""
        if self._should_update(self.train_batch_idx, self.total_train_batches):
            _update_n(self.main_progress_bar, self.train_batch_idx + self._val_processed)
            self.main_progress_bar.set_postfix(self.get_metrics(trainer, pl_module))

    def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Refresh the metrics postfix of the main bar at the end of the epoch."""
        if not self.main_progress_bar.disable:
            self.main_progress_bar.set_postfix(self.get_metrics(trainer, pl_module))

    def on_train_end(self, *_: Any) -> None:
        """Close the main progress bar."""
        self.main_progress_bar.close()

    def on_validation_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Create the validation bar, unless the sanity check already created one."""
        if not trainer.sanity_checking:
            self.val_progress_bar = self.init_validation_tqdm()

    def on_validation_batch_start(
        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int
    ) -> None:
        """Update total and description of the validation bar when the dataloader changes."""
        if not self.has_dataloader_changed(dataloader_idx):
            return

        self.val_progress_bar.total = convert_inf(self.total_val_batches_current_dataloader)
        desc = self.sanity_check_description if trainer.sanity_checking else self.validation_description
        self.val_progress_bar.set_description(f"{desc} DataLoader {dataloader_idx}")

    def on_validation_batch_end(self, trainer: "pl.Trainer", *_: Any) -> None:
        """Advance the validation bar and, while fitting, the main bar."""
        if self._should_update(self.val_batch_idx, self.total_val_batches_current_dataloader):
            _update_n(self.val_progress_bar, self.val_batch_idx)
            # NOTE(review): the main-bar update is kept inside the refresh check so that it
            # honours `refresh_rate` together with the validation bar - confirm against upstream.
            if trainer.state.fn == "fit":
                _update_n(self.main_progress_bar, self.train_batch_idx + self._val_processed)

    def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Refresh main-bar metrics (when fitting), close the validation bar and reset tracking."""
        if self._main_progress_bar is not None and trainer.state.fn == "fit":
            self.main_progress_bar.set_postfix(self.get_metrics(trainer, pl_module))
        self.val_progress_bar.close()
        self.reset_dataloader_idx_tracker()

    def on_test_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Create the test progress bar."""
        self.test_progress_bar = self.init_test_tqdm()

    def on_test_batch_start(
        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int
    ) -> None:
        """Update total and description of the test bar when the dataloader changes."""
        if not self.has_dataloader_changed(dataloader_idx):
            return

        self.test_progress_bar.total = convert_inf(self.total_test_batches_current_dataloader)
        self.test_progress_bar.set_description(f"{self.test_description} DataLoader {dataloader_idx}")

    def on_test_batch_end(self, *_: Any) -> None:
        """Advance the test bar when an update is due."""
        if self._should_update(self.test_batch_idx, self.total_test_batches_current_dataloader):
            _update_n(self.test_progress_bar, self.test_batch_idx)

    def on_test_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Close the test bar and reset dataloader tracking."""
        self.test_progress_bar.close()
        self.reset_dataloader_idx_tracker()

    def on_predict_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Create the predict progress bar."""
        self.predict_progress_bar = self.init_predict_tqdm()

    def on_predict_batch_start(
        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int
    ) -> None:
        """Update total and description of the predict bar when the dataloader changes."""
        if not self.has_dataloader_changed(dataloader_idx):
            return

        self.predict_progress_bar.total = convert_inf(self.total_predict_batches_current_dataloader)
        self.predict_progress_bar.set_description(f"{self.predict_description} DataLoader {dataloader_idx}")

    def on_predict_batch_end(self, *_: Any) -> None:
        """Advance the predict bar when an update is due."""
        if self._should_update(self.predict_batch_idx, self.total_predict_batches_current_dataloader):
            _update_n(self.predict_progress_bar, self.predict_batch_idx)

    def on_predict_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Close the predict bar and reset dataloader tracking."""
        self.predict_progress_bar.close()
        self.reset_dataloader_idx_tracker()

    def print(self, *args: Any, sep: str = " ", **kwargs: Any) -> None:
        """Write a message through the first enabled bar (main, validation, test, predict).

        Nothing is written when no bar has been created or all of them are disabled.
        """
        active_progress_bar = None

        if self._main_progress_bar is not None and not self.main_progress_bar.disable:
            active_progress_bar = self.main_progress_bar
        elif self._val_progress_bar is not None and not self.val_progress_bar.disable:
            active_progress_bar = self.val_progress_bar
        elif self._test_progress_bar is not None and not self.test_progress_bar.disable:
            active_progress_bar = self.test_progress_bar
        elif self._predict_progress_bar is not None and not self.predict_progress_bar.disable:
            active_progress_bar = self.predict_progress_bar

        if active_progress_bar is not None:
            s = sep.join(map(str, args))
            active_progress_bar.write(s, **kwargs)

    def _should_update(self, current: int, total: Union[int, float]) -> bool:
        """Return True on every ``refresh_rate``-th batch and on the final batch."""
        return self.refresh_rate > 0 and (current % self.refresh_rate == 0 or current == total)

    @staticmethod
    def _resolve_refresh_rate(refresh_rate: int) -> int:
        """Return the refresh rate to use, raising the default of ``1`` to ``20`` on Colab."""
        if os.getenv("COLAB_GPU") and refresh_rate == 1:
            # smaller refresh rate on colab causes crashes, choose a higher value
            rank_zero_debug("Using a higher refresh rate on Colab. Setting it to `20`")
            refresh_rate = 20
        return refresh_rate
def convert_inf(x: Optional[Union[int, float]]) -> Optional[Union[int, float]]:
    """Map values tqdm cannot display (``inf``/``nan``) to ``None``.

    ``None`` is passed through unchanged; any other number is returned as is.
    """
    if x is None:
        return None
    not_representable = math.isinf(x) or math.isnan(x)
    return None if not_representable else x


def _update_n(bar: _tqdm, value: int) -> None:
    """Set the counter of ``bar`` to ``value`` and redraw it; disabled bars are left untouched."""
    if bar.disable:
        return
    bar.n = value
    bar.refresh()

© Copyright (c) 2018-2023, William Falcon et al.

Built with Sphinx using a theme provided by Read the Docs.