Source code for lightning.pytorch.loggers.tensorboard
# Copyright The Lightning AI team.## Licensed under the Apache License, Version 2.0 (the "License");# you may not use this file except in compliance with the License.# You may obtain a copy of the License at## http://www.apache.org/licenses/LICENSE-2.0## Unless required by applicable law or agreed to in writing, software# distributed under the License is distributed on an "AS IS" BASIS,# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.# See the License for the specific language governing permissions and# limitations under the License."""TensorBoard Logger------------------"""importosfromargparseimportNamespacefromtypingimportAny,Optional,UnionfromtorchimportTensorfromtyping_extensionsimportoverrideimportlightning.pytorchasplfromlightning.fabric.loggers.tensorboardimport_TENSORBOARD_AVAILABLEfromlightning.fabric.loggers.tensorboardimportTensorBoardLoggerasFabricTensorBoardLoggerfromlightning.fabric.utilities.cloud_ioimport_is_dirfromlightning.fabric.utilities.loggerimport_convert_paramsfromlightning.fabric.utilities.typesimport_PATHfromlightning.pytorch.callbacksimportModelCheckpointfromlightning.pytorch.core.savingimportsave_hparams_to_yamlfromlightning.pytorch.loggers.loggerimportLoggerfromlightning.pytorch.utilities.importsimport_OMEGACONF_AVAILABLEfromlightning.pytorch.utilities.rank_zeroimportrank_zero_only,rank_zero_warn
class TensorBoardLogger(Logger, FabricTensorBoardLogger):
    r"""Log to local or remote file system in `TensorBoard <https://www.tensorflow.org/tensorboard>`_ format.

    Implemented using :class:`~tensorboardX.SummaryWriter`. Logs are saved to
    ``os.path.join(save_dir, name, version)``. This is the default logger in Lightning, it comes
    preinstalled.

    This logger supports logging to remote filesystems via ``fsspec``. Make sure you have it installed
    and you don't have tensorflow (otherwise it will use tf.io.gfile instead of fsspec).

    Example:

    .. testcode::
        :skipif: not _TENSORBOARD_AVAILABLE or not _TENSORBOARDX_AVAILABLE

        from lightning.pytorch import Trainer
        from lightning.pytorch.loggers import TensorBoardLogger

        logger = TensorBoardLogger("tb_logs", name="my_model")
        trainer = Trainer(logger=logger)

    Args:
        save_dir: Save directory
        name: Experiment name. Defaults to ``'lightning_logs'``. If it is the empty string then no
            per-experiment subdirectory is used.
        version: Experiment version. If version is not specified the logger inspects the save
            directory for existing versions, then automatically assigns the next available version.
            If it is a string then it is used as the run-specific subdirectory name,
            otherwise ``'version_${version}'`` is used.
        log_graph: Adds the computational graph to tensorboard. This requires that
            the user has defined the `self.example_input_array` attribute in their
            model.
        default_hp_metric: Enables a placeholder metric with key `hp_metric` when `log_hyperparams` is
            called without a metric (otherwise calls to log_hyperparams without a metric are ignored).
        prefix: A string to put at the beginning of metric keys.
        sub_dir: Sub-directory to group TensorBoard logs. If a sub_dir argument is passed
            then logs are saved in ``/save_dir/name/version/sub_dir/``. Defaults to ``None`` in which
            logs are saved in ``/save_dir/name/version/``.
        \**kwargs: Additional arguments used by :class:`tensorboardX.SummaryWriter` can be passed as keyword
            arguments in this logger. To automatically flush to disk, `max_queue` sets the size
            of the queue for pending logs before flushing. `flush_secs` determines how many seconds
            elapses before flushing.

    """

    # File name (inside ``log_dir``) that ``save()`` writes the hyperparameters to.
    NAME_HPARAMS_FILE = "hparams.yaml"

    def __init__(
        self,
        save_dir: _PATH,
        name: Optional[str] = "lightning_logs",
        version: Optional[Union[int, str]] = None,
        log_graph: bool = False,
        default_hp_metric: bool = True,
        prefix: str = "",
        sub_dir: Optional[_PATH] = None,
        **kwargs: Any,
    ):
        # The Fabric base class receives ``save_dir`` under its own parameter name ``root_dir``.
        super().__init__(
            root_dir=save_dir,
            name=name,
            version=version,
            default_hp_metric=default_hp_metric,
            prefix=prefix,
            sub_dir=sub_dir,
            **kwargs,
        )
        # Graph logging needs the ``tensorboard`` package itself; warn and disable it otherwise.
        if log_graph and not _TENSORBOARD_AVAILABLE:
            rank_zero_warn(
                "You set `TensorBoardLogger(log_graph=True)` but `tensorboard` is not available.\n"
                f"{_TENSORBOARD_AVAILABLE!s}"
            )
        self._log_graph = log_graph and _TENSORBOARD_AVAILABLE
        # Accumulated hyperparameters; written to ``NAME_HPARAMS_FILE`` by ``save()``.
        self.hparams: Union[dict[str, Any], Namespace] = {}

    @property
    @override
    def root_dir(self) -> str:
        """Parent directory for all tensorboard checkpoint subdirectories.

        If the experiment name parameter is an empty string, no experiment subdirectory is used and the checkpoint will
        be saved in "save_dir/version"

        """
        return os.path.join(super().root_dir, self.name)

    @property
    @override
    def log_dir(self) -> str:
        """The directory for this run's tensorboard checkpoint.

        By default, it is named ``'version_${self.version}'`` but it can be overridden by passing a string value for the
        constructor's version parameter instead of ``None`` or an int. When a ``sub_dir`` is set it is appended to the
        path. Environment variables and a leading ``~`` in the resulting path are expanded.

        """
        # create a pseudo standard path ala test-tube
        version = self.version if isinstance(self.version, str) else f"version_{self.version}"
        log_dir = os.path.join(self.root_dir, version)
        # ``sub_dir`` is annotated as ``_PATH`` in ``__init__``; accept path-like objects as well as
        # plain strings so a ``pathlib.Path`` is not silently ignored.
        if isinstance(self.sub_dir, (str, os.PathLike)):
            log_dir = os.path.join(log_dir, self.sub_dir)
        log_dir = os.path.expandvars(log_dir)
        log_dir = os.path.expanduser(log_dir)
        return log_dir

    @property
    @override
    def save_dir(self) -> str:
        """Gets the save directory where the TensorBoard experiments are saved.

        Returns:
            The local path to the save directory where the TensorBoard experiments are saved.

        """
        return self._root_dir

    @override
    @rank_zero_only
    def log_hyperparams(
        self,
        params: Union[dict[str, Any], Namespace],
        metrics: Optional[dict[str, Any]] = None,
        step: Optional[int] = None,
    ) -> None:
        """Record hyperparameters. TensorBoard logs with and without saved hyperparameters are incompatible, the
        hyperparameters are then not displayed in the TensorBoard. Please delete or move the previously saved logs to
        display the new ones with hyperparameters.

        The converted parameters are also accumulated in ``self.hparams`` so that ``save()`` can write them out.

        Args:
            params: A dictionary-like container with the hyperparameters
            metrics: Dictionary with metric names as keys and measured quantities as values
            step: Optional global step number for the logged metrics

        """
        # omegaconf is optional, so it is only imported when available; the ``isinstance`` check
        # below is guarded by the same flag, so ``Container`` is never referenced when missing.
        if _OMEGACONF_AVAILABLE:
            from omegaconf import Container, OmegaConf

        params = _convert_params(params)

        # store params to output
        if _OMEGACONF_AVAILABLE and isinstance(params, Container):
            self.hparams = OmegaConf.merge(self.hparams, params)
        else:
            self.hparams.update(params)

        return super().log_hyperparams(params=params, metrics=metrics, step=step)

    @override
    @rank_zero_only
    def log_graph(  # type: ignore[override]
        self, model: "pl.LightningModule", input_array: Optional[Tensor] = None
    ) -> None:
        """Add the model's computational graph to TensorBoard.

        Does nothing unless graph logging was enabled with ``log_graph=True`` and ``tensorboard`` is available.

        Args:
            model: The module whose graph is added.
            input_array: Example input used for tracing. Falls back to ``model.example_input_array`` when ``None``.
                It must be a ``Tensor`` or a tuple of positional arguments to ``forward()``; otherwise a warning
                is emitted and no graph is logged.

        """
        if not self._log_graph:
            return

        input_array = model.example_input_array if input_array is None else input_array

        if input_array is None:
            rank_zero_warn(
                "Could not log computational graph to TensorBoard: The `model.example_input_array` attribute"
                " is not set or `input_array` was not given."
            )
        elif not isinstance(input_array, (Tensor, tuple)):
            rank_zero_warn(
                "Could not log computational graph to TensorBoard: The `input_array` or `model.example_input_array`"
                f" has type {type(input_array)} which can't be traced by TensorBoard. Make the input array a tuple"
                " representing the positional arguments to the model's `forward()` implementation."
            )
        else:
            # Run the module's batch-transfer hooks so the example input goes through the same
            # pre-processing path as a real batch before tracing.
            input_array = model._on_before_batch_transfer(input_array)
            input_array = model._apply_batch_transfer_handler(input_array)
            # NOTE(review): the exact effect of this context manager is defined in ``pl.core.module``;
            # presumably it flags the module as being scripted/traced while the graph is built — confirm.
            with pl.core.module._jit_is_scripting():
                self.experiment.add_graph(model, input_array)

    @override
    @rank_zero_only
    def save(self) -> None:
        """Call the parent ``save()``, then write the collected hyperparameters to ``hparams.yaml``.

        The file is placed in :attr:`log_dir` and is only written when that directory exists and the
        file is not already present, so an existing hyperparameters file is never overwritten.

        """
        super().save()
        dir_path = self.log_dir

        # prepare the file path
        hparams_file = os.path.join(dir_path, self.NAME_HPARAMS_FILE)

        # save the metatags file if it doesn't exist and the log directory exists
        if _is_dir(self._fs, dir_path) and not self._fs.isfile(hparams_file):
            save_hparams_to_yaml(hparams_file, self.hparams)

    @override
    @rank_zero_only
    def finalize(self, status: str) -> None:
        """Finalize the run via the parent class and persist hyperparameters on success.

        Args:
            status: Final status of the run; only ``"success"`` triggers :meth:`save`.

        """
        super().finalize(status)
        if status == "success":
            # saving hparams happens independent of experiment manager
            self.save()

    @override
    def after_save_checkpoint(self, checkpoint_callback: ModelCheckpoint) -> None:
        """Called after model checkpoint callback saves a new checkpoint.

        This logger intentionally does nothing here.

        Args:
            checkpoint_callback: the model checkpoint callback instance

        """
To analyze traffic and optimize your experience, we serve cookies on this
site. By clicking or navigating, you agree to allow our usage of cookies.
Read PyTorch Lightning's Privacy Policy.
You are viewing an outdated version of PyTorch Lightning Docs