Source code for pytorch_lightning.callbacks.xla_stats_monitor
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
XLA Stats Monitor
=================

Monitors and logs XLA stats during training.

"""
import time

import pytorch_lightning as pl
from pytorch_lightning.accelerators import TPUAccelerator
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.utilities import _TPU_AVAILABLE
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_info

if _TPU_AVAILABLE:
    import torch_xla.core.xla_model as xm
class XLAStatsMonitor(Callback):
    r"""
    .. deprecated:: v1.5
        The `XLAStatsMonitor` callback was deprecated in v1.5 and will be removed in v1.7.
        Please use the `DeviceStatsMonitor` callback instead.

    Automatically monitors and logs XLA stats during the training stage. ``XLAStatsMonitor`` is a callback and in
    order to use it you need to assign a logger in the ``Trainer``.

    Args:
        verbose: Set to ``True`` to print average peak and free memory, and epoch time every epoch.

    Raises:
        MisconfigurationException:
            If not running on TPUs, or ``Trainer`` has no logger.

    Example::

        >>> from pytorch_lightning import Trainer
        >>> from pytorch_lightning.callbacks import XLAStatsMonitor
        >>> xla_stats = XLAStatsMonitor()  # doctest: +SKIP
        >>> trainer = Trainer(callbacks=[xla_stats])  # doctest: +SKIP
    """

    def __init__(self, verbose: bool = True) -> None:
        super().__init__()
        rank_zero_deprecation(
            "The `XLAStatsMonitor` callback was deprecated in v1.5 and will be removed in v1.7."
            " Please use the `DeviceStatsMonitor` callback instead."
        )
        if not _TPU_AVAILABLE:
            raise MisconfigurationException("Cannot use XLAStatsMonitor when TPUs are not available.")

        self._verbose = verbose
    def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        if not trainer.loggers:
            raise MisconfigurationException("Cannot use XLAStatsMonitor callback with Trainer that has no logger.")

        if not isinstance(trainer.accelerator, TPUAccelerator):
            raise MisconfigurationException(
                "You are using XLAStatsMonitor but are not running on TPU."
                f" The accelerator is set to {trainer.accelerator.__class__.__name__}."
            )

        device = trainer.strategy.root_device
        memory_info = xm.get_memory_info(device)
        # `xm.get_memory_info` reports memory in kB; multiplying by 0.001 converts to MB.
        # `strategy.reduce` averages the value across TPU cores.
        total_memory = trainer.strategy.reduce(memory_info["kb_total"]) * 0.001
        rank_zero_info(f"Average Total memory: {total_memory:.2f} MB")
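    # NOTE: `on_train_epoch_end` below reads `self._start_time`, so a hook along the
    # following lines is assumed to record the wall-clock time at the start of each
    # training epoch (a minimal sketch, using the standard Callback hook signature).
    def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        # Remember when the epoch started so the epoch duration can be computed at epoch end.
        self._start_time = time.time()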
    def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        if not trainer.loggers:
            raise MisconfigurationException("Cannot use XLAStatsMonitor callback with Trainer that has no logger.")

        device = trainer.strategy.root_device
        memory_info = xm.get_memory_info(device)
        epoch_time = time.time() - self._start_time

        free_memory = memory_info["kb_free"]
        peak_memory = memory_info["kb_total"] - free_memory

        free_memory = trainer.strategy.reduce(free_memory) * 0.001
        peak_memory = trainer.strategy.reduce(peak_memory) * 0.001
        epoch_time = trainer.strategy.reduce(epoch_time)

        for logger in trainer.loggers:
            logger.log_metrics(
                {"avg. free memory (MB)": float(free_memory), "avg. peak memory (MB)": float(peak_memory)},
                step=trainer.current_epoch,
            )

        if self._verbose:
            rank_zero_info(f"Average Epoch time: {epoch_time:.2f} seconds")
            rank_zero_info(f"Average Peak memory: {peak_memory:.2f} MB")
            rank_zero_info(f"Average Free memory: {free_memory:.2f} MB")
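Since this callback is deprecated, the message above points to ``DeviceStatsMonitor`` as the replacement. A minimal sketch of the equivalent setup, assuming a TPU run with eight cores and a logger attached to the ``Trainer``:

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import DeviceStatsMonitor

    # DeviceStatsMonitor logs stats for the active accelerator (including TPU memory)
    # to the Trainer's logger at each batch.
    trainer = Trainer(accelerator="tpu", devices=8, callbacks=[DeviceStatsMonitor()])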