WandbLogger
- class pytorch_lightning.loggers.WandbLogger(name=None, save_dir='.', version=None, offline=False, dir=None, id=None, anonymous=None, project='lightning_logs', log_model=False, experiment=None, prefix='', **kwargs)
Bases: pytorch_lightning.loggers.logger.Logger
Log using Weights and Biases.
Installation and set-up
Install with pip:
pip install wandb
Create a WandbLogger instance:
from pytorch_lightning.loggers import WandbLogger

wandb_logger = WandbLogger(project="MNIST")
Pass the logger instance to the Trainer:
from pytorch_lightning import Trainer

trainer = Trainer(logger=wandb_logger)
If you have not already created a W&B run manually with wandb.init(), a new run is created when training starts.
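The run-creation rule above can be modelled in plain Python. This is a simplified sketch, not Lightning's code. `FakeRun`, `ACTIVE_RUN`, `start_run` and `LazyLogger` are hypothetical stand-ins for a W&B run, the global `wandb.run`, `wandb.init()` and the logger's `experiment` property. The sketch assumes that the logger reuses a run that already exists and otherwise creates one the first time the experiment is accessed.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class FakeRun:
    """Hypothetical stand-in for a W&B run object."""
    project: str


# Hypothetical stand-in for the global `wandb.run` (None until a run exists).
ACTIVE_RUN: Optional[FakeRun] = None


def start_run(project: str) -> FakeRun:
    """Hypothetical stand-in for `wandb.init()`: creates and registers a run."""
    global ACTIVE_RUN
    ACTIVE_RUN = FakeRun(project=project)
    return ACTIVE_RUN


class LazyLogger:
    """Creates (or reuses) its run only when the experiment is first accessed."""

    def __init__(self, project: str) -> None:
        self.project = project
        self._experiment: Optional[FakeRun] = None

    @property
    def experiment(self) -> FakeRun:
        if self._experiment is None:
            # Assumed rule: reuse a manually created run, otherwise start a new one.
            self._experiment = ACTIVE_RUN if ACTIVE_RUN is not None else start_run(self.project)
        return self._experiment
```

Constructing `LazyLogger(project="MNIST")` does not start anything. The run only comes into existence on the first access to `.experiment`, which corresponds to the start of training in the description above.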
Log metrics
Log from LightningModule:
class LitModule(LightningModule):
    def training_step(self, batch, batch_idx):
        # ... compute `loss` for this batch ...
        self.log("train/loss", loss)
        return loss
Or use the wandb module directly:
wandb.log({"train/loss": loss})
Log hyper-parameters
Save LightningModule parameters:
class LitModule(LightningModule):
    def __init__(self, *args, **kwargs):
        super().__init__()
        self.save_hyperparameters()
Add other config parameters:
# add one parameter
wandb_logger.experiment.config["key"] = value
# add multiple parameters
wandb_logger.experiment.config.update({"key1": val1, "key2": val2})

# use the wandb module directly
wandb.config["key"] = value
wandb.config.update({"key1": val1, "key2": val2})
Log gradients, parameters and model topology
Call the watch method to track gradients automatically:
# log gradients and model topology
wandb_logger.watch(model)
# log gradients, parameter histograms and model topology
wandb_logger.watch(model, log="all")
# change log frequency of gradients and parameters (100 steps by default)
wandb_logger.watch(model, log_freq=500)
# do not log graph (in case of errors)
wandb_logger.watch(model, log_graph=False)
The watch method adds hooks to the model, which can be removed at the end of training:
wandb_logger.experiment.unwatch(model)
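To get a rough sense of what `log_freq` controls, the hypothetical helper below lists the steps at which gradients would be recorded. It assumes a simplified counting model: one counter incremented per backward pass, with logging whenever the counter is a multiple of `log_freq`. The hooks that `wandb.watch` actually installs may count differently.

```python
from typing import List


def logged_steps(total_steps: int, log_freq: int = 100) -> List[int]:
    """Return the 1-based steps at which gradients would be logged.

    Hypothetical model: a counter is incremented on every backward pass and
    gradients are logged whenever the counter is a multiple of `log_freq`.
    """
    if log_freq <= 0:
        raise ValueError("log_freq must be a positive integer")
    return [step for step in range(1, total_steps + 1) if step % log_freq == 0]
```

Under this model, raising `log_freq` from the default 100 to 500 reduces the number of uploads over 1000 steps from ten to two. That trade-off is the reason to increase it for long runs.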
Log model checkpoints
Log model checkpoints at the end of training:
wandb_logger = WandbLogger(log_model=True)
Log model checkpoints as they get created during training:
wandb_logger = WandbLogger(log_model="all")
Custom checkpointing can be set up through ModelCheckpoint:
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger

# log model only if `val_accuracy` increases
wandb_logger = WandbLogger(log_model="all")
checkpoint_callback = ModelCheckpoint(monitor="val_accuracy", mode="max")
trainer = Trainer(logger=wandb_logger, callbacks=[checkpoint_callback])
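The documented `log_model` rules can be summarized as a small decision function. `should_log_checkpoint` and its arguments are hypothetical names used only for illustration. The function encodes the behaviour described for `log_model` (including the `save_top_k == -1` exception) and does not reproduce Lightning's internals.

```python
from typing import Union


def should_log_checkpoint(log_model: Union[bool, str], save_top_k: int, training_finished: bool) -> bool:
    """Decide whether a checkpoint would be uploaded at a given moment.

    Hypothetical helper encoding the documented rules:
    - log_model == "all": checkpoints are logged as they are created.
    - log_model is True: checkpoints are logged at the end of training,
      except when save_top_k == -1, which also logs every checkpoint during training.
    - log_model is False (default): nothing is logged.
    """
    if log_model == "all":
        return True
    if log_model is True:
        return training_finished or save_top_k == -1
    return False
```

The function shows why `log_model=True` combined with `save_top_k=-1` behaves like `"all"` during training.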
The latest and best aliases are set automatically, so a model checkpoint can be retrieved easily:
from pathlib import Path

import wandb

# the reference can be retrieved in the artifacts panel
# "VERSION" can be a version (e.g. "v2") or an alias ("latest" or "best")
checkpoint_reference = "USER/PROJECT/MODEL-RUN_ID:VERSION"

# download checkpoint locally (if not already cached)
run = wandb.init(project="MNIST")
artifact = run.use_artifact(checkpoint_reference, type="model")
artifact_dir = artifact.download()

# load checkpoint
model = LitModule.load_from_checkpoint(Path(artifact_dir) / "model.ckpt")
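To make the reference format concrete, here is a hypothetical helper that splits a string of the form USER/PROJECT/MODEL-RUN_ID:VERSION into its parts. It is purely illustrative, since wandb parses references itself. The values `alice` and `1a2b3c` used with it are made up.

```python
from typing import NamedTuple


class ArtifactRef(NamedTuple):
    entity: str
    project: str
    name: str
    version: str


def parse_artifact_reference(reference: str) -> ArtifactRef:
    """Split "USER/PROJECT/NAME:VERSION" into its parts (hypothetical helper)."""
    path, sep, version = reference.rpartition(":")
    if not sep:
        raise ValueError(f"missing ':VERSION' suffix in {reference!r}")
    parts = path.split("/")
    if len(parts) != 3:
        raise ValueError(f"expected USER/PROJECT/NAME, got {path!r}")
    entity, project, name = parts
    return ArtifactRef(entity, project, name, version)
```

For example, `parse_artifact_reference("alice/MNIST/model-1a2b3c:best")` yields the alias `best`, while a suffix such as `:v2` selects a fixed version.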
Log media
Log text with:
# using columns and data
columns = ["input", "label", "prediction"]
data = [["cheese", "english", "english"], ["fromage", "french", "spanish"]]
wandb_logger.log_text(key="samples", columns=columns, data=data)

# using a pandas DataFrame
wandb_logger.log_text(key="samples", dataframe=my_dataframe)
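Because a text table can be defined either with `columns` and `data` or with a `dataframe`, a quick pre-check can catch malformed input before anything is sent. The function below is a hypothetical sketch that does not call wandb. Two of its rules are assumptions made for this sketch: `columns` and `data` must be passed together, and every row must hold one value per column.

```python
from typing import Any, Optional, Sequence


def validate_table_args(
    columns: Optional[Sequence[str]] = None,
    data: Optional[Sequence[Sequence[Any]]] = None,
    dataframe: Optional[Any] = None,
) -> None:
    """Raise ValueError if the table arguments are inconsistent (hypothetical helper)."""
    using_lists = columns is not None or data is not None
    if using_lists and dataframe is not None:
        raise ValueError("pass either columns and data, or dataframe, not both")
    if not using_lists and dataframe is None:
        raise ValueError("pass columns and data, or dataframe")
    if using_lists:
        if columns is None or data is None:
            raise ValueError("columns and data must be passed together")
        for i, row in enumerate(data):
            if len(row) != len(columns):
                raise ValueError(f"row {i} has {len(row)} values, expected {len(columns)}")
```

The `columns`/`data` pair from the example above passes this check. A row with an extra value would be rejected.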
Log images with:
# using tensors, numpy arrays or PIL images
wandb_logger.log_image(key="samples", images=[img1, img2])

# adding captions
wandb_logger.log_image(key="samples", images=[img1, img2], caption=["tree", "person"])

# using file path
wandb_logger.log_image(key="samples", images=["img_1.jpg", "img_2.jpg"])
Additional arguments can be passed to log segmentation masks and bounding boxes; refer to the W&B Image Overlays documentation.
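The optional kwargs of log_image are lists with one entry per image, so each list must have the same length as `images`. The sketch below shows one way to split such lists into one keyword dict per image. `split_image_kwargs` is a hypothetical name. It is assumed, not confirmed here, that the logger then passes each resulting dict to the `wandb.Image` built for the matching image.

```python
from typing import Any, Dict, List, Sequence


def split_image_kwargs(images: Sequence[Any], **kwargs: Sequence[Any]) -> List[Dict[str, Any]]:
    """Turn per-key lists (e.g. caption=[...]) into one kwargs dict per image."""
    n = len(images)
    for key, values in kwargs.items():
        if len(values) != n:
            raise ValueError(f"Expected {n} items but found {len(values)} for {key}")
    # The i-th dict collects the i-th entry of every keyword list.
    return [{key: values[i] for key, values in kwargs.items()} for i in range(n)]
```

With `images=["img_1.jpg", "img_2.jpg"]` and `caption=["tree", "person"]`, the first image receives the caption "tree" and the second receives "person".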
Log Tables
W&B Tables can be used to log, query and analyze tabular data.
They support any type of media (text, image, video, audio, molecule, html, etc.) and are well suited to storing, understanding and sharing any form of data, from datasets to model predictions.
import wandb

columns = ["caption", "image", "sound"]
data = [
    ["cheese", wandb.Image(img_1), wandb.Audio(snd_1)],
    ["wine", wandb.Image(img_2), wandb.Audio(snd_2)],
]
wandb_logger.log_table(key="samples", columns=columns, data=data)
Downloading and Using Artifacts
To download an artifact without starting a run, call the download_artifact function on the class:
from pytorch_lightning.loggers import WandbLogger

artifact_dir = WandbLogger.download_artifact(artifact="path/to/artifact")
To download an artifact and link it to an ongoing run, call the download_artifact function on the logger instance:
class MyModule(LightningModule):
    def any_lightning_module_function_or_hook(self):
        self.logger.download_artifact(artifact="path/to/artifact")
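Both call styles work because download_artifact is declared static (see its signature on this page), so it can be invoked on the class or on an instance. The sketch below models the linking behaviour. It rests on our assumption that linking happens only when a run is active and `use_artifact=True`. `SketchLogger` and `active_run` are hypothetical stand-ins for the logger and `wandb.run`, and no download is performed.

```python
from typing import Optional


class SketchLogger:
    """Hypothetical stand-in for WandbLogger, modelling only the dispatch."""

    # Hypothetical stand-in for `wandb.run`; None means no run is active.
    active_run: Optional[str] = None

    @staticmethod
    def download_artifact(artifact: str, use_artifact: bool = True) -> str:
        # Assumed rule: with an active run, the artifact is linked to that run;
        # otherwise it is fetched without creating or linking a run.
        if SketchLogger.active_run is not None and use_artifact:
            return f"linked {artifact} to {SketchLogger.active_run}"
        return f"fetched {artifact} without a run"
```

Under this model, calling the method from inside a LightningModule hook, where a run is already active, links the artifact. Calling it from a standalone script only fetches the artifact.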
To link an artifact from a previous run, use the use_artifact function:
from pytorch_lightning.loggers import WandbLogger

wandb_logger = WandbLogger(project="my_project", name="my_run")
wandb_logger.use_artifact(artifact="path/to/artifact")
See also
Demo in Google Colab with hyperparameter search and model logging
- Parameters:
  - version (Optional[str]) – Sets the version, mainly used to resume a previous run.
  - offline (bool) – Run offline (data can be streamed later to wandb servers).
  - anonymous (Optional[bool]) – Enables or explicitly disables anonymous logging.
  - project (str) – The name of the project to which this run will belong.
  - log_model (Union[str, bool]) – Log checkpoints created by ModelCheckpoint as W&B artifacts. The latest and best aliases are set automatically.
    - If log_model == 'all', checkpoints are logged during training.
    - If log_model == True, checkpoints are logged at the end of training, except when save_top_k == -1, which also logs every checkpoint during training.
    - If log_model == False (default), no checkpoint is logged.
  - prefix (str) – A string to put at the beginning of metric keys.
  - experiment (None) – WandB experiment object. Automatically set when creating a run.
  - **kwargs (Any) – Arguments passed to wandb.init(), like entity, group, tags, etc.
- Raises:
  - ModuleNotFoundError – If the required wandb package is not installed.
  - MisconfigurationException – If both log_model and offline are set to True.
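The second error condition can be illustrated with a small validation sketch. `MisconfigurationException` is defined locally here as a stand-in for Lightning's exception of the same name, and `check_logger_args` is a hypothetical helper. Treating any truthy log_model (including "all") as conflicting with offline mode is our assumption; the text above only names the True/True case.

```python
from typing import Union


class MisconfigurationException(Exception):
    """Local stand-in for Lightning's exception of the same name."""


def check_logger_args(offline: bool, log_model: Union[bool, str]) -> None:
    """Reject the invalid offline/log_model combination (hypothetical helper)."""
    # Assumption: any truthy log_model ("all" or True) conflicts with offline mode.
    if offline and log_model:
        raise MisconfigurationException(
            f"log_model={log_model!r} cannot be combined with offline=True"
        )
```

Under this check, `offline=True` still works with the default `log_model=False`, and checkpoint logging still works with `offline=False`.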
- after_save_checkpoint(checkpoint_callback)
Called after the model checkpoint callback saves a new checkpoint.
- static download_artifact(artifact, save_dir=None, artifact_type=None, use_artifact=True)
Downloads an artifact from the wandb server.
- Returns:
The path to the downloaded artifact.
- log_image(key, images, step=None, **kwargs)
Log images (tensors, numpy arrays, PIL Images or file paths).
Optional kwargs are lists with one entry per image (e.g. caption, masks, boxes).
- log_metrics(metrics, step=None)
Records metrics. This method logs metrics as soon as it receives them.
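As a rough model of what log_metrics sends, the sketch below applies `prefix` to each key and, when `step` is given, adds the step as an extra entry. Two details come from our reading of Lightning's source and are not documented on this page, so treat them as assumptions: the '-' join character and the `trainer/global_step` key. `build_payload` is a hypothetical name.

```python
from typing import Any, Dict, Mapping, Optional


def build_payload(
    metrics: Mapping[str, Any],
    prefix: str = "",
    step: Optional[int] = None,
    separator: str = "-",  # assumed join character between prefix and key
) -> Dict[str, Any]:
    """Return the dict that would be passed to the run's log call (sketch)."""
    if prefix:
        payload = {f"{prefix}{separator}{key}": value for key, value in metrics.items()}
    else:
        payload = dict(metrics)
    if step is not None:
        # Assumed key under which the trainer step is recorded.
        payload["trainer/global_step"] = step
    return payload
```

Under these assumptions, a logger created with `prefix="fold1"` would report `loss` as `fold1-loss`. This makes several runs' metrics distinguishable in the same dashboard.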
- log_table(key, columns=None, data=None, dataframe=None, step=None)
Log a Table containing any object type (text, image, audio, video, molecule, html, etc.).
Can be defined either with columns and data or with dataframe.
- log_text(key, columns=None, data=None, dataframe=None, step=None)
Log text as a Table.
Can be defined either with columns and data or with dataframe.
- use_artifact(artifact, artifact_type=None)
Records in the wandb dashboard that the given artifact is used by the run.
- property experiment: None
Actual wandb object. To use wandb features in your LightningModule, do the following.
Example:
self.logger.experiment.some_wandb_function()