
# Introduction to PyTorch Lightning

* **Author:** PL team
* **License:** CC BY-SA
* **Generated:** 2021-11-09T00:18:24.296916

In this notebook, we'll go over the basics of Lightning by preparing models to train on the [MNIST Handwritten Digits dataset](https://en.wikipedia.org/wiki/MNIST_database).

---
Open in [Colab](https://colab.research.google.com/github/PytorchLightning/lightning-tutorials/blob/publication/.notebooks/lightning_examples/mnist-hello-world.ipynb)

Give us a ⭐ [on GitHub](https://www.github.com/PytorchLightning/pytorch-lightning/)
| Check out [the documentation](https://pytorch-lightning.readthedocs.io/en/latest/)
| Join us [on Slack](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-pw5v393p-qRaDgEk24~EjiZNBpSQFgQ)

## Setup
This notebook requires some packages besides pytorch-lightning.

In [1]:
```python
! pip install --quiet "torchvision" "torch>=1.6, <1.9" "pytorch-lightning>=1.3" "torchmetrics>=0.3"
```

In [2]:
```python
import os

import torch
from pytorch_lightning import LightningModule, Trainer
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split
from torchmetrics import Accuracy
from torchvision import transforms
from torchvision.datasets import MNIST

PATH_DATASETS = os.environ.get("PATH_DATASETS", ".")
AVAIL_GPUS = min(1, torch.cuda.device_count())
BATCH_SIZE = 256 if AVAIL_GPUS else 64
```

## Simplest example

Here's the simplest possible example, with just a training loop (no validation, no testing).

**Keep in mind:** a `LightningModule` *is* a PyTorch `nn.Module`; it just has a few more helpful features.

In [3]:
```python
class MNISTModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.l1 = torch.nn.Linear(28 * 28, 10)

    def forward(self, x):
        return torch.relu(self.l1(x.view(x.size(0), -1)))

    def training_step(self, batch, batch_nb):
        x, y = batch
        loss = F.cross_entropy(self(x), y)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.02)
```
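Because `MNISTModel` is still an `nn.Module`, its `forward` is ordinary PyTorch. The sketch below rebuilds just that one layer in plain PyTorch (no Lightning) and feeds it a random, MNIST-shaped batch, a hypothetical stand-in for real images, to show the shapes involved: `x.view(x.size(0), -1)` flattens each `1 x 28 x 28` image into 784 values, and `nn.Linear(784, 10)` holds 784 × 10 weights plus 10 biases. That is 7,850 parameters, which matches the `7.9 K` in the model summary that `trainer.fit` prints below.

```python
import torch
from torch import nn

# Same layer as MNISTModel.l1: 28*28 inputs -> 10 class scores
l1 = nn.Linear(28 * 28, 10)

# Hypothetical stand-in for a batch of 4 MNIST images: (batch, channel, height, width)
x = torch.rand(4, 1, 28, 28)

# Flatten each image to a 784-vector, exactly as MNISTModel.forward does
flat = x.view(x.size(0), -1)
out = torch.relu(l1(flat))

print(flat.shape)  # torch.Size([4, 784])
print(out.shape)   # torch.Size([4, 10])

# 784 * 10 weights + 10 biases
n_params = sum(p.numel() for p in l1.parameters())
print(n_params)    # 7850
```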

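By using the `Trainer`, you automatically get:
1. TensorBoard logging
2. Model checkpointing
3. The training and validation loops

Early stopping is built in as well, but you turn it on by passing an `EarlyStopping` callback to the `Trainer`.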
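For comparison, here is roughly the training loop that `trainer.fit(mnist_model, train_loader)` runs for you, written in plain PyTorch. It is a simplified sketch, not Lightning's actual implementation: it leaves out logging, checkpointing, device placement and callbacks. To stay self-contained it uses a small random dataset (a hypothetical stand-in for MNIST) but keeps the model shape, loss and optimizer settings of `MNISTModel`.

```python
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Hypothetical stand-in for MNIST: 512 random "images" with random labels 0-9
images = torch.rand(512, 1, 28, 28)
labels = torch.randint(0, 10, (512,))
toy_loader = DataLoader(TensorDataset(images, labels), batch_size=64, shuffle=True)

# Same layer, loss and optimizer settings as MNISTModel
toy_model = nn.Linear(28 * 28, 10)
optimizer = torch.optim.Adam(toy_model.parameters(), lr=0.02)

max_epochs = 3
global_step = 0
for epoch in range(max_epochs):
    toy_model.train()
    for x, y in toy_loader:
        scores = torch.relu(toy_model(x.view(x.size(0), -1)))  # MNISTModel.forward
        loss = F.cross_entropy(scores, y)                       # MNISTModel.training_step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        global_step += 1
    print(f"epoch {epoch}: last batch loss {loss.item():.4f}")

print(global_step)  # 3 epochs x 8 batches = 24 optimizer steps
```

Everything in this loop except the model, loss and optimizer is boilerplate, which is exactly the part the `Trainer` takes over.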

In [4]:
```python
# Init our model
mnist_model = MNISTModel()

# Init DataLoader from MNIST Dataset
train_ds = MNIST(PATH_DATASETS, train=True, download=True, transform=transforms.ToTensor())
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE)

# Initialize a trainer
trainer = Trainer(
    gpus=AVAIL_GPUS,
    max_epochs=3,
    progress_bar_refresh_rate=20,
)

# Train the model ⚡
trainer.fit(mnist_model, train_loader)
```

```text
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name | Type   | Params
--------------------------------
0 | l1   | Linear | 7.9 K
--------------------------------
7.9 K     Trainable params
0         Non-trainable params
7.9 K     Total params
0.031     Total estimated model params size (MB)
```



## A more complete MNIST Lightning Module Example

That wasn't so hard, was it?

Now that we've got our feet wet, let's dive a bit deeper and write a more complete `LightningModule` for MNIST.

This time, we'll bake all the dataset-specific pieces directly into the `LightningModule`.
This way, we can avoid writing extra code at the beginning of our script every time we want to run it.
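One of those dataset-specific pieces is the normalization in the `self.transform` of the `LitMNIST` class below. After `ToTensor()` has scaled pixels to [0, 1], `transforms.Normalize((0.1307,), (0.3081,))` subtracts 0.1307 (the mean pixel value of the MNIST training set) and divides by 0.3081 (its standard deviation) on the single channel. The plain-tensor sketch below applies that same per-channel formula, `(x - mean) / std`, to a black, a mid-grey and a white pixel so you can see the resulting range.

```python
import torch

# MNIST mean and standard deviation, as passed to transforms.Normalize in LitMNIST
MEAN, STD = 0.1307, 0.3081

# Pixel intensities after ToTensor(): 0.0 = black, 0.5 = mid grey, 1.0 = white
pixels = torch.tensor([0.0, 0.5, 1.0])

# Same per-channel formula that transforms.Normalize applies
normalized = (pixels - MEAN) / STD
print(normalized)  # approximately -0.4242, 1.1986, 2.8215
```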

---

### Note what the following built-in hooks are doing:

1. [prepare_data()](https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.core.lightning.html#pytorch_lightning.core.lightning.LightningModule.prepare_data) 💾
   - This is where we download the dataset. We point torchvision's `MNIST` dataset class at our desired directory and ask it to download the data if it isn't already there.
   - **Note that we do not make any state assignments in this function** (i.e. `self.something = ...`). Lightning calls `prepare_data()` from a single process only (per node, by default), so any state set here would be missing on the other processes.

2. [setup(stage)](https://pytorch-lightning.readthedocs.io/en/latest/common/lightning-module.html#setup) ⚙️
   - Loads the data from disk and prepares PyTorch datasets for each split (train, val, test).
   - `setup` expects a `stage` argument, which is used to separate the logic for `'fit'` and `'test'`.
   - If you don't mind loading all your datasets at once, you can let both the `'fit'` and the `'test'` branches run whenever `stage` is `None` (or drop the `stage` checks altogether).
   - **Note that this runs on every process (one per GPU), so it *is* safe to make state assignments here.**

3. [x_dataloader()](https://pytorch-lightning.readthedocs.io/en/latest/common/lightning-module.html#data-hooks) ♻️
   - `train_dataloader()`, `val_dataloader()`, and `test_dataloader()` each return a PyTorch `DataLoader` that wraps the corresponding dataset prepared in `setup()`.
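The `setup()` hook in `LitMNIST` below splits the 60,000 MNIST training images into 55,000 for training and 5,000 for validation with `random_split`. Since `setup()` runs in every process, it is worth making that split deterministic so every process, and every rerun, sees the same partition. Lightning's `seed_everything` is one way to do this; the sketch below shows another, passing a seeded `torch.Generator` to `random_split`. It runs on a hypothetical stand-in dataset of 60,000 integers so no download is needed, and the seed value 42 is arbitrary.

```python
import torch
from torch.utils.data import TensorDataset, random_split

# Hypothetical stand-in for the 60,000-image MNIST training set
full = TensorDataset(torch.arange(60_000))


def split(seed: int):
    # A freshly seeded generator yields the same permutation for the same seed
    generator = torch.Generator().manual_seed(seed)
    return random_split(full, [55_000, 5_000], generator=generator)


train_a, val_a = split(42)
train_b, val_b = split(42)

print(len(train_a), len(val_a))        # 55000 5000
print(val_a.indices == val_b.indices)  # True: identical validation split
```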

In [5]:
```python
class LitMNIST(LightningModule):
    def __init__(self, data_dir=PATH_DATASETS, hidden_size=64, learning_rate=2e-4):

        super().__init__()

        # Set our init args as class attributes
        self.data_dir = data_dir
        self.hidden_size = hidden_size
        self.learning_rate = learning_rate

        # Hardcode some dataset specific attributes
        self.num_classes = 10
        self.dims = (1, 28, 28)
        channels, height, width = self.dims
        self.transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,)),
            ]
        )

        # Define PyTorch model
        self.model = nn.Sequential(
            nn.Flatten(),
            nn.Linear(channels * width * height, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size, self.num_classes),
        )

        self.accuracy = Accuracy()

    def forward(self, x):
        x = self.model(x)
        return F.log_softmax(x, dim=1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        preds = torch.argmax(logits, dim=1)
        self.accuracy(preds, y)

        # Calling self.log will surface up scalars for you in TensorBoard
        self.log("val_loss", loss, prog_bar=True)
        self.log("val_acc", self.accuracy, prog_bar=True)
        return loss

    def test_step(self, batch, batch_idx):
        # Here we just reuse the validation_step for testing
        return self.validation_step(batch, batch_idx)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer

    ####################
    # DATA RELATED HOOKS
    ####################

    def prepare_data(self):
        # download
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):

        # Assign train/val datasets for use in dataloaders
        if stage == "fit" or stage is None:
            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])

        # Assign test dataset for use in dataloader(s)
        if stage == "test" or stage is None:
            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=BATCH_SIZE)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=BATCH_SIZE)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=BATCH_SIZE)
```
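`LitMNIST.forward` returns log-probabilities (`F.log_softmax`), which is why its steps use `F.nll_loss` rather than `F.cross_entropy`. The two routes give the same loss, because `F.cross_entropy` applies `log_softmax` and `nll_loss` internally. The short check below illustrates this on random scores and labels (hypothetical data), and also shows that taking `argmax` of the log-probabilities, as `validation_step` does, picks the same class as taking it of the raw scores, since `log_softmax` preserves the ordering within each row.

```python
import torch
from torch.nn import functional as F

torch.manual_seed(0)

# Random raw scores for a batch of 8 samples over 10 classes, plus random labels
scores = torch.randn(8, 10)
targets = torch.randint(0, 10, (8,))

# Route used by LitMNIST: log_softmax in forward(), then nll_loss in the step
loss_nll = F.nll_loss(F.log_softmax(scores, dim=1), targets)

# Equivalent single call on the raw scores
loss_ce = F.cross_entropy(scores, targets)

print(torch.allclose(loss_nll, loss_ce))  # True

# log_softmax shifts each row by a constant, so predicted classes are unchanged
same_preds = torch.equal(
    torch.argmax(F.log_softmax(scores, dim=1), dim=1),
    torch.argmax(scores, dim=1),
)
print(same_preds)  # True
```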

In [6]:
```python
model = LitMNIST()
trainer = Trainer(
    gpus=AVAIL_GPUS,
    max_epochs=3,
    progress_bar_refresh_rate=20,
)
trainer.fit(model)
```

```text
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name     | Type       | Params
----------------------------------------
0 | model    | Sequential | 55.1 K
1 | accuracy | Accuracy   | 0
----------------------------------------
55.1 K    Trainable params
0         Non-trainable params
55.1 K    Total params
0.220     Total estimated model params size (MB)
```



### Testing

To test a model, call `trainer.test(model)`.

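Or, if you've just trained a model, you can simply call `trainer.test()` and Lightning will automatically test using the best checkpoint saved during training. To have "best" chosen by validation loss, pass a `ModelCheckpoint(monitor="val_loss")` callback to the `Trainer`. Because `test_step` reuses `validation_step`, the test metrics are reported under the `val_loss` and `val_acc` names.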
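As a sanity check on the checkpoint name in the output below (`epoch=2-step=644.ckpt`), the arithmetic matches this run, assuming the GPU case shown in the logs so that `BATCH_SIZE` is 256. The 55,000-image training split gives ceil(55,000 / 256) = 215 batches per epoch, so 3 epochs make 645 optimizer steps. Epochs in the name count from zero (epoch=2 is the third), and the step number agrees if it counts from zero too.

```python
import math

train_size = 55_000   # training split from LitMNIST.setup()
batch_size = 256      # BATCH_SIZE when a GPU is available (assumed, per the logs)
max_epochs = 3        # as passed to the Trainer

# DataLoader keeps the final partial batch by default (drop_last=False)
batches_per_epoch = math.ceil(train_size / batch_size)
total_steps = batches_per_epoch * max_epochs

print(batches_per_epoch)  # 215
print(total_steps)        # 645
print(f"epoch={max_epochs - 1}-step={total_steps - 1}.ckpt")  # epoch=2-step=644.ckpt
```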

In [7]:
```python
trainer.test()
```

```text
Restoring states from the checkpoint path at /__w/1/s/lightning_logs/version_1/checkpoints/epoch=2-step=644.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
Loaded model weights from checkpoint at /__w/1/s/lightning_logs/version_1/checkpoints/epoch=2-step=644.ckpt
--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'val_acc': 0.9241999983787537, 'val_loss': 0.25223809480667114}
--------------------------------------------------------------------------------

[{'val_loss': 0.25223809480667114, 'val_acc': 0.9241999983787537}]
```

### Bonus Tip

You can keep calling `trainer.fit(model)` as many times as you'd like to continue training.

In [8]:
```python
trainer.fit(model)
```

```text
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name     | Type       | Params
----------------------------------------
0 | model    | Sequential | 55.1 K
1 | accuracy | Accuracy   | 0
----------------------------------------
55.1 K    Trainable params
0         Non-trainable params
55.1 K    Total params
0.220     Total estimated model params size (MB)

 rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
```

In Colab, you can use the TensorBoard magic function to view the logs that Lightning has created for you!

In [9]:
```python
# Start tensorboard.
%load_ext tensorboard
%tensorboard --logdir lightning_logs/
```

## Congratulations - Time to Join the Community!

Congratulations on completing this notebook tutorial! If you enjoyed this and would like to join the Lightning
movement, you can do so in the following ways!

### Star [Lightning](https://github.com/PyTorchLightning/pytorch-lightning) on GitHub
The easiest way to help our community is just by starring the GitHub repos! This helps raise awareness of the cool
tools we're building.

### Join our [Slack](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-pw5v393p-qRaDgEk24~EjiZNBpSQFgQ)!
The best way to keep up to date on the latest advancements is to join our community! Make sure to introduce yourself
and share your interests in the `#general` channel.


### Contributions!
The best way to contribute to our community is to become a code contributor! At any time, you can go to the GitHub Issues page of
[Lightning](https://github.com/PyTorchLightning/pytorch-lightning) or [Bolt](https://github.com/PyTorchLightning/lightning-bolts)
and filter for "good first issue".

* [Lightning good first issue](https://github.com/PyTorchLightning/pytorch-lightning/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)
* [Bolt good first issue](https://github.com/PyTorchLightning/lightning-bolts/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)
* You can also contribute your own notebooks with useful examples!

### Many thanks from the entire PyTorch Lightning team for your interest!
