# Copyright The Lightning AI team.## Licensed under the Apache License, Version 2.0 (the "License");# you may not use this file except in compliance with the License.# You may obtain a copy of the License at## http://www.apache.org/licenses/LICENSE-2.0## Unless required by applicable law or agreed to in writing, software# distributed under the License is distributed on an "AS IS" BASIS,# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.# See the License for the specific language governing permissions and# limitations under the License."""Weights and Biases Logger-------------------------"""importosfromargparseimportNamespacefrompathlibimportPathfromtypingimportAny,Dict,List,Mapping,Optional,Unionimporttorch.nnasnnfromlightning_utilities.core.importsimportRequirementCachefromtorchimportTensorfromlightning_fabric.utilities.loggerimport_add_prefix,_convert_params,_flatten_dict,_sanitize_callable_paramsfromlightning_fabric.utilities.typesimport_PATHfrompytorch_lightning.callbacks.model_checkpointimportModelCheckpointfrompytorch_lightning.loggers.loggerimportLogger,rank_zero_experimentfrompytorch_lightning.utilities.exceptionsimportMisconfigurationExceptionfrompytorch_lightning.utilities.loggerimport_scan_checkpointsfrompytorch_lightning.utilities.rank_zeroimportrank_zero_only,rank_zero_warntry:importwandbfromwandb.sdk.libimportRunDisabledfromwandb.wandb_runimportRunexceptModuleNotFoundError:# needed for test mocks, these tests shall be updatedwandb,Run,RunDisabled=None,None,None_WANDB_AVAILABLE=RequirementCache("wandb")_WANDB_GREATER_EQUAL_0_10_22=RequirementCache("wandb>=0.10.22")_WANDB_GREATER_EQUAL_0_12_10=RequirementCache("wandb>=0.12.10")
class WandbLogger(Logger):
    r"""
    Log using `Weights and Biases <https://docs.wandb.ai/integrations/lightning>`_.

    **Installation and set-up**

    Install with pip:

    .. code-block:: bash

        pip install wandb

    Create a `WandbLogger` instance:

    .. code-block:: python

        from pytorch_lightning.loggers import WandbLogger

        wandb_logger = WandbLogger(project="MNIST")

    Pass the logger instance to the `Trainer`:

    .. code-block:: python

        trainer = Trainer(logger=wandb_logger)

    A new W&B run will be created when training starts if you have not created one manually before with `wandb.init()`.

    **Log metrics**

    Log from :class:`~pytorch_lightning.core.module.LightningModule`:

    .. code-block:: python

        class LitModule(LightningModule):
            def training_step(self, batch, batch_idx):
                self.log("train/loss", loss)

    Use the wandb module directly:

    .. code-block:: python

        wandb.log({"train/loss": loss})

    **Log hyper-parameters**

    Save :class:`~pytorch_lightning.core.module.LightningModule` parameters:

    .. code-block:: python

        class LitModule(LightningModule):
            def __init__(self, *args, **kwargs):
                self.save_hyperparameters()

    Add other config parameters:

    .. code-block:: python

        # add one parameter
        wandb_logger.experiment.config["key"] = value

        # add multiple parameters
        wandb_logger.experiment.config.update({key1: val1, key2: val2})

        # use directly wandb module
        wandb.config["key"] = value
        wandb.config.update()

    **Log gradients, parameters and model topology**

    Call the `watch` method for automatically tracking gradients:

    .. code-block:: python

        # log gradients and model topology
        wandb_logger.watch(model)

        # log gradients, parameter histogram and model topology
        wandb_logger.watch(model, log="all")

        # change log frequency of gradients and parameters (100 steps by default)
        wandb_logger.watch(model, log_freq=500)

        # do not log graph (in case of errors)
        wandb_logger.watch(model, log_graph=False)

    The `watch` method adds hooks to the model which can be removed at the end of training:

    .. code-block:: python

        wandb_logger.experiment.unwatch(model)

    **Log model checkpoints**

    Log model checkpoints at the end of training:

    .. code-block:: python

        wandb_logger = WandbLogger(log_model=True)

    Log model checkpoints as they get created during training:

    .. code-block:: python

        wandb_logger = WandbLogger(log_model="all")

    Custom checkpointing can be set up through :class:`~pytorch_lightning.callbacks.ModelCheckpoint`:

    .. code-block:: python

        # log model only if `val_accuracy` increases
        wandb_logger = WandbLogger(log_model="all")
        checkpoint_callback = ModelCheckpoint(monitor="val_accuracy", mode="max")
        trainer = Trainer(logger=wandb_logger, callbacks=[checkpoint_callback])

    `latest` and `best` aliases are automatically set to easily retrieve a model checkpoint:

    .. code-block:: python

        # reference can be retrieved in artifacts panel
        # "VERSION" can be a version (ex: "v2") or an alias ("latest" or "best")
        checkpoint_reference = "USER/PROJECT/MODEL-RUN_ID:VERSION"

        # download checkpoint locally (if not already cached)
        run = wandb.init(project="MNIST")
        artifact = run.use_artifact(checkpoint_reference, type="model")
        artifact_dir = artifact.download()

        # load checkpoint
        model = LitModule.load_from_checkpoint(Path(artifact_dir) / "model.ckpt")

    **Log media**

    Log text with:

    .. code-block:: python

        # using columns and data
        columns = ["input", "label", "prediction"]
        data = [["cheese", "english", "english"], ["fromage", "french", "spanish"]]
        wandb_logger.log_text(key="samples", columns=columns, data=data)

        # using a pandas DataFrame
        wandb_logger.log_text(key="samples", dataframe=my_dataframe)

    Log images with:

    .. code-block:: python

        # using tensors, numpy arrays or PIL images
        wandb_logger.log_image(key="samples", images=[img1, img2])

        # adding captions
        wandb_logger.log_image(key="samples", images=[img1, img2], caption=["tree", "person"])

        # using file path
        wandb_logger.log_image(key="samples", images=["img_1.jpg", "img_2.jpg"])

    More arguments can be passed for logging segmentation masks and bounding boxes. Refer to
    `Image Overlays documentation <https://docs.wandb.ai/guides/track/log/media#image-overlays>`_.

    **Log Tables**

    `W&B Tables <https://docs.wandb.ai/guides/data-vis>`_ can be used to log, query and analyze tabular data.

    They support any type of media (text, image, video, audio, molecule, html, etc) and are great for storing,
    understanding and sharing any form of data, from datasets to model predictions.

    .. code-block:: python

        columns = ["caption", "image", "sound"]
        data = [["cheese", wandb.Image(img_1), wandb.Audio(snd_1)], ["wine", wandb.Image(img_2), wandb.Audio(snd_2)]]
        wandb_logger.log_table(key="samples", columns=columns, data=data)

    **Downloading and Using Artifacts**

    To download an artifact without starting a run, call the ``download_artifact``
    function on the class:

    .. code-block:: python

        from pytorch_lightning.loggers import WandbLogger

        artifact_dir = WandbLogger.download_artifact(artifact="path/to/artifact")

    To download an artifact and link it to an ongoing run call the ``download_artifact``
    function on the logger instance:

    .. code-block:: python

        class MyModule(LightningModule):
            def any_lightning_module_function_or_hook(self):
                self.logger.download_artifact(artifact="path/to/artifact")

    To link an artifact from a previous run you can use ``use_artifact`` function:

    .. code-block:: python

        from pytorch_lightning.loggers import WandbLogger

        wandb_logger = WandbLogger(project="my_project", name="my_run")
        wandb_logger.use_artifact(artifact="path/to/artifact")

    See Also:
        - `Demo in Google Colab <http://wandb.me/lightning>`__ with hyperparameter search and model logging
        - `W&B Documentation <https://docs.wandb.ai/integrations/lightning>`__

    Args:
        name: Display name for the run.
        save_dir: Path where data is saved.
        version: Sets the version, mainly used to resume a previous run.
        offline: Run offline (data can be streamed later to wandb servers).
        dir: Alias of ``save_dir``; used when ``save_dir`` is left at its default ``"."`` (or is ``None``).
        id: Same as version.
        anonymous: If truthy, ``anonymous="allow"`` is passed to :func:`wandb.init`; otherwise ``None`` is
            passed and W&B's own default applies.
        project: The name of the project to which this run will belong.
        log_model: Log checkpoints created by :class:`~pytorch_lightning.callbacks.ModelCheckpoint`
            as W&B artifacts. `latest` and `best` aliases are automatically set.

            * if ``log_model == 'all'``, checkpoints are logged during training.
            * if ``log_model == True``, checkpoints are logged at the end of training, except when
              :paramref:`~pytorch_lightning.callbacks.ModelCheckpoint.save_top_k` ``== -1``
              which also logs every checkpoint during training.
            * if ``log_model == False`` (default), no checkpoint is logged.

        prefix: A string to put at the beginning of metric keys.
        experiment: WandB experiment object. Automatically set when creating a run.
        checkpoint_name: Name of the model checkpoint artifact being logged.
        \**kwargs: Arguments passed to :func:`wandb.init` like `entity`, `group`, `tags`, etc.

    Raises:
        ModuleNotFoundError:
            If required WandB package is not installed on the device.
        MisconfigurationException:
            If ``offline`` is ``True`` while ``log_model`` is enabled.

    """

    LOGGER_JOIN_CHAR = "-"

    def __init__(
        self,
        name: Optional[str] = None,
        save_dir: _PATH = ".",
        version: Optional[str] = None,
        offline: bool = False,
        dir: Optional[_PATH] = None,
        id: Optional[str] = None,
        anonymous: Optional[bool] = None,
        project: str = "lightning_logs",
        log_model: Union[str, bool] = False,
        experiment: Union[Run, RunDisabled, None] = None,
        prefix: str = "",
        checkpoint_name: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        if wandb is None:
            raise ModuleNotFoundError(
                "You want to use `wandb` logger which is not installed yet,"
                " install it with `pip install wandb`."  # pragma: no-cover
            )

        if offline and log_model:
            raise MisconfigurationException(
                f"Providing log_model={log_model} and offline={offline} is an invalid configuration"
                " since model checkpoints cannot be uploaded in offline mode.\n"
                "Hint: Set `offline=False` to log your model."
            )

        if log_model and not _WANDB_GREATER_EQUAL_0_10_22:
            rank_zero_warn(
                f"Providing log_model={log_model} requires wandb version >= 0.10.22"
                " for logging associated model metadata.\n"
                "Hint: Upgrade with `pip install --upgrade wandb`."
            )

        super().__init__()
        self._offline = offline
        self._log_model = log_model
        self._prefix = prefix
        self._experiment = experiment
        self._logged_model_time: Dict[str, float] = {}
        self._checkpoint_callback: Optional[ModelCheckpoint] = None
        # set before the eager `self.experiment` access below so every attribute exists from then on
        self._checkpoint_name = checkpoint_name

        # paths are processed as strings
        save_dir = os.fspath(save_dir) if save_dir is not None else None
        dir = os.fspath(dir) if dir is not None else None
        # `dir` is documented as an alias of `save_dir`. Because `save_dir` defaults to ".", a plain
        # `save_dir or dir` would always discard `dir`; honour it when `save_dir` was not customised.
        if dir is not None and save_dir in (None, "."):
            save_dir = dir

        # set wandb init arguments
        self._wandb_init: Dict[str, Any] = dict(
            name=name,
            project=project,
            dir=save_dir,
            id=version or id,
            resume="allow",
            anonymous=("allow" if anonymous else None),
        )
        # user kwargs take precedence over the defaults above
        self._wandb_init.update(**kwargs)
        # extract parameters
        self._project = self._wandb_init.get("project")
        self._save_dir = self._wandb_init.get("dir")
        self._name = self._wandb_init.get("name")
        self._id = self._wandb_init.get("id")

        # start wandb run (to create an attach_id for distributed modes)
        if _WANDB_GREATER_EQUAL_0_12_10:
            wandb.require("service")
            _ = self.experiment

    def __getstate__(self) -> Dict[str, Any]:
        """Return a picklable copy of the logger state.

        The live run object is dropped; its ``id``, ``_attach_id`` and ``name`` are stored instead so that
        ``experiment`` can re-attach to (or recreate) the run after unpickling.
        """
        state = self.__dict__.copy()
        # args needed to reload correct experiment
        if self._experiment is not None:
            state["_id"] = getattr(self._experiment, "id", None)
            state["_attach_id"] = getattr(self._experiment, "_attach_id", None)
            state["_name"] = self._experiment.name

        # cannot be pickled
        state["_experiment"] = None
        return state

    @property
    @rank_zero_experiment
    def experiment(self) -> Union[Run, RunDisabled]:
        r"""
        Actual wandb object. To use wandb features in your
        :class:`~pytorch_lightning.core.module.LightningModule` do the following.

        Example:

        .. code-block:: python

            self.logger.experiment.some_wandb_function()

        The run is resolved lazily, in this order: an already-active ``wandb.run`` is reused, otherwise a
        run referenced by a stored ``_attach_id`` is attached, otherwise a new run is created with the
        arguments collected in ``__init__``.
        """
        if self._experiment is None:
            if self._offline:
                # process-wide side effect: switches wandb to offline mode via the environment
                os.environ["WANDB_MODE"] = "dryrun"

            # `_attach_id` only exists after unpickling (see `__getstate__`)
            attach_id = getattr(self, "_attach_id", None)
            if wandb.run is not None:
                # wandb process already created in this instance
                rank_zero_warn(
                    "There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse"
                    " this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`."
                )
                self._experiment = wandb.run
            elif attach_id is not None and hasattr(wandb, "_attach"):
                # attach to wandb process referenced
                self._experiment = wandb._attach(attach_id)
            else:
                # create new wandb process
                self._experiment = wandb.init(**self._wandb_init)

                # define default x-axis
                if isinstance(self._experiment, (Run, RunDisabled)) and getattr(
                    self._experiment, "define_metric", None
                ):
                    self._experiment.define_metric("trainer/global_step")
                    self._experiment.define_metric("*", step_metric="trainer/global_step", step_sync=True)

        assert isinstance(self._experiment, (Run, RunDisabled))
        return self._experiment

    def watch(self, model: nn.Module, log: str = "gradients", log_freq: int = 100, log_graph: bool = True) -> None:
        """Forward to ``experiment.watch`` to hook gradient/parameter logging onto ``model``.

        Args:
            model: The module to watch.
            log: What to log (e.g. ``"gradients"`` or ``"all"``), passed through to wandb.
            log_freq: Logging frequency in steps, passed through to wandb.
            log_graph: Whether to log the model graph, passed through to wandb.
        """
        self.experiment.watch(model, log=log, log_freq=log_freq, log_graph=log_graph)

    @rank_zero_only
    def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
        """Record hyper-parameters in the run config.

        ``params`` is converted to a dict, flattened and sanitized (callables replaced) before being merged
        into ``experiment.config``; existing keys may be overwritten (``allow_val_change=True``).

        Args:
            params: Hyper-parameters as a dict or :class:`argparse.Namespace`.
        """
        params = _convert_params(params)
        params = _flatten_dict(params)
        params = _sanitize_callable_params(params)
        self.experiment.config.update(params, allow_val_change=True)

    @rank_zero_only
    def log_metrics(self, metrics: Mapping[str, float], step: Optional[int] = None) -> None:
        """Log ``metrics`` (keys prefixed with ``prefix``) to the run.

        When ``step`` is given it is logged alongside as ``trainer/global_step``, which ``experiment``
        registers as the default x-axis for new runs.
        """
        assert rank_zero_only.rank == 0, "experiment tried to log from global_rank != 0"

        metrics = _add_prefix(metrics, self._prefix, self.LOGGER_JOIN_CHAR)
        if step is not None:
            self.experiment.log(dict(metrics, **{"trainer/global_step": step}))
        else:
            self.experiment.log(metrics)

    @rank_zero_only
    def log_table(
        self,
        key: str,
        columns: Optional[List[str]] = None,
        data: Optional[List[List[Any]]] = None,
        dataframe: Any = None,
        step: Optional[int] = None,
    ) -> None:
        """Log a Table containing any object type (text, image, audio, video, molecule, html, etc).

        Can be defined either with `columns` and `data` or with `dataframe`.
        """
        metrics = {key: wandb.Table(columns=columns, data=data, dataframe=dataframe)}
        self.log_metrics(metrics, step)

    @rank_zero_only
    def log_text(
        self,
        key: str,
        columns: Optional[List[str]] = None,
        data: Optional[List[List[str]]] = None,
        dataframe: Any = None,
        step: Optional[int] = None,
    ) -> None:
        """Log text as a Table.

        Can be defined either with `columns` and `data` or with `dataframe`.
        """
        self.log_table(key, columns, data, dataframe, step)

    @rank_zero_only
    def log_image(self, key: str, images: List[Any], step: Optional[int] = None, **kwargs: Any) -> None:
        """Log images (tensors, numpy arrays, PIL Images or file paths).

        Optional kwargs are lists passed to each image (ex: caption, masks, boxes).

        Raises:
            TypeError: If ``images`` is not a list.
            ValueError: If a kwarg list does not have one entry per image.
        """
        if not isinstance(images, list):
            raise TypeError(f'Expected a list as "images", found {type(images)}')
        n = len(images)
        for k, v in kwargs.items():
            if len(v) != n:
                raise ValueError(f"Expected {n} items but only found {len(v)} for {k}")
        # transpose {kwarg: [per-image values]} into one kwargs dict per image
        kwarg_list = [{k: kwargs[k][i] for k in kwargs.keys()} for i in range(n)]
        metrics = {key: [wandb.Image(img, **kwarg) for img, kwarg in zip(images, kwarg_list)]}
        self.log_metrics(metrics, step)

    @property
    def save_dir(self) -> Optional[str]:
        """Gets the save directory.

        Returns:
            The path to the save directory.
        """
        return self._save_dir

    @property
    def name(self) -> Optional[str]:
        """The project name of this experiment.

        Returns:
            The name of the project the current experiment belongs to. This name is not the same as `wandb.Run`'s
            name. To access wandb's internal experiment name, use ``logger.experiment.name`` instead.
        """
        return self._project

    @property
    def version(self) -> Optional[str]:
        """Gets the id of the experiment.

        Returns:
            The id of the experiment if the experiment exists else the id given to the constructor.
        """
        # don't create an experiment if we don't have one
        return self._experiment.id if self._experiment else self._id

    def after_save_checkpoint(self, checkpoint_callback: ModelCheckpoint) -> None:
        """Upload checkpoints right away or remember the callback for upload at ``finalize``.

        Checkpoints are uploaded immediately when ``log_model == "all"``, or when ``log_model is True`` and the
        callback keeps every checkpoint (``save_top_k == -1``). For ``log_model is True`` otherwise, the callback
        is stored and its checkpoints are uploaded in ``finalize``.
        """
        # log checkpoints as artifacts
        if self._log_model == "all" or (self._log_model is True and checkpoint_callback.save_top_k == -1):
            self._scan_and_log_checkpoints(checkpoint_callback)
        elif self._log_model is True:
            self._checkpoint_callback = checkpoint_callback

    @staticmethod
    @rank_zero_only
    def download_artifact(
        artifact: str,
        save_dir: Optional[_PATH] = None,
        artifact_type: Optional[str] = None,
        use_artifact: Optional[bool] = True,
    ) -> str:
        """Downloads an artifact from the wandb server.

        Args:
            artifact: The path of the artifact to download.
            save_dir: The directory to save the artifact to.
            artifact_type: The type of artifact to download.
            use_artifact: Whether to add an edge between the artifact graph.

        Returns:
            The path to the downloaded artifact.
        """
        if wandb.run is not None and use_artifact:
            # link the artifact to the active run; forward the type like the public-API path below does
            resolved = wandb.run.use_artifact(artifact, type=artifact_type)
        else:
            api = wandb.Api()
            resolved = api.artifact(artifact, type=artifact_type)

        root = None if save_dir is None else os.fspath(save_dir)
        return resolved.download(root=root)

    def use_artifact(self, artifact: str, artifact_type: Optional[str] = None) -> "wandb.Artifact":
        """Logs to the wandb dashboard that the mentioned artifact is used by the run.

        Args:
            artifact: The path of the artifact.
            artifact_type: The type of artifact being used.

        Returns:
            wandb Artifact object for the artifact.
        """
        return self.experiment.use_artifact(artifact, type=artifact_type)

    @rank_zero_only
    def finalize(self, status: str) -> None:
        """Upload deferred checkpoints (``log_model=True``) once training finished successfully."""
        if status != "success":
            # Currently, checkpoints only get logged on success
            return
        # log checkpoints as artifacts
        if self._checkpoint_callback and self._experiment is not None:
            self._scan_and_log_checkpoints(self._checkpoint_callback)

    def _scan_and_log_checkpoints(self, checkpoint_callback: ModelCheckpoint) -> None:
        """Upload every checkpoint of ``checkpoint_callback`` not yet logged (or modified since) as an artifact."""
        # get checkpoints to be saved with associated score
        checkpoints = _scan_checkpoints(checkpoint_callback, self._logged_model_time)

        # log iteratively all new checkpoints
        for t, p, s, _ in checkpoints:
            metadata = (
                {
                    "score": s.item() if isinstance(s, Tensor) else s,
                    "original_filename": Path(p).name,
                    checkpoint_callback.__class__.__name__: {
                        k: getattr(checkpoint_callback, k)
                        for k in [
                            "monitor",
                            "mode",
                            "save_last",
                            "save_top_k",
                            "save_weights_only",
                            "_every_n_train_steps",
                        ]
                        # ensure it does not break if `ModelCheckpoint` args change
                        if hasattr(checkpoint_callback, k)
                    },
                }
                if _WANDB_GREATER_EQUAL_0_10_22
                else None
            )
            if not self._checkpoint_name:
                self._checkpoint_name = f"model-{self.experiment.id}"
            artifact = wandb.Artifact(name=self._checkpoint_name, type="model", metadata=metadata)
            artifact.add_file(p, name="model.ckpt")
            # Honour the documented contract: every newly logged version becomes `latest`, and the current
            # best checkpoint additionally gets `best`. A single scan tag cannot carry both aliases.
            aliases = ["latest", "best"] if p == checkpoint_callback.best_model_path else ["latest"]
            self.experiment.log_artifact(artifact, aliases=aliases)
            # remember logged models - timestamp needed in case filename didn't change (last.ckpt or custom name)
            self._logged_model_time[p] = t
To analyze traffic and optimize your experience, we serve cookies on this site. By clicking or navigating, you agree to allow our usage of cookies. Read PyTorch Lightning's Privacy Policy.