{"cells": [{"cell_type": "markdown", "id": "b4257d74", "metadata": {"papermill": {"duration": 0.006769, "end_time": "2023-01-05T11:17:00.355472", "exception": false, "start_time": "2023-01-05T11:17:00.348703", "status": "completed"}, "tags": []}, "source": ["\n", "# Introduction to Pytorch Lightning\n", "\n", "* **Author:** PL team\n", "* **License:** CC BY-SA\n", "* **Generated:** 2023-01-05T12:09:29.379466\n", "\n", "In this notebook, we'll go over the basics of lightning by preparing models to train on the [MNIST Handwritten Digits dataset](https://en.wikipedia.org/wiki/MNIST_database).\n", "\n", "---\n", "Open in [![Open In Colab](data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAHUAAAAUCAYAAACzrHJDAAAIuUlEQVRoQ+1ZaVRURxb+qhdolmbTUVSURpZgmLhHbQVFZIlGQBEXcMvJhKiTEzfigjQg7oNEJ9GMGidnjnNMBs2czIzajksEFRE1xklCTKJiQLRFsUGkoUWw+82pamn79etGYoKek1B/4NW99/tu3e/dquJBAGD27NkHALxKf39WY39gyrOi+i3xqGtUoePJrFmznrmgtModorbTu8YRNZk5cybXTvCtwh7o6NR2KzuZMWNGh6jtVt7nA0ymT5/eJlF9POrh7PAQl6s8bGYa3PUum//htmebVtLRqW0q01M5keTk5FZFzU0oRle3+zxwg5Hgtb+PZiL/ZVohxCI+hL5JgjmfjPxZ26+33BG3dA+ealHPM4gQAo5rU59gsI8bRvl54t3Ca62mvHyUAhtOlLd5WSQpKcluBjumnoCLs1EARkVd9E8l3p9y2i7RbQ1B6pFwu/YDgW8KbHJHMTQrwnjz2oZm9M4pavOCfo5jWrgCaaMVcMs6/pNhDr0+AMN93XlxV7R6DNpyzi7W/OE+yIrsjU6rTrbKV5cd/pNyItOmTbMp6sbBB+EqaYJY4cWE3VUciNt1TpgfcRFv71Fi54xT5kSoyLvOBEJMOMxWXkFlBeBSX4u6Zkcs+3KszYRtiapbNRqF31UgetVuc8z9vBXIv1qD+F1f83B6uDlCUyfsZGepGPpmg01OB7EITQbhS9ribKy+DmP1DUiClLz4bnIHVOqa7BY+Z1wg5g3zgUvyehiNpnJKxSLc/ts76LKm0BzX3c0RNy1yXjDcB5lWoro4iNHQxM+f1kWeWQARAWQS++trISJTp061Kep25X/MycwtjuctSC5rxo7ppi7VNUox5+PhPHtrsS2O1qJ6yx1QujQUzm9sh6hbkBlvvGcN8hYnwjUjH6kjfZEd5c/jitz5Jc5U3ENnFynKl4eB7nyEgP2UZ+Yz3/rVEbyYr27qELrtC4FIC0J7sc7xWnmccdHfRRTs0VB+cA4lt+oFcRR/wUeH8FG5w2Mbx8FQ8TXEvv1xYf4wBP3O2WyL3/UVjpXWgIqaFeUPr+wTmDvUB7njH6/bOv+HRg4SqioAg5GDe1aB3ZeMTJkyRSBqkLsWqSEm0fZVBEN94zEZnYvrdx1JL5cxe+a+AbhSJecRRHW/ikTFRTa38dtQlNZ5CRKwFvUtZU/kvBoEF9Uxni/XqIM+dwKbTw3rhcxIf7gmr2M+H6SMwx8iBzJbw5oxeG3Lv5FX9B3AGaHPS8e8z77H7v9VMpvPG5ug1enh7eGK8h0LBTwUb+G
InqzInlRUK65DmTPQu4c3+uQKjwKK77zwUxBX4Tq7yR1RuiwUsqlrABCM6esHdXoy47fk4+prYKy8ZF574x4V5BnHQBuf4g9Z9ld8U36L2aktZNNplNfw7zotwWTy5MkCUft4aLEopJj5/OPHl1BQqeAVOnHgNSQOqmBzq9V9cfEm/yx5ubMGKS9cYPZ3vx2OS/c6PVHUuUO7Y1Pci3BO/1zgq18byebfGemLtNF+6JRtOvMk926ibussZqM+1mNz4TWkH7rCbM5phwGRGDAaoF8fY5OHFnlldAA8sgoEXKnDukA1NgSeNjqkJT9brbN4pC9WRweYXyLugR73c+MYvyWfu0yC6+mjzN1Isfw3FKJS98CU/zI1IHFkFPR52cHL2FJk0sB6kMTERIGo9GzcPkLNfA0cwdwi/hfEYO86ZMd9w+y1egfM2T2Eh/vesMNwljSzuZRT420SW3eqy8N6aHMmwmnFUZ7/PGVPbIoNZvNU1BURdHs0bT2+HjL8sDSM2e6vi4Lj5NW8WOLVA6RTT2azxLV+bglaFNqLieqemS/gWkw7NyoAHo+2dEsiivengjKsPFoqWOvbSh/kxPaxyW/JRzH2Fl3EzD9/xjAefJqB3usKUFn/0Gb+S/d/jy3FN2yLOmnSJJtn6oehByEiHPSeXnDxFGPRnoFoaBJjcdQlbDwcjL1zTNuQpoxD7R0OG0uUTMi0fkVwdzBdYIwcwZunxrVJVLplNm54BZp7jfDfYLoNyqQi1K6KxIdHzmN+QQ2WjFIwUT2zTGdlRXo4NFXVUO4sgX5dFC7f0aP/ZlNeUjFBuL8Xjl6uRuP6aMjSjpjzsH62FDU7JhBuGccEXIvDfJFFBc/gHw80dklfCVYnRaDfpiJcutPA4F7qJsfJeUPQI+1fqMlNhFx1FM0GDqkjFVg7NojlQ0Vt4aM5ReSqcbpaCg8nCW5lRsBvbT4T1TLfFptsfh7gItzuKTdJSEiwKSrt1vcmnEXXrsLbYnWDA1bu+z2WKy9Arq+1KRqdfKsoBo0GcdtEpS/B1bO4v0cFiUhkjskvKcMrWwtAPHuwQq8Z+4LZ1vTQANfXt4J0DwZX9gWa9qh4XDM/voC9JXfwYEMMHJcfNtusn82ihvliVUwg5KrPGVf6GH94ZJpEZBen6EC4qYTHA1dXhW0JIex8txzv//c8lhzXIi/BFxOH9jGbQhZsRalTIBZZ8KkGyZAxeRQvXkFF1TWz/Hm46jNYUnjPbt3JxIkT7f6dSj8qfJJyVvBxgaIlblOyjtysNHWN9fjjqWi7glJfW3/S0Hlj2XnA8PhKT9w6g3Qx3XiXhvuxQsuT1proxBKI/AaZqY1Xz5muvY8G8XkRRCaHsfQsRAFDH/tZPbcYuHotOG0FRIqB4HR3wNVoIPLtz8ycTguu+jpEigE218vd1YCr5m+HpHMvEI9u4LTXwNWaLjl0iPwGAmIpeHx1VeCqTJdPs1/vweweQPO3HC24NhOhnTphwoQnfv6QSY2ICbkNmdSA4h87oaLaiYfn5diIEd4att2erOwJXbPUHp953p6orQVSUVWRAXBT8c/dJ5L9xhzaJGp71GR/wFP8P5V2z10NSC9T93QM2xUg8fHxT+zU9ijeU4naHon8CjFJXFzc8/kn+dN06q9QgF98SYSo2Xen2NjYZy5sR6f+4nLSK5Iam2PH/x87a1YN/t5sBgAAAABJRU5ErkJggg==){height=\"20px\" width=\"117px\"}](https://colab.research.google.com/github/PytorchLightning/lightning-tutorials/blob/publication/.notebooks/lightning_examples/mnist-hello-world.ipynb)\n", "\n", "Give us a \u2b50 [on Github](https://www.github.com/Lightning-AI/lightning/)\n", 
"| Check out [the documentation](https://pytorch-lightning.readthedocs.io/en/stable/)\n", "| Join us [on Slack](https://www.pytorchlightning.ai/community)"]}, {"cell_type": "markdown", "id": "2bf27b63", "metadata": {"papermill": {"duration": 0.002315, "end_time": "2023-01-05T11:17:00.362659", "exception": false, "start_time": "2023-01-05T11:17:00.360344", "status": "completed"}, "tags": []}, "source": ["## Setup\n", "This notebook requires some packages besides pytorch-lightning."]}, {"cell_type": "code", "execution_count": 1, "id": "372258bd", "metadata": {"colab": {}, "colab_type": "code", "execution": {"iopub.execute_input": "2023-01-05T11:17:00.369360Z", "iopub.status.busy": "2023-01-05T11:17:00.368896Z", "iopub.status.idle": "2023-01-05T11:17:03.134157Z", "shell.execute_reply": "2023-01-05T11:17:03.132833Z"}, "id": "LfrJLKPFyhsK", "lines_to_next_cell": 0, "papermill": {"duration": 2.772465, "end_time": "2023-01-05T11:17:03.137713", "exception": false, "start_time": "2023-01-05T11:17:00.365248", "status": "completed"}, "tags": []}, "outputs": [{"name": "stdout", "output_type": "stream", "text": ["\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\r\n", "\u001b[0m"]}], "source": ["! 
pip install --quiet \"torchmetrics>=0.11.0, <0.12\" \"seaborn\" \"ipython[notebook]>=8.0.0, <8.9.0\" \"pytorch-lightning>=1.4, <1.9\" \"setuptools==65.6.3\" \"pandas\" \"torchvision\" \"torch>=1.8.1, <1.14.0\""]}, {"cell_type": "code", "execution_count": 2, "id": "d495d12a", "metadata": {"execution": {"iopub.execute_input": "2023-01-05T11:17:03.148080Z", "iopub.status.busy": "2023-01-05T11:17:03.147841Z", "iopub.status.idle": "2023-01-05T11:17:06.489921Z", "shell.execute_reply": "2023-01-05T11:17:06.488590Z"}, "papermill": {"duration": 3.348533, "end_time": "2023-01-05T11:17:06.492608", "exception": false, "start_time": "2023-01-05T11:17:03.144075", "status": "completed"}, "tags": []}, "outputs": [], "source": ["import os\n", "\n", "import pandas as pd\n", "import seaborn as sn\n", "import torch\n", "from IPython.display import display\n", "from pytorch_lightning import LightningModule, Trainer\n", "from pytorch_lightning.callbacks.progress import TQDMProgressBar\n", "from pytorch_lightning.loggers import CSVLogger\n", "from torch import nn\n", "from torch.nn import functional as F\n", "from torch.utils.data import DataLoader, random_split\n", "from torchmetrics import Accuracy\n", "from torchvision import transforms\n", "from torchvision.datasets import MNIST\n", "\n", "PATH_DATASETS = os.environ.get(\"PATH_DATASETS\", \".\")\n", "BATCH_SIZE = 256 if torch.cuda.is_available() else 64"]}, {"cell_type": "markdown", "id": "2444aa0a", "metadata": {"lines_to_next_cell": 2, "papermill": {"duration": 0.002899, "end_time": "2023-01-05T11:17:06.501651", "exception": false, "start_time": "2023-01-05T11:17:06.498752", "status": "completed"}, "tags": []}, "source": ["## Simplest 
example\n", "\n", "Here's the simplest, most minimal example with just a training loop (no validation, no testing).\n", "\n", "**Keep in Mind** - A `LightningModule` *is* a PyTorch `nn.Module` - it just has a few more helpful features."]}, {"cell_type": "code", "execution_count": 3, "id": "0ea17820", "metadata": {"execution": {"iopub.execute_input": "2023-01-05T11:17:06.509112Z", "iopub.status.busy": "2023-01-05T11:17:06.508220Z", "iopub.status.idle": "2023-01-05T11:17:06.518338Z", "shell.execute_reply": "2023-01-05T11:17:06.517544Z"}, "papermill": {"duration": 0.016148, "end_time": "2023-01-05T11:17:06.520310", "exception": false, "start_time": "2023-01-05T11:17:06.504162", "status": "completed"}, "tags": []}, "outputs": [], "source": ["class MNISTModel(LightningModule):\n", "    def __init__(self):\n", "        super().__init__()\n", "        self.l1 = torch.nn.Linear(28 * 28, 10)\n", "\n", "    def forward(self, x):\n", "        # Flatten each image to a 784-vector before the linear layer\n", "        return torch.relu(self.l1(x.view(x.size(0), -1)))\n", "\n", "    def training_step(self, batch, batch_idx):\n", "        x, y = batch\n", "        loss = F.cross_entropy(self(x), y)\n", "        return loss\n", "\n", "    def configure_optimizers(self):\n", "        return torch.optim.Adam(self.parameters(), lr=0.02)"]}, {"cell_type": "markdown", "id": "33edeb48", "metadata": {"papermill": {"duration": 0.005641, "end_time": "2023-01-05T11:17:06.531624", "exception": false, "start_time": "2023-01-05T11:17:06.525983", "status": "completed"}, "tags": []}, "source": ["By using the `Trainer` you automatically get:\n", "1. TensorBoard logging\n", "2. Model checkpointing\n", "3. Training and validation loop\n", "4. 
Early stopping (via the `EarlyStopping` callback)"]}, {"cell_type": "code", "execution_count": 4, "id": "6319bb60", "metadata": {"execution": {"iopub.execute_input": "2023-01-05T11:17:06.537593Z", "iopub.status.busy": "2023-01-05T11:17:06.537425Z", "iopub.status.idle": "2023-01-05T11:17:23.735631Z", "shell.execute_reply": "2023-01-05T11:17:23.734970Z"}, "papermill": {"duration": 17.203177, "end_time": "2023-01-05T11:17:23.737335", "exception": false, "start_time": "2023-01-05T11:17:06.534158", "status": "completed"}, "tags": []}, "outputs": [{"name": "stdout", "output_type": "stream", "text": ["Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz\n", "Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to /__w/6/s/.datasets/MNIST/raw/train-images-idx3-ubyte.gz\n"]}, {"data": {"application/vnd.jupyter.widget-view+json": {"model_id": "bf9ae660143b48ecb89e7080715b9a18", "version_major": 2, "version_minor": 0}, "text/plain": [" 0%| | 0/9912422 [00:00, ?it/s]"]}, "metadata": {}, "output_type": "display_data"}, {"name": "stdout", "output_type": "stream", "text": ["Extracting /__w/6/s/.datasets/MNIST/raw/train-images-idx3-ubyte.gz to /__w/6/s/.datasets/MNIST/raw\n"]}, {"name": "stdout", "output_type": "stream", "text": ["\n", "Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz\n", "Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to /__w/6/s/.datasets/MNIST/raw/train-labels-idx1-ubyte.gz\n"]}, {"data": {"application/vnd.jupyter.widget-view+json": {"model_id": "02fd1a140eeb4c59a858d5e24c096de6", "version_major": 2, "version_minor": 0}, "text/plain": [" 0%| | 0/28881 [00:00, ?it/s]"]}, "metadata": {}, "output_type": "display_data"}, {"name": "stdout", "output_type": "stream", "text": ["Extracting /__w/6/s/.datasets/MNIST/raw/train-labels-idx1-ubyte.gz to /__w/6/s/.datasets/MNIST/raw\n", "\n", "Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz\n", "Downloading 
http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to /__w/6/s/.datasets/MNIST/raw/t10k-images-idx3-ubyte.gz\n"]}, {"data": {"application/vnd.jupyter.widget-view+json": {"model_id": "9081419f28a44f2b94de23669c5b6bb2", "version_major": 2, "version_minor": 0}, "text/plain": [" 0%| | 0/1648877 [00:00, ?it/s]"]}, "metadata": {}, "output_type": "display_data"}, {"name": "stdout", "output_type": "stream", "text": ["Extracting /__w/6/s/.datasets/MNIST/raw/t10k-images-idx3-ubyte.gz to /__w/6/s/.datasets/MNIST/raw\n", "\n", "Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz\n", "Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to /__w/6/s/.datasets/MNIST/raw/t10k-labels-idx1-ubyte.gz\n"]}, {"data": {"application/vnd.jupyter.widget-view+json": {"model_id": "ecfb0c477d7c451a93c6bab38ec300af", "version_major": 2, "version_minor": 0}, "text/plain": [" 0%| | 0/4542 [00:00, ?it/s]"]}, "metadata": {}, "output_type": "display_data"}, {"name": "stdout", "output_type": "stream", "text": ["Extracting /__w/6/s/.datasets/MNIST/raw/t10k-labels-idx1-ubyte.gz to /__w/6/s/.datasets/MNIST/raw\n", "\n"]}, {"name": "stderr", "output_type": "stream", "text": ["GPU available: True (cuda), used: True\n"]}, {"name": "stderr", "output_type": "stream", "text": ["TPU available: False, using: 0 TPU cores\n"]}, {"name": "stderr", "output_type": "stream", "text": ["IPU available: False, using: 0 IPUs\n"]}, {"name": "stderr", "output_type": "stream", "text": ["HPU available: False, using: 0 HPUs\n"]}, {"name": "stderr", "output_type": "stream", "text": ["Missing logger folder: /__w/6/s/lightning_logs\n"]}, {"name": "stderr", "output_type": "stream", "text": ["LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n"]}, {"name": "stderr", "output_type": "stream", "text": ["\n", " | Name | Type | Params\n", "--------------------------------\n", "0 | l1 | Linear | 7.9 K \n", "--------------------------------\n", "7.9 K Trainable params\n", "0 Non-trainable params\n", 
"7.9 K Total params\n", "0.031 Total estimated model params size (MB)\n"]}, {"name": "stderr", "output_type": "stream", "text": ["/usr/local/lib/python3.9/dist-packages/pytorch_lightning/trainer/connectors/data_connector.py:224: PossibleUserWarning: The dataloader, train_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 64 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.\n", " rank_zero_warn(\n"]}, {"data": {"application/vnd.jupyter.widget-view+json": {"model_id": "ccc427227c1b45b98a5851a745e9527e", "version_major": 2, "version_minor": 0}, "text/plain": ["Training: 0it [00:00, ?it/s]"]}, "metadata": {}, "output_type": "display_data"}, {"name": "stderr", "output_type": "stream", "text": ["`Trainer.fit` stopped: `max_epochs=3` reached.\n"]}], "source": ["# Init our model\n", "mnist_model = MNISTModel()\n", "\n", "# Init DataLoader from MNIST Dataset\n", "train_ds = MNIST(PATH_DATASETS, train=True, download=True, transform=transforms.ToTensor())\n", "train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE)\n", "\n", "# Initialize a trainer\n", "trainer = Trainer(\n", "    accelerator=\"auto\",\n", "    devices=1 if torch.cuda.is_available() else None,  # limiting for IPython runs\n", "    max_epochs=3,\n", "    callbacks=[TQDMProgressBar(refresh_rate=20)],\n", ")\n", "\n", "# Train the model \u26a1\n", "trainer.fit(mnist_model, train_loader)"]}, {"cell_type": "markdown", "id": "b974bd59", "metadata": {"lines_to_next_cell": 2, "papermill": {"duration": 0.003859, "end_time": "2023-01-05T11:17:23.752001", "exception": false, "start_time": "2023-01-05T11:17:23.748142", "status": "completed"}, "tags": []}, "source": ["## A more complete MNIST Lightning Module Example\n", "\n", "That wasn't so hard, was it?\n", "\n", "Now that we've got our feet wet, let's dive in a bit deeper and write a more complete `LightningModule` for MNIST...\n", "\n", "This time, we'll 
bake in all the dataset-specific pieces directly in the `LightningModule`.\n", "This way, we can avoid writing extra code at the beginning of our script every time we want to run it.\n", "\n", "---\n", "\n", "### Note what the following built-in functions are doing:\n", "\n", "1. [prepare_data()](https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_module.html#prepare-data) \ud83d\udcbe\n", "    - This is where we can download the dataset. We point to our desired dataset directory and ask torchvision's `MNIST` dataset class to download the data if it isn't found there.\n", "    - **Note we do not make any state assignments in this function** (i.e. `self.something = ...`): it is called from a single process only, so state set here would not be available on the other processes.\n", "\n", "2. [setup(stage)](https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_module.html#setup) \u2699\ufe0f\n", "    - Loads the data from file and prepares PyTorch tensor datasets for each split (train, val, test).\n", "    - `setup` expects a `stage` argument, which is used to separate the logic for 'fit' and 'test'.\n", "    - If you don't mind loading all your datasets at once, you can set up a condition that runs both the 'fit'-related and the 'test'-related setup whenever `None` is passed to `stage` (or ignore `stage` altogether and drop the conditionals).\n", "    - **Note this runs across all GPUs and it *is* safe to make state assignments here**\n", "\n", "3. 
[x_dataloader()](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.core.hooks.DataHooks.html#pytorch_lightning.core.hooks.DataHooks.train_dataloader) \u267b\ufe0f\n", "    - `train_dataloader()`, `val_dataloader()`, and `test_dataloader()` all return PyTorch `DataLoader` instances that are created by wrapping their respective datasets that we prepared in `setup()`"]}, {"cell_type": "code", "execution_count": 5, "id": "4ba977ca", "metadata": {"execution": {"iopub.execute_input": "2023-01-05T11:17:23.761973Z", "iopub.status.busy": "2023-01-05T11:17:23.761379Z", "iopub.status.idle": "2023-01-05T11:17:23.784018Z", "shell.execute_reply": "2023-01-05T11:17:23.783186Z"}, "papermill": {"duration": 0.030516, "end_time": "2023-01-05T11:17:23.786267", "exception": false, "start_time": "2023-01-05T11:17:23.755751", "status": "completed"}, "tags": []}, "outputs": [], "source": ["class LitMNIST(LightningModule):\n", "    def __init__(self, data_dir=PATH_DATASETS, hidden_size=64, learning_rate=2e-4):\n", "        super().__init__()\n", "\n", "        # Set our init args as class attributes\n", "        self.data_dir = data_dir\n", "        self.hidden_size = hidden_size\n", "        self.learning_rate = learning_rate\n", "\n", "        # Hardcode some dataset-specific attributes\n", "        self.num_classes = 10\n", "        self.dims = (1, 28, 28)\n", "        channels, width, height = self.dims\n", "        self.transform = transforms.Compose(\n", "            [\n", "                transforms.ToTensor(),\n", "                transforms.Normalize((0.1307,), (0.3081,)),\n", "            ]\n", "        )\n", "\n", "        # Define PyTorch model\n", "        self.model = nn.Sequential(\n", "            nn.Flatten(),\n", "            nn.Linear(channels * width * height, hidden_size),\n", "            nn.ReLU(),\n", "            nn.Dropout(0.1),\n", "            nn.Linear(hidden_size, hidden_size),\n", "            nn.ReLU(),\n", "            nn.Dropout(0.1),\n", "            nn.Linear(hidden_size, self.num_classes),\n", "        )\n", "\n", "        self.val_accuracy = Accuracy(task=\"multiclass\", num_classes=self.num_classes)\n", "        self.test_accuracy = Accuracy(task=\"multiclass\", num_classes=self.num_classes)\n", "\n", "    def 
forward(self, x):\n", "        x = self.model(x)\n", "        return F.log_softmax(x, dim=1)\n", "\n", "    def training_step(self, batch, batch_idx):\n", "        x, y = batch\n", "        logits = self(x)\n", "        loss = F.nll_loss(logits, y)\n", "        return loss\n", "\n", "    def validation_step(self, batch, batch_idx):\n", "        x, y = batch\n", "        logits = self(x)\n", "        loss = F.nll_loss(logits, y)\n", "        preds = torch.argmax(logits, dim=1)\n", "        self.val_accuracy.update(preds, y)\n", "\n", "        # self.log sends scalars to the Trainer's logger (TensorBoard by default, CSVLogger in this notebook)\n", "        self.log(\"val_loss\", loss, prog_bar=True)\n", "        self.log(\"val_acc\", self.val_accuracy, prog_bar=True)\n", "\n", "    def test_step(self, batch, batch_idx):\n", "        x, y = batch\n", "        logits = self(x)\n", "        loss = F.nll_loss(logits, y)\n", "        preds = torch.argmax(logits, dim=1)\n", "        self.test_accuracy.update(preds, y)\n", "\n", "        # self.log sends scalars to the Trainer's logger (TensorBoard by default, CSVLogger in this notebook)\n", "        self.log(\"test_loss\", loss, prog_bar=True)\n", "        self.log(\"test_acc\", self.test_accuracy, prog_bar=True)\n", "\n", "    def configure_optimizers(self):\n", "        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)\n", "        return optimizer\n", "\n", "    ####################\n", "    # DATA RELATED HOOKS\n", "    ####################\n", "\n", "    def prepare_data(self):\n", "        # download\n", "        MNIST(self.data_dir, train=True, download=True)\n", "        MNIST(self.data_dir, train=False, download=True)\n", "\n", "    def setup(self, stage=None):\n", "        # Assign train/val datasets for use in dataloaders\n", "        if stage == \"fit\" or stage is None:\n", "            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)\n", "            self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])\n", "\n", "        # Assign test dataset for use in dataloader(s)\n", "        if stage == \"test\" or stage is None:\n", "            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)\n", "\n", "    def train_dataloader(self):\n", "        return 
DataLoader(self.mnist_train, batch_size=BATCH_SIZE)\n", "\n", "    def val_dataloader(self):\n", "        return DataLoader(self.mnist_val, batch_size=BATCH_SIZE)\n", "\n", "    def test_dataloader(self):\n", "        return DataLoader(self.mnist_test, batch_size=BATCH_SIZE)"]}, {"cell_type": "code", "execution_count": 6, "id": "f55efaa6", "metadata": {"execution": {"iopub.execute_input": "2023-01-05T11:17:23.799831Z", "iopub.status.busy": "2023-01-05T11:17:23.799660Z", "iopub.status.idle": "2023-01-05T11:17:48.541415Z", "shell.execute_reply": "2023-01-05T11:17:48.540535Z"}, "papermill": {"duration": 24.749407, "end_time": "2023-01-05T11:17:48.543854", "exception": false, "start_time": "2023-01-05T11:17:23.794447", "status": "completed"}, "tags": []}, "outputs": [{"name": "stderr", "output_type": "stream", "text": ["GPU available: True (cuda), used: True\n"]}, {"name": "stderr", "output_type": "stream", "text": ["TPU available: False, using: 0 TPU cores\n"]}, {"name": "stderr", "output_type": "stream", "text": ["IPU available: False, using: 0 IPUs\n"]}, {"name": "stderr", "output_type": "stream", "text": ["HPU available: False, using: 0 HPUs\n"]}, {"name": "stderr", "output_type": "stream", "text": ["Missing logger folder: logs/lightning_logs\n"]}, {"name": "stderr", "output_type": "stream", "text": ["LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n"]}, {"name": "stderr", "output_type": "stream", "text": ["\n", " | Name | Type | Params\n", "-----------------------------------------------------\n", "0 | model | Sequential | 55.1 K\n", "1 | val_accuracy | MulticlassAccuracy | 0 \n", "2 | test_accuracy | MulticlassAccuracy | 0 \n", "-----------------------------------------------------\n", "55.1 K Trainable params\n", "0 Non-trainable params\n", "55.1 K Total params\n", "0.220 Total estimated model params size (MB)\n"]}, {"data": {"application/vnd.jupyter.widget-view+json": {"model_id": "39d287fa391b46d7b79dd570141b4e86", "version_major": 2, "version_minor": 0}, "text/plain": ["Sanity 
Checking: 0it [00:00, ?it/s]"]}, "metadata": {}, "output_type": "display_data"}, {"name": "stderr", "output_type": "stream", "text": ["/usr/local/lib/python3.9/dist-packages/pytorch_lightning/trainer/connectors/data_connector.py:224: PossibleUserWarning: The dataloader, val_dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 64 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.\n", " rank_zero_warn(\n"]}, {"data": {"application/vnd.jupyter.widget-view+json": {"model_id": "01baac8cb1944700be1cc8e22e1db1cb", "version_major": 2, "version_minor": 0}, "text/plain": ["Training: 0it [00:00, ?it/s]"]}, "metadata": {}, "output_type": "display_data"}, {"data": {"application/vnd.jupyter.widget-view+json": {"model_id": "e8fdc1872be740b0aeba4a9f602f91c4", "version_major": 2, "version_minor": 0}, "text/plain": ["Validation: 0it [00:00, ?it/s]"]}, "metadata": {}, "output_type": "display_data"}, {"data": {"application/vnd.jupyter.widget-view+json": {"model_id": "3b6b4bc1dc0a4418ac5c7c3029577cbb", "version_major": 2, "version_minor": 0}, "text/plain": ["Validation: 0it [00:00, ?it/s]"]}, "metadata": {}, "output_type": "display_data"}, {"data": {"application/vnd.jupyter.widget-view+json": {"model_id": "61e0d0d7dc72431ea2735586bf18ae7b", "version_major": 2, "version_minor": 0}, "text/plain": ["Validation: 0it [00:00, ?it/s]"]}, "metadata": {}, "output_type": "display_data"}, {"name": "stderr", "output_type": "stream", "text": ["`Trainer.fit` stopped: `max_epochs=3` reached.\n"]}], "source": ["model = LitMNIST()\n", "trainer = Trainer(\n", "    accelerator=\"auto\",\n", "    devices=1 if torch.cuda.is_available() else None,  # limiting for IPython runs\n", "    max_epochs=3,\n", "    callbacks=[TQDMProgressBar(refresh_rate=20)],\n", "    logger=CSVLogger(save_dir=\"logs/\"),\n", ")\n", "trainer.fit(model)"]}, {"cell_type": "markdown", "id": "6563a3ae", "metadata": {"papermill": 
{"duration": 0.005077, "end_time": "2023-01-05T11:17:48.558415", "exception": false, "start_time": "2023-01-05T11:17:48.553338", "status": "completed"}, "tags": []}, "source": ["### Testing\n", "\n", "To test a model, call `trainer.test(model)`.\n", "\n", "Or, if you've just trained a model, you can just call `trainer.test()` and Lightning will automatically\n", "test using the best saved checkpoint (conditioned on val_loss)."]}, {"cell_type": "code", "execution_count": 7, "id": "343253ba", "metadata": {"execution": {"iopub.execute_input": "2023-01-05T11:17:48.569703Z", "iopub.status.busy": "2023-01-05T11:17:48.569336Z", "iopub.status.idle": "2023-01-05T11:17:49.842116Z", "shell.execute_reply": "2023-01-05T11:17:49.841115Z"}, "papermill": {"duration": 1.280816, "end_time": "2023-01-05T11:17:49.843930", "exception": false, "start_time": "2023-01-05T11:17:48.563114", "status": "completed"}, "tags": []}, "outputs": [{"name": "stderr", "output_type": "stream", "text": ["/usr/local/lib/python3.9/dist-packages/pytorch_lightning/trainer/connectors/checkpoint_connector.py:134: UserWarning: `.test(ckpt_path=None)` was called without a model. The best model of the previous `fit` call will be used. You can pass `.test(ckpt_path='best')` to use the best model or `.test(ckpt_path='last')` to use the last model. 
If you pass a value, this warning will be silenced.\n", " rank_zero_warn(\n", "Restoring states from the checkpoint path at logs/lightning_logs/version_0/checkpoints/epoch=2-step=645.ckpt\n"]}, {"name": "stderr", "output_type": "stream", "text": ["LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n"]}, {"name": "stderr", "output_type": "stream", "text": ["Loaded model weights from checkpoint at logs/lightning_logs/version_0/checkpoints/epoch=2-step=645.ckpt\n"]}, {"name": "stderr", "output_type": "stream", "text": ["/usr/local/lib/python3.9/dist-packages/pytorch_lightning/trainer/connectors/data_connector.py:224: PossibleUserWarning: The dataloader, test_dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 64 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.\n", " rank_zero_warn(\n"]}, {"data": {"application/vnd.jupyter.widget-view+json": {"model_id": "43b93132772c4cea88ae0d507362ec05", "version_major": 2, "version_minor": 0}, "text/plain": ["Testing: 0it [00:00, ?it/s]"]}, "metadata": {}, "output_type": "display_data"}, {"data": {"text/html": ["
\u250f\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2533\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2513\n", "\u2503 Test metric \u2503 DataLoader 0 \u2503\n", "\u2521\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2547\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2529\n", "\u2502 test_acc \u2502 0.9228000044822693 \u2502\n", "\u2502 test_loss \u2502 0.2596317529678345 \u2502\n", "\u2514\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2534\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2518\n", "\n"], "text/plain": ["\u250f\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2533\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2513\n", "\u2503\u001b[1m \u001b[0m\u001b[1m Test metric \u001b[0m\u001b[1m \u001b[0m\u2503\u001b[1m \u001b[0m\u001b[1m DataLoader 0 \u001b[0m\u001b[1m \u001b[0m\u2503\n", 
"\u2521\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2547\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2529\n", "\u2502\u001b[36m \u001b[0m\u001b[36m test_acc \u001b[0m\u001b[36m \u001b[0m\u2502\u001b[35m \u001b[0m\u001b[35m 0.9228000044822693 \u001b[0m\u001b[35m \u001b[0m\u2502\n", "\u2502\u001b[36m \u001b[0m\u001b[36m test_loss \u001b[0m\u001b[36m \u001b[0m\u2502\u001b[35m \u001b[0m\u001b[35m 0.2596317529678345 \u001b[0m\u001b[35m \u001b[0m\u2502\n", "\u2514\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2534\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2518\n"]}, "metadata": {}, "output_type": "display_data"}, {"data": {"text/plain": ["[{'test_loss': 0.2596317529678345, 'test_acc': 0.9228000044822693}]"]}, "execution_count": 7, "metadata": {}, "output_type": "execute_result"}], "source": ["trainer.test()"]}, {"cell_type": "markdown", "id": "145bc32e", "metadata": {"papermill": {"duration": 0.007028, "end_time": "2023-01-05T11:17:49.858372", "exception": false, "start_time": "2023-01-05T11:17:49.851344", "status": "completed"}, "tags": []}, "source": ["### Bonus Tip\n", "\n", "You can keep calling `trainer.fit(model)` as many times as you'd like to continue training"]}, {"cell_type": "code", "execution_count": 8, "id": "42796164", "metadata": {"execution": {"iopub.execute_input": "2023-01-05T11:17:49.870643Z", "iopub.status.busy": "2023-01-05T11:17:49.870288Z", "iopub.status.idle": "2023-01-05T11:17:50.057166Z", "shell.execute_reply": "2023-01-05T11:17:50.056281Z"}, "papermill": {"duration": 0.195226, 
"end_time": "2023-01-05T11:17:50.059108", "exception": false, "start_time": "2023-01-05T11:17:49.863882", "status": "completed"}, "tags": []}, "outputs": [{"name": "stderr", "output_type": "stream", "text": ["/usr/local/lib/python3.9/dist-packages/pytorch_lightning/callbacks/model_checkpoint.py:604: UserWarning: Checkpoint directory logs/lightning_logs/version_0/checkpoints exists and is not empty.\n", " rank_zero_warn(f\"Checkpoint directory {dirpath} exists and is not empty.\")\n", "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n"]}, {"name": "stderr", "output_type": "stream", "text": ["\n", " | Name | Type | Params\n", "-----------------------------------------------------\n", "0 | model | Sequential | 55.1 K\n", "1 | val_accuracy | MulticlassAccuracy | 0 \n", "2 | test_accuracy | MulticlassAccuracy | 0 \n", "-----------------------------------------------------\n", "55.1 K Trainable params\n", "0 Non-trainable params\n", "55.1 K Total params\n", "0.220 Total estimated model params size (MB)\n"]}, {"data": {"application/vnd.jupyter.widget-view+json": {"model_id": "e3ae3881c87e4d2d874b0ff385de8837", "version_major": 2, "version_minor": 0}, "text/plain": ["Sanity Checking: 0it [00:00, ?it/s]"]}, "metadata": {}, "output_type": "display_data"}, {"name": "stderr", "output_type": "stream", "text": ["`Trainer.fit` stopped: `max_epochs=3` reached.\n"]}], "source": ["trainer.fit(model)"]}, {"cell_type": "markdown", "id": "5c6c52f2", "metadata": {"papermill": {"duration": 0.005622, "end_time": "2023-01-05T11:17:50.072391", "exception": false, "start_time": "2023-01-05T11:17:50.066769", "status": "completed"}, "tags": []}, "source": ["In Colab, you can use the TensorBoard magic function to view the logs that Lightning has created for you!"]}, {"cell_type": "code", "execution_count": 9, "id": "5371c3d1", "metadata": {"execution": {"iopub.execute_input": "2023-01-05T11:17:50.085310Z", "iopub.status.busy": "2023-01-05T11:17:50.084474Z", "iopub.status.idle": 
"2023-01-05T11:17:50.416356Z", "shell.execute_reply": "2023-01-05T11:17:50.415351Z"}, "papermill": {"duration": 0.340885, "end_time": "2023-01-05T11:17:50.418703", "exception": false, "start_time": "2023-01-05T11:17:50.077818", "status": "completed"}, "tags": []}, "outputs": [{"data": {"text/html": ["
\n", "| epoch | val_loss | val_acc | test_loss | test_acc |\n", "|---|---|---|---|---|\n", "| 0 | 0.432111 | 0.8884 | NaN | NaN |\n", "| 1 | 0.310814 | 0.9124 | NaN | NaN |\n", "| 2 | 0.264833 | 0.9224 | NaN | NaN |\n", "| 2 | NaN | NaN | 0.259632 | 0.9228 |\n", "