{"cells": [{"cell_type": "markdown", "id": "14e28b67", "metadata": {"papermill": {"duration": 0.004712, "end_time": "2023-10-11T19:17:43.890508", "exception": false, "start_time": "2023-10-11T19:17:43.885796", "status": "completed"}, "tags": []}, "source": ["\n", "# Introduction to PyTorch Lightning\n", "\n", "* **Author:** Lightning.ai\n", "* **License:** CC BY-SA\n", "* **Generated:** 2023-10-11T19:15:44.096249\n", "\n", "In this notebook, we'll go over the basics of Lightning by preparing models to train on the [MNIST Handwritten Digits dataset](https://en.wikipedia.org/wiki/MNIST_database).\n", "\n", "---\n", "[Open in Colab](https://colab.research.google.com/github/PytorchLightning/lightning-tutorials/blob/publication/.notebooks/lightning_examples/mnist-hello-world.ipynb)\n", "\n", "Give us a \u2b50 [on GitHub](https://www.github.com/Lightning-AI/lightning/)\n", "| Check out [the documentation](https://pytorch-lightning.readthedocs.io/en/stable/)\n", "| Join us [on Slack](https://www.pytorchlightning.ai/community)"]}, {"cell_type": "markdown", "id": "76eb4d71", "metadata": {"papermill": {"duration": 0.003791, "end_time": "2023-10-11T19:17:43.898299", "exception": false, "start_time": "2023-10-11T19:17:43.894508", "status": "completed"}, "tags": []}, "source": ["## Setup\n", "This notebook requires some packages besides pytorch-lightning."]}, {"cell_type": "code", "execution_count": 1, "id": "9f4f2ba7", "metadata": {"colab": {}, "colab_type": "code", "execution": {"iopub.execute_input": "2023-10-11T19:17:43.906826Z", "iopub.status.busy": "2023-10-11T19:17:43.906598Z", "iopub.status.idle": "2023-10-11T19:19:16.223109Z", "shell.execute_reply": "2023-10-11T19:19:16.221646Z"}, "id": "LfrJLKPFyhsK", "lines_to_next_cell": 0, "papermill": {"duration": 92.323974, "end_time": "2023-10-11T19:19:16.225919", "exception": false, "start_time": "2023-10-11T19:17:43.901945", "status": "completed"}, "tags": []}, "outputs": [{"name": "stdout", 
"output_type": "stream", "text": ["\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\r\n", "\u001b[0m"]}], "source": ["! pip install --quiet \"pandas\" \"torchmetrics>=0.7, <1.3\" \"urllib3\" \"lightning>=2.0.0\" \"setuptools>=68.0.0, <68.3.0\" \"torchvision\" \"pytorch-lightning>=1.4, <2.1.0\" \"ipython[notebook]>=8.0.0, <8.17.0\" \"torchmetrics >=0.11.0\" \"seaborn\" \"torch>=1.8.1, <2.1.0\" \"matplotlib>=3.0.0, <3.9.0\""]}, {"cell_type": "code", "execution_count": 2, "id": "fa6ca114", "metadata": {"execution": {"iopub.execute_input": "2023-10-11T19:19:16.241909Z", "iopub.status.busy": "2023-10-11T19:19:16.241535Z", "iopub.status.idle": "2023-10-11T19:19:20.298920Z", "shell.execute_reply": "2023-10-11T19:19:20.297076Z"}, "papermill": {"duration": 4.069363, "end_time": "2023-10-11T19:19:20.301611", "exception": false, "start_time": "2023-10-11T19:19:16.232248", "status": "completed"}, "tags": []}, "outputs": [], "source": ["\n", "# ------------------- Preliminaries ------------------- #\n", "import os\n", "from dataclasses import dataclass\n", "from typing import Tuple\n", "\n", "import lightning as L\n", "import pandas as pd\n", "import seaborn as sn\n", "import torch\n", "from IPython.display import display\n", "from lightning.pytorch.loggers import CSVLogger\n", "from torch import nn\n", "from torch.nn import functional as F\n", "from torch.utils.data import DataLoader, random_split\n", "from torchmetrics import Accuracy\n", "from torchvision import transforms\n", "from torchvision.datasets import MNIST\n", "\n", "# ------------------- Configuration ------------------- #\n", "\n", "\n", "@dataclass\n", "class Config:\n", " \"\"\"Configuration options for the Lightning MNIST example.\n", "\n", " Args:\n", " data_dir : The path to the directory where the MNIST 
dataset is stored. Defaults to the value of\n", " the 'PATH_DATASETS' environment variable or '.' if not set.\n", "\n", " save_dir : The path to the directory where the training logs will be saved. Defaults to 'logs/'.\n", "\n", " batch_size : The batch size to use during training. Defaults to 256 if a GPU is available,\n", " or 64 otherwise.\n", "\n", " max_epochs : The maximum number of epochs to train the model for. Defaults to 3.\n", "\n", " accelerator : The accelerator to use for training. Can be one of \"cpu\", \"gpu\", \"tpu\", \"ipu\", \"auto\".\n", " Defaults to \"auto\".\n", "\n", " devices : The number of devices to use for training. Defaults to 1.\n", "\n", " Examples:\n", " This dataclass can be used to specify the configuration options for training a PyTorch Lightning model on the\n", " MNIST dataset. A new instance of this dataclass can be created as follows:\n", "\n", " >>> config = Config()\n", "\n", " The default values for each argument are shown in the documentation above. If desired, any of these values can be\n", " overridden when creating a new instance of the dataclass:\n", "\n", " >>> config = Config(batch_size=128, max_epochs=5)\n", " \"\"\"\n", "\n", " data_dir: str = os.environ.get(\"PATH_DATASETS\", \".\")\n", " save_dir: str = \"logs/\"\n", " batch_size: int = 256 if torch.cuda.is_available() else 64\n", " max_epochs: int = 3\n", " accelerator: str = \"auto\"\n", " devices: int = 1\n", "\n", "\n", "config = Config()"]}, {"cell_type": "markdown", "id": "507ccd62", "metadata": {"lines_to_next_cell": 2, "papermill": {"duration": 0.006133, "end_time": "2023-10-11T19:19:20.314347", "exception": false, "start_time": "2023-10-11T19:19:20.308214", "status": "completed"}, "tags": []}, "source": ["## Simplest example\n", "\n", "Here's the simplest, most minimal example, with just a training loop (no validation, no testing).\n", "\n", "**Keep in mind:** A `LightningModule` *is* a PyTorch `nn.Module`; it just has a few more helpful features."]}, {"cell_type": "code", 
"execution_count": 3, "id": "4a2a4091", "metadata": {"execution": {"iopub.execute_input": "2023-10-11T19:19:20.329307Z", "iopub.status.busy": "2023-10-11T19:19:20.328270Z", "iopub.status.idle": "2023-10-11T19:19:20.341979Z", "shell.execute_reply": "2023-10-11T19:19:20.341006Z"}, "papermill": {"duration": 0.023944, "end_time": "2023-10-11T19:19:20.344384", "exception": false, "start_time": "2023-10-11T19:19:20.320440", "status": "completed"}, "tags": []}, "outputs": [], "source": ["\n", "\n", "class MNISTModel(L.LightningModule):\n", " \"\"\"A PyTorch Lightning module for classifying images in the MNIST dataset.\n", "\n", " Attributes:\n", " l1 : A linear layer that maps input features to output features.\n", "\n", " Methods:\n", " forward(x):\n", " Performs a forward pass through the model.\n", "\n", " training_step(batch, batch_nb):\n", " Defines a single training step for the model.\n", "\n", " configure_optimizers():\n", " Configures the optimizer to use during training.\n", "\n", " Examples:\n", " The MNISTModel class can be used to create and train a PyTorch Lightning model for classifying images in the MNIST\n", " dataset. 
To create a new instance of the model, simply instantiate the class:\n", "\n", " >>> model = MNISTModel()\n", "\n", " The model can then be trained using a Lightning trainer object and a training DataLoader (here `train_loader`):\n", "\n", " >>> trainer = L.Trainer()\n", " >>> trainer.fit(model, train_loader)\n", " \"\"\"\n", "\n", " def __init__(self):\n", " \"\"\"Initializes a new instance of the MNISTModel class.\"\"\"\n", " super().__init__()\n", " self.l1 = torch.nn.Linear(28 * 28, 10)\n", "\n", " def forward(self, x: torch.Tensor) -> torch.Tensor:\n", " \"\"\"Performs a forward pass through the model.\n", "\n", " Args:\n", " x : The input tensor to pass through the model.\n", "\n", " Returns:\n", " activated : The output tensor produced by the model.\n", "\n", " Examples:\n", " >>> model = MNISTModel()\n", " >>> x = torch.randn(1, 1, 28, 28)\n", " >>> output = model(x)\n", " \"\"\"\n", " flattened = x.view(x.size(0), -1)\n", " hidden = self.l1(flattened)\n", " activated = torch.relu(hidden)\n", "\n", " return activated\n", "\n", " def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_nb: int) -> torch.Tensor:\n", " \"\"\"Defines a single training step for the model.\n", "\n", " Args:\n", " batch: A tuple containing the input and target tensors for the batch.\n", " batch_nb: The batch number.\n", "\n", " Returns:\n", " torch.Tensor: The loss value for the current batch.\n", "\n", " Examples:\n", " >>> model = MNISTModel()\n", " >>> x = torch.randn(1, 1, 28, 28)\n", " >>> y = torch.tensor([1])\n", " >>> loss = model.training_step((x, y), 0)\n", " \"\"\"\n", " x, y = batch\n", " loss = F.cross_entropy(self(x), y)\n", " return loss\n", "\n", " def configure_optimizers(self) -> torch.optim.Optimizer:\n", " \"\"\"Configures the optimizer to use during training.\n", "\n", " Returns:\n", " torch.optim.Optimizer: The optimizer to use during training.\n", "\n", " Examples:\n", " >>> model = MNISTModel()\n", " >>> optimizer = model.configure_optimizers()\n", " \"\"\"\n", " return 
torch.optim.Adam(self.parameters(), lr=0.02)"]}, {"cell_type": "markdown", "id": "e189cea8", "metadata": {"papermill": {"duration": 0.006164, "end_time": "2023-10-11T19:19:20.356982", "exception": false, "start_time": "2023-10-11T19:19:20.350818", "status": "completed"}, "tags": []}, "source": ["By using the `Trainer` you automatically get:\n", "1. Tensorboard logging\n", "2. Model checkpointing\n", "3. Training and validation loop\n", "4. early-stopping"]}, {"cell_type": "code", "execution_count": 4, "id": "8ea1b3bb", "metadata": {"execution": {"iopub.execute_input": "2023-10-11T19:19:20.371087Z", "iopub.status.busy": "2023-10-11T19:19:20.370720Z", "iopub.status.idle": "2023-10-11T19:19:38.567377Z", "shell.execute_reply": "2023-10-11T19:19:38.566580Z"}, "papermill": {"duration": 18.206536, "end_time": "2023-10-11T19:19:38.569794", "exception": false, "start_time": "2023-10-11T19:19:20.363258", "status": "completed"}, "tags": []}, "outputs": [{"name": "stdout", "output_type": "stream", "text": ["Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz\n", "Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to /__w/14/s/.datasets/MNIST/raw/train-images-idx3-ubyte.gz\n"]}, {"name": "stderr", "output_type": "stream", "text": ["\r", " 0%| | 0/9912422 [00:00, ?it/s]"]}, {"name": "stderr", "output_type": "stream", "text": ["\r", " 45%|\u2588\u2588\u2588\u2588\u258d | 4423680/9912422 [00:00<00:00, 44140128.13it/s]"]}, {"name": "stderr", "output_type": "stream", "text": ["\r", " 95%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u258c| 9437184/9912422 [00:00<00:00, 47607761.85it/s]"]}, {"name": "stderr", "output_type": "stream", "text": ["\r", "100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 9912422/9912422 [00:00<00:00, 47157242.46it/s]"]}, {"name": "stderr", "output_type": "stream", "text": ["\n"]}, {"name": "stdout", "output_type": "stream", "text": ["Extracting 
/__w/14/s/.datasets/MNIST/raw/train-images-idx3-ubyte.gz to /__w/14/s/.datasets/MNIST/raw\n"]}, {"name": "stdout", "output_type": "stream", "text": ["\n", "Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz\n", "Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to /__w/14/s/.datasets/MNIST/raw/train-labels-idx1-ubyte.gz\n"]}, {"name": "stderr", "output_type": "stream", "text": ["\r", " 0%| | 0/28881 [00:00, ?it/s]"]}, {"name": "stderr", "output_type": "stream", "text": ["\r", "100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 28881/28881 [00:00<00:00, 47152858.63it/s]"]}, {"name": "stderr", "output_type": "stream", "text": ["\n"]}, {"name": "stdout", "output_type": "stream", "text": ["Extracting /__w/14/s/.datasets/MNIST/raw/train-labels-idx1-ubyte.gz to /__w/14/s/.datasets/MNIST/raw\n", "\n", "Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz\n", "Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to /__w/14/s/.datasets/MNIST/raw/t10k-images-idx3-ubyte.gz\n"]}, {"name": "stderr", "output_type": "stream", "text": ["\r", " 0%| | 0/1648877 [00:00, ?it/s]"]}, {"name": "stderr", "output_type": "stream", "text": ["\r", "100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 1648877/1648877 [00:00<00:00, 37055704.43it/s]"]}, {"name": "stderr", "output_type": "stream", "text": ["\n"]}, {"name": "stdout", "output_type": "stream", "text": ["Extracting /__w/14/s/.datasets/MNIST/raw/t10k-images-idx3-ubyte.gz to /__w/14/s/.datasets/MNIST/raw\n", "\n", "Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz\n"]}, {"name": "stdout", "output_type": "stream", "text": ["Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to /__w/14/s/.datasets/MNIST/raw/t10k-labels-idx1-ubyte.gz\n"]}, {"name": "stderr", "output_type": "stream", "text": ["\r", " 0%| | 0/4542 [00:00, ?it/s]"]}, {"name": "stderr", "output_type": "stream", "text": ["\r", 
"100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 4542/4542 [00:00<00:00, 9379876.30it/s]"]}, {"name": "stderr", "output_type": "stream", "text": ["\n"]}, {"name": "stdout", "output_type": "stream", "text": ["Extracting /__w/14/s/.datasets/MNIST/raw/t10k-labels-idx1-ubyte.gz to /__w/14/s/.datasets/MNIST/raw\n", "\n"]}, {"name": "stderr", "output_type": "stream", "text": ["GPU available: True (cuda), used: True\n"]}, {"name": "stderr", "output_type": "stream", "text": ["TPU available: False, using: 0 TPU cores\n"]}, {"name": "stderr", "output_type": "stream", "text": ["IPU available: False, using: 0 IPUs\n"]}, {"name": "stderr", "output_type": "stream", "text": ["HPU available: False, using: 0 HPUs\n"]}, {"name": "stderr", "output_type": "stream", "text": ["/usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py:67: UserWarning: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `lightning.pytorch` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default\n", " warning_cache.warn(\n"]}, {"name": "stderr", "output_type": "stream", "text": ["You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. 
For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision\n"]}, {"name": "stderr", "output_type": "stream", "text": ["LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [2,3]\n"]}, {"name": "stderr", "output_type": "stream", "text": ["\n", " | Name | Type | Params\n", "--------------------------------\n", "0 | l1 | Linear | 7.9 K \n", "--------------------------------\n", "7.9 K Trainable params\n", "0 Non-trainable params\n", "7.9 K Total params\n", "0.031 Total estimated model params size (MB)\n"]}, {"name": "stderr", "output_type": "stream", "text": ["/usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/connectors/data_connector.py:442: PossibleUserWarning: The dataloader, train_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 64 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.\n", " rank_zero_warn(\n"]}, {"data": {"application/vnd.jupyter.widget-view+json": {"model_id": "37b0c1b7cbe74ab383f5bdefe7aed12b", "version_major": 2, "version_minor": 0}, "text/plain": ["Training: 0it [00:00, ?it/s]"]}, "metadata": {}, "output_type": "display_data"}, {"name": "stderr", "output_type": "stream", "text": ["`Trainer.fit` stopped: `max_epochs=3` reached.\n"]}], "source": ["# Init our model\n", "mnist_model = MNISTModel()\n", "\n", "# Init DataLoader from MNIST Dataset\n", "train_ds = MNIST(config.data_dir, train=True, download=True, transform=transforms.ToTensor())\n", "\n", "# Create a dataloader\n", "train_loader = DataLoader(train_ds, batch_size=config.batch_size)\n", "\n", "# Initialize a trainer\n", "trainer = L.Trainer(\n", " accelerator=config.accelerator,\n", " devices=config.devices,\n", " max_epochs=config.max_epochs,\n", ")\n", "\n", "# Train the model \u26a1\n", "trainer.fit(mnist_model, train_loader)"]}, {"cell_type": "markdown", "id": "c642887e", 
"metadata": {"lines_to_next_cell": 2, "papermill": {"duration": 0.01001, "end_time": "2023-10-11T19:19:38.591383", "exception": false, "start_time": "2023-10-11T19:19:38.581373", "status": "completed"}, "tags": []}, "source": ["## A more complete MNIST Lightning Module Example\n", "\n", "That wasn't so hard was it?\n", "\n", "Now that we've got our feet wet, let's dive in a bit deeper and write a more complete `LightningModule` for MNIST...\n", "\n", "This time, we'll bake in all the dataset specific pieces directly in the `LightningModule`.\n", "This way, we can avoid writing extra code at the beginning of our script every time we want to run it.\n", "\n", "---\n", "\n", "### Note what the following built-in functions are doing:\n", "\n", "1. [prepare_data()](https://lightning.ai/docs/pytorch/stable/common/lightning_module.html#prepare-data) \ud83d\udcbe\n", " - This is where we can download the dataset. We point to our desired dataset and ask torchvision's `MNIST` dataset class to download if the dataset isn't found there.\n", " - **Note we do not make any state assignments in this function** (i.e. `self.something = ...`)\n", "\n", "2. [setup(stage)](https://lightning.ai/docs/pytorch/stable/common/lightning_module.html#setup) \u2699\ufe0f\n", " - Loads in data from file and prepares PyTorch tensor datasets for each split (train, val, test).\n", " - Setup expects a 'stage' arg which is used to separate logic for 'fit' and 'test'.\n", " - If you don't mind loading all your datasets at once, you can set up a condition to allow for both 'fit' related setup and 'test' related setup to run whenever `None` is passed to `stage` (or ignore it altogether and exclude any conditionals).\n", " - **Note this runs across all GPUs and it *is* safe to make state assignments here**\n", "\n", "3. 
[x_dataloader()](https://lightning.ai/docs/pytorch/stable/api/pytorch_lightning.core.hooks.DataHooks.html#pytorch_lightning.core.hooks.DataHooks.train_dataloader) \u267b\ufe0f\n", " - `train_dataloader()`, `val_dataloader()`, and `test_dataloader()` all return PyTorch `DataLoader` instances that are created by wrapping their respective datasets that we prepared in `setup()`"]}, {"cell_type": "code", "execution_count": 5, "id": "b590d798", "metadata": {"execution": {"iopub.execute_input": "2023-10-11T19:19:38.612552Z", "iopub.status.busy": "2023-10-11T19:19:38.612287Z", "iopub.status.idle": "2023-10-11T19:19:38.632716Z", "shell.execute_reply": "2023-10-11T19:19:38.632022Z"}, "papermill": {"duration": 0.032842, "end_time": "2023-10-11T19:19:38.634253", "exception": false, "start_time": "2023-10-11T19:19:38.601411", "status": "completed"}, "tags": []}, "outputs": [], "source": ["\n", "\n", "class LitMNIST(L.LightningModule):\n", " \"\"\"PyTorch Lightning module for training a multi-layer perceptron (MLP) on the MNIST dataset.\n", "\n", " Attributes:\n", " data_dir : The path to the directory where the MNIST data will be downloaded.\n", "\n", " hidden_size : The number of units in the hidden layer of the MLP.\n", "\n", " learning_rate : The learning rate to use for training the MLP.\n", "\n", " Methods:\n", " forward(x):\n", " Performs a forward pass through the MLP.\n", "\n", " training_step(batch, batch_idx):\n", " Defines a single training step for the MLP.\n", "\n", " validation_step(batch, batch_idx):\n", " Defines a single validation step for the MLP.\n", "\n", " test_step(batch, batch_idx):\n", " Defines a single testing step for the MLP.\n", "\n", " configure_optimizers():\n", " Configures the optimizer to use for training the MLP.\n", "\n", " prepare_data():\n", " Downloads the MNIST dataset.\n", "\n", " setup(stage=None):\n", " Splits the MNIST dataset into train, validation, and test sets.\n", "\n", " train_dataloader():\n", " Returns a DataLoader for the 
training set.\n", "\n", " val_dataloader():\n", " Returns a DataLoader for the validation set.\n", "\n", " test_dataloader():\n", " Returns a DataLoader for the test set.\n", " \"\"\"\n", "\n", " def __init__(self, data_dir: str = config.data_dir, hidden_size: int = 64, learning_rate: float = 2e-4):\n", " \"\"\"Initializes a new instance of the LitMNIST class.\n", "\n", " Args:\n", " data_dir : The path to the directory where the MNIST data will be downloaded. Defaults to config.data_dir.\n", "\n", " hidden_size : The number of units in the hidden layer of the MLP (default is 64).\n", "\n", " learning_rate : The learning rate to use for training the MLP (default is 2e-4).\n", " \"\"\"\n", " super().__init__()\n", "\n", " # Set our init args as class attributes\n", " self.data_dir = data_dir\n", " self.hidden_size = hidden_size\n", " self.learning_rate = learning_rate\n", "\n", " # Hardcode some dataset specific attributes\n", " self.num_classes = 10\n", " self.dims = (1, 28, 28)\n", " channels, width, height = self.dims\n", "\n", " self.transform = transforms.Compose(\n", " [\n", " transforms.ToTensor(),\n", " transforms.Normalize((0.1307,), (0.3081,)),\n", " ]\n", " )\n", "\n", " # Define PyTorch model\n", " self.model = nn.Sequential(\n", " nn.Flatten(),\n", " nn.Linear(channels * width * height, hidden_size),\n", " nn.ReLU(),\n", " nn.Dropout(0.1),\n", " nn.Linear(hidden_size, hidden_size),\n", " nn.ReLU(),\n", " nn.Dropout(0.1),\n", " nn.Linear(hidden_size, self.num_classes),\n", " )\n", "\n", " self.val_accuracy = Accuracy(task=\"multiclass\", num_classes=10)\n", " self.test_accuracy = Accuracy(task=\"multiclass\", num_classes=10)\n", "\n", " def forward(self, x: torch.Tensor) -> torch.Tensor:\n", " \"\"\"Performs a forward pass through the MLP.\n", "\n", " Args:\n", " x : The input data.\n", "\n", " Returns:\n", " torch.Tensor: The output of the MLP.\n", " \"\"\"\n", " x = self.model(x)\n", " return F.log_softmax(x, dim=1)\n", "\n", " def training_step(self, 
batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor:\n", " \"\"\"Defines a single training step for the MLP.\n", "\n", " Args:\n", " batch: A tuple containing the input data and target labels.\n", "\n", " batch_idx: The index of the current batch.\n", "\n", " Returns:\n", " torch.Tensor: The training loss.\n", " \"\"\"\n", " x, y = batch\n", " logits = self(x)\n", " loss = F.nll_loss(logits, y)\n", " return loss\n", "\n", " def validation_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> None:\n", " \"\"\"Defines a single validation step for the MLP.\n", "\n", " Args:\n", " batch : A tuple containing the input data and target labels.\n", " batch_idx : The index of the current batch.\n", " \"\"\"\n", " x, y = batch\n", " logits = self(x)\n", " loss = F.nll_loss(logits, y)\n", " preds = torch.argmax(logits, dim=1)\n", " self.val_accuracy.update(preds, y)\n", "\n", " # Calling self.log sends scalars to the configured logger (CSVLogger in this notebook)\n", " self.log(\"val_loss\", loss, prog_bar=True)\n", " self.log(\"val_acc\", self.val_accuracy, prog_bar=True)\n", "\n", " def test_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> None:\n", " \"\"\"Defines a single testing step for the MLP.\n", "\n", " Args:\n", " batch : A tuple containing the input data and target labels.\n", " batch_idx : The index of the current batch.\n", " \"\"\"\n", " x, y = batch\n", " logits = self(x)\n", " loss = F.nll_loss(logits, y)\n", " preds = torch.argmax(logits, dim=1)\n", " self.test_accuracy.update(preds, y)\n", "\n", " # Calling self.log sends scalars to the configured logger (CSVLogger in this notebook)\n", " self.log(\"test_loss\", loss, prog_bar=True)\n", " self.log(\"test_acc\", self.test_accuracy, prog_bar=True)\n", "\n", " def configure_optimizers(self) -> torch.optim.Optimizer:\n", " \"\"\"Configures the optimizer to use for training the MLP.\n", "\n", " Returns:\n", " torch.optim.Optimizer: The optimizer.\n", " \"\"\"\n", " optimizer = 
torch.optim.Adam(self.parameters(), lr=self.learning_rate)\n", "\n", " return optimizer\n", "\n", " # ------------------------------------- #\n", " # DATA RELATED HOOKS\n", " # ------------------------------------- #\n", "\n", " def prepare_data(self) -> None:\n", " \"\"\"Downloads the MNIST dataset.\"\"\"\n", " MNIST(self.data_dir, train=True, download=True)\n", "\n", " MNIST(self.data_dir, train=False, download=True)\n", "\n", " def setup(self, stage: str = None) -> None:\n", " \"\"\"Splits the MNIST dataset into train, validation, and test sets.\n", "\n", " Args:\n", " stage : The current stage (either \"fit\" or \"test\"). Defaults to None.\n", " \"\"\"\n", " # Assign train/val datasets for use in dataloaders\n", " if stage == \"fit\" or stage is None:\n", " mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)\n", "\n", " self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])\n", "\n", " # Assign test dataset for use in dataloader(s)\n", " if stage == \"test\" or stage is None:\n", " self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)\n", "\n", " def train_dataloader(self) -> DataLoader:\n", " \"\"\"Returns a DataLoader for the training set.\n", "\n", " Returns:\n", " DataLoader: The training DataLoader.\n", " \"\"\"\n", " return DataLoader(self.mnist_train, batch_size=config.batch_size)\n", "\n", " def val_dataloader(self) -> DataLoader:\n", " \"\"\"Returns a DataLoader for the validation set.\n", "\n", " Returns:\n", " DataLoader: The validation DataLoader.\n", " \"\"\"\n", " return DataLoader(self.mnist_val, batch_size=config.batch_size)\n", "\n", " def test_dataloader(self) -> DataLoader:\n", " \"\"\"Returns a DataLoader for the test set.\n", "\n", " Returns:\n", " DataLoader: The test DataLoader.\n", " \"\"\"\n", " return DataLoader(self.mnist_test, batch_size=config.batch_size)"]}, {"cell_type": "code", "execution_count": 6, "id": "3e4b6829", "metadata": {"execution": {"iopub.execute_input": 
"2023-10-11T19:19:38.647554Z", "iopub.status.busy": "2023-10-11T19:19:38.647330Z", "iopub.status.idle": "2023-10-11T19:20:06.384372Z", "shell.execute_reply": "2023-10-11T19:20:06.383605Z"}, "papermill": {"duration": 27.746092, "end_time": "2023-10-11T19:20:06.386494", "exception": false, "start_time": "2023-10-11T19:19:38.640402", "status": "completed"}, "tags": []}, "outputs": [{"name": "stderr", "output_type": "stream", "text": ["GPU available: True (cuda), used: True\n"]}, {"name": "stderr", "output_type": "stream", "text": ["TPU available: False, using: 0 TPU cores\n"]}, {"name": "stderr", "output_type": "stream", "text": ["IPU available: False, using: 0 IPUs\n"]}, {"name": "stderr", "output_type": "stream", "text": ["HPU available: False, using: 0 HPUs\n"]}, {"name": "stderr", "output_type": "stream", "text": ["LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [2,3]\n"]}, {"name": "stderr", "output_type": "stream", "text": ["\n", " | Name | Type | Params\n", "-----------------------------------------------------\n", "0 | model | Sequential | 55.1 K\n", "1 | val_accuracy | MulticlassAccuracy | 0 \n", "2 | test_accuracy | MulticlassAccuracy | 0 \n", "-----------------------------------------------------\n", "55.1 K Trainable params\n", "0 Non-trainable params\n", "55.1 K Total params\n", "0.220 Total estimated model params size (MB)\n"]}, {"data": {"application/vnd.jupyter.widget-view+json": {"model_id": "2c07ea9039cb4476a1a8dc08bc9808c7", "version_major": 2, "version_minor": 0}, "text/plain": ["Sanity Checking: 0it [00:00, ?it/s]"]}, "metadata": {}, "output_type": "display_data"}, {"name": "stderr", "output_type": "stream", "text": ["/usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/connectors/data_connector.py:442: PossibleUserWarning: The dataloader, val_dataloader, does not have many workers which may be a bottleneck. 
Consider increasing the value of the `num_workers` argument` (try 64 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.\n", " rank_zero_warn(\n"]}, {"data": {"application/vnd.jupyter.widget-view+json": {"model_id": "a496173fbae644c6913bc1165e6e15a3", "version_major": 2, "version_minor": 0}, "text/plain": ["Training: 0it [00:00, ?it/s]"]}, "metadata": {}, "output_type": "display_data"}, {"data": {"application/vnd.jupyter.widget-view+json": {"model_id": "935105a8a90b4507a57417d2dd17b32a", "version_major": 2, "version_minor": 0}, "text/plain": ["Validation: 0it [00:00, ?it/s]"]}, "metadata": {}, "output_type": "display_data"}, {"data": {"application/vnd.jupyter.widget-view+json": {"model_id": "c05d49bc81f64628ba3f1c0773394344", "version_major": 2, "version_minor": 0}, "text/plain": ["Validation: 0it [00:00, ?it/s]"]}, "metadata": {}, "output_type": "display_data"}, {"data": {"application/vnd.jupyter.widget-view+json": {"model_id": "5ef6a7e7e41a44cd9b227c830b1474c5", "version_major": 2, "version_minor": 0}, "text/plain": ["Validation: 0it [00:00, ?it/s]"]}, "metadata": {}, "output_type": "display_data"}, {"name": "stderr", "output_type": "stream", "text": ["`Trainer.fit` stopped: `max_epochs=3` reached.\n"]}], "source": ["# Instantiate the LitMNIST model\n", "model = LitMNIST()\n", "\n", "# Instantiate a PyTorch Lightning trainer with the specified configuration\n", "trainer = L.Trainer(\n", " accelerator=config.accelerator,\n", " devices=config.devices,\n", " max_epochs=config.max_epochs,\n", " logger=CSVLogger(save_dir=config.save_dir),\n", ")\n", "\n", "# Train the model using the trainer\n", "trainer.fit(model)"]}, {"cell_type": "markdown", "id": "37774f48", "metadata": {"papermill": {"duration": 0.011502, "end_time": "2023-10-11T19:20:06.404153", "exception": false, "start_time": "2023-10-11T19:20:06.392651", "status": "completed"}, "tags": []}, "source": ["### Testing\n", "\n", "To test a model, call 
`trainer.test(model)`.\n", "\n", "Or, if you've just trained a model, you can just call `trainer.test()` and Lightning will automatically\n", "test using the best saved checkpoint (conditioned on val_loss)."]}, {"cell_type": "code", "execution_count": 7, "id": "da572e76", "metadata": {"execution": {"iopub.execute_input": "2023-10-11T19:20:06.427157Z", "iopub.status.busy": "2023-10-11T19:20:06.426978Z", "iopub.status.idle": "2023-10-11T19:20:07.850798Z", "shell.execute_reply": "2023-10-11T19:20:07.850259Z"}, "papermill": {"duration": 1.435929, "end_time": "2023-10-11T19:20:07.852072", "exception": false, "start_time": "2023-10-11T19:20:06.416143", "status": "completed"}, "tags": []}, "outputs": [{"name": "stderr", "output_type": "stream", "text": ["Restoring states from the checkpoint path at logs/lightning_logs/version_0/checkpoints/epoch=2-step=645.ckpt\n"]}, {"name": "stderr", "output_type": "stream", "text": ["LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [2,3]\n"]}, {"name": "stderr", "output_type": "stream", "text": ["Loaded model weights from the checkpoint at logs/lightning_logs/version_0/checkpoints/epoch=2-step=645.ckpt\n"]}, {"name": "stderr", "output_type": "stream", "text": ["/usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/connectors/data_connector.py:442: PossibleUserWarning: The dataloader, test_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 64 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.\n", " rank_zero_warn(\n"]}, {"data": {"application/vnd.jupyter.widget-view+json": {"model_id": "84a8558070594e57aa7d892c85785833", "version_major": 2, "version_minor": 0}, "text/plain": ["Testing: 0it [00:00, ?it/s]"]}, "metadata": {}, "output_type": "display_data"}, {"data": {"text/html": ["
\u250f\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2533\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2513\n", "\u2503 Test metric \u2503 DataLoader 0 \u2503\n", "\u2521\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2547\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2529\n", "\u2502 test_acc \u2502 0.9228000044822693 \u2502\n", "\u2502 test_loss \u2502 0.25202980637550354 \u2502\n", "\u2514\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2534\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2518\n", "\n"], "text/plain": ["\u250f\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2533\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2513\n", "\u2503\u001b[1m \u001b[0m\u001b[1m Test metric \u001b[0m\u001b[1m \u001b[0m\u2503\u001b[1m \u001b[0m\u001b[1m DataLoader 0 \u001b[0m\u001b[1m \u001b[0m\u2503\n", 
"\u2521\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2547\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2529\n", "\u2502\u001b[36m \u001b[0m\u001b[36m test_acc \u001b[0m\u001b[36m \u001b[0m\u2502\u001b[35m \u001b[0m\u001b[35m 0.9228000044822693 \u001b[0m\u001b[35m \u001b[0m\u2502\n", "\u2502\u001b[36m \u001b[0m\u001b[36m test_loss \u001b[0m\u001b[36m \u001b[0m\u2502\u001b[35m \u001b[0m\u001b[35m 0.25202980637550354 \u001b[0m\u001b[35m \u001b[0m\u2502\n", "\u2514\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2534\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2518\n"]}, "metadata": {}, "output_type": "display_data"}, {"data": {"text/plain": ["[{'test_loss': 0.25202980637550354, 'test_acc': 0.9228000044822693}]"]}, "execution_count": 7, "metadata": {}, "output_type": "execute_result"}], "source": ["trainer.test(ckpt_path=\"best\")"]}, {"cell_type": "markdown", "id": "1e64ccd6", "metadata": {"papermill": {"duration": 0.012018, "end_time": "2023-10-11T19:20:07.870339", "exception": false, "start_time": "2023-10-11T19:20:07.858321", "status": "completed"}, "tags": []}, "source": ["### Bonus Tip\n", "\n", "You can keep calling `trainer.fit(model)` as many times as you'd like to continue training"]}, {"cell_type": "code", "execution_count": 8, "id": "fe9227f0", "metadata": {"execution": {"iopub.execute_input": "2023-10-11T19:20:07.893915Z", "iopub.status.busy": "2023-10-11T19:20:07.893744Z", "iopub.status.idle": "2023-10-11T19:20:08.114959Z", "shell.execute_reply": "2023-10-11T19:20:08.114466Z"}, "papermill": 
{"duration": 0.233489, "end_time": "2023-10-11T19:20:08.117072", "exception": false, "start_time": "2023-10-11T19:20:07.883583", "status": "completed"}, "tags": []}, "outputs": [{"name": "stderr", "output_type": "stream", "text": ["/usr/local/lib/python3.10/dist-packages/lightning/pytorch/callbacks/model_checkpoint.py:617: UserWarning: Checkpoint directory logs/lightning_logs/version_0/checkpoints exists and is not empty.\n", " rank_zero_warn(f\"Checkpoint directory {dirpath} exists and is not empty.\")\n", "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [2,3]\n"]}, {"name": "stderr", "output_type": "stream", "text": ["\n", " | Name | Type | Params\n", "-----------------------------------------------------\n", "0 | model | Sequential | 55.1 K\n", "1 | val_accuracy | MulticlassAccuracy | 0 \n", "2 | test_accuracy | MulticlassAccuracy | 0 \n", "-----------------------------------------------------\n", "55.1 K Trainable params\n", "0 Non-trainable params\n", "55.1 K Total params\n", "0.220 Total estimated model params size (MB)\n"]}, {"data": {"application/vnd.jupyter.widget-view+json": {"model_id": "f41a57de3dc441cf83c3917a96d37f22", "version_major": 2, "version_minor": 0}, "text/plain": ["Sanity Checking: 0it [00:00, ?it/s]"]}, "metadata": {}, "output_type": "display_data"}, {"name": "stderr", "output_type": "stream", "text": ["`Trainer.fit` stopped: `max_epochs=3` reached.\n"]}], "source": ["trainer.fit(model)"]}, {"cell_type": "markdown", "id": "0878bd72", "metadata": {"papermill": {"duration": 0.01378, "end_time": "2023-10-11T19:20:08.145011", "exception": false, "start_time": "2023-10-11T19:20:08.131231", "status": "completed"}, "tags": []}, "source": ["Because the trainer above logs with `CSVLogger`, the metrics Lightning recorded are stored in a CSV file that can be loaded and inspected with pandas. (With a TensorBoard logger, you could instead use the TensorBoard magic function in Colab.)"]}, {"cell_type": "code", "execution_count": 9, "id": "c54f8924", "metadata": {"execution": {"iopub.execute_input": "2023-10-11T19:20:08.168663Z", "iopub.status.busy": "2023-10-11T19:20:08.168418Z", 
"iopub.status.idle": "2023-10-11T19:20:08.482524Z", "shell.execute_reply": "2023-10-11T19:20:08.481877Z"}, "papermill": {"duration": 0.326153, "end_time": "2023-10-11T19:20:08.484824", "exception": false, "start_time": "2023-10-11T19:20:08.158671", "status": "completed"}, "tags": []}, "outputs": [{"data": {"text/html": ["
\n", " | val_loss | \n", "val_acc | \n", "test_loss | \n", "test_acc | \n", "
---|---|---|---|---|
epoch | \n", "\n", " | \n", " | \n", " | \n", " |
0 | \n", "0.427113 | \n", "0.8896 | \n", "NaN | \n", "NaN | \n", "
1 | \n", "0.317400 | \n", "0.9096 | \n", "NaN | \n", "NaN | \n", "
2 | \n", "0.272898 | \n", "0.9174 | \n", "NaN | \n", "NaN | \n", "
3 | \n", "NaN | \n", "NaN | \n", "0.25203 | \n", "0.9228 | \n", "