The Lightning Trainer provides many flags that can help you debug your LightningModule.

The following are flags that make debugging much easier.

Quick Unit Testing


The fast_dev_run flag runs a “unit test” of your code. If set to an int N, it runs N training, validation, testing and predict batches for a single epoch; if set to True, it runs 1 batch of each. The point is to have a dry run that detects bugs in the respective loop without having to wait for a complete loop to crash.

Internally, it sets limit_<train/val/test/predict>_batches=fast_dev_run and max_epochs=1 to limit the number of batches that are run.

(See: fast_dev_run argument of Trainer)

from pytorch_lightning import Trainer

# runs 1 train, val, test batch and program ends
trainer = Trainer(fast_dev_run=True)

# runs 7 train, val, test batches and program ends
trainer = Trainer(fast_dev_run=7)


Setting fast_dev_run disables the tuner, checkpoint callbacks, early stopping callbacks, loggers, and logger callbacks like LearningRateMonitor and DeviceStatsMonitor.
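The mapping described in this section can be sketched in plain Python. resolve_fast_dev_run is a hypothetical helper, not part of Lightning: it covers only the settings named here (the real Trainer also switches off the components listed above), and it assumes that False or 0 means “off”.

```python
from typing import Dict, Union


def resolve_fast_dev_run(fast_dev_run: Union[bool, int]) -> Dict[str, int]:
    """Hypothetical helper: the Trainer limits implied by a fast_dev_run value (sketch)."""
    # bool is a subclass of int, so check it first: True means "1 batch", False means "off"
    if isinstance(fast_dev_run, bool):
        num_batches = 1 if fast_dev_run else 0
    elif isinstance(fast_dev_run, int) and fast_dev_run >= 0:
        num_batches = fast_dev_run
    else:
        raise ValueError("fast_dev_run must be a bool or a non-negative int")

    if num_batches == 0:
        return {}  # disabled: leave the user's own limits untouched

    settings = {
        f"limit_{stage}_batches": num_batches
        for stage in ("train", "val", "test", "predict")
    }
    settings["max_epochs"] = 1  # a single epoch is enough for a dry run
    return settings


print(resolve_fast_dev_run(True))
print(resolve_fast_dev_run(7))
```

Returning a plain dict keeps the sketch easy to inspect; in Lightning itself these values are applied directly to the Trainer.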

Shorten Epochs

Sometimes it’s helpful to use only a fraction of your training, val, test, or predict data (or a set number of batches). Pass a float to use a fraction of the batches, or an int to use a fixed number of batches. For example, you can use 10% of the training set and 1% of the validation set.

On larger datasets like ImageNet, this can help you debug or test a few things faster than waiting for a full epoch.

# use only 10% of training data and 1% of val data
trainer = Trainer(limit_train_batches=0.1, limit_val_batches=0.01)

# use 10 batches of train and 5 batches of val
trainer = Trainer(limit_train_batches=10, limit_val_batches=5)
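How a float or an int limit turns into a batch count can be sketched as follows. effective_num_batches is a hypothetical helper, and its rule (a float fraction rounded down, an int capped at the dataloader length) is an assumption for illustration rather than Lightning's exact code.

```python
from typing import Union


def effective_num_batches(dataloader_len: int, limit: Union[int, float]) -> int:
    """Hypothetical helper: how many batches a limit_*_batches value keeps (sketch)."""
    if isinstance(limit, float):
        if not 0.0 <= limit <= 1.0:
            raise ValueError("a float limit is a fraction and must be in [0.0, 1.0]")
        # a fraction of the dataloader, rounded down
        return int(dataloader_len * limit)
    # an int is an absolute batch count, capped at the dataloader length
    return min(dataloader_len, limit)


print(effective_num_batches(1000, 0.1))  # 10% of 1000 batches
print(effective_num_batches(1000, 10))   # a fixed 10 batches
```

One thing the sketch makes visible: a very small fraction of a short dataloader, such as 0.01 of 50 batches, rounds down to 0. Lightning may raise an error in that situation, so check how your version handles it.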

Validation Sanity Check

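Lightning runs a few steps of validation at the beginning of training. This avoids crashing in the validation loop deep into a lengthy training run. By default it runs 2 validation batches; set num_sanity_val_steps=0 to turn the check off, or -1 to run all validation batches.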

(See: num_sanity_val_steps argument of Trainer)

trainer = Trainer(num_sanity_val_steps=2)
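The idea behind the sanity check is simple: run the first few validation batches before any training happens, so that a broken validation_step fails within seconds. The sketch below uses a hypothetical run_sanity_check helper, not Lightning's implementation, and mirrors the num_sanity_val_steps conventions (0 skips, -1 runs everything).

```python
from itertools import islice
from typing import Any, Callable, Iterable


def run_sanity_check(
    val_batches: Iterable[Any],
    validation_step: Callable[[Any, int], Any],
    num_steps: int = 2,
) -> int:
    """Hypothetical helper: run the first validation batches before training starts.

    Any exception raised by validation_step surfaces here, before training time is spent.
    Returns the number of batches that were run.
    """
    if num_steps == 0:
        return 0  # check disabled
    batches = val_batches if num_steps == -1 else islice(val_batches, num_steps)
    count = 0
    for batch_idx, batch in enumerate(batches):
        validation_step(batch, batch_idx)
        count += 1
    return count


print(run_sanity_check(range(10), lambda batch, batch_idx: None))
```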

Make Model Overfit on Subset of Data

A good debugging technique is to take a tiny portion of your data (say 2 samples per class) and try to get your model to overfit it. If it can’t, it’s a sign it won’t work with large datasets.

(See: overfit_batches argument of Trainer)

# use only 1% of training data (and turn off validation)
trainer = Trainer(overfit_batches=0.01)

# similar, but with a fixed 10 batches
trainer = Trainer(overfit_batches=10)

When using this flag, validation will be disabled. Lightning also replaces the sampler of the training set to turn off shuffling for you.
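To make the idea behind this check concrete, here is a dependency-free sketch (not Lightning code). It fits a tiny linear model y = w * x + b to the same two fixed samples at every step. A model and optimizer that work should drive the training loss close to zero. The data, learning rate, and step count are arbitrary choices for illustration.

```python
from typing import List, Tuple


def overfit_tiny_subset(
    data: List[Tuple[float, float]], steps: int = 2000, lr: float = 0.1
) -> float:
    """Train y = w * x + b on a tiny, fixed dataset; return the final MSE loss."""
    w, b = 0.0, 0.0
    n = len(data)
    for _ in range(steps):
        # full-batch gradients of the mean squared error
        grad_w = sum(2 * (w * x + b - y) * x for x, y in data) / n
        grad_b = sum(2 * (w * x + b - y) for x, y in data) / n
        w -= lr * grad_w
        b -= lr * grad_b
    return sum((w * x + b - y) ** 2 for x, y in data) / n


loss = overfit_tiny_subset([(1.0, 3.0), (2.0, 5.0)])
print(f"final training loss: {loss:.2e}")
```

Calling it with lr=0.0, a stand-in for a bug such as an optimizer that never steps, leaves the loss at its starting value of 17.0. A loss that stays flat like this is exactly the signal the overfitting check is meant to surface.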


Inspect Gradient Norms

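This flag logs the norm of the gradients to the logger. Set it to the order p of the norm (for example 2 for the 2-norm) or to 'inf' for the infinity norm; the default of -1 disables tracking.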

(See: track_grad_norm argument of Trainer)

# the 2-norm
trainer = Trainer(track_grad_norm=2)
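What a p-norm of the gradients means can be sketched without PyTorch. In the sketch below, lists of floats stand in for gradient tensors, and grad_norms is a hypothetical helper. Lightning's own implementation and its logged metric names are not shown here.

```python
import math
from typing import Dict, List, Tuple, Union


def grad_norms(
    grads: Dict[str, List[float]], norm_type: Union[float, str] = 2.0
) -> Tuple[Dict[str, float], float]:
    """Hypothetical helper: per-parameter p-norms of gradients and their total p-norm."""
    p = float(norm_type)  # float("inf") turns the string "inf" into math.inf
    if p <= 0:
        raise ValueError("norm_type must be positive or 'inf'")

    def p_norm(values: List[float]) -> float:
        if math.isinf(p):
            return max(abs(v) for v in values)  # infinity norm: largest magnitude
        return sum(abs(v) ** p for v in values) ** (1.0 / p)

    per_param = {name: p_norm(g) for name, g in grads.items()}
    # The p-norm over every gradient entry equals the p-norm of the per-parameter norms.
    total = p_norm(list(per_param.values()))
    return per_param, total


per_param, total = grad_norms({"weight": [3.0, 4.0], "bias": [12.0]}, norm_type=2)
print(per_param, total)
```

Computing the total from the per-parameter norms avoids a second pass over every gradient entry, and it gives the same result because (sum over parameters of ||g_i||_p^p)^(1/p) is exactly the p-norm of all entries together.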

Detect Anomaly

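You can enable anomaly detection for the autograd engine. It uses PyTorch’s built-in anomaly detection context manager: any backward computation that produces a NaN raises an error, and the traceback points to the forward operation that created the failing backward function. Because these checks slow down execution, enable it only while debugging.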

To enable it within Lightning, use Trainer’s flag:

trainer = Trainer(detect_anomaly=True)


Log Device Statistics

Monitor and log device stats during training with the DeviceStatsMonitor.

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import DeviceStatsMonitor

trainer = Trainer(callbacks=[DeviceStatsMonitor()])


To find performance bottlenecks, check out the Profiler document.

Model Statistics

Debugging with Distributed Strategies

DDP Debugging

If you are having a hard time debugging DDP on your remote machine, you can debug DDP locally on the CPU. Note that this will not provide any speed benefits.

trainer = Trainer(accelerator="cpu", strategy="ddp", devices=2)

To inspect your code, you can use pdb, breakpoint(), or regular print statements.

from pytorch_lightning import LightningModule


class LitModel(LightningModule):
    def training_step(self, batch, batch_idx):
        debugging_message = ...
        print(f"RANK - {self.trainer.global_rank}: {debugging_message}")

        # drop into the debugger on rank 0 only
        if self.trainer.global_rank == 0:
            import pdb

            pdb.set_trace()

        # to prevent other processes from moving forward until all processes are in sync
        self.trainer.strategy.barrier()
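The role of the barrier in DDP debugging can be illustrated with the standard library alone. In this sketch, threads stand in for DDP processes and threading.Barrier stands in for Lightning's barrier (self.trainer.strategy.barrier()). It is an analogy for the synchronization behavior, not how Lightning implements it, and run_with_barrier is a hypothetical helper.

```python
import threading
from typing import List


def run_with_barrier(world_size: int = 2) -> List[str]:
    """Simulate ranks with threads: no rank passes the barrier until every rank reaches it."""
    barrier = threading.Barrier(world_size)
    events: List[str] = []
    lock = threading.Lock()

    def worker(rank: int) -> None:
        with lock:
            events.append(f"RANK - {rank}: reached barrier")
        barrier.wait()  # blocks until all world_size threads have called wait()
        with lock:
            events.append(f"RANK - {rank}: past barrier")

    threads = [threading.Thread(target=worker, args=(rank,)) for rank in range(world_size)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return events


for line in run_with_barrier():
    print(line)
```

The order within each group can vary between runs, but every "reached" line is logged before any "past" line. This is why the other ranks wait while rank 0 sits in the debugger.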

When everything works, switch back to GPU by changing only the accelerator.

trainer = Trainer(accelerator="gpu", strategy="ddp", devices=2)