
Loops

Loops let advanced users swap out the default gradient descent optimization loop at the core of Lightning with a different optimization paradigm.

The Lightning Trainer is built on top of the standard gradient descent optimization loop, which works for 90%+ of machine learning use cases:

for i, batch in enumerate(dataloader):
    x, y = batch
    y_hat = model(x)
    loss = loss_function(y_hat, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

However, some new research use cases, such as meta-learning, active learning, and recommendation systems, require a different loop structure. For example, here is a simple loop that guides the weight updates with a loss from a special validation split:

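import torch

for i, batch in enumerate(train_dataloader):
    x, y = batch
    y_hat = model(x)
    loss = loss_function(y_hat, y)
    optimizer.zero_grad()
    loss.backward()

    # compute the loss on the special validation split without tracking gradients
    val_loss = 0
    with torch.no_grad():
        for val_batch in val_dataloader:
            val_x, val_y = val_batch
            val_y_hat = model(val_x)
            val_loss += loss_function(val_y_hat, val_y)

    # scale_gradients is a placeholder for your own gradient-scaling logic
    scale_gradients(model, 1 / val_loss)
    optimizer.step()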

With Lightning Loops, you can customize the Trainer to run non-standard gradient descent optimizations such as the loop above:

trainer = Trainer()
trainer.fit_loop.epoch_loop = MyGradientDescentLoop()

Think of this as swapping out the engine in a car!
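The validation-guided loop from the start of this section can be run end to end in a minimal pure-Python sketch, assuming a one-parameter model y_hat = w * x trained with mean squared error. Gradients are computed by hand instead of by autograd, and the placeholder scale_gradients step is modeled as multiplying the gradient by 1 / val_loss. The helper names, toy data, and learning rate are hypothetical and chosen only for illustration.

```python
def loss_and_grad(w, batch):
    """Mean squared error of y_hat = w * x over a batch, and its derivative in w."""
    n = len(batch)
    loss = sum((w * x - y) ** 2 for x, y in batch) / n
    grad = sum(2 * (w * x - y) * x for x, y in batch) / n
    return loss, grad


def train_epoch(w, train_batches, val_batches, lr=0.01):
    for batch in train_batches:
        _, grad = loss_and_grad(w, batch)  # plays the role of loss.backward()

        # Loss on the special validation split; only its value is used.
        val_loss = sum(loss_and_grad(w, val_batch)[0] for val_batch in val_batches)

        grad *= 1 / val_loss  # hypothetical stand-in for scale_gradients(model, 1 / val_loss)
        w -= lr * grad        # plays the role of optimizer.step()
    return w


# Toy data: training targets follow y = 2x; the validation split is slightly noisy.
train_batches = [[(1.0, 2.0), (2.0, 4.0)], [(3.0, 6.0)]]
val_batches = [[(1.0, 2.5), (2.0, 3.5)]]

w = 0.0
for epoch in range(100):
    w = train_epoch(w, train_batches, val_batches)
print(round(w, 3))  # 2.0
```

Because the noisy validation loss never reaches zero on this data, 1 / val_loss stays bounded and the scaled updates converge; with other data, this scaling rule could blow up as the validation loss approaches zero.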


Understanding the default Trainer loop

The Lightning Trainer automates the standard optimization loop which every PyTorch user is familiar with:

for i, batch in enumerate(dataloader):
    x, y = batch
    y_hat = model(x)
    loss = loss_function(y_hat, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

The core research logic is simply shifted to the LightningModule:

for i, batch in enumerate(dataloader):
    # x, y = batch                      moved to training_step
    # y_hat = model(x)                  moved to training_step
    # loss = loss_function(y_hat, y)    moved to training_step
    loss = lightning_module.training_step(batch, i)

    # Lightning handles automatically:
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

Under the hood, the above loop is implemented with the Loop API. In simplified form, it looks like this:

class DefaultLoop(Loop):
    def advance(self, batch, i):
        loss = lightning_module.training_step(batch, i)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    def run(self, dataloader):
        for i, batch in enumerate(dataloader):
            self.advance(batch, i)
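To make the pattern concrete without depending on Lightning, here is a self-contained sketch of the same shape: run() keeps calling advance() until a done property becomes true. MiniLoop and ToyDefaultLoop are hypothetical classes, and training_step is a plain function standing in for LightningModule.training_step. The zero_grad/backward/step calls are left as a comment because there is no real optimizer here.

```python
class MiniLoop:
    """Hypothetical minimal loop base: run() calls advance() until done is True."""

    @property
    def done(self):
        raise NotImplementedError

    def advance(self, *args, **kwargs):
        raise NotImplementedError

    def run(self, *args, **kwargs):
        while not self.done:
            self.advance(*args, **kwargs)


class ToyDefaultLoop(MiniLoop):
    def __init__(self, training_step):
        self.training_step = training_step  # stands in for LightningModule.training_step
        self.losses = []
        self._exhausted = False

    @property
    def done(self):
        return self._exhausted

    def advance(self, iterator):
        try:
            batch_idx, batch = next(iterator)
        except StopIteration:
            self._exhausted = True  # no batches left, so the loop is done
            return
        loss = self.training_step(batch, batch_idx)
        # a real trainer would call optimizer.zero_grad(), loss.backward(), optimizer.step() here
        self.losses.append(loss)


loop = ToyDefaultLoop(training_step=lambda batch, i: sum(batch) * 0.5)
loop.run(enumerate([[1, 2], [3, 4], [5]]))
print(loop.losses)  # [1.5, 3.5, 2.5]
```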

Defining a loop within a class interface instead of hard-coding a raw Python for/while loop has several benefits:

  1. You can have full control over the data flow through loops.

  2. You can add new loops and nest as many of them as you want.

  3. If needed, the state of a loop can be saved and resumed.

  4. New hooks can be injected at any point.
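The third benefit follows from keeping the loop's progress in attributes instead of in the local variables of a for statement. The following minimal sketch assumes a hypothetical ResumableLoop whose state_dict/load_state_dict methods borrow the familiar PyTorch naming convention; it is not Lightning's checkpointing code. It shows an interrupted run being resumed from a saved step counter.

```python
class ResumableLoop:
    """Hypothetical loop whose progress can be saved and restored."""

    def __init__(self, max_steps):
        self.max_steps = max_steps
        self.step = 0
        self.log = []

    @property
    def done(self):
        return self.step >= self.max_steps

    def advance(self):
        self.log.append(f"step {self.step}")
        self.step += 1

    def run(self, stop_after=None):
        # stop_after simulates an interruption (for example a crash) after N advances
        n = 0
        while not self.done:
            if stop_after is not None and n >= stop_after:
                break
            self.advance()
            n += 1

    def state_dict(self):
        return {"step": self.step}

    def load_state_dict(self, state):
        self.step = state["step"]


first = ResumableLoop(max_steps=5)
first.run(stop_after=2)          # interrupted after two steps
checkpoint = first.state_dict()  # {'step': 2}

resumed = ResumableLoop(max_steps=5)
resumed.load_state_dict(checkpoint)
resumed.run()
print(resumed.log)  # ['step 2', 'step 3', 'step 4']
```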

Animation showing how to convert a standard training loop to a Lightning loop

Overriding the default loops

The fastest way to get started with loops is to override the functionality of an existing loop. Lightning uses three main loop classes for its four entry points: FitLoop for training (with validation), EvaluationLoop for validating and testing, and PredictionLoop for predicting.

For simple changes that don’t require a custom loop, you can modify each of these loops.

Each loop has a series of methods that can be modified. For example, with the FitLoop:

from pytorch_lightning.loops import FitLoop


class MyLoop(FitLoop):
    def advance(self):
        """Advance from one iteration to the next."""

    def on_advance_end(self):
        """Do something at the end of an iteration."""

    def on_run_end(self):
        """Do something when the loop ends."""
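The idea of customizing an existing loop by overriding only selected methods can be sketched without Lightning. ToyEpochLoop below plays the role of a built-in loop (it is hypothetical, not FitLoop), and the subclass MyToyLoop overrides only on_advance_end and on_run_end while inheriting done, advance, and run.

```python
class ToyEpochLoop:
    """Hypothetical stand-in for a built-in loop that runs a fixed number of epochs."""

    def __init__(self, max_epochs):
        self.max_epochs = max_epochs
        self.current_epoch = 0

    @property
    def done(self):
        return self.current_epoch >= self.max_epochs

    def advance(self):
        self.current_epoch += 1  # one "epoch" of work

    def on_advance_end(self):
        pass

    def on_run_end(self):
        return None

    def run(self):
        while not self.done:
            self.advance()
            self.on_advance_end()
        return self.on_run_end()


class MyToyLoop(ToyEpochLoop):
    def __init__(self, max_epochs):
        super().__init__(max_epochs)
        self.events = []

    def on_advance_end(self):
        # do something at the end of each iteration
        self.events.append(f"finished epoch {self.current_epoch}")

    def on_run_end(self):
        # do something when the loop ends
        return self.events


print(MyToyLoop(max_epochs=3).run())  # ['finished epoch 1', 'finished epoch 2', 'finished epoch 3']
```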

A full list of all built-in loops and subloops can be found here.

To add your own modifications to a loop, simply subclass an existing loop class and override what you need. Here is a simple example of how to add a new hook:

from pytorch_lightning.loops import FitLoop


class CustomFitLoop(FitLoop):
    def advance(self):
        # ... whatever code before

        # pass anything you want to the hook
        self.trainer.call_hook("my_new_hook", *args, **kwargs)

        # ... whatever code after

Then attach the custom loop to the Trainer directly:

trainer = Trainer(...)
trainer.fit_loop = CustomFitLoop()

# fit() now uses the new FitLoop!
trainer.fit(...)

# the equivalent for validate(), test(), predict()
# (CustomValLoop is a placeholder for your own loop class)
val_loop = CustomValLoop()
trainer = Trainer()
trainer.validate_loop = val_loop
trainer.validate(model)

Now your code is fully flexible, and you can still leverage all the best parts of Lightning!

Animation showing how to replace a loop on the Trainer
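The self.trainer.call_hook("my_new_hook", ...) line in CustomFitLoop can be approximated as a lookup by name: if the model defines a method with that name, it is called with the given arguments. The sketch below is a simplified, hypothetical dispatcher (ToyTrainer, ToyLoop, ToyModel, and my_new_hook are illustrative names), not Lightning's implementation, which does more than this.

```python
class ToyTrainer:
    def __init__(self, model):
        self.model = model

    def call_hook(self, hook_name, *args, **kwargs):
        # Look the hook up by name and call it only if the model defines it.
        hook = getattr(self.model, hook_name, None)
        if callable(hook):
            return hook(*args, **kwargs)
        return None


class ToyLoop:
    def __init__(self, trainer):
        self.trainer = trainer

    def advance(self, batch):
        # ... code before the new hook
        self.trainer.call_hook("my_new_hook", batch, note="mid-advance")
        # ... code after the new hook


class ToyModel:
    def __init__(self):
        self.seen = []

    def my_new_hook(self, batch, note):
        self.seen.append((batch, note))


model = ToyModel()
loop = ToyLoop(ToyTrainer(model))
loop.advance([1, 2, 3])
print(model.seen)  # [([1, 2, 3], 'mid-advance')]
```

A model that does not define the hook is simply skipped, so adding a new hook does not break existing models.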

Creating a new loop from scratch

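You can also go wild and implement a full loop from scratch by subclassing the Loop base class. At a minimum, you need to define the done property and the advance method. Because reset is also declared abstract on the base class (see the Loop API below), you need to provide it as well, even if it does nothing:

from pytorch_lightning.loops import Loop


class MyFancyLoop(Loop):
    @property
    def done(self):
        """Provide a condition to stop the loop."""

    def reset(self):
        """Reset the internal state before each call to run()."""

    def advance(self):
        """
        Access your dataloader/s in whatever way you want.
        Do your fancy optimization things.
        Call the LightningModule methods at your leisure.
        """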

Finally, attach your loop to the Trainer:

trainer = Trainer(...)
trainer.fit_loop = MyFancyLoop()

# fit() now uses your fancy loop!
trainer.fit(...)

Now you have full control over the Trainer. But beware: The power of loop customization comes with great responsibility. We recommend that you familiarize yourself with overriding the default loops first before you start building a new loop from the ground up.
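For a runnable illustration of the done/reset/advance contract described in this section, the sketch below minimizes f(w) = (w - 3)^2 with plain gradient descent and stops once the gradient is small or a step budget runs out. It is a standalone imitation of the contract, not a subclass of Lightning's Loop; the class name, tolerance, and learning rate are assumptions made for the example.

```python
class GradientDescentLoop:
    """Hypothetical from-scratch loop: minimizes (w - 3)^2 one step per advance()."""

    def __init__(self, lr=0.1, tol=1e-6, max_steps=1000):
        self.lr = lr
        self.tol = tol
        self.max_steps = max_steps
        self.reset()

    def reset(self):
        # fresh state so run() can be called repeatedly
        self.w = 0.0
        self.steps = 0
        self.grad = float("inf")

    @property
    def done(self):
        # stop when the last gradient was tiny or the step budget is used up
        return abs(self.grad) < self.tol or self.steps >= self.max_steps

    def advance(self):
        self.grad = 2 * (self.w - 3.0)  # derivative of (w - 3)^2
        self.w -= self.lr * self.grad
        self.steps += 1

    def run(self):
        self.reset()
        while not self.done:
            self.advance()
        return self.w


loop = GradientDescentLoop()
w = loop.run()
print(round(w, 4))  # 3.0
```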


Loop API

Here are the key properties and methods of the Loop base class.

The Loop class is the base for all loops in Lightning, just like the LightningModule is the base for all models. It defines a public interface that each loop implementation must follow. The key members are:

Properties

done

Loop.done

Property indicating when the loop is finished.

Example:

@property
def done(self):
    return self.trainer.global_step >= self.trainer.max_steps

Return type

bool

skip (optional)

Loop.skip

Determine whether to return immediately from the call to run().

Example:

@property
def skip(self):
    return len(self.trainer.train_dataloader) == 0

Return type

bool

Methods

reset (optional)

abstract Loop.reset()

Resets the internal state of the loop at the beginning of each call to run.

Example:

def reset(self):
    # reset your internal state or add custom logic
    # if you expect run() to be called multiple times
    self.current_iteration = 0
    self.outputs = []

Return type

None

advance

abstract Loop.advance(*args, **kwargs)

Performs a single step.

Accepts all arguments passed to run.

Example:

def advance(self, iterator):
    # assumes run() was called with enumerate(dataloader)
    batch_idx, batch = next(iterator)
    loss = self.trainer.lightning_module.training_step(batch, batch_idx)
    ...

Return type

None

run (optional)

Loop.run(*args, **kwargs)

The main entry point to the loop.

It repeatedly checks the done condition and calls advance until done evaluates to True.

Override this if you wish to change the default behavior. In simplified form, the default implementation is:

def run(self, *args, **kwargs):
    if self.skip:
        return self.on_skip()

    self.reset()
    self.on_run_start(*args, **kwargs)

    while not self.done:
        self.advance(*args, **kwargs)

    output = self.on_run_end()
    return output

Return type

T (the generic output type of the loop)

Returns

The output of on_run_end (often outputs collected from each step of the loop)
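The pieces of this API fit together as follows in a self-contained sketch: skip short-circuits run(), reset() clears state so the loop can be run again, advance() handles one item per call with the arguments passed to run(), and on_run_end() returns the collected outputs. SketchLoop and BatchSumLoop are hypothetical classes that mimic the simplified run() shown above; they do not import Lightning.

```python
class SketchLoop:
    """Hypothetical base mirroring the simplified default run()."""

    @property
    def skip(self):
        return False

    def on_skip(self):
        return None

    def on_run_start(self, *args, **kwargs):
        pass

    def on_run_end(self):
        return None

    def run(self, *args, **kwargs):
        if self.skip:
            return self.on_skip()
        self.reset()
        self.on_run_start(*args, **kwargs)
        while not self.done:
            self.advance(*args, **kwargs)
        return self.on_run_end()


class BatchSumLoop(SketchLoop):
    def __init__(self, batches):
        self.batches = batches

    @property
    def skip(self):
        return len(self.batches) == 0  # nothing to iterate over

    def on_skip(self):
        return []

    def reset(self):
        self.current_iteration = 0
        self.outputs = []

    @property
    def done(self):
        return self.current_iteration >= len(self.batches)

    def advance(self, scale):
        batch = self.batches[self.current_iteration]
        self.outputs.append(scale * sum(batch))
        self.current_iteration += 1

    def on_run_end(self):
        return self.outputs


loop = BatchSumLoop([[1, 2], [3, 4, 5]])
first = loop.run(10)
second = loop.run(1)  # reset() makes a second run start from a clean state
print(first, second)  # [30, 120] [3, 12]
print(BatchSumLoop([]).run(10))  # []
```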


Subloops

When you want to customize nested loops within loops, use the connect() method:

# Step 1: create your loop
my_epoch_loop = MyEpochLoop()

# Step 2: use connect()
trainer.fit_loop.connect(epoch_loop=my_epoch_loop)

# Trainer runs the fit loop with your new epoch loop!
trainer.fit(model)
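Conceptually, connecting a subloop stores the child on the parent, and the parent delegates each step to whatever child is currently attached. The following pure-Python sketch illustrates that idea; ToyEpochLoop, MyEpochLoop, and ToyFitLoop are hypothetical classes, and this connect() is a simplified stand-in for Lightning's, not its actual implementation.

```python
class ToyEpochLoop:
    def run(self, epoch):
        return f"epoch {epoch}: default"


class MyEpochLoop(ToyEpochLoop):
    def run(self, epoch):
        return f"epoch {epoch}: custom"


class ToyFitLoop:
    def __init__(self, max_epochs):
        self.max_epochs = max_epochs
        self.epoch_loop = ToyEpochLoop()  # default child loop

    def connect(self, epoch_loop):
        # Swap in a different child loop; the parent's own logic is unchanged.
        self.epoch_loop = epoch_loop

    def run(self):
        # Each epoch is delegated to whichever child loop is currently connected.
        return [self.epoch_loop.run(epoch) for epoch in range(self.max_epochs)]


fit_loop = ToyFitLoop(max_epochs=2)
fit_loop.connect(epoch_loop=MyEpochLoop())
print(fit_loop.run())  # ['epoch 0: custom', 'epoch 1: custom']
```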

More about the built-in loops and how they are composed is explained in the next section.

Animation showing how to connect a custom subloop

Built-in Loops

The training loop in Lightning is called the fit loop and is actually a combination of several loops. Here is what the structure looks like in plain Python:

# FitLoop
for epoch in range(max_epochs):

    # TrainingEpochLoop
    for batch_idx, batch in enumerate(train_dataloader):

        # TrainingBatchLoop
        for split_batch in tbptt_split(batch):

            # OptimizerLoop
            for optimizer_idx, opt in enumerate(optimizers):

                loss = lightning_module.training_step(split_batch, batch_idx, optimizer_idx)
                ...

        # ValidationEpochLoop (runs at the requested validation interval)
        for val_batch_idx, val_batch in enumerate(val_dataloader):
            lightning_module.validation_step(val_batch, val_batch_idx)
            ...

Each of these for-loops represents a class implementing the Loop interface.
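To show how the nesting translates into classes, here is a hypothetical pure-Python sketch in which each level is its own object and each parent delegates to its child's run(). The class names only echo Lightning's built-in loops; TBPTT splitting and validation are left out, and none of this is Lightning's real implementation.

```python
class ToyOptimizerLoop:
    """Leaf loop: calls training_step once per optimizer index."""

    def __init__(self, num_optimizers, training_step):
        self.num_optimizers = num_optimizers
        self.training_step = training_step

    def run(self, batch, batch_idx):
        return [
            self.training_step(batch, batch_idx, optimizer_idx)
            for optimizer_idx in range(self.num_optimizers)
        ]


class ToyEpochLoop:
    """Iterates over the dataloader and hands each batch to the optimizer loop."""

    def __init__(self, optimizer_loop):
        self.optimizer_loop = optimizer_loop

    def run(self, dataloader):
        outputs = []
        for batch_idx, batch in enumerate(dataloader):
            outputs.extend(self.optimizer_loop.run(batch, batch_idx))
        return outputs


class ToyFitLoop:
    """Top-level loop: counts epochs and runs the epoch loop once per epoch."""

    def __init__(self, epoch_loop, max_epochs):
        self.epoch_loop = epoch_loop
        self.max_epochs = max_epochs

    def run(self, dataloader):
        return [self.epoch_loop.run(dataloader) for _ in range(self.max_epochs)]


calls = []


def training_step(batch, batch_idx, optimizer_idx):
    calls.append((batch_idx, optimizer_idx))
    return batch * 10 + optimizer_idx


fit_loop = ToyFitLoop(ToyEpochLoop(ToyOptimizerLoop(2, training_step)), max_epochs=1)
output = fit_loop.run([1, 2])
print(output)  # [[10, 11, 20, 21]]
```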

Trainer entry points and associated loops

Built-in loop

Description

FitLoop

The FitLoop is the top-level loop where training starts. It simply counts the epochs and iterates from one to the next by calling TrainingEpochLoop.run() in its advance() method.

TrainingEpochLoop

The TrainingEpochLoop is the one that iterates over the dataloader that the user returns in their train_dataloader() method. Its main responsibilities are calling the *_epoch_start and *_epoch_end hooks, accumulating outputs if the user requests them in one of these hooks, and running validation at the requested interval. The validation is carried out by yet another loop, ValidationEpochLoop.

In its run() method, the training epoch loop could, in theory, call LightningModule.training_step directly and perform the optimization. However, Lightning has built-in support for automatic optimization with multiple optimizers and also supports truncated back-propagation through time. For this reason, there are actually two more loops nested under TrainingEpochLoop.

TrainingBatchLoop

The responsibility of the TrainingBatchLoop is to split a batch given by the TrainingEpochLoop along the time dimension and iterate over the list of splits. It also keeps track of the hidden state (hiddens) returned by the training step. By default, when truncated back-propagation through time (TBPTT) is turned off, this loop does nothing except redirect the call to the OptimizerLoop. Read more about TBPTT.

OptimizerLoop

The OptimizerLoop iterates over one or multiple optimizers and, for each one, calls the training_step() method with the batch, the current batch index, and (if multiple optimizers are used) the optimizer index. It is the leaf node in the tree of loops and performs the actual optimization (forward, zero grad, backward, optimizer step).

ManualOptimization

Replaces the OptimizerLoop when manual optimization is enabled and implements the manual optimization step.
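The TBPTT splitting done by the TrainingBatchLoop can be illustrated with plain Python lists: a batch holds sequences of length T, and each split keeps every sequence but only a window of split_size consecutive time steps. This tbptt_split helper is a hypothetical sketch that reuses the name from the plain-Python structure earlier; it is not Lightning's splitting code, and it assumes batch-first, equal-length sequences.

```python
def tbptt_split(batch, split_size):
    """Split a batch of equal-length sequences into chunks along the time axis.

    batch: list of sequences (batch-first), each a list of T time steps.
    Returns a list of splits; each split is a batch of shorter sequences.
    """
    seq_len = len(batch[0])
    return [
        [seq[t:t + split_size] for seq in batch]
        for t in range(0, seq_len, split_size)
    ]


batch = [[1, 2, 3, 4, 5], [10, 20, 30, 40, 50]]
for split in tbptt_split(batch, split_size=2):
    print(split)
# [[1, 2], [10, 20]]
# [[3, 4], [30, 40]]
# [[5], [50]]
```

In a real TBPTT setup, the hidden state produced for one split would be passed into the training step for the next split; that bookkeeping is omitted here.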


Available Loops in Lightning Flash

Active Learning is a machine learning practice in which the user interacts with the learner in order to provide new labels when required.

You can find a real use case in Lightning Flash.

Flash implements the ActiveLearningLoop that you can use together with the ActiveLearningDataModule to label new data on the fly. To run the following demo, install Flash and BaaL first:

pip install lightning-flash baal

import torch

import flash
from flash.core.classification import Probabilities
from flash.core.data.utils import download_data
from flash.image import ImageClassificationData, ImageClassifier
from flash.image.classification.integrations.baal import ActiveLearningDataModule, ActiveLearningLoop

# 1. Create the DataModule
download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", "./data")

# Implement the research use-case where we mask labels from labelled dataset.
datamodule = ActiveLearningDataModule(
    ImageClassificationData.from_folders(train_folder="data/hymenoptera_data/train/", batch_size=2),
    val_split=0.1,
)

# 2. Build the task
head = torch.nn.Sequential(
    torch.nn.Dropout(p=0.1),
    torch.nn.Linear(512, datamodule.num_classes),
)
model = ImageClassifier(backbone="resnet18", head=head, num_classes=datamodule.num_classes, serializer=Probabilities())

# 3.1 Create the trainer
trainer = flash.Trainer(max_epochs=3)

# 3.2 Create the active learning loop and connect it to the trainer
active_learning_loop = ActiveLearningLoop(label_epoch_frequency=1)
active_learning_loop.connect(trainer.fit_loop)
trainer.fit_loop = active_learning_loop

# 3.3 Finetune
trainer.finetune(model, datamodule=datamodule, strategy="freeze")

# 4. Predict what's on a few images! ants or bees?
predictions = model.predict("data/hymenoptera_data/val/bees/65038344_52a45d090d.jpg")
print(predictions)

# 5. Save the model!
trainer.save_checkpoint("image_classification_model.pt")

Here is the Active Learning Loop example and the code for the active learning loop.


Advanced Examples

Ready-to-run loop examples and tutorials

Link to Example

Description

K-fold Cross Validation

KFold / Cross Validation is a machine learning practice in which the training dataset is partitioned into num_folds complementary subsets. Each cross validation round performs a fit in which one fold is held out for validation and the remaining folds are used for training. To reduce variability, once all rounds have been performed using the different folds, the trained models are ensembled, and their predictions are averaged when estimating the model’s predictive performance on the test dataset.

Yielding Training Step

This loop enables you to write the training_step() hook as a Python generator for automatic optimization with multiple optimizers, i.e., you can yield loss values from it instead of returning them. This can enable more elegant and expressive implementations, as shown with a GAN in this example.
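The fold bookkeeping described for K-fold cross validation can be sketched in a few lines: partition the sample indices into num_folds complementary subsets, then, for each round, hold one fold out for validation and train on the rest. kfold_indices and ensemble_predict are hypothetical helpers (no shuffling, models are plain callables); the linked example's actual implementation may differ.

```python
def kfold_indices(num_samples, num_folds):
    """Yield (train_idx, val_idx) pairs; the val folds partition range(num_samples)."""
    fold_sizes = [num_samples // num_folds] * num_folds
    for i in range(num_samples % num_folds):
        fold_sizes[i] += 1  # spread the remainder over the first folds
    folds, start = [], 0
    for size in fold_sizes:
        folds.append(list(range(start, start + size)))
        start += size
    for k, val_idx in enumerate(folds):
        # train on every fold except the held-out one
        train_idx = [i for j, fold in enumerate(folds) if j != k for i in fold]
        yield train_idx, val_idx


def ensemble_predict(models, x):
    """Average the predictions of the per-fold models."""
    return sum(model(x) for model in models) / len(models)


for train_idx, val_idx in kfold_indices(num_samples=5, num_folds=2):
    print(train_idx, val_idx)
# [3, 4] [0, 1, 2]
# [0, 1, 2] [3, 4]

print(ensemble_predict([lambda x: x, lambda x: 3 * x], 2))  # 4.0
```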


Advanced Features

Next: Advanced loop features