{"cells": [{"cell_type": "markdown", "id": "b2a5db02", "metadata": {"papermill": {"duration": 0.00674, "end_time": "2023-03-15T10:52:35.390992", "exception": false, "start_time": "2023-03-15T10:52:35.384252", "status": "completed"}, "tags": []}, "source": ["\n", "# Introduction to PyTorch Lightning\n", "\n", "* **Author:** PL team\n", "* **License:** CC BY-SA\n", "* **Generated:** 2023-03-15T10:51:00.876251\n", "\n", "In this notebook, we'll go over the basics of lightning by preparing models to train on the [MNIST Handwritten Digits dataset](https://en.wikipedia.org/wiki/MNIST_database).\n", "\n", "---\n", "Open in [![Open In Colab](data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAHUAAAAUCAYAAACzrHJDAAAIuUlEQVRoQ+1ZaVRURxb+qhdolmbTUVSURpZgmLhHbQVFZIlGQBEXcMvJhKiTEzfigjQg7oNEJ9GMGidnjnNMBs2czIzajksEFRE1xklCTKJiQLRFsUGkoUWw+82pamn79etGYoKek1B/4NW99/tu3e/dquJBAGD27NkHALxKf39WY39gyrOi+i3xqGtUoePJrFmznrmgtModorbTu8YRNZk5cybXTvCtwh7o6NR2KzuZMWNGh6jtVt7nA0ymT5/eJlF9POrh7PAQl6s8bGYa3PUum//htmebVtLRqW0q01M5keTk5FZFzU0oRle3+zxwg5Hgtb+PZiL/ZVohxCI+hL5JgjmfjPxZ26+33BG3dA+ealHPM4gQAo5rU59gsI8bRvl54t3Ca62mvHyUAhtOlLd5WSQpKcluBjumnoCLs1EARkVd9E8l3p9y2i7RbQ1B6pFwu/YDgW8KbHJHMTQrwnjz2oZm9M4pavOCfo5jWrgCaaMVcMs6/pNhDr0+AMN93XlxV7R6DNpyzi7W/OE+yIrsjU6rTrbKV5cd/pNyItOmTbMp6sbBB+EqaYJY4cWE3VUciNt1TpgfcRFv71Fi54xT5kSoyLvOBEJMOMxWXkFlBeBSX4u6Zkcs+3KszYRtiapbNRqF31UgetVuc8z9vBXIv1qD+F1f83B6uDlCUyfsZGepGPpmg01OB7EITQbhS9ribKy+DmP1DUiClLz4bnIHVOqa7BY+Z1wg5g3zgUvyehiNpnJKxSLc/ts76LKm0BzX3c0RNy1yXjDcB5lWoro4iNHQxM+f1kWeWQARAWQS++trISJTp061Kep25X/MycwtjuctSC5rxo7ppi7VNUox5+PhPHtrsS2O1qJ6yx1QujQUzm9sh6hbkBlvvGcN8hYnwjUjH6kjfZEd5c/jitz5Jc5U3ENnFynKl4eB7nyEgP2UZ+Yz3/rVEbyYr27qELrtC4FIC0J7sc7xWnmccdHfRRTs0VB+cA4lt+oFcRR/wUeH8FG5w2Mbx8FQ8TXEvv1xYf4wBP3O2WyL3/UVjpXWgIqaFeUPr+wTmDvUB7njH6/bOv+HRg4SqioAg5GDe1aB3ZeMTJkyRSBqkLsWqSEm0fZVBEN94zEZnYvrdx1JL5cxe+a+AbhSJecRRHW/ikTFRTa38dtQlNZ5CRKwFvUtZU/kvBoEF9Uxni/XqIM+dwKbTw3rhcxIf7gmr2M+H6SMwx8iBzJbw5oxeG3Lv5FX9B3AGaHPS8e8z77H7v9VMpvPG5ug1enh7eGK8h0LBTwUb+GI
nqzInlRUK65DmTPQu4c3+uQKjwKK77zwUxBX4Tq7yR1RuiwUsqlrABCM6esHdXoy47fk4+prYKy8ZF574x4V5BnHQBuf4g9Z9ld8U36L2aktZNNplNfw7zotwWTy5MkCUft4aLEopJj5/OPHl1BQqeAVOnHgNSQOqmBzq9V9cfEm/yx5ubMGKS9cYPZ3vx2OS/c6PVHUuUO7Y1Pci3BO/1zgq18byebfGemLtNF+6JRtOvMk926ibussZqM+1mNz4TWkH7rCbM5phwGRGDAaoF8fY5OHFnlldAA8sgoEXKnDukA1NgSeNjqkJT9brbN4pC9WRweYXyLugR73c+MYvyWfu0yC6+mjzN1Isfw3FKJS98CU/zI1IHFkFPR52cHL2FJk0sB6kMTERIGo9GzcPkLNfA0cwdwi/hfEYO86ZMd9w+y1egfM2T2Eh/vesMNwljSzuZRT420SW3eqy8N6aHMmwmnFUZ7/PGVPbIoNZvNU1BURdHs0bT2+HjL8sDSM2e6vi4Lj5NW8WOLVA6RTT2azxLV+bglaFNqLieqemS/gWkw7NyoAHo+2dEsiivengjKsPFoqWOvbSh/kxPaxyW/JRzH2Fl3EzD9/xjAefJqB3usKUFn/0Gb+S/d/jy3FN2yLOmnSJJtn6oehByEiHPSeXnDxFGPRnoFoaBJjcdQlbDwcjL1zTNuQpoxD7R0OG0uUTMi0fkVwdzBdYIwcwZunxrVJVLplNm54BZp7jfDfYLoNyqQi1K6KxIdHzmN+QQ2WjFIwUT2zTGdlRXo4NFXVUO4sgX5dFC7f0aP/ZlNeUjFBuL8Xjl6uRuP6aMjSjpjzsH62FDU7JhBuGccEXIvDfJFFBc/gHw80dklfCVYnRaDfpiJcutPA4F7qJsfJeUPQI+1fqMlNhFx1FM0GDqkjFVg7NojlQ0Vt4aM5ReSqcbpaCg8nCW5lRsBvbT4T1TLfFptsfh7gItzuKTdJSEiwKSrt1vcmnEXXrsLbYnWDA1bu+z2WKy9Arq+1KRqdfKsoBo0GcdtEpS/B1bO4v0cFiUhkjskvKcMrWwtAPHuwQq8Z+4LZ1vTQANfXt4J0DwZX9gWa9qh4XDM/voC9JXfwYEMMHJcfNtusn82ihvliVUwg5KrPGVf6GH94ZJpEZBen6EC4qYTHA1dXhW0JIex8txzv//c8lhzXIi/BFxOH9jGbQhZsRalTIBZZ8KkGyZAxeRQvXkFF1TWz/Hm46jNYUnjPbt3JxIkT7f6dSj8qfJJyVvBxgaIlblOyjtysNHWN9fjjqWi7glJfW3/S0Hlj2XnA8PhKT9w6g3Qx3XiXhvuxQsuT1proxBKI/AaZqY1Xz5muvY8G8XkRRCaHsfQsRAFDH/tZPbcYuHotOG0FRIqB4HR3wNVoIPLtz8ycTguu+jpEigE218vd1YCr5m+HpHMvEI9u4LTXwNWaLjl0iPwGAmIpeHx1VeCqTJdPs1/vweweQPO3HC24NhOhnTphwoQnfv6QSY2ICbkNmdSA4h87oaLaiYfn5diIEd4att2erOwJXbPUHp953p6orQVSUVWRAXBT8c/dJ5L9xhzaJGp71GR/wFP8P5V2z10NSC9T93QM2xUg8fHxT+zU9ijeU4naHon8CjFJXFzc8/kn+dN06q9QgF98SYSo2Xen2NjYZy5sR6f+4nLSK5Iam2PH/x87a1YN/t5sBgAAAABJRU5ErkJggg==){height=\"20px\" width=\"117px\"}](https://colab.research.google.com/github/PytorchLightning/lightning-tutorials/blob/publication/.notebooks/lightning_examples/mnist-hello-world.ipynb)\n", "\n", "Give us a \u2b50 [on Github](https://www.github.com/Lightning-AI/lightning/)\n", 
"| Check out [the documentation](https://pytorch-lightning.readthedocs.io/en/stable/)\n", "| Join us [on Slack](https://www.pytorchlightning.ai/community)"]}, {"cell_type": "markdown", "id": "34fded74", "metadata": {"papermill": {"duration": 0.004191, "end_time": "2023-03-15T10:52:35.400930", "exception": false, "start_time": "2023-03-15T10:52:35.396739", "status": "completed"}, "tags": []}, "source": ["## Setup\n", "This notebook requires some packages besides pytorch-lightning."]}, {"cell_type": "code", "execution_count": 1, "id": "1677b0d2", "metadata": {"colab": {}, "colab_type": "code", "execution": {"iopub.execute_input": "2023-03-15T10:52:35.407494Z", "iopub.status.busy": "2023-03-15T10:52:35.407141Z", "iopub.status.idle": "2023-03-15T10:52:38.694278Z", "shell.execute_reply": "2023-03-15T10:52:38.693256Z"}, "id": "LfrJLKPFyhsK", "lines_to_next_cell": 0, "papermill": {"duration": 3.293351, "end_time": "2023-03-15T10:52:38.696867", "exception": false, "start_time": "2023-03-15T10:52:35.403516", "status": "completed"}, "tags": []}, "outputs": [{"name": "stdout", "output_type": "stream", "text": ["\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\r\n", "\u001b[0m"]}], "source": ["! 
pip install --quiet \"seaborn\" \"pytorch-lightning>=1.4, <2.0.0\" \"torchvision\" \"setuptools==67.4.0\" \"lightning>=2.0.0rc0\" \"ipython[notebook]>=8.0.0, <8.12.0\" \"pandas\" \"torchmetrics >=0.11.0\" \"torch>=1.8.1, <1.14.0\" \"torchmetrics>=0.7, <0.12\""]}, {"cell_type": "code", "execution_count": 2, "id": "61040850", "metadata": {"execution": {"iopub.execute_input": "2023-03-15T10:52:38.704796Z", "iopub.status.busy": "2023-03-15T10:52:38.704420Z", "iopub.status.idle": "2023-03-15T10:52:41.976645Z", "shell.execute_reply": "2023-03-15T10:52:41.975673Z"}, "papermill": {"duration": 3.279766, "end_time": "2023-03-15T10:52:41.979880", "exception": false, "start_time": "2023-03-15T10:52:38.700114", "status": "completed"}, "tags": []}, "outputs": [], "source": ["import os\n", "\n", "import lightning as L\n", "import pandas as pd\n", "import seaborn as sn\n", "import torch\n", "from IPython.display import display\n", "from lightning.pytorch.loggers import CSVLogger\n", "from torch import nn\n", "from torch.nn import functional as F\n", "from torch.utils.data import DataLoader, random_split\n", "from torchmetrics import Accuracy\n", "from torchvision import transforms\n", "from torchvision.datasets import MNIST\n", "\n", "PATH_DATASETS = os.environ.get(\"PATH_DATASETS\", \".\")\n", "BATCH_SIZE = 256 if torch.cuda.is_available() else 64"]}, {"cell_type": "markdown", "id": "525db095", "metadata": {"lines_to_next_cell": 2, "papermill": {"duration": 0.003391, "end_time": "2023-03-15T10:52:41.989818", "exception": false, "start_time": "2023-03-15T10:52:41.986427", "status": "completed"}, "tags": []}, "source": ["## Simplest example\n", "\n", "Here's a minimal example with just a training loop (no validation, no testing).\n", "\n", "**Keep in mind:** a `LightningModule` *is* a PyTorch `nn.Module`; it just has a few more helpful features."]}, {"cell_type": "code", "execution_count": 3, "id": "d7fd5aa3", "metadata": {"execution": {"iopub.execute_input": 
"2023-03-15T10:52:41.998177Z", "iopub.status.busy": "2023-03-15T10:52:41.997250Z", "iopub.status.idle": "2023-03-15T10:52:42.006185Z", "shell.execute_reply": "2023-03-15T10:52:42.005593Z"}, "papermill": {"duration": 0.015597, "end_time": "2023-03-15T10:52:42.008612", "exception": false, "start_time": "2023-03-15T10:52:41.993015", "status": "completed"}, "tags": []}, "outputs": [], "source": ["class MNISTModel(L.LightningModule):\n", " def __init__(self):\n", " super().__init__()\n", " self.l1 = torch.nn.Linear(28 * 28, 10)\n", "\n", " def forward(self, x):\n", " return torch.relu(self.l1(x.view(x.size(0), -1)))\n", "\n", " def training_step(self, batch, batch_nb):\n", " x, y = batch\n", " loss = F.cross_entropy(self(x), y)\n", " return loss\n", "\n", " def configure_optimizers(self):\n", " return torch.optim.Adam(self.parameters(), lr=0.02)"]}, {"cell_type": "markdown", "id": "d889d2c8", "metadata": {"papermill": {"duration": 0.003515, "end_time": "2023-03-15T10:52:42.018006", "exception": false, "start_time": "2023-03-15T10:52:42.014491", "status": "completed"}, "tags": []}, "source": ["By using the `Trainer` you automatically get:\n", "1. Tensorboard logging\n", "2. Model checkpointing\n", "3. Training and validation loop\n", "4. 
early-stopping"]}, {"cell_type": "code", "execution_count": 4, "id": "bb6cb1f2", "metadata": {"execution": {"iopub.execute_input": "2023-03-15T10:52:42.027654Z", "iopub.status.busy": "2023-03-15T10:52:42.027265Z", "iopub.status.idle": "2023-03-15T10:53:01.085934Z", "shell.execute_reply": "2023-03-15T10:53:01.085233Z"}, "papermill": {"duration": 19.065087, "end_time": "2023-03-15T10:53:01.087714", "exception": false, "start_time": "2023-03-15T10:52:42.022627", "status": "completed"}, "tags": []}, "outputs": [{"name": "stdout", "output_type": "stream", "text": ["Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz\n", "Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to /__w/6/s/.datasets/MNIST/raw/train-images-idx3-ubyte.gz\n"]}, {"data": {"application/vnd.jupyter.widget-view+json": {"model_id": "434a8b77741c4088b87555e446eb4dde", "version_major": 2, "version_minor": 0}, "text/plain": [" 0%| | 0/9912422 [00:00, ?it/s]"]}, "metadata": {}, "output_type": "display_data"}, {"name": "stdout", "output_type": "stream", "text": ["Extracting /__w/6/s/.datasets/MNIST/raw/train-images-idx3-ubyte.gz to /__w/6/s/.datasets/MNIST/raw\n"]}, {"name": "stdout", "output_type": "stream", "text": ["\n", "Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz\n", "Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to /__w/6/s/.datasets/MNIST/raw/train-labels-idx1-ubyte.gz\n"]}, {"data": {"application/vnd.jupyter.widget-view+json": {"model_id": "894c26e73bf044d9b653997ea187c20d", "version_major": 2, "version_minor": 0}, "text/plain": [" 0%| | 0/28881 [00:00, ?it/s]"]}, "metadata": {}, "output_type": "display_data"}, {"name": "stdout", "output_type": "stream", "text": ["Extracting /__w/6/s/.datasets/MNIST/raw/train-labels-idx1-ubyte.gz to /__w/6/s/.datasets/MNIST/raw\n", "\n", "Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz\n", "Downloading 
http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to /__w/6/s/.datasets/MNIST/raw/t10k-images-idx3-ubyte.gz\n"]}, {"data": {"application/vnd.jupyter.widget-view+json": {"model_id": "363fe6c5bc4b4befbbea695dfeaf1312", "version_major": 2, "version_minor": 0}, "text/plain": [" 0%| | 0/1648877 [00:00, ?it/s]"]}, "metadata": {}, "output_type": "display_data"}, {"name": "stdout", "output_type": "stream", "text": ["Extracting /__w/6/s/.datasets/MNIST/raw/t10k-images-idx3-ubyte.gz to /__w/6/s/.datasets/MNIST/raw\n"]}, {"name": "stdout", "output_type": "stream", "text": ["\n", "Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz\n", "Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to /__w/6/s/.datasets/MNIST/raw/t10k-labels-idx1-ubyte.gz\n"]}, {"data": {"application/vnd.jupyter.widget-view+json": {"model_id": "d89b268ad7de40328e613bb4ce8e2c64", "version_major": 2, "version_minor": 0}, "text/plain": [" 0%| | 0/4542 [00:00, ?it/s]"]}, "metadata": {}, "output_type": "display_data"}, {"name": "stdout", "output_type": "stream", "text": ["Extracting /__w/6/s/.datasets/MNIST/raw/t10k-labels-idx1-ubyte.gz to /__w/6/s/.datasets/MNIST/raw\n", "\n"]}, {"name": "stderr", "output_type": "stream", "text": ["GPU available: True (cuda), used: True\n"]}, {"name": "stderr", "output_type": "stream", "text": ["TPU available: False, using: 0 TPU cores\n"]}, {"name": "stderr", "output_type": "stream", "text": ["IPU available: False, using: 0 IPUs\n"]}, {"name": "stderr", "output_type": "stream", "text": ["HPU available: False, using: 0 HPUs\n"]}, {"name": "stderr", "output_type": "stream", "text": ["You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. 
For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision\n"]}, {"name": "stderr", "output_type": "stream", "text": ["Missing logger folder: /__w/6/s/lightning_logs\n"]}, {"name": "stderr", "output_type": "stream", "text": ["LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n"]}, {"name": "stderr", "output_type": "stream", "text": ["\n", " | Name | Type | Params\n", "--------------------------------\n", "0 | l1 | Linear | 7.9 K \n", "--------------------------------\n", "7.9 K Trainable params\n", "0 Non-trainable params\n", "7.9 K Total params\n", "0.031 Total estimated model params size (MB)\n"]}, {"name": "stderr", "output_type": "stream", "text": ["/usr/local/lib/python3.9/dist-packages/lightning/pytorch/trainer/connectors/data_connector.py:208: PossibleUserWarning: The dataloader, train_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 64 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.\n", " rank_zero_warn(\n"]}, {"data": {"application/vnd.jupyter.widget-view+json": {"model_id": "e5560260e56a461ab403391e3a751cba", "version_major": 2, "version_minor": 0}, "text/plain": ["Training: 0it [00:00, ?it/s]"]}, "metadata": {}, "output_type": "display_data"}, {"name": "stderr", "output_type": "stream", "text": ["`Trainer.fit` stopped: `max_epochs=3` reached.\n"]}], "source": ["# Init our model\n", "mnist_model = MNISTModel()\n", "\n", "# Init DataLoader from MNIST Dataset\n", "train_ds = MNIST(PATH_DATASETS, train=True, download=True, transform=transforms.ToTensor())\n", "train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE)\n", "\n", "# Initialize a trainer\n", "trainer = L.Trainer(\n", "    accelerator=\"auto\",\n", "    devices=1,\n", "    max_epochs=3,\n", ")\n", "\n", "# Train the model \u26a1\n", "trainer.fit(mnist_model, train_loader)"]}, {"cell_type": "markdown", 
"id": "3026f08d", "metadata": {"lines_to_next_cell": 2, "papermill": {"duration": 0.004221, "end_time": "2023-03-15T10:53:01.099138", "exception": false, "start_time": "2023-03-15T10:53:01.094917", "status": "completed"}, "tags": []}, "source": ["## A more complete MNIST Lightning Module Example\n", "\n", "That wasn't so hard was it?\n", "\n", "Now that we've got our feet wet, let's dive in a bit deeper and write a more complete `LightningModule` for MNIST...\n", "\n", "This time, we'll bake in all the dataset specific pieces directly in the `LightningModule`.\n", "This way, we can avoid writing extra code at the beginning of our script every time we want to run it.\n", "\n", "---\n", "\n", "### Note what the following built-in functions are doing:\n", "\n", "1. [prepare_data()](https://lightning.ai/docs/pytorch/stable/common/lightning_module.html#prepare-data) \ud83d\udcbe\n", " - This is where we can download the dataset. We point to our desired dataset and ask torchvision's `MNIST` dataset class to download if the dataset isn't found there.\n", " - **Note we do not make any state assignments in this function** (i.e. `self.something = ...`)\n", "\n", "2. [setup(stage)](https://lightning.ai/docs/pytorch/stable/common/lightning_module.html#setup) \u2699\ufe0f\n", " - Loads in data from file and prepares PyTorch tensor datasets for each split (train, val, test).\n", " - Setup expects a 'stage' arg which is used to separate logic for 'fit' and 'test'.\n", " - If you don't mind loading all your datasets at once, you can set up a condition to allow for both 'fit' related setup and 'test' related setup to run whenever `None` is passed to `stage` (or ignore it altogether and exclude any conditionals).\n", " - **Note this runs across all GPUs and it *is* safe to make state assignments here**\n", "\n", "3. 
[x_dataloader()](https://lightning.ai/docs/pytorch/stable/api/pytorch_lightning.core.hooks.DataHooks.html#pytorch_lightning.core.hooks.DataHooks.train_dataloader) \u267b\ufe0f\n", "    - `train_dataloader()`, `val_dataloader()`, and `test_dataloader()` all return PyTorch `DataLoader` instances that are created by wrapping their respective datasets that we prepared in `setup()`"]}, {"cell_type": "code", "execution_count": 5, "id": "f6c60301", "metadata": {"execution": {"iopub.execute_input": "2023-03-15T10:53:01.110802Z", "iopub.status.busy": "2023-03-15T10:53:01.109929Z", "iopub.status.idle": "2023-03-15T10:53:01.132035Z", "shell.execute_reply": "2023-03-15T10:53:01.131061Z"}, "papermill": {"duration": 0.031258, "end_time": "2023-03-15T10:53:01.134236", "exception": false, "start_time": "2023-03-15T10:53:01.102978", "status": "completed"}, "tags": []}, "outputs": [], "source": ["class LitMNIST(L.LightningModule):\n", "    def __init__(self, data_dir=PATH_DATASETS, hidden_size=64, learning_rate=2e-4):\n", "        super().__init__()\n", "\n", "        # Set our init args as class attributes\n", "        self.data_dir = data_dir\n", "        self.hidden_size = hidden_size\n", "        self.learning_rate = learning_rate\n", "\n", "        # Hardcode some dataset specific attributes\n", "        self.num_classes = 10\n", "        self.dims = (1, 28, 28)\n", "        channels, width, height = self.dims\n", "        self.transform = transforms.Compose(\n", "            [\n", "                transforms.ToTensor(),\n", "                transforms.Normalize((0.1307,), (0.3081,)),\n", "            ]\n", "        )\n", "\n", "        # Define PyTorch model\n", "        self.model = nn.Sequential(\n", "            nn.Flatten(),\n", "            nn.Linear(channels * width * height, hidden_size),\n", "            nn.ReLU(),\n", "            nn.Dropout(0.1),\n", "            nn.Linear(hidden_size, hidden_size),\n", "            nn.ReLU(),\n", "            nn.Dropout(0.1),\n", "            nn.Linear(hidden_size, self.num_classes),\n", "        )\n", "\n", "        self.val_accuracy = Accuracy(task=\"multiclass\", num_classes=10)\n", "        self.test_accuracy = Accuracy(task=\"multiclass\", num_classes=10)\n", "\n", "    def 
forward(self, x):\n", "        x = self.model(x)\n", "        return F.log_softmax(x, dim=1)\n", "\n", "    def training_step(self, batch, batch_idx):\n", "        x, y = batch\n", "        logits = self(x)\n", "        loss = F.nll_loss(logits, y)\n", "        return loss\n", "\n", "    def validation_step(self, batch, batch_idx):\n", "        x, y = batch\n", "        logits = self(x)\n", "        loss = F.nll_loss(logits, y)\n", "        preds = torch.argmax(logits, dim=1)\n", "        self.val_accuracy.update(preds, y)\n", "\n", "        # Calling self.log sends scalars to the trainer's logger\n", "        self.log(\"val_loss\", loss, prog_bar=True)\n", "        self.log(\"val_acc\", self.val_accuracy, prog_bar=True)\n", "\n", "    def test_step(self, batch, batch_idx):\n", "        x, y = batch\n", "        logits = self(x)\n", "        loss = F.nll_loss(logits, y)\n", "        preds = torch.argmax(logits, dim=1)\n", "        self.test_accuracy.update(preds, y)\n", "\n", "        # Calling self.log sends scalars to the trainer's logger\n", "        self.log(\"test_loss\", loss, prog_bar=True)\n", "        self.log(\"test_acc\", self.test_accuracy, prog_bar=True)\n", "\n", "    def configure_optimizers(self):\n", "        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)\n", "        return optimizer\n", "\n", "    ####################\n", "    # DATA RELATED HOOKS\n", "    ####################\n", "\n", "    def prepare_data(self):\n", "        # download\n", "        MNIST(self.data_dir, train=True, download=True)\n", "        MNIST(self.data_dir, train=False, download=True)\n", "\n", "    def setup(self, stage=None):\n", "        # Assign train/val datasets for use in dataloaders\n", "        if stage == \"fit\" or stage is None:\n", "            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)\n", "            self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])\n", "\n", "        # Assign test dataset for use in dataloader(s)\n", "        if stage == \"test\" or stage is None:\n", "            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)\n", "\n", "    def train_dataloader(self):\n", "        return 
DataLoader(self.mnist_train, batch_size=BATCH_SIZE)\n", "\n", "    def val_dataloader(self):\n", "        return DataLoader(self.mnist_val, batch_size=BATCH_SIZE)\n", "\n", "    def test_dataloader(self):\n", "        return DataLoader(self.mnist_test, batch_size=BATCH_SIZE)"]}, {"cell_type": "code", "execution_count": 6, "id": "2eae34ee", "metadata": {"execution": {"iopub.execute_input": "2023-03-15T10:53:01.146924Z", "iopub.status.busy": "2023-03-15T10:53:01.146587Z", "iopub.status.idle": "2023-03-15T10:53:27.802216Z", "shell.execute_reply": "2023-03-15T10:53:27.801283Z"}, "papermill": {"duration": 26.663806, "end_time": "2023-03-15T10:53:27.804001", "exception": false, "start_time": "2023-03-15T10:53:01.140195", "status": "completed"}, "tags": []}, "outputs": [{"name": "stderr", "output_type": "stream", "text": ["GPU available: True (cuda), used: True\n"]}, {"name": "stderr", "output_type": "stream", "text": ["TPU available: False, using: 0 TPU cores\n"]}, {"name": "stderr", "output_type": "stream", "text": ["IPU available: False, using: 0 IPUs\n"]}, {"name": "stderr", "output_type": "stream", "text": ["HPU available: False, using: 0 HPUs\n"]}, {"name": "stderr", "output_type": "stream", "text": ["You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. 
For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision\n"]}, {"name": "stderr", "output_type": "stream", "text": ["Missing logger folder: logs/lightning_logs\n"]}, {"name": "stderr", "output_type": "stream", "text": ["LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n"]}, {"name": "stderr", "output_type": "stream", "text": ["\n", " | Name | Type | Params\n", "-----------------------------------------------------\n", "0 | model | Sequential | 55.1 K\n", "1 | val_accuracy | MulticlassAccuracy | 0 \n", "2 | test_accuracy | MulticlassAccuracy | 0 \n", "-----------------------------------------------------\n", "55.1 K Trainable params\n", "0 Non-trainable params\n", "55.1 K Total params\n", "0.220 Total estimated model params size (MB)\n"]}, {"data": {"application/vnd.jupyter.widget-view+json": {"model_id": "a33a0f29dd0748d382eed219a41f9ccf", "version_major": 2, "version_minor": 0}, "text/plain": ["Sanity Checking: 0it [00:00, ?it/s]"]}, "metadata": {}, "output_type": "display_data"}, {"name": "stderr", "output_type": "stream", "text": ["/usr/local/lib/python3.9/dist-packages/lightning/pytorch/trainer/connectors/data_connector.py:208: PossibleUserWarning: The dataloader, val_dataloader 0, does not have many workers which may be a bottleneck. 
Consider increasing the value of the `num_workers` argument` (try 64 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.\n", " rank_zero_warn(\n"]}, {"data": {"application/vnd.jupyter.widget-view+json": {"model_id": "f96fabb0dd564a378ad0d2776e1f676b", "version_major": 2, "version_minor": 0}, "text/plain": ["Training: 0it [00:00, ?it/s]"]}, "metadata": {}, "output_type": "display_data"}, {"data": {"application/vnd.jupyter.widget-view+json": {"model_id": "e4e66a4868694a79b0757ae75e5cdc4d", "version_major": 2, "version_minor": 0}, "text/plain": ["Validation: 0it [00:00, ?it/s]"]}, "metadata": {}, "output_type": "display_data"}, {"data": {"application/vnd.jupyter.widget-view+json": {"model_id": "7adcbd5a28fe44d3acac9f25d66cb166", "version_major": 2, "version_minor": 0}, "text/plain": ["Validation: 0it [00:00, ?it/s]"]}, "metadata": {}, "output_type": "display_data"}, {"data": {"application/vnd.jupyter.widget-view+json": {"model_id": "c82bd2cc5c15493cacc609a922450910", "version_major": 2, "version_minor": 0}, "text/plain": ["Validation: 0it [00:00, ?it/s]"]}, "metadata": {}, "output_type": "display_data"}, {"name": "stderr", "output_type": "stream", "text": ["`Trainer.fit` stopped: `max_epochs=3` reached.\n"]}], "source": ["model = LitMNIST()\n", "trainer = L.Trainer(\n", "    accelerator=\"auto\",\n", "    devices=1,\n", "    max_epochs=3,\n", "    logger=CSVLogger(save_dir=\"logs/\"),\n", ")\n", "trainer.fit(model)"]}, {"cell_type": "markdown", "id": "19510931", "metadata": {"papermill": {"duration": 0.0062, "end_time": "2023-03-15T10:53:27.817925", "exception": false, "start_time": "2023-03-15T10:53:27.811725", "status": "completed"}, "tags": []}, "source": ["### Testing\n", "\n", "To test a model, call `trainer.test(model)`.\n", "\n", "Or, if you've just trained a model, you can call `trainer.test()` without arguments, and Lightning will automatically\n", "test using the best checkpoint saved during the previous `fit` call."]}, {"cell_type": "code", 
"execution_count": 7, "id": "9a149c25", "metadata": {"execution": {"iopub.execute_input": "2023-03-15T10:53:27.829422Z", "iopub.status.busy": "2023-03-15T10:53:27.829050Z", "iopub.status.idle": "2023-03-15T10:53:29.221980Z", "shell.execute_reply": "2023-03-15T10:53:29.221429Z"}, "papermill": {"duration": 1.400795, "end_time": "2023-03-15T10:53:29.223522", "exception": false, "start_time": "2023-03-15T10:53:27.822727", "status": "completed"}, "tags": []}, "outputs": [{"name": "stderr", "output_type": "stream", "text": ["/usr/local/lib/python3.9/dist-packages/lightning/pytorch/trainer/connectors/checkpoint_connector.py:148: UserWarning: `.test(ckpt_path=None)` was called without a model. The best model of the previous `fit` call will be used. You can pass `.test(ckpt_path='best')` to use the best model or `.test(ckpt_path='last')` to use the last model. If you pass a value, this warning will be silenced.\n", " rank_zero_warn(\n", "You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision\n"]}, {"name": "stderr", "output_type": "stream", "text": ["Restoring states from the checkpoint path at logs/lightning_logs/version_0/checkpoints/epoch=2-step=645.ckpt\n"]}, {"name": "stderr", "output_type": "stream", "text": ["Lightning automatically upgraded your loaded checkpoint from v2.0.0rc0 to v2.0.0rc0. 
To apply the upgrade to your files permanently, run `python -m lightning.pytorch.utilities.upgrade_checkpoint --file logs/lightning_logs/version_0/checkpoints/epoch=2-step=645.ckpt`\n"]}, {"name": "stderr", "output_type": "stream", "text": ["LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n"]}, {"name": "stderr", "output_type": "stream", "text": ["Loaded model weights from the checkpoint at logs/lightning_logs/version_0/checkpoints/epoch=2-step=645.ckpt\n"]}, {"name": "stderr", "output_type": "stream", "text": ["/usr/local/lib/python3.9/dist-packages/lightning/pytorch/trainer/connectors/data_connector.py:208: PossibleUserWarning: The dataloader, test_dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 64 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.\n", " rank_zero_warn(\n"]}, {"data": {"application/vnd.jupyter.widget-view+json": {"model_id": "e9eec3383c3e4def9e5660d71f568139", "version_major": 2, "version_minor": 0}, "text/plain": ["Testing: 0it [00:00, ?it/s]"]}, "metadata": {}, "output_type": "display_data"}, {"data": {"text/html": ["
\u250f\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2533\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2513\n", "\u2503 Test metric \u2503 DataLoader 0 \u2503\n", "\u2521\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2547\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2529\n", "\u2502 test_acc \u2502 0.9243000149726868 \u2502\n", "\u2502 test_loss \u2502 0.26116958260536194 \u2502\n", "\u2514\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2534\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2518\n", "\n"], "text/plain": ["\u250f\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2533\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2513\n", "\u2503\u001b[1m \u001b[0m\u001b[1m Test metric \u001b[0m\u001b[1m \u001b[0m\u2503\u001b[1m \u001b[0m\u001b[1m DataLoader 0 \u001b[0m\u001b[1m \u001b[0m\u2503\n", 
"\u2521\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2547\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2529\n", "\u2502\u001b[36m \u001b[0m\u001b[36m test_acc \u001b[0m\u001b[36m \u001b[0m\u2502\u001b[35m \u001b[0m\u001b[35m 0.9243000149726868 \u001b[0m\u001b[35m \u001b[0m\u2502\n", "\u2502\u001b[36m \u001b[0m\u001b[36m test_loss \u001b[0m\u001b[36m \u001b[0m\u2502\u001b[35m \u001b[0m\u001b[35m 0.26116958260536194 \u001b[0m\u001b[35m \u001b[0m\u2502\n", "\u2514\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2534\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2518\n"]}, "metadata": {}, "output_type": "display_data"}, {"data": {"text/plain": ["[{'test_loss': 0.26116958260536194, 'test_acc': 0.9243000149726868}]"]}, "execution_count": 7, "metadata": {}, "output_type": "execute_result"}], "source": ["trainer.test()"]}, {"cell_type": "markdown", "id": "1a84b96f", "metadata": {"papermill": {"duration": 0.00662, "end_time": "2023-03-15T10:53:29.238296", "exception": false, "start_time": "2023-03-15T10:53:29.231676", "status": "completed"}, "tags": []}, "source": ["### Bonus Tip\n", "\n", "You can keep calling `trainer.fit(model)` as many times as you'd like to continue training"]}, {"cell_type": "code", "execution_count": 8, "id": "f07b46d0", "metadata": {"execution": {"iopub.execute_input": "2023-03-15T10:53:29.250657Z", "iopub.status.busy": "2023-03-15T10:53:29.250287Z", "iopub.status.idle": "2023-03-15T10:53:29.463735Z", "shell.execute_reply": "2023-03-15T10:53:29.462980Z"}, "papermill": {"duration": 0.222294, 
"end_time": "2023-03-15T10:53:29.466001", "exception": false, "start_time": "2023-03-15T10:53:29.243707", "status": "completed"}, "tags": []}, "outputs": [{"name": "stderr", "output_type": "stream", "text": ["You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision\n"]}, {"name": "stderr", "output_type": "stream", "text": ["/usr/local/lib/python3.9/dist-packages/lightning/pytorch/callbacks/model_checkpoint.py:612: UserWarning: Checkpoint directory logs/lightning_logs/version_0/checkpoints exists and is not empty.\n", " rank_zero_warn(f\"Checkpoint directory {dirpath} exists and is not empty.\")\n", "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n"]}, {"name": "stderr", "output_type": "stream", "text": ["\n", " | Name | Type | Params\n", "-----------------------------------------------------\n", "0 | model | Sequential | 55.1 K\n", "1 | val_accuracy | MulticlassAccuracy | 0 \n", "2 | test_accuracy | MulticlassAccuracy | 0 \n", "-----------------------------------------------------\n", "55.1 K Trainable params\n", "0 Non-trainable params\n", "55.1 K Total params\n", "0.220 Total estimated model params size (MB)\n"]}, {"data": {"application/vnd.jupyter.widget-view+json": {"model_id": "8440d2a25b1340d0a9e801aa07ffe7fa", "version_major": 2, "version_minor": 0}, "text/plain": ["Sanity Checking: 0it [00:00, ?it/s]"]}, "metadata": {}, "output_type": "display_data"}, {"name": "stderr", "output_type": "stream", "text": ["`Trainer.fit` stopped: `max_epochs=3` reached.\n"]}], "source": ["trainer.fit(model)"]}, {"cell_type": "markdown", "id": "0bce274e", "metadata": {"papermill": {"duration": 0.007134, "end_time": "2023-03-15T10:53:29.481972", "exception": false, "start_time": 
"2023-03-15T10:53:29.474838", "status": "completed"}, "tags": []}, "source": ["In Colab, you can use the TensorBoard magic function to view the logs that Lightning has created for you!"]}, {"cell_type": "code", "execution_count": 9, "id": "156662bf", "metadata": {"execution": {"iopub.execute_input": "2023-03-15T10:53:29.494949Z", "iopub.status.busy": "2023-03-15T10:53:29.494491Z", "iopub.status.idle": "2023-03-15T10:53:30.049467Z", "shell.execute_reply": "2023-03-15T10:53:30.048869Z"}, "papermill": {"duration": 0.564254, "end_time": "2023-03-15T10:53:30.051905", "exception": false, "start_time": "2023-03-15T10:53:29.487651", "status": "completed"}, "tags": []}, "outputs": [{"data": {"text/html": ["
\n", "| epoch | val_loss | val_acc | test_loss | test_acc |\n",
"|---|---|---|---|---|\n",
"| 0 | 0.425393 | 0.8854 | NaN | NaN |\n",
"| 1 | 0.308394 | 0.9058 | NaN | NaN |\n",
"| 2 | 0.265774 | 0.9200 | NaN | NaN |\n",
"| 3 | NaN | NaN | 0.26117 | 0.9243 |\n", "