{"cells": [{"cell_type": "markdown", "id": "9214cb4b", "metadata": {"papermill": {"duration": 0.005, "end_time": "2023-03-15T10:40:36.106989", "exception": false, "start_time": "2023-03-15T10:40:36.101989", "status": "completed"}, "tags": []}, "source": ["\n", "# PyTorch Lightning DataModules\n", "\n", "* **Author:** PL team\n", "* **License:** CC BY-SA\n", "* **Generated:** 2023-03-15T10:38:58.977380\n", "\n", "This notebook will walk you through how to start using Datamodules. With the release of `pytorch-lightning` version 0.9.0, we have included a new class called `LightningDataModule` to help you decouple data related hooks from your `LightningModule`. The most up-to-date documentation on datamodules can be found [here](https://lightning.ai/docs/pytorch/stable/data/datamodule.html).\n", "\n", "---\n", "Open in [![Open In Colab](data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAHUAAAAUCAYAAACzrHJDAAAIuUlEQVRoQ+1ZaVRURxb+qhdolmbTUVSURpZgmLhHbQVFZIlGQBEXcMvJhKiTEzfigjQg7oNEJ9GMGidnjnNMBs2czIzajksEFRE1xklCTKJiQLRFsUGkoUWw+82pamn79etGYoKek1B/4NW99/tu3e/dquJBAGD27NkHALxKf39WY39gyrOi+i3xqGtUoePJrFmznrmgtModorbTu8YRNZk5cybXTvCtwh7o6NR2KzuZMWNGh6jtVt7nA0ymT5/eJlF9POrh7PAQl6s8bGYa3PUum//htmebVtLRqW0q01M5keTk5FZFzU0oRle3+zxwg5Hgtb+PZiL/ZVohxCI+hL5JgjmfjPxZ26+33BG3dA+ealHPM4gQAo5rU59gsI8bRvl54t3Ca62mvHyUAhtOlLd5WSQpKcluBjumnoCLs1EARkVd9E8l3p9y2i7RbQ1B6pFwu/YDgW8KbHJHMTQrwnjz2oZm9M4pavOCfo5jWrgCaaMVcMs6/pNhDr0+AMN93XlxV7R6DNpyzi7W/OE+yIrsjU6rTrbKV5cd/pNyItOmTbMp6sbBB+EqaYJY4cWE3VUciNt1TpgfcRFv71Fi54xT5kSoyLvOBEJMOMxWXkFlBeBSX4u6Zkcs+3KszYRtiapbNRqF31UgetVuc8z9vBXIv1qD+F1f83B6uDlCUyfsZGepGPpmg01OB7EITQbhS9ribKy+DmP1DUiClLz4bnIHVOqa7BY+Z1wg5g3zgUvyehiNpnJKxSLc/ts76LKm0BzX3c0RNy1yXjDcB5lWoro4iNHQxM+f1kWeWQARAWQS++trISJTp061Kep25X/MycwtjuctSC5rxo7ppi7VNUox5+PhPHtrsS2O1qJ6yx1QujQUzm9sh6hbkBlvvGcN8hYnwjUjH6kjfZEd5c/jitz5Jc5U3ENnFynKl4eB7nyEgP2UZ+Yz3/rVEbyYr27qELrtC4FIC0J7sc7xWnmccdHfRRTs0VB+cA4lt+oFcRR/wUeH8FG5w2Mbx8FQ8TXEvv1xYf4wBP3O2WyL3/UVjpXWgIqaFeUPr+wTmDvUB7njH6/bOv+HRg4SqioAg
5GDe1aB3ZeMTJkyRSBqkLsWqSEm0fZVBEN94zEZnYvrdx1JL5cxe+a+AbhSJecRRHW/ikTFRTa38dtQlNZ5CRKwFvUtZU/kvBoEF9Uxni/XqIM+dwKbTw3rhcxIf7gmr2M+H6SMwx8iBzJbw5oxeG3Lv5FX9B3AGaHPS8e8z77H7v9VMpvPG5ug1enh7eGK8h0LBTwUb+GInqzInlRUK65DmTPQu4c3+uQKjwKK77zwUxBX4Tq7yR1RuiwUsqlrABCM6esHdXoy47fk4+prYKy8ZF574x4V5BnHQBuf4g9Z9ld8U36L2aktZNNplNfw7zotwWTy5MkCUft4aLEopJj5/OPHl1BQqeAVOnHgNSQOqmBzq9V9cfEm/yx5ubMGKS9cYPZ3vx2OS/c6PVHUuUO7Y1Pci3BO/1zgq18byebfGemLtNF+6JRtOvMk926ibussZqM+1mNz4TWkH7rCbM5phwGRGDAaoF8fY5OHFnlldAA8sgoEXKnDukA1NgSeNjqkJT9brbN4pC9WRweYXyLugR73c+MYvyWfu0yC6+mjzN1Isfw3FKJS98CU/zI1IHFkFPR52cHL2FJk0sB6kMTERIGo9GzcPkLNfA0cwdwi/hfEYO86ZMd9w+y1egfM2T2Eh/vesMNwljSzuZRT420SW3eqy8N6aHMmwmnFUZ7/PGVPbIoNZvNU1BURdHs0bT2+HjL8sDSM2e6vi4Lj5NW8WOLVA6RTT2azxLV+bglaFNqLieqemS/gWkw7NyoAHo+2dEsiivengjKsPFoqWOvbSh/kxPaxyW/JRzH2Fl3EzD9/xjAefJqB3usKUFn/0Gb+S/d/jy3FN2yLOmnSJJtn6oehByEiHPSeXnDxFGPRnoFoaBJjcdQlbDwcjL1zTNuQpoxD7R0OG0uUTMi0fkVwdzBdYIwcwZunxrVJVLplNm54BZp7jfDfYLoNyqQi1K6KxIdHzmN+QQ2WjFIwUT2zTGdlRXo4NFXVUO4sgX5dFC7f0aP/ZlNeUjFBuL8Xjl6uRuP6aMjSjpjzsH62FDU7JhBuGccEXIvDfJFFBc/gHw80dklfCVYnRaDfpiJcutPA4F7qJsfJeUPQI+1fqMlNhFx1FM0GDqkjFVg7NojlQ0Vt4aM5ReSqcbpaCg8nCW5lRsBvbT4T1TLfFptsfh7gItzuKTdJSEiwKSrt1vcmnEXXrsLbYnWDA1bu+z2WKy9Arq+1KRqdfKsoBo0GcdtEpS/B1bO4v0cFiUhkjskvKcMrWwtAPHuwQq8Z+4LZ1vTQANfXt4J0DwZX9gWa9qh4XDM/voC9JXfwYEMMHJcfNtusn82ihvliVUwg5KrPGVf6GH94ZJpEZBen6EC4qYTHA1dXhW0JIex8txzv//c8lhzXIi/BFxOH9jGbQhZsRalTIBZZ8KkGyZAxeRQvXkFF1TWz/Hm46jNYUnjPbt3JxIkT7f6dSj8qfJJyVvBxgaIlblOyjtysNHWN9fjjqWi7glJfW3/S0Hlj2XnA8PhKT9w6g3Qx3XiXhvuxQsuT1proxBKI/AaZqY1Xz5muvY8G8XkRRCaHsfQsRAFDH/tZPbcYuHotOG0FRIqB4HR3wNVoIPLtz8ycTguu+jpEigE218vd1YCr5m+HpHMvEI9u4LTXwNWaLjl0iPwGAmIpeHx1VeCqTJdPs1/vweweQPO3HC24NhOhnTphwoQnfv6QSY2ICbkNmdSA4h87oaLaiYfn5diIEd4att2erOwJXbPUHp953p6orQVSUVWRAXBT8c/dJ5L9xhzaJGp71GR/wFP8P5V2z10NSC9T93QM2xUg8fHxT+zU9ijeU4naHon8CjFJXFzc8/kn+dN06q9QgF98SYSo2Xen2NjYZy5sR6f+4nLSK5Iam2PH/x87a1YN/t5sBgAAAABJRU5ErkJggg==){height=\"20px\" 
width=\"117px\"}](https://colab.research.google.com/github/PytorchLightning/lightning-tutorials/blob/publication/.notebooks/lightning_examples/datamodules.ipynb)\n", "\n", "Give us a \u2b50 [on GitHub](https://www.github.com/Lightning-AI/lightning/)\n", "| Check out [the documentation](https://pytorch-lightning.readthedocs.io/en/stable/)\n", "| Join us [on Slack](https://www.pytorchlightning.ai/community)"]}, {"cell_type": "markdown", "id": "52d6b0f8", "metadata": {"papermill": {"duration": 0.003497, "end_time": "2023-03-15T10:40:36.114445", "exception": false, "start_time": "2023-03-15T10:40:36.110948", "status": "completed"}, "tags": []}, "source": ["## Setup\n", "This notebook requires a few packages in addition to pytorch-lightning."]}, {"cell_type": "code", "execution_count": 1, "id": "4cda7c76", "metadata": {"colab": {}, "colab_type": "code", "execution": {"iopub.execute_input": "2023-03-15T10:40:36.122514Z", "iopub.status.busy": "2023-03-15T10:40:36.122049Z", "iopub.status.idle": "2023-03-15T10:40:39.373356Z", "shell.execute_reply": "2023-03-15T10:40:39.372021Z"}, "id": "LfrJLKPFyhsK", "lines_to_next_cell": 0, "papermill": {"duration": 3.257442, "end_time": "2023-03-15T10:40:39.375326", "exception": false, "start_time": "2023-03-15T10:40:36.117884", "status": "completed"}, "tags": []}, "outputs": [{"name": "stdout", "output_type": "stream", "text": ["\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\r\n", "\u001b[0m"]}], "source": ["! 
pip install --quiet \"torchmetrics>=0.7, <0.12\" \"setuptools==67.4.0\" \"ipython[notebook]>=8.0.0, <8.12.0\" \"torch>=1.8.1, <1.14.0\" \"torchvision\" \"pytorch-lightning>=1.4, <2.0.0\" \"lightning\""]}, {"cell_type": "markdown", "id": "6de813cb", "metadata": {"papermill": {"duration": 0.003626, "end_time": "2023-03-15T10:40:39.383174", "exception": false, "start_time": "2023-03-15T10:40:39.379548", "status": "completed"}, "tags": []}, "source": ["## Introduction\n", "\n", "First, we'll go over a regular `LightningModule` implementation that does not use a `LightningDataModule`."]}, {"cell_type": "code", "execution_count": 2, "id": "2179b1f1", "metadata": {"execution": {"iopub.execute_input": "2023-03-15T10:40:39.392552Z", "iopub.status.busy": "2023-03-15T10:40:39.391720Z", "iopub.status.idle": "2023-03-15T10:40:41.953236Z", "shell.execute_reply": "2023-03-15T10:40:41.951759Z"}, "papermill": {"duration": 2.568674, "end_time": "2023-03-15T10:40:41.955375", "exception": false, "start_time": "2023-03-15T10:40:39.386701", "status": "completed"}, "tags": []}, "outputs": [], "source": ["import os\n", "\n", "import lightning as L\n", "import torch\n", "import torch.nn.functional as F\n", "from torch import nn\n", "from torch.utils.data import DataLoader, random_split\n", "from torchmetrics.functional import accuracy\n", "from torchvision import transforms\n", "\n", "# Note: you must have torchvision installed for this example\n", "from torchvision.datasets import CIFAR10, MNIST\n", "\n", "PATH_DATASETS = os.environ.get(\"PATH_DATASETS\", \".\")\n", "BATCH_SIZE = 256 if torch.cuda.is_available() else 64"]}, {"cell_type": "markdown", "id": "cac3d299", "metadata": {"lines_to_next_cell": 2, "papermill": {"duration": 0.003759, "end_time": "2023-03-15T10:40:41.963128", "exception": false, "start_time": "2023-03-15T10:40:41.959369", "status": "completed"}, "tags": []}, "source": ["### Defining the LitMNIST Model\n", "\n", "Below, we reuse a `LightningModule` from our hello world tutorial
that classifies MNIST handwritten digits.\n", "\n", "Unfortunately, we have hardcoded dataset-specific items within the model,\n", "forever limiting it to working with MNIST data. \ud83d\ude22\n", "\n", "This is fine if you don't plan on training or evaluating your model on other datasets.\n", "However, it becomes bothersome as soon as you want to try out your architecture on different datasets, since each new dataset means editing the model itself."]}, {"cell_type": "code", "execution_count": 3, "id": "709cf586", "metadata": {"execution": {"iopub.execute_input": "2023-03-15T10:40:41.972713Z", "iopub.status.busy": "2023-03-15T10:40:41.971922Z", "iopub.status.idle": "2023-03-15T10:40:41.992949Z", "shell.execute_reply": "2023-03-15T10:40:41.992110Z"}, "papermill": {"duration": 0.027845, "end_time": "2023-03-15T10:40:41.994581", "exception": false, "start_time": "2023-03-15T10:40:41.966736", "status": "completed"}, "tags": []}, "outputs": [], "source": ["class LitMNIST(L.LightningModule):\n", "    def __init__(self, data_dir=PATH_DATASETS, hidden_size=64, learning_rate=2e-4):\n", "        super().__init__()\n", "\n", "        # We hardcode dataset-specific attributes here.\n", "        self.data_dir = data_dir\n", "        self.num_classes = 10\n", "        self.dims = (1, 28, 28)\n", "        channels, height, width = self.dims\n", "        self.transform = transforms.Compose(\n", "            [\n", "                transforms.ToTensor(),\n", "                transforms.Normalize((0.1307,), (0.3081,)),\n", "            ]\n", "        )\n", "\n", "        self.hidden_size = hidden_size\n", "        self.learning_rate = learning_rate\n", "\n", "        # Build model\n", "        self.model = nn.Sequential(\n", "            nn.Flatten(),\n", "            nn.Linear(channels * height * width, hidden_size),\n", "            nn.ReLU(),\n", "            nn.Dropout(0.1),\n", "            nn.Linear(hidden_size, hidden_size),\n", "            nn.ReLU(),\n", "            nn.Dropout(0.1),\n", "            nn.Linear(hidden_size, self.num_classes),\n", "        )\n", "\n", "    def forward(self, x):\n", "        x = self.model(x)\n", "        return F.log_softmax(x, dim=1)\n", "\n", "    def training_step(self, batch):\n", "        x, y = batch\n", "        logits = self(x)\n", "        loss = 
F.nll_loss(logits, y)\n", "        return loss\n", "\n", "    def validation_step(self, batch, batch_idx):\n", "        x, y = batch\n", "        logits = self(x)\n", "        loss = F.nll_loss(logits, y)\n", "        preds = torch.argmax(logits, dim=1)\n", "        acc = accuracy(preds, y, task=\"multiclass\", num_classes=self.num_classes)\n", "        self.log(\"val_loss\", loss, prog_bar=True)\n", "        self.log(\"val_acc\", acc, prog_bar=True)\n", "\n", "    def configure_optimizers(self):\n", "        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)\n", "        return optimizer\n", "\n", "    ####################\n", "    # DATA RELATED HOOKS\n", "    ####################\n", "\n", "    def prepare_data(self):\n", "        # Download only; Lightning calls this on one process per node, so don't assign state here\n", "        MNIST(self.data_dir, train=True, download=True)\n", "        MNIST(self.data_dir, train=False, download=True)\n", "\n", "    def setup(self, stage=None):\n", "        # Assign train/val datasets for use in dataloaders\n", "        if stage == \"fit\" or stage is None:\n", "            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)\n", "            self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])\n", "\n", "        # Assign test dataset for use in dataloader(s)\n", "        if stage == \"test\" or stage is None:\n", "            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)\n", "\n", "    def train_dataloader(self):\n", "        return DataLoader(self.mnist_train, batch_size=128)\n", "\n", "    def val_dataloader(self):\n", "        return DataLoader(self.mnist_val, batch_size=128)\n", "\n", "    def test_dataloader(self):\n", "        return DataLoader(self.mnist_test, batch_size=128)"]}, {"cell_type": "markdown", "id": "4a655491", "metadata": {"papermill": {"duration": 0.003778, "end_time": "2023-03-15T10:40:42.002076", "exception": false, "start_time": "2023-03-15T10:40:41.998298", "status": "completed"}, "tags": []}, "source": ["### Training the LitMNIST Model"]}, {"cell_type": "code", "execution_count": 4, "id": "29812281", "metadata": {"execution": {"iopub.execute_input": "2023-03-15T10:40:42.010531Z", 
"iopub.status.busy": "2023-03-15T10:40:42.010333Z", "iopub.status.idle": "2023-03-15T10:41:10.455990Z", "shell.execute_reply": "2023-03-15T10:41:10.454725Z"}, "papermill": {"duration": 28.452752, "end_time": "2023-03-15T10:41:10.458522", "exception": false, "start_time": "2023-03-15T10:40:42.005770", "status": "completed"}, "tags": []}, "outputs": [{"name": "stderr", "output_type": "stream", "text": ["GPU available: True (cuda), used: True\n"]}, {"name": "stderr", "output_type": "stream", "text": ["TPU available: False, using: 0 TPU cores\n"]}, {"name": "stderr", "output_type": "stream", "text": ["IPU available: False, using: 0 IPUs\n"]}, {"name": "stderr", "output_type": "stream", "text": ["HPU available: False, using: 0 HPUs\n"]}, {"name": "stdout", "output_type": "stream", "text": ["Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz\n", "Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to /__w/6/s/.datasets/MNIST/raw/train-images-idx3-ubyte.gz\n"]}, {"data": {"application/vnd.jupyter.widget-view+json": {"model_id": "bc45212939024bb4ba157142ddb0e9dc", "version_major": 2, "version_minor": 0}, "text/plain": [" 0%| | 0/9912422 [00:00