Track and Visualize Experiments (intermediate)¶
Audience: Users who want to track more complex outputs and use third-party experiment managers.
Track audio and other artifacts¶
To track other artifacts, such as histograms or model topology graphs, first select one of the many loggers supported by Lightning:
from pytorch_lightning import loggers as pl_loggers
tensorboard = pl_loggers.TensorBoardLogger(save_dir="")
trainer = Trainer(logger=tensorboard)
then access the logger’s API directly:
def training_step(self, batch, batch_idx):
    tensorboard = self.logger.experiment
    tensorboard.add_image(...)
    tensorboard.add_histogram(...)
    tensorboard.add_figure(...)
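For example, here is a minimal sketch (assuming the module stores a linear layer as self.layer and matplotlib is installed) that logs a weight histogram and a figure once per validation epoch:
import matplotlib.pyplot as plt
import torch
from pytorch_lightning import LightningModule


class LitModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(28 * 28, 10)

    def on_validation_epoch_end(self):
        tensorboard = self.logger.experiment  # a SummaryWriter
        # histogram of the layer weights, indexed by epoch
        tensorboard.add_histogram("layer/weight", self.layer.weight, global_step=self.current_epoch)
        # an arbitrary matplotlib figure
        fig, ax = plt.subplots()
        ax.plot([0, 1, 2], [1, 2, 3])
        tensorboard.add_figure("sample_figure", fig, global_step=self.current_epoch)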
Comet.ml¶
To use Comet.ml, first install the comet-ml package:
pip install comet-ml
Configure the logger and pass it to the Trainer:
from pytorch_lightning.loggers import CometLogger
comet_logger = CometLogger(api_key="YOUR_COMET_API_KEY")
trainer = Trainer(logger=comet_logger)
Access the Comet logger from any function (except the LightningModule init) to use its API for tracking advanced artifacts:
class LitModel(LightningModule):
    def any_lightning_module_function_or_hook(self):
        comet = self.logger.experiment
        fake_images = torch.rand(32, 3, 28, 28)
        # Comet's Experiment API logs images with log_image and expects channels-last arrays
        comet.log_image(fake_images[0].permute(1, 2, 0).numpy(), name="generated_images")
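Beyond images, the same experiment handle exposes other Comet logging methods; a short sketch (the logged values here are only examples):
from pytorch_lightning import LightningModule


class LitModel(LightningModule):
    def any_lightning_module_function_or_hook(self):
        comet = self.logger.experiment
        # free-form notes attached to the experiment
        comet.log_text("baseline run with default augmentation")
        # an extra dictionary of parameters, shown in the Comet UI
        comet.log_parameters({"augmentation": "default", "scheduler": "cosine"})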
Here’s the full documentation for the CometLogger.
MLflow¶
To use MLflow, first install the MLflow package:
pip install mlflow
Configure the logger and pass it to the Trainer:
from pytorch_lightning.loggers import MLFlowLogger
mlf_logger = MLFlowLogger(experiment_name="lightning_logs", tracking_uri="file:./ml-runs")
trainer = Trainer(logger=mlf_logger)
Access the MLflow logger from any function (except the LightningModule init) to use its API for tracking advanced artifacts:
class LitModel(LightningModule):
    def any_lightning_module_function_or_hook(self):
        mlf_client = self.logger.experiment  # an MlflowClient
        fake_images = torch.rand(32, 3, 28, 28)
        # MlflowClient.log_image expects a run id and a channels-last array with values in [0, 1]
        mlf_client.log_image(self.logger.run_id, fake_images[0].permute(1, 2, 0).numpy(), "generated_images.png")
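Because the experiment attribute of MLFlowLogger is an MlflowClient, any of its methods can be combined with the run id stored on the logger; a short sketch (the file path is hypothetical):
from pytorch_lightning import LightningModule


class LitModel(LightningModule):
    def any_lightning_module_function_or_hook(self):
        mlf_client = self.logger.experiment  # an MlflowClient
        run_id = self.logger.run_id
        # upload an arbitrary local file as an artifact of this run
        mlf_client.log_artifact(run_id, "path/to/confusion_matrix.png")
        # attach a tag to the run
        mlf_client.set_tag(run_id, "stage", "debugging")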
Here’s the full documentation for the MLFlowLogger.
Neptune.ai¶
To use Neptune.ai, first install the neptune-client package:
pip install neptune-client
or with conda:
conda install -c conda-forge neptune-client
Configure the logger and pass it to the Trainer:
from pytorch_lightning.loggers import NeptuneLogger
neptune_logger = NeptuneLogger(
api_key="ANONYMOUS", # replace with your own
project="common/pytorch-lightning-integration", # format "<WORKSPACE/PROJECT>"
)
trainer = Trainer(logger=neptune_logger)
Access the Neptune logger from any function (except the LightningModule init) to use its API for tracking advanced artifacts:
class LitModel(LightningModule):
    def any_lightning_module_function_or_hook(self):
        neptune_logger = self.logger.experiment["your/metadata/structure"]
        neptune_logger.log(metadata)  # ``metadata`` can be a float, str, or file, for example
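The object returned by self.logger.experiment is a Neptune run, so values can be appended to series or uploaded as files under any namespace; a short sketch (the namespace names and file path are examples):
from pytorch_lightning import LightningModule


class LitModel(LightningModule):
    def any_lightning_module_function_or_hook(self):
        neptune_run = self.logger.experiment
        # append a value to a metric series
        neptune_run["training/extra_score"].log(0.95)
        # upload an arbitrary file
        neptune_run["artifacts/confusion_matrix"].upload("path/to/confusion_matrix.png")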
Here’s the full documentation for the NeptuneLogger.
TensorBoard¶
TensorBoard already comes installed with Lightning. If you have removed it, install the following package:
pip install tensorboard
Configure the logger and pass it to the Trainer:
from pytorch_lightning.loggers import TensorBoardLogger
logger = TensorBoardLogger(save_dir="")
trainer = Trainer(logger=logger)
Access the TensorBoard logger from any function (except the LightningModule init) to use its API for tracking advanced artifacts:
class LitModel(LightningModule):
    def any_lightning_module_function_or_hook(self):
        tensorboard_logger = self.logger.experiment
        fake_images = torch.Tensor(32, 3, 28, 28)
        # add_image expects a single CHW image; use add_images for a whole batch
        tensorboard_logger.add_image("generated_images", fake_images[0], 0)
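The same handle also accepts free-form text; a short sketch (the tag and note are arbitrary):
from pytorch_lightning import LightningModule


class LitModel(LightningModule):
    def on_train_start(self):
        tensorboard_logger = self.logger.experiment
        # record a note next to the run, rendered under TensorBoard's Text tab
        tensorboard_logger.add_text("notes", "baseline run with default augmentation", global_step=0)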
Here’s the full documentation for the TensorBoardLogger.
Weights and Biases¶
To use Weights and Biases (wandb), first install the wandb package:
pip install wandb
Configure the logger and pass it to the Trainer:
from pytorch_lightning.loggers import WandbLogger
wandb_logger = WandbLogger(project="MNIST", log_model="all")
trainer = Trainer(logger=wandb_logger)
# log gradients and model topology
wandb_logger.watch(model)
Access the wandb logger from any function (except the LightningModule init) to use its API for tracking advanced artifacts:
import wandb


class MyModule(LightningModule):
    def any_lightning_module_function_or_hook(self):
        wandb_run = self.logger.experiment
        fake_images = torch.Tensor(32, 3, 28, 28)
        # Option 1: log through the underlying wandb run
        wandb_run.log({"generated_images": [wandb.Image(fake_images, caption="...")]})
        # Option 2: use the WandbLogger helper made specifically for images
        self.logger.log_image(key="generated_images", images=[fake_images])
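The underlying run can also log richer media such as tables; a short sketch (the column names and values are made up):
import wandb
from pytorch_lightning import LightningModule


class MyModule(LightningModule):
    def any_lightning_module_function_or_hook(self):
        wandb_run = self.logger.experiment
        # a small table of predictions, browsable in the W&B UI
        table = wandb.Table(columns=["id", "prediction", "target"], data=[[0, 3, 3], [1, 7, 1]])
        wandb_run.log({"predictions": table})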
Here’s the full documentation for the WandbLogger.
Demo in Google Colab with hyperparameter search and model logging.
Use multiple experiment managers¶
To use multiple experiment managers at the same time, pass a list to the Trainer’s logger argument.
from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger
logger1 = TensorBoardLogger(save_dir="")
logger2 = WandbLogger()
trainer = Trainer(logger=[logger1, logger2])
Access all loggers from any function (except the LightningModule init) to use their APIs for tracking advanced artifacts:
import wandb


class MyModule(LightningModule):
    def any_lightning_module_function_or_hook(self):
        tensorboard_logger = self.logger.experiment[0]
        wandb_run = self.logger.experiment[1]
        fake_images = torch.Tensor(32, 3, 28, 28)
        # the SummaryWriter expects a single CHW image
        tensorboard_logger.add_image("generated_images", fake_images[0], 0)
        # the wandb run has no add_image method; wrap the tensor in wandb.Image instead
        wandb_run.log({"generated_images": [wandb.Image(fake_images, caption="generated")]})
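Scalar metrics logged with self.log are sent to every configured logger automatically, so the per-logger experiment handles are only needed for logger-specific artifacts; a short sketch assuming a classification module:
import torch
from pytorch_lightning import LightningModule


class MyModule(LightningModule):
    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.cross_entropy(self(x), y)
        # recorded by both TensorBoard and W&B with a single call
        self.log("train_loss", loss)
        return loss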
Track multiple metrics in the same chart¶
If your logger supports plotting multiple metrics on the same chart, pass in a dictionary to self.log.
self.log("performance", {"acc": acc, "recall": recall})
Track hyperparameters¶
To track hyperparameters, first call save_hyperparameters from the LightningModule init:
class MyLightningModule(LightningModule):
def __init__(self, learning_rate, another_parameter, *args, **kwargs):
super().__init__()
self.save_hyperparameters()
If your logger supports hyperparameter tracking, the saved hyperparameters will automatically show up on its dashboard.
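Once saved, the values are also available everywhere in the module through self.hparams; a short sketch (the argument names are examples):
import torch
from pytorch_lightning import LightningModule


class MyLightningModule(LightningModule):
    def __init__(self, learning_rate=1e-3, hidden_dim=128):
        super().__init__()
        self.save_hyperparameters()
        self.layer = torch.nn.Linear(self.hparams.hidden_dim, 10)

    def configure_optimizers(self):
        # the saved hyperparameters are attributes of self.hparams
        return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)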
Track model topology¶
Multiple loggers support visualizing the model topology. Here’s an example that tracks the model topology using Tensorboard.
def any_lightning_module_function_or_hook(self):
    tensorboard_logger = self.logger  # log_graph lives on the TensorBoardLogger itself
    prototype_array = torch.Tensor(32, 1, 28, 27)
    tensorboard_logger.log_graph(model=self, input_array=prototype_array)
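Alternatively, TensorBoardLogger can record the graph automatically when the module defines example_input_array and the logger is created with log_graph=True; a minimal sketch:
import torch
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.loggers import TensorBoardLogger


class LitModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(28 * 28, 10)
        # a prototype input used to trace the graph
        self.example_input_array = torch.rand(32, 28 * 28)

    def forward(self, x):
        return self.layer(x)


trainer = Trainer(logger=TensorBoardLogger(save_dir="", log_graph=True))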