

PyTorch Lightning for Dummies – A Tutorial and Overview


You’ll learn to use PyTorch Lightning’s Core API features by completing an applied project to train a Language Transformer written in PyTorch on the WikiText2 dataset.

The code in this tutorial is available on GitHub in the text-lab repo. Clone the repo and follow along!


Training deep learning models at scale is an incredibly interesting and complex task. Reproducibility is key for any project, and PyTorch Lightning gives us reproducible code bases for training and finetuning. Another benefit is that the framework is domain agnostic and complementary to PyTorch. It does not replace PyTorch, and it lets us train text, vision, audio, and multimodal models with the same framework: PyTorch Lightning.

The Research

Our research objective for this tutorial is to train a small language model using a Transformer on the WikiText2 dataset. Both the Transformer and the dataset are available to us in PyTorch Lightning at pytorch_lightning.demos.transformer. We’ll see later how we can pull those into our Python module or Jupyter Notebook for use in our custom LightningDataModule and LightningModule.

PyTorch and PyTorch Lightning

PyTorch Lightning is not a replacement for PyTorch. Rather, PyTorch Lightning is an extension – a framework used to train models that have been implemented with PyTorch. This relationship is visualized in the following snippet.

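import pytorch_lightning as pl
from pytorch_lightning.demos.transformer import Transformer

class LabModule(pl.LightningModule):
    def __init__(self, vocab_size: int = 33278):
        super().__init__()  # initialize the LightningModule (nn.Module) internals first
        # the plain PyTorch model becomes the "internal module"
        self.model = Transformer(vocab_size=vocab_size)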

When we create self.model as shown above, we often refer to self.model as the internal module. Let’s keep reading to learn how to apply this interoperability between PyTorch and PyTorch Lightning!

PyTorch Lightning: The Core API

Okay – time to get to it! In the next sections, we will cover how to use the Core API of PyTorch Lightning. What is the Core API? First, consider how the training steps of almost any deep learning project can be organized in sequence: processing the data, creating a model, and then training that model on the dataset.

The Core API is structured around exactly these steps, with LightningDataModule, LightningModule, and Trainer.


LightningDataModule (LDM) wraps the data phase. It organizes a custom PyTorch Dataset and the DataLoaders built from it, so that Trainer can handle data during training. If you need additional customization, LDM exposes the prepare_data and setup hooks. For the training phase, the PyTorch DataLoaders must be returned from methods named train_dataloader and val_dataloader. The following snippet is pseudocode showing how to import LightningDataModule and use it to create a custom class.

import pytorch_lightning as pl

class LabDataModule(pl.LightningDataModule):
    def __init__(self):
        super().__init__()

We will see examples of creating train_dataloader and val_dataloader methods in LDM later in this tutorial.


LightningModule is the main training interface around the PyTorch models mentioned earlier, which we refer to as ‘internal modules’. LightningModule itself is a subclass of torch.nn.Module, extended with dozens of additional hooks like on_fit_start and on_fit_end. By overriding these hooks, we gain finer control over Trainer’s flow and can add custom behaviors. The following pseudocode shows how to import LightningModule and use it to create a custom class.

import pytorch_lightning as pl

class LabModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
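To make the idea of hooks concrete, here is a plain-Python sketch of the pattern with no Lightning import. ToyModule, ToyTrainer, and LoggingModule are hypothetical names invented for illustration. The trainer owns the loop and calls hooks at fixed points, and a subclass changes behavior by overriding only the hooks it needs. The real Trainer calls many more hooks than this toy does.

```python
class ToyModule:
    """Hypothetical stand-in for a LightningModule; default hooks do nothing."""

    def on_fit_start(self):
        pass

    def on_fit_end(self):
        pass

    def training_step(self, batch, batch_idx):
        raise NotImplementedError


class ToyTrainer:
    """Hypothetical trainer: it owns the loop and calls hooks at fixed points."""

    def fit(self, module, batches):
        module.on_fit_start()
        losses = [module.training_step(batch, i) for i, batch in enumerate(batches)]
        module.on_fit_end()
        return losses


class LoggingModule(ToyModule):
    def __init__(self):
        self.events = []

    # Override only the hooks we need; the trainer's loop stays untouched.
    def on_fit_start(self):
        self.events.append("fit started")

    def on_fit_end(self):
        self.events.append("fit ended")

    def training_step(self, batch, batch_idx):
        return sum(batch) / len(batch)  # stand-in for a real loss


module = LoggingModule()
losses = ToyTrainer().fit(module, [[1, 2], [3, 5]])
print(losses)         # [1.5, 4.0]
print(module.events)  # ['fit started', 'fit ended']
```

The design choice this mirrors is inversion of control. Your class never writes the loop. It only fills in the steps the loop asks for.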


Trainer configures the training scope and manages the training loop with LightningModule and LightningDataModule. The simplest Trainer configuration is accomplished by setting flags like devices, accelerator, and strategy and by passing in our choice of loggers, profilers, callbacks, and plugins.

import pytorch_lightning as pl

# instantiate the trainer
trainer = pl.Trainer()

# instantiate the datamodule
datamodule = LabDataModule()
# instantiate the model
model = LabModule()

# call fit to start training
trainer.fit(model=model, datamodule=datamodule)

Rather see this explained in a video? Sebastian Raschka, our Lead AI Educator, breaks down how to get started structuring PyTorch code with PyTorch Lightning.

Getting Started: Hands-on Coding

Installing PyTorch Lightning

First, we need to install PyTorch Lightning. The installation itself shows how closely PyTorch Lightning is integrated with PyTorch: the following terminal command installs PyTorch Lightning and also installs PyTorch into our virtual environment as a dependency. So let’s go ahead and install PyTorch Lightning.

pip install pytorch-lightning

We also need to install TorchText in order to run the demo. Let’s also do that by using the following command in the terminal.

pip install torchtext

Do you need help creating a virtual environment? There’s a video for that too!

The Custom LightningDataModule

The dataset we will use is WikiText2. The demo code available to us in PyTorch Lightning will automatically fetch WikiText2 for us – so there’s no need to worry about downloading the dataset from torchtext.

In the code that follows, WikiText2 is imported as LabDataset. If you wish, you can check out the code used to create the custom PyTorch Dataset in the text-lab repo. For the purposes of this tutorial, however, we can ignore that implementation for now. Once again, here’s the pseudocode for creating a LightningDataModule without any additional customization.

import pytorch_lightning as pl

class LabDataModule(pl.LightningDataModule):
    def __init__(self):
        super().__init__()

Creating the Custom Class

Compared to the pseudocode example, we need to customize the __init__ method further. This lets us randomly split the dataset and tells the LightningDataModule where the data lives. This is also where we can provide domain-specific arguments, like block_size for text datasets or image_size for vision datasets.

In the code blocks below, we use a custom class for this project called LabDataModule. It inherits from PyTorch Lightning’s LightningDataModule base class.

from pathlib import Path

from torch.utils.data import DataLoader, random_split

import pytorch_lightning as pl

from textlab import Config
from textlab.pipeline import LabDataset

class LabDataModule(pl.LightningDataModule):
    def __init__(
        self,
        num_workers: int = 2,
        data_dir: Path = Path("data"),
        block_size: int = 35,
        download: bool = True,
        train_size: float = 0.8,
    ):
        super().__init__()
        self.data_dir = data_dir
        self.block_size = block_size
        self.download = download
        self.num_workers = num_workers
        self.train_size = train_size  # fraction of samples used for training
        self.dataset = None  # populated in prepare_data
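The block_size argument stored above determines how a long stream of token ids is cut into fixed-length training samples. The plain-Python sketch below assumes WikiText2 follows the usual next-token setup, where the target window is the input window shifted by one token. make_blocks is a hypothetical helper and is not part of the demo code.

```python
def make_blocks(token_ids, block_size):
    """Split a token stream into (inputs, target) pairs for next-token prediction."""
    pairs = []
    # Stop early enough that every target window still has block_size tokens.
    for start in range(0, len(token_ids) - block_size, block_size):
        inputs = token_ids[start : start + block_size]
        target = token_ids[start + 1 : start + block_size + 1]
        pairs.append((inputs, target))
    return pairs


tokens = list(range(10))  # stand-in for integer token ids
for inputs, target in make_blocks(tokens, block_size=4):
    print(inputs, "->", target)
# [0, 1, 2, 3] -> [1, 2, 3, 4]
# [4, 5, 6, 7] -> [5, 6, 7, 8]
```

Under this assumption, each position in a sample is trained to predict the token that comes right after it.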

Preparing the Data

Trainer calls the prepare_data method first. At that point, the dataset is either downloaded or loaded from the cache in the data directory. This matters especially for multi-node training, where each node needs its own copy of the training dataset. By default, Lightning calls prepare_data once per node.

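    def prepare_data(self):
        # download WikiText2, or load it from the cache in data_dir
        self.dataset = LabDataset(
            data_dir=self.data_dir,
            block_size=self.block_size,
            download=self.download,
        )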
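Lightning’s documentation advises against assigning state (such as self.dataset = ...) inside prepare_data. That method runs in only one process per node, so an attribute set there is missing in every other process. In the LabDataModule above, those processes would reach setup with self.dataset still set to None. A common pattern is to download in prepare_data and assign attributes in setup, which runs in every process.

The sketch below simulates that pattern in plain Python. ToyDataModule is a hypothetical stand-in for LabDataModule, and writing a small text file stands in for downloading WikiText2.

```python
import tempfile
from pathlib import Path


class ToyDataModule:
    """Hypothetical stand-in for a LightningDataModule (no Lightning import)."""

    def __init__(self, data_dir: Path):
        self.data_dir = data_dir
        self.dataset = None

    def prepare_data(self):
        # Only write/download to disk here; do NOT assign self.dataset,
        # because the other processes never run this method.
        path = self.data_dir / "corpus.txt"
        if not path.exists():
            path.write_text("the quick brown fox jumps over the lazy dog")

    def setup(self, stage: str):
        # Runs in every process, so state assigned here exists everywhere.
        text = (self.data_dir / "corpus.txt").read_text()
        self.dataset = text.split()


with tempfile.TemporaryDirectory() as tmp:
    data_dir = Path(tmp)
    # Simulate two processes on one node: only "rank 0" prepares the data.
    rank0, rank1 = ToyDataModule(data_dir), ToyDataModule(data_dir)
    rank0.prepare_data()
    for dm in (rank0, rank1):
        dm.setup("fit")
    print(len(rank0.dataset), len(rank1.dataset))  # 9 9
```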

Setting Up the Data Splits

Trainer calls the setup method on each device/GPU while it sets up the training process. In this hook, we create the train_data, val_data, and test_data splits from the full dataset and store them as attributes of LabDataModule. These splits are later passed to the DataLoaders.

    def setup(self, stage: str):
        if stage == "fit" or stage is None:
            # e.g. train_size=0.8 puts 80% of the samples in train_data
            train_size = int(len(self.dataset) * self.train_size)
            val_size = len(self.dataset) - train_size
            self.train_data, self.val_data = random_split(self.dataset, lengths=[train_size, val_size])
        if stage == "test" or stage is None:
            # this demo reuses the validation split for testing, so the "fit" branch must run first
            self.test_data = self.val_data
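The split arithmetic in setup is easy to check by hand. Conceptually, random_split shuffles the sample indices and then slices them by the requested lengths. The plain-Python analogue below uses a hypothetical helper, split_indices, and Python’s random module instead of torch.

```python
import random


def split_indices(n_items, train_fraction, seed=42):
    """Plain-Python analogue of random_split: shuffle the indices, then slice them."""
    train_size = int(n_items * train_fraction)  # same arithmetic as setup()
    val_size = n_items - train_size             # the remainder goes to validation
    indices = list(range(n_items))
    random.Random(seed).shuffle(indices)
    return indices[:train_size], indices[train_size:train_size + val_size]


train_idx, val_idx = split_indices(n_items=1000, train_fraction=0.8)
print(len(train_idx), len(val_idx))  # 800 200
```

Because int() truncates, any fractional sample goes to validation. With 1,001 samples, for example, the split is 800 and 201. torch’s random_split also accepts a seeded generator, such as generator=torch.Generator().manual_seed(42), which makes the split reproducible across runs.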

The Train, Validation, and Test DataLoaders

Trainer does not access train_data, val_data, or test_data directly. Instead, it receives batches from the PyTorch DataLoaders returned by the following methods.

    def train_dataloader(self):
        return DataLoader(self.train_data, num_workers=self.num_workers)

    def val_dataloader(self):
        return DataLoader(self.val_data, num_workers=self.num_workers)

    def test_dataloader(self):
        return DataLoader(self.test_data, num_workers=self.num_workers)
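These DataLoaders only set num_workers, so every other argument keeps PyTorch’s default. That means a batch_size of 1, no shuffling, and a final partial batch that is kept rather than dropped. The plain-Python sketch below shows what those arguments control. iterate_batches is a hypothetical helper, not the DataLoader implementation, and it skips worker processes and tensor collation.

```python
import random


def iterate_batches(dataset, batch_size=1, shuffle=False, seed=0):
    """Yield lists of samples, like a DataLoader without workers or tensor collation."""
    order = list(range(len(dataset)))
    if shuffle:
        random.Random(seed).shuffle(order)
    # The last batch may be smaller, matching DataLoader's default drop_last=False.
    for start in range(0, len(order), batch_size):
        yield [dataset[i] for i in order[start:start + batch_size]]


samples = ["a", "b", "c", "d", "e"]
print(list(iterate_batches(samples)))                # [['a'], ['b'], ['c'], ['d'], ['e']]
print(list(iterate_batches(samples, batch_size=2)))  # [['a', 'b'], ['c', 'd'], ['e']]
```

If you want larger or shuffled training batches, DataLoader accepts batch_size and shuffle arguments. For example: DataLoader(self.train_data, batch_size=32, shuffle=True, num_workers=self.num_workers). The value 32 is only illustrative.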

Now that we have handled the data phase by creating LabDataModule, let’s move on to the model phase by creating a custom LightningModule around our Transformer internal module.

Need a video break? Here’s an awesome summary by Sebastian on Organizing Your Data Loaders with Data Modules.

The Custom LightningModule

Creating the Language Model with PyTorch

The demo Transformer from PyTorch Lightning is implemented with PyTorch. The details of the implementation are outside of the scope of this post. For now, let’s just remember the previously discussed concept of torch.nn.Modules as internal modules that will be used in LightningModules as self.model.

Curious about what the Transformer implementation looks like? You can check it out on GitHub!

Creating the Custom Class

Just as we further customized LabDataModule, we also need to customize our LightningModule in the class shown below. This new class, LabModule, is our custom LightningModule that will interface with Trainer and train the internal module with the hooks that we cover in the next few sections.

Below, we have added forward and training_step in addition to __init__. The forward method calls the internal module. While that forward pass runs, PyTorch’s autograd records the operations involved, which is what later allows gradients to be computed from the loss. Autograd is beyond the scope of this post. If you’d like to learn more, you can read this Gentle Introduction to torch.autograd by the PyTorch team.

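import torch
import pytorch_lightning as pl
from pytorch_lightning.demos.transformer import Transformer

class LabModule(pl.LightningModule):
    def __init__(self, vocab_size: int = 33278):
        super().__init__()
        self.model = Transformer(vocab_size=vocab_size)

    def forward(self, inputs, target):
        return self.model(inputs, target)

    def training_step(self, batch, batch_idx):
        inputs, target = batch
        output = self(inputs, target)  # calls forward()
        # nll_loss expects log-probabilities from the model
        loss = torch.nn.functional.nll_loss(output, target.view(-1))
        return loss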
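To build intuition for the autograd step mentioned above, here is a toy scalar version of reverse-mode automatic differentiation in plain Python. The Value class is hypothetical and far simpler than torch.autograd, which works on tensors. The idea is the same, though. The forward pass records how each result was computed, and backward() walks that record in reverse, applying the chain rule to fill in every input’s gradient.

```python
class Value:
    """Toy scalar node for reverse-mode autodiff (illustrative, not torch.autograd)."""

    def __init__(self, data, parents=(), grad_fns=()):
        self.data = data
        self.grad = 0.0
        self._parents = parents    # nodes this value was computed from
        self._grad_fns = grad_fns  # maps this node's grad to each parent's contribution

    def __add__(self, other):
        return Value(self.data + other.data, (self, other), (lambda g: g, lambda g: g))

    def __mul__(self, other):
        return Value(self.data * other.data, (self, other),
                     (lambda g: g * other.data, lambda g: g * self.data))

    def backward(self):
        # Topological order, so each node is processed after all nodes that use it.
        order, seen = [], set()

        def visit(node):
            if id(node) not in seen:
                seen.add(id(node))
                for parent in node._parents:
                    visit(parent)
                order.append(node)

        visit(self)
        self.grad = 1.0
        for node in reversed(order):
            for parent, fn in zip(node._parents, node._grad_fns):
                parent.grad += fn(node.grad)  # chain rule, accumulated over all uses


w, x, b = Value(3.0), Value(2.0), Value(1.0)
y = w * x + b  # forward pass records the graph
y.backward()   # reverse pass fills in .grad
print(y.data, w.grad, x.grad, b.grad)  # 7.0 2.0 3.0 1.0
```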

The training_step Method

The primary interaction between a LightningModule and Trainer happens through the training_step method of the LightningModule. This method receives a batch of inputs and targets, passes the inputs to the internal module, and collects its output. It then calculates and returns the loss. Trainer uses that loss to compute gradients and update the model’s weights.

    def training_step(self, batch, batch_idx):
        inputs, target = batch
        output = self(inputs, target)
        loss = torch.nn.functional.nll_loss(output, target.view(-1))
        self.log("training-loss", loss)
        return loss
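torch.nn.functional.nll_loss takes log-probabilities of shape (N, C) and N target class indices. Flattening target with target.view(-1) therefore produces one token id per row, assuming the model’s output is flattened the same way. The plain-Python sketch below computes the same quantity with illustrative numbers and hypothetical helper names. The result is the average negative log-probability the model assigns to the correct tokens.

```python
import math


def log_softmax(logits):
    """Turn raw scores into log-probabilities (subtract the max for numerical stability)."""
    m = max(logits)
    log_norm = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_norm for x in logits]


def nll_loss(log_probs, targets):
    """Mean negative log-likelihood, like F.nll_loss with its default reduction='mean'."""
    picked = [row[t] for row, t in zip(log_probs, targets)]
    return -sum(picked) / len(picked)


# Two flattened token positions, vocabulary of three tokens.
logits = [[2.0, 1.0, 0.1], [0.5, 2.5, 0.3]]
targets = [0, 1]  # plays the role of target.view(-1)
loss = nll_loss([log_softmax(row) for row in logits], targets)
print(round(loss, 4))
```

Applying log_softmax followed by nll_loss is the same computation as cross-entropy on raw scores.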

The validation_step and test_step Methods

The validation and test methods are similar to training_step, except that they do not return the loss. No weights are updated during validation or testing, so the loss is only logged. If we add the EarlyStopping callback to Trainer, it monitors the logged validation loss ("val-loss") and stops training if no improvement is observed.

    def validation_step(self, batch, batch_idx):
        inputs, target = batch
        output = self(inputs, target)
        loss = torch.nn.functional.nll_loss(output, target.view(-1))
        self.log("val-loss", loss)

    def test_step(self, batch, batch_idx):
        inputs, target = batch
        output = self(inputs, target)
        loss = torch.nn.functional.nll_loss(output, target.view(-1))
        self.log("test-loss", loss)

Curious about EarlyStopping? Check out this video by our founder, Will Falcon to learn more!
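Here is a rough plain-Python sketch of the rule EarlyStopping applies with mode="min". A check counts as an improvement only if the monitored value drops below the best value seen so far by more than min_delta. Otherwise a counter grows, and training stops once the counter reaches patience. Lightning’s defaults are patience=3 and min_delta=0.0, and patience counts validation checks rather than training steps. The helper stop_index is hypothetical and leaves out the callback’s other options, such as its stopping thresholds.

```python
def stop_index(val_losses, patience=3, min_delta=0.0):
    """Return the index of the check at which training would stop, or None."""
    best = float("inf")
    wait = 0
    for i, loss in enumerate(val_losses):
        if loss < best - min_delta:  # mode="min": lower is better
            best, wait = loss, 0
        else:
            wait += 1
            if wait >= patience:
                return i
    return None


losses = [2.0, 1.5, 1.4, 1.45, 1.41, 1.42, 1.3]
print(stop_index(losses))  # 5: three checks in a row without beating 1.4
```

With patience=4, the same run would continue, and the value 1.3 at the last check would reset the counter.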

The configure_optimizers Method

PyTorch Lightning provides two modes for managing the optimization process: manual and automatic. For the majority of research cases, automatic optimization does the right thing, and it is what most users should use. We will use automatic optimization, which only asks us to return an optimizer from configure_optimizers, as shown below.

    def configure_optimizers(self):
        return torch.optim.SGD(self.model.parameters(), lr=0.1)
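With automatic optimization, Lightning runs the optimization loop for you. For each batch it calls training_step, clears the old gradients, calls backward() on the returned loss, and steps the optimizer. The plain-Python sketch below spells out that cycle for plain SGD with the lr=0.1 used above. It fits a hypothetical one-weight model to y = 2x, and the gradient is written out by hand instead of being computed by autograd.

```python
def train_one_weight(xs, ys, lr=0.1, steps=50):
    """Fit y = w * x with mean-squared error and plain SGD; return w and the loss history."""
    w = 0.0
    losses = []
    for _ in range(steps):
        # forward pass: predictions and mean-squared-error loss
        preds = [w * x for x in xs]
        loss = sum((p - y) ** 2 for p, y in zip(preds, ys)) / len(xs)
        losses.append(loss)
        # "backward": d(loss)/dw written out by hand (autograd would compute this)
        grad = sum(2 * (p - y) * x for p, y, x in zip(preds, ys, xs)) / len(xs)
        # "optimizer.step()": the SGD update w <- w - lr * grad
        w -= lr * grad
    return w, losses


w, losses = train_one_weight(xs=[1.0, 2.0, 3.0], ys=[2.0, 4.0, 6.0])
print(round(w, 6))  # 2.0
```

In LabModule, none of this is written by hand. Returning the loss from training_step and an optimizer from configure_optimizers is enough for Trainer to run the loop.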

Using Trainer

Okay – let’s talk a little more about Trainer, the star of the show. Several years of development have gone into PyTorch Lightning, and especially into Trainer and everything under its hood, such as its ability to configure the training environment.

We’ll take the code snippet for Trainer that was used in the introduction and modify it in a way that allows us to understand how to configure the environment and then conduct a training run.

Configuring Trainer with Flags

We can configure our environment when we first instantiate the Trainer object. This is done with flags like devices, accelerator, and strategy and by passing in our choice of loggers, profilers, and callbacks.

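import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.loggers.wandb import WandbLogger
from pytorch_lightning.profilers import PyTorchProfiler

trainer = pl.Trainer(
    # illustrative values: pick the accelerator, devices, strategy, and precision for your hardware
    accelerator="auto",
    devices="auto",
    strategy="auto",
    precision="32-true",
    enable_checkpointing=True,
    callbacks=EarlyStopping(monitor="val-loss", mode="min"),
    logger=WandbLogger(name="textlab-demo", save_dir="logs/wandb"),
    profiler=PyTorchProfiler(),
)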

Aside from these additional features, you will also notice that we have set devices, accelerator, strategy, precision, and enable_checkpointing. Additional context for each is provided below:

  • accelerator: the accelerator type (“cpu”, “gpu”, “tpu”, “ipu”, “auto”) or a custom accelerator instance.
  • devices: the number of devices to train on (int), which devices to train on (list or str), or “auto”.
  • strategy: the training strategy, given as an alias (“ddp”, “fsdp”, etc.) or as a configured strategy instance.
  • precision: the numerical precision used during training, such as 32-bit or 16-bit mixed precision. Check out this video by Will on precision.
  • enable_checkpointing: automatically saves a checkpoint with the state of your last training epoch (by default under your current working directory).

In the snippet shown above, we are also enabling our trainer session with additional features like EarlyStopping, WandbLogger, and PyTorchProfiler.

Why is using PyTorchProfiler handy? Because it allows us to find bottlenecks in our training loop. If you’re interested in learning more – here’s our official documentation on the topic.

Why use WandbLogger? Because it lets us visualize and compare our experiments with interactive graphs and tables on the Weights and Biases platform.

Training the Model

Now that we have passed in our appropriate flags, callbacks, and support plugins, we are ready to train the model! We can start the training with the three easy lines shown below.

# instantiate the datamodule
datamodule = LabDataModule()
# instantiate the model
model = LabModule()
# call .fit
trainer.fit(model=model, datamodule=datamodule)

If you’ve cloned and installed the demo repo, text-lab, then you can test out what we’ve done above with one of the following commands in the terminal.

To run Trainer in fast-dev-run mode, use this command:

lab run dev-run

Otherwise, you can test a full demo run and log it with CSVLogger with:

lab run demo-run

And if you have a Weights and Biases account, you can log your run with:

lab run demo-run --logger wandb

Great Work!

Above, we learned how to organize our PyTorch code into a LightningDataModule and a LightningModule, and to automate everything else with Trainer. Along the way, we trained a language Transformer on the WikiText2 dataset and created custom classes by inheriting from Lightning’s base classes. We also used a custom CLI built with Typer.


Research and production code can quickly grow from simple modules into feature-rich packages as technical needs like distributed training and quantization arise. PyTorch Lightning implements these features for you and keeps them separate from the code you care about most: the research code focused on training your model.

By using PyTorch Lightning, you inherit well-tested features, which means fewer bugs, faster iterations, and faster prototyping. You also inherit code written for researchers by researchers. PyTorch Lightning’s contributors include Ph.D. candidates as well as research scientists and research engineers at leading AI labs.

PyTorch Lightning structures our code into four cohesive segments: data code, engineering code, research code, and support code (loggers, profilers). Separating code by task keeps the code base organized, readable, and reusable. In turn, this creates a more maintainable code base that suits beginners and experts alike.

Remember: PyTorch Lightning is not a replacement for PyTorch. Instead, it is a philosophy and methodology for organizing your PyTorch code to create reproducible, state-of-the-art research at scale, with ease.




Still have questions?

We have an amazing community and team of core engineers ready to answer questions you might have about PyTorch Lightning and the rest of the Lightning ecosystem. So, join us on Discourse or Discord. See you there!