{"cells": [{"cell_type": "markdown", "id": "3a99e833", "metadata": {"papermill": {"duration": 0.005085, "end_time": "2022-08-15T07:29:47.607922", "exception": false, "start_time": "2022-08-15T07:29:47.602837", "status": "completed"}, "tags": []}, "source": ["\n", "# PyTorch Lightning DataModules\n", "\n", "* **Author:** PL team\n", "* **License:** CC BY-SA\n", "* **Generated:** 2022-08-15T09:28:45.750977\n", "\n", "This notebook will walk you through how to start using Datamodules. With the release of `pytorch-lightning` version 0.9.0, we have included a new class called `LightningDataModule` to help you decouple data related hooks from your `LightningModule`. The most up-to-date documentation on datamodules can be found [here](https://pytorch-lightning.readthedocs.io/en/stable/data/datamodule.html).\n", "\n", "---\n", "Open in [![Open In Colab](data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAHUAAAAUCAYAAACzrHJDAAAIuUlEQVRoQ+1ZaVRURxb+qhdolmbTUVSURpZgmLhHbQVFZIlGQBEXcMvJhKiTEzfigjQg7oNEJ9GMGidnjnNMBs2czIzajksEFRE1xklCTKJiQLRFsUGkoUWw+82pamn79etGYoKek1B/4NW99/tu3e/dquJBAGD27NkHALxKf39WY39gyrOi+i3xqGtUoePJrFmznrmgtModorbTu8YRNZk5cybXTvCtwh7o6NR2KzuZMWNGh6jtVt7nA0ymT5/eJlF9POrh7PAQl6s8bGYa3PUum//htmebVtLRqW0q01M5keTk5FZFzU0oRle3+zxwg5Hgtb+PZiL/ZVohxCI+hL5JgjmfjPxZ26+33BG3dA+ealHPM4gQAo5rU59gsI8bRvl54t3Ca62mvHyUAhtOlLd5WSQpKcluBjumnoCLs1EARkVd9E8l3p9y2i7RbQ1B6pFwu/YDgW8KbHJHMTQrwnjz2oZm9M4pavOCfo5jWrgCaaMVcMs6/pNhDr0+AMN93XlxV7R6DNpyzi7W/OE+yIrsjU6rTrbKV5cd/pNyItOmTbMp6sbBB+EqaYJY4cWE3VUciNt1TpgfcRFv71Fi54xT5kSoyLvOBEJMOMxWXkFlBeBSX4u6Zkcs+3KszYRtiapbNRqF31UgetVuc8z9vBXIv1qD+F1f83B6uDlCUyfsZGepGPpmg01OB7EITQbhS9ribKy+DmP1DUiClLz4bnIHVOqa7BY+Z1wg5g3zgUvyehiNpnJKxSLc/ts76LKm0BzX3c0RNy1yXjDcB5lWoro4iNHQxM+f1kWeWQARAWQS++trISJTp061Kep25X/MycwtjuctSC5rxo7ppi7VNUox5+PhPHtrsS2O1qJ6yx1QujQUzm9sh6hbkBlvvGcN8hYnwjUjH6kjfZEd5c/jitz5Jc5U3ENnFynKl4eB7nyEgP2UZ+Yz3/rVEbyYr27qELrtC4FIC0J7sc7xWnmccdHfRRTs0VB+cA4lt+oFcRR/wUeH8FG5w2Mbx8FQ8TXEvv1xYf4wBP3O2WyL3/UVjpXWgIqaFeUPr+wTmDvUB7njH6/bOv+HRg4SqioAg5GDe1aB3ZeMTJkyRSBqkLsWqSEm0fZVBEN94zEZnYvrdx1JL5cxe+a+AbhSJecRRHW/ikTFRTa38dtQlNZ5CRKwFvUtZU/kvBoEF9Uxni/XqIM+dwKbTw3rhcxIf7gmr2M+H6SMwx8iBzJbw5oxeG3Lv5FX9B3AGaHPS8e8z77H7v9VMpvPG5ug1enh7eGK8h0LBTwUb+GInqzInlRUK65DmTPQu4c3+uQKjwKK77zwUxBX4Tq7yR1RuiwUsqlrABCM6esHdXoy47fk4+prYKy8ZF574x4V5BnHQBuf4g9Z9ld8U36L2aktZNNplNfw7zotwWTy5MkCUft4aLEopJj5/OPHl1BQqeAVOnHgNSQOqmBzq9V9cfEm/yx5ubMGKS9cYPZ3vx2OS/c6PVHUuUO7Y1Pci3BO/1zgq18byebfGemLtNF+6JRtOvMk926ibussZqM+1mNz4TWkH7rCbM5phwGRGDAaoF8fY5OHFnlldAA8sgoEXKnDukA1NgSeNjqkJT9brbN4pC9WRweYXyLugR73c+MYvyWfu0yC6+mjzN1Isfw3FKJS98CU/zI1IHFkFPR52cHL2FJk0sB6kMTERIGo9GzcPkLNfA0cwdwi/hfEYO86ZMd9w+y1egfM2T2Eh/vesMNwljSzuZRT420SW3eqy8N6aHMmwmnFUZ7/PGVPbIoNZvNU1BURdHs0bT2+HjL8sDSM2e6vi4Lj5NW8WOLVA6RTT2azxLV+bglaFNqLieqemS/gWkw7NyoAHo+2dEsiivengjKsPFoqWOvbSh/kxPaxyW/JRzH2Fl3EzD9/xjAefJqB3usKUFn/0Gb+S/d/jy3FN2yLOmnSJJtn6oehByEiHPSeXnDxFGPRnoFoaBJjcdQlbDwcjL1zTNuQpoxD7R0OG0uUTMi0fkVwdzBdYIwcwZunxrVJVLplNm54BZp7jfDfYLoNyqQi1K6KxIdHzmN+QQ2WjFIwUT2zTGdlRXo4NFXVUO4sgX5dFC7f0aP/ZlNeUjFBuL8Xjl6uRuP6aMjSjpjzsH62FDU7JhBuGccEXIvDfJFFBc/gHw80dklfCVYnRaDfpiJcutPA4F7qJsfJeUPQI+1fqMlNhFx1FM0GDqkjFVg7NojlQ0Vt4aM5ReSqcbpaCg8nCW5lRsBvbT4T1TLfFptsfh7gItzuKTdJSEiwKSrt1vcmnEXXrsLbYnWDA1bu+z2WKy9Arq+1KRqdfKsoBo0GcdtEpS/B1bO4v0cFiUhkjskvKcMrWwtAPHuwQq8Z+4LZ1vTQANfXt4J0DwZX9gWa9qh4XDM/voC9JXfwYEMMHJcfNtusn82ihvliVUwg5KrPGVf6GH94ZJpEZBen6EC4qYTHA1dXhW0JIex8txzv//c8lhzXIi/BFxOH9jGbQhZsRalTIBZZ8KkGyZAxeRQvXkFF1TWz/Hm46jNYUnjPbt3JxIkT7f6dSj8qfJJyVvBxgaIlblOyjtysNHWN9fjjqWi7glJfW3/S0Hlj2XnA8PhKT9w6g3Qx3XiXhvuxQsuT1proxBK
I/AaZqY1Xz5muvY8G8XkRRCaHsfQsRAFDH/tZPbcYuHotOG0FRIqB4HR3wNVoIPLtz8ycTguu+jpEigE218vd1YCr5m+HpHMvEI9u4LTXwNWaLjl0iPwGAmIpeHx1VeCqTJdPs1/vweweQPO3HC24NhOhnTphwoQnfv6QSY2ICbkNmdSA4h87oaLaiYfn5diIEd4att2erOwJXbPUHp953p6orQVSUVWRAXBT8c/dJ5L9xhzaJGp71GR/wFP8P5V2z10NSC9T93QM2xUg8fHxT+zU9ijeU4naHon8CjFJXFzc8/kn+dN06q9QgF98SYSo2Xen2NjYZy5sR6f+4nLSK5Iam2PH/x87a1YN/t5sBgAAAABJRU5ErkJggg==){height=\"20px\" width=\"117px\"}](https://colab.research.google.com/github/PytorchLightning/lightning-tutorials/blob/publication/.notebooks/lightning_examples/datamodules.ipynb)\n", "\n", "Give us a \u2b50 [on Github](https://www.github.com/PytorchLightning/pytorch-lightning/)\n", "| Check out [the documentation](https://pytorch-lightning.readthedocs.io/en/stable/)\n", "| Join us [on Slack](https://www.pytorchlightning.ai/community)"]}, {"cell_type": "markdown", "id": "26a673e7", "metadata": {"papermill": {"duration": 0.002829, "end_time": "2022-08-15T07:29:47.614057", "exception": false, "start_time": "2022-08-15T07:29:47.611228", "status": "completed"}, "tags": []}, "source": ["## Setup\n", "This notebook requires some packages besides pytorch-lightning."]}, {"cell_type": "code", "execution_count": 1, "id": "10b5261a", "metadata": {"colab": {}, "colab_type": "code", "execution": {"iopub.execute_input": "2022-08-15T07:29:47.621509Z", "iopub.status.busy": "2022-08-15T07:29:47.620801Z", "iopub.status.idle": "2022-08-15T07:29:51.215440Z", "shell.execute_reply": "2022-08-15T07:29:51.214432Z"}, "id": "LfrJLKPFyhsK", "lines_to_next_cell": 0, "papermill": {"duration": 3.600823, "end_time": "2022-08-15T07:29:51.217684", "exception": false, "start_time": "2022-08-15T07:29:47.616861", "status": "completed"}, "tags": []}, "outputs": [], "source": ["! pip install --quiet \"torchmetrics>=0.7\" \"ipython[notebook]\" \"setuptools==59.5.0\" \"torch>=1.8\" \"pytorch-lightning>=1.4\" \"torchvision\""]}, {"cell_type": "markdown", "id": "e1f9ff6c", "metadata": {"papermill": {"duration": 0.002944, "end_time": "2022-08-15T07:29:51.224176", "exception": false, "start_time": "2022-08-15T07:29:51.221232", "status": "completed"}, "tags": []}, "source": ["## Introduction\n", "\n", "First, we'll go over a regular `LightningModule` implementation without the use of a `LightningDataModule`"]}, {"cell_type": "code", "execution_count": 2, "id": "781d3fad", "metadata": {"execution": {"iopub.execute_input": "2022-08-15T07:29:51.231307Z", "iopub.status.busy": "2022-08-15T07:29:51.230775Z", "iopub.status.idle": "2022-08-15T07:29:55.072790Z", "shell.execute_reply": "2022-08-15T07:29:55.071853Z"}, "papermill": {"duration": 3.847887, "end_time": "2022-08-15T07:29:55.074824", "exception": false, "start_time": "2022-08-15T07:29:51.226937", "status": "completed"}, "tags": []}, "outputs": [{"name": "stderr", "output_type": "stream", "text": ["WARNING:root:Bagua cannot detect bundled NCCL library, Bagua will try to use system NCCL instead. 
If you encounter any error, please run `import bagua_core; bagua_core.install_deps()` or the `bagua_install_deps.py` script to install bundled libraries.\n"]}], "source": ["import os\n", "\n", "import torch\n", "import torch.nn.functional as F\n", "from pytorch_lightning import LightningDataModule, LightningModule, Trainer\n", "from pytorch_lightning.callbacks.progress import TQDMProgressBar\n", "from torch import nn\n", "from torch.utils.data import DataLoader, random_split\n", "from torchmetrics.functional import accuracy\n", "from torchvision import transforms\n", "\n", "# Note - you must have torchvision installed for this example\n", "from torchvision.datasets import CIFAR10, MNIST\n", "\n", "PATH_DATASETS = os.environ.get(\"PATH_DATASETS\", \".\")\n", "BATCH_SIZE = 256 if torch.cuda.is_available() else 64"]}, {"cell_type": "markdown", "id": "44efa955", "metadata": {"lines_to_next_cell": 2, "papermill": {"duration": 0.003021, "end_time": "2022-08-15T07:29:55.081483", "exception": false, "start_time": "2022-08-15T07:29:55.078462", "status": "completed"}, "tags": []}, "source": ["### Defining the LitMNISTModel\n", "\n", "Below, we reuse a `LightningModule` from our hello world tutorial that classifies MNIST handwritten digits.\n", "\n", "Unfortunately, we have hardcoded dataset-specific items within the model,\n", "forever limiting it to working with MNIST data. \ud83d\ude22\n", "\n", "This is fine if you don't plan on training/evaluating your model on different datasets.\n", "However, in many cases, this can become bothersome when you want to try out your architecture with different datasets."]}, {"cell_type": "code", "execution_count": 3, "id": "61831b14", "metadata": {"execution": {"iopub.execute_input": "2022-08-15T07:29:55.090065Z", "iopub.status.busy": "2022-08-15T07:29:55.089477Z", "iopub.status.idle": "2022-08-15T07:29:55.101643Z", "shell.execute_reply": "2022-08-15T07:29:55.101010Z"}, "papermill": {"duration": 0.018774, "end_time": "2022-08-15T07:29:55.103106", "exception": false, "start_time": "2022-08-15T07:29:55.084332", "status": "completed"}, "tags": []}, "outputs": [], "source": ["class LitMNIST(LightningModule):\n", " def __init__(self, data_dir=PATH_DATASETS, hidden_size=64, learning_rate=2e-4):\n", "\n", " super().__init__()\n", "\n", " # We hardcode dataset-specific details here (paths, classes, dims, transforms).\n", " self.data_dir = data_dir\n", " self.num_classes = 10\n", " self.dims = (1, 28, 28)\n", " channels, width, height = self.dims\n", " self.transform = transforms.Compose(\n", " [\n", " transforms.ToTensor(),\n", " transforms.Normalize((0.1307,), (0.3081,)),\n", " ]\n", " )\n", "\n", " self.hidden_size = hidden_size\n", " self.learning_rate = learning_rate\n", "\n", " # Build model\n", " self.model = nn.Sequential(\n", " nn.Flatten(),\n", " nn.Linear(channels * width * height, hidden_size),\n", " nn.ReLU(),\n", " nn.Dropout(0.1),\n", " nn.Linear(hidden_size, hidden_size),\n", " nn.ReLU(),\n", " nn.Dropout(0.1),\n", " nn.Linear(hidden_size, self.num_classes),\n", " )\n", "\n", " def forward(self, x):\n", " x = self.model(x)\n", " return F.log_softmax(x, dim=1)\n", "\n", " def training_step(self, batch, batch_idx):\n", " x, y = batch\n", " logits = self(x)\n", " loss = F.nll_loss(logits, y)\n", " return loss\n", "\n", " def validation_step(self, batch, batch_idx):\n", " x, y = batch\n", " logits = self(x)\n", " loss = F.nll_loss(logits, y)\n", " preds = torch.argmax(logits, dim=1)\n", " acc = accuracy(preds, y)\n", " self.log(\"val_loss\", loss, prog_bar=True)\n", " self.log(\"val_acc\", acc, 
prog_bar=True)\n", "\n", " def configure_optimizers(self):\n", " optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)\n", " return optimizer\n", "\n", " ####################\n", " # DATA RELATED HOOKS\n", " ####################\n", "\n", " def prepare_data(self):\n", " # download\n", " MNIST(self.data_dir, train=True, download=True)\n", " MNIST(self.data_dir, train=False, download=True)\n", "\n", " def setup(self, stage=None):\n", "\n", " # Assign train/val datasets for use in dataloaders\n", " if stage == \"fit\" or stage is None:\n", " mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)\n", " self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])\n", "\n", " # Assign test dataset for use in dataloader(s)\n", " if stage == \"test\" or stage is None:\n", " self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)\n", "\n", " def train_dataloader(self):\n", " return DataLoader(self.mnist_train, batch_size=128)\n", "\n", " def val_dataloader(self):\n", " return DataLoader(self.mnist_val, batch_size=128)\n", "\n", " def test_dataloader(self):\n", " return DataLoader(self.mnist_test, batch_size=128)"]}, 
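{"cell_type": "markdown", "id": "a1b2c3d0", "metadata": {}, "source": ["As a quick sanity check, we can push a dummy batch through the untrained network. This is just a minimal sketch: the batch size of 4 is arbitrary, but the input shape must match the hardcoded `self.dims = (1, 28, 28)`."]}, {"cell_type": "code", "execution_count": null, "id": "a1b2c3d1", "metadata": {}, "outputs": [], "source": ["# Illustrative smoke test: a fake batch shaped like MNIST images.\n", "model = LitMNIST()\n", "dummy_batch = torch.randn(4, 1, 28, 28)  # (batch, channels, height, width)\n", "log_probs = model(dummy_batch)\n", "print(log_probs.shape)  # expected: torch.Size([4, 10]), one log-probability per class"]}, 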
{"cell_type": "markdown", "id": "9b8bba65", "metadata": {"papermill": {"duration": 0.002895, "end_time": "2022-08-15T07:29:55.109180", "exception": false, "start_time": "2022-08-15T07:29:55.106285", "status": "completed"}, "tags": []}, "source": ["### Training the LitMNIST Model"]}, {"cell_type": "code", "execution_count": 4, "id": "f1e68a81", "metadata": {"execution": {"iopub.execute_input": "2022-08-15T07:29:55.116110Z", "iopub.status.busy": "2022-08-15T07:29:55.115667Z", "iopub.status.idle": "2022-08-15T07:30:24.034392Z", "shell.execute_reply": "2022-08-15T07:30:24.033695Z"}, "papermill": {"duration": 28.923829, "end_time": "2022-08-15T07:30:24.035856", "exception": false, "start_time": "2022-08-15T07:29:55.112027", "status": "completed"}, "tags": []}, "outputs": [{"name": "stderr", "output_type": "stream", "text": ["GPU available: True (cuda), used: True\n"]}, {"name": "stderr", "output_type": "stream", "text": ["TPU available: False, using: 0 TPU cores\n"]}, {"name": "stderr", "output_type": "stream", "text": ["IPU available: False, using: 0 IPUs\n"]}, {"name": "stderr", "output_type": "stream", "text": ["HPU available: False, using: 0 HPUs\n"]}, {"name": "stderr", "output_type": "stream", "text": ["LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n"]}, {"name": "stderr", "output_type": "stream", "text": ["\n", " | Name | Type | Params\n", "-------------------------------------\n", "0 | model | Sequential | 55.1 K\n", "-------------------------------------\n", "55.1 K Trainable params\n", "0 Non-trainable params\n", "55.1 K Total params\n", "0.220 Total estimated model params size (MB)\n"]}, {"data": {"application/vnd.jupyter.widget-view+json": {"model_id": "ed4af29f643a43cbb0c9087a4795c714", "version_major": 2, "version_minor": 0}, "text/plain": ["Sanity Checking: 0it [00:00, ?it/s]"]}, "metadata": {}, "output_type": "display_data"}, {"name": "stderr", "output_type": "stream", "text": ["/home/AzDevOps_azpcontainer/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:219: PossibleUserWarning: The dataloader, val_dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 12 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.\n", " rank_zero_warn(\n"]}, {"name": "stderr", "output_type": "stream", "text": ["/home/AzDevOps_azpcontainer/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:219: PossibleUserWarning: The dataloader, train_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 12 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.\n", " rank_zero_warn(\n"]}, {"data": {"application/vnd.jupyter.widget-view+json": {"model_id": "2062fc4ac00b48b68a03073231da29db", "version_major": 2, "version_minor": 0}, "text/plain": ["Training: 0it [00:00, ?it/s]"]}, "metadata": {}, "output_type": "display_data"}, {"data": {"application/vnd.jupyter.widget-view+json": {"model_id": "82b1c1d9b6c44bff845e31464787cc65", "version_major": 2, "version_minor": 0}, "text/plain": ["Validation: 0it [00:00, ?it/s]"]}, "metadata": {}, "output_type": "display_data"}, {"data": {"application/vnd.jupyter.widget-view+json": {"model_id": "115025c3b49c45fba4c3a479758b07e8", "version_major": 2, "version_minor": 0}, "text/plain": ["Validation: 0it [00:00, ?it/s]"]}, "metadata": {}, "output_type": "display_data"}, {"name": "stderr", "output_type": "stream", "text": ["`Trainer.fit` stopped: `max_epochs=2` reached.\n"]}], "source": ["model = LitMNIST()\n", "trainer = Trainer(\n", " max_epochs=2,\n", " accelerator=\"auto\",\n", " devices=1 if torch.cuda.is_available() else None, # limit to 1 GPU for interactive (IPython) runs\n", " callbacks=[TQDMProgressBar(refresh_rate=20)],\n", ")\n", "trainer.fit(model)"]}, {"cell_type": "markdown", "id": "35ec750c", "metadata": {"papermill": {"duration": 0.003823, "end_time": "2022-08-15T07:30:24.043888", "exception": false, "start_time": "2022-08-15T07:30:24.040065", "status": "completed"}, "tags": []}, "source": ["## Using DataModules\n", "\n", "DataModules are a way of decoupling data-related hooks from the `LightningModule`\n", "so you can develop dataset-agnostic models."]}, {"cell_type": "markdown", "id": "496631fd", "metadata": {"lines_to_next_cell": 2, "papermill": {"duration": 0.003622, "end_time": "2022-08-15T07:30:24.051116", "exception": false, "start_time": "2022-08-15T07:30:24.047494", "status": "completed"}, "tags": []}, "source": ["### Defining the MNISTDataModule\n", "\n", "Let's go over each function in the class below and talk about what they're doing:\n", "\n", "1. ```__init__```\n", " - Takes in a `data_dir` arg that points to where you have downloaded/wish to download the MNIST dataset.\n", " - Defines a transform that will be applied across train, val, and test dataset splits.\n", " - Defines default `self.dims`.\n", "\n", "\n", "2. ```prepare_data```\n", " - This is where we can download the dataset. We point to our desired dataset and ask torchvision's `MNIST` dataset class to download if the dataset isn't found there.\n", " - **Note we do not make any state assignments in this function** (i.e. `self.something = ...`)\n", "\n", "3. ```setup```\n", " - Loads in data from file and prepares PyTorch tensor datasets for each split (train, val, test).\n", " - Setup expects a 'stage' arg which is used to separate logic for 'fit' and 'test'.\n", " - If you don't mind loading all your datasets at once, you can set up a condition to allow for both 'fit' related setup and 'test' related setup to run whenever `None` is passed to `stage`.\n", " - **Note this runs across all GPUs and it *is* safe to make state assignments here**\n", "\n", "\n", "4. ```x_dataloader```\n", " - `train_dataloader()`, `val_dataloader()`, and `test_dataloader()` all return PyTorch `DataLoader` instances that are created by wrapping their respective datasets that we prepared in `setup()`\n", "\n", "We'll also exercise these hooks by hand in a short sketch right after the class definition below."]}, {"cell_type": "code", "execution_count": 5, "id": "1a2eb03e", "metadata": {"execution": {"iopub.execute_input": "2022-08-15T07:30:24.065309Z", "iopub.status.busy": "2022-08-15T07:30:24.064699Z", "iopub.status.idle": "2022-08-15T07:30:24.077192Z", "shell.execute_reply": "2022-08-15T07:30:24.075890Z"}, "lines_to_next_cell": 2, "papermill": {"duration": 0.023908, "end_time": "2022-08-15T07:30:24.078753", "exception": false, "start_time": "2022-08-15T07:30:24.054845", "status": "completed"}, "tags": []}, "outputs": [], "source": ["class MNISTDataModule(LightningDataModule):\n", " def __init__(self, data_dir: str = PATH_DATASETS):\n", " super().__init__()\n", " self.data_dir = data_dir\n", " self.transform = transforms.Compose(\n", " [\n", " transforms.ToTensor(),\n", " transforms.Normalize((0.1307,), (0.3081,)),\n", " ]\n", " )\n", "\n", " self.dims = (1, 28, 28)\n", " self.num_classes = 10\n", "\n", " def prepare_data(self):\n", " # download\n", " MNIST(self.data_dir, train=True, download=True)\n", " MNIST(self.data_dir, train=False, download=True)\n", "\n", " def setup(self, stage=None):\n", "\n", " # Assign train/val datasets for use in dataloaders\n", " if stage == \"fit\" or stage is None:\n", " mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)\n", " self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])\n", "\n", " # Assign test dataset for use in dataloader(s)\n", " if stage == \"test\" or stage is None:\n", " self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)\n", "\n", " def train_dataloader(self):\n", " return DataLoader(self.mnist_train, batch_size=BATCH_SIZE)\n", "\n", " def val_dataloader(self):\n", " return DataLoader(self.mnist_val, batch_size=BATCH_SIZE)\n", "\n", " def test_dataloader(self):\n", " return DataLoader(self.mnist_test, batch_size=BATCH_SIZE)"]}, 
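{"cell_type": "markdown", "id": "a1b2c3d2", "metadata": {}, "source": ["Although the `Trainer` normally drives these hooks for us, here is a minimal sketch that calls them by hand so you can see what each step produces. The printed shapes assume the `BATCH_SIZE` defined earlier."]}, {"cell_type": "code", "execution_count": null, "id": "a1b2c3d3", "metadata": {}, "outputs": [], "source": ["# Minimal sketch: drive the datamodule hooks manually.\n", "dm_demo = MNISTDataModule()\n", "dm_demo.prepare_data()  # downloads MNIST if missing; assigns no state\n", "dm_demo.setup(stage=\"fit\")  # assigns dm_demo.mnist_train and dm_demo.mnist_val\n", "xb, yb = next(iter(dm_demo.train_dataloader()))\n", "print(xb.shape, yb.shape)  # e.g. torch.Size([BATCH_SIZE, 1, 28, 28]) and torch.Size([BATCH_SIZE])"]}, 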
{"cell_type": "markdown", "id": "08c04f24", "metadata": {"lines_to_next_cell": 2, "papermill": {"duration": 0.004621, "end_time": "2022-08-15T07:30:24.087451", "exception": false, "start_time": "2022-08-15T07:30:24.082830", "status": "completed"}, "tags": []}, "source": ["### Defining the dataset-agnostic `LitModel`\n", "\n", "Below, we define the same model as the `LitMNIST` model we made earlier.\n", "\n", "However, this time our model has the freedom to use any input data that we'd like \ud83d\udd25."]}, {"cell_type": "code", "execution_count": 6, "id": "d3374d8f", "metadata": {"execution": {"iopub.execute_input": "2022-08-15T07:30:24.096691Z", "iopub.status.busy": "2022-08-15T07:30:24.095996Z", "iopub.status.idle": "2022-08-15T07:30:24.104062Z", "shell.execute_reply": "2022-08-15T07:30:24.103362Z"}, "papermill": {"duration": 0.014248, "end_time": "2022-08-15T07:30:24.105574", "exception": false, "start_time": "2022-08-15T07:30:24.091326", "status": "completed"}, "tags": []}, "outputs": [], "source": ["class LitModel(LightningModule):\n", " def __init__(self, channels, width, height, num_classes, hidden_size=64, learning_rate=2e-4):\n", "\n", " super().__init__()\n", "\n", " # We take in input dimensions as parameters and use those to build the model dynamically.\n", " self.channels = channels\n", " self.width = width\n", " self.height = height\n", " self.num_classes = num_classes\n", " self.hidden_size = hidden_size\n", " self.learning_rate = learning_rate\n", "\n", " self.model = nn.Sequential(\n", " nn.Flatten(),\n", " nn.Linear(channels * width * height, hidden_size),\n", " nn.ReLU(),\n", " nn.Dropout(0.1),\n", " nn.Linear(hidden_size, hidden_size),\n", " nn.ReLU(),\n", " nn.Dropout(0.1),\n", " nn.Linear(hidden_size, num_classes),\n", " )\n", "\n", " def forward(self, x):\n", " x = self.model(x)\n", " return F.log_softmax(x, dim=1)\n", "\n", " def training_step(self, batch, batch_idx):\n", " x, y = batch\n", " logits = self(x)\n", " loss = F.nll_loss(logits, y)\n", " return loss\n", "\n", " def validation_step(self, batch, batch_idx):\n", " x, y = batch\n", " logits = self(x)\n", " loss = F.nll_loss(logits, y)\n", " preds = torch.argmax(logits, dim=1)\n", " acc = accuracy(preds, y)\n", " self.log(\"val_loss\", loss, prog_bar=True)\n", " self.log(\"val_acc\", acc, prog_bar=True)\n", "\n", " def configure_optimizers(self):\n", " optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)\n", " return optimizer"]}, 
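{"cell_type": "markdown", "id": "a1b2c3d4", "metadata": {}, "source": ["To see that nothing in `LitModel` is tied to MNIST, here is a small sketch that builds it with CIFAR10-like dimensions (3 channels, 32x32 pixels, 10 classes; these numbers are just for illustration) and runs a dummy batch through it."]}, {"cell_type": "code", "execution_count": null, "id": "a1b2c3d5", "metadata": {}, "outputs": [], "source": ["# Illustrative check: the same class builds a model for a different input shape.\n", "rgb_model = LitModel(channels=3, width=32, height=32, num_classes=10)\n", "dummy_rgb = torch.randn(2, 3, 32, 32)  # two fake RGB images\n", "print(rgb_model(dummy_rgb).shape)  # expected: torch.Size([2, 10])"]}, 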
"2022-08-15T07:30:24.091326", "status": "completed"}, "tags": []}, "outputs": [], "source": ["class LitModel(LightningModule):\n", " def __init__(self, channels, width, height, num_classes, hidden_size=64, learning_rate=2e-4):\n", "\n", " super().__init__()\n", "\n", " # We take in input dimensions as parameters and use those to dynamically build model.\n", " self.channels = channels\n", " self.width = width\n", " self.height = height\n", " self.num_classes = num_classes\n", " self.hidden_size = hidden_size\n", " self.learning_rate = learning_rate\n", "\n", " self.model = nn.Sequential(\n", " nn.Flatten(),\n", " nn.Linear(channels * width * height, hidden_size),\n", " nn.ReLU(),\n", " nn.Dropout(0.1),\n", " nn.Linear(hidden_size, hidden_size),\n", " nn.ReLU(),\n", " nn.Dropout(0.1),\n", " nn.Linear(hidden_size, num_classes),\n", " )\n", "\n", " def forward(self, x):\n", " x = self.model(x)\n", " return F.log_softmax(x, dim=1)\n", "\n", " def training_step(self, batch, batch_idx):\n", " x, y = batch\n", " logits = self(x)\n", " loss = F.nll_loss(logits, y)\n", " return loss\n", "\n", " def validation_step(self, batch, batch_idx):\n", " x, y = batch\n", " logits = self(x)\n", " loss = F.nll_loss(logits, y)\n", " preds = torch.argmax(logits, dim=1)\n", " acc = accuracy(preds, y)\n", " self.log(\"val_loss\", loss, prog_bar=True)\n", " self.log(\"val_acc\", acc, prog_bar=True)\n", "\n", " def configure_optimizers(self):\n", " optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)\n", " return optimizer"]}, {"cell_type": "markdown", "id": "6ddfd21b", "metadata": {"papermill": {"duration": 0.003685, "end_time": "2022-08-15T07:30:24.113034", "exception": false, "start_time": "2022-08-15T07:30:24.109349", "status": "completed"}, "tags": []}, "source": ["### Training the `LitModel` using the `MNISTDataModule`\n", "\n", "Now, we initialize and train the `LitModel` using the `MNISTDataModule`'s configuration settings and dataloaders."]}, {"cell_type": "code", "execution_count": 7, "id": "fe45fa8a", "metadata": {"execution": {"iopub.execute_input": "2022-08-15T07:30:24.122883Z", "iopub.status.busy": "2022-08-15T07:30:24.122354Z", "iopub.status.idle": "2022-08-15T07:30:58.818315Z", "shell.execute_reply": "2022-08-15T07:30:58.817588Z"}, "papermill": {"duration": 34.702048, "end_time": "2022-08-15T07:30:58.819831", "exception": false, "start_time": "2022-08-15T07:30:24.117783", "status": "completed"}, "tags": []}, "outputs": [{"name": "stderr", "output_type": "stream", "text": ["GPU available: True (cuda), used: True\n"]}, {"name": "stderr", "output_type": "stream", "text": ["TPU available: False, using: 0 TPU cores\n"]}, {"name": "stderr", "output_type": "stream", "text": ["IPU available: False, using: 0 IPUs\n"]}, {"name": "stderr", "output_type": "stream", "text": ["HPU available: False, using: 0 HPUs\n"]}, {"name": "stderr", "output_type": "stream", "text": ["LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n"]}, {"name": "stderr", "output_type": "stream", "text": ["\n", " | Name | Type | Params\n", "-------------------------------------\n", "0 | model | Sequential | 55.1 K\n", "-------------------------------------\n", "55.1 K Trainable params\n", "0 Non-trainable params\n", "55.1 K Total params\n", "0.220 Total estimated model params size (MB)\n"]}, {"data": {"application/vnd.jupyter.widget-view+json": {"model_id": "70a22add91db405298bc3c9d30fb2200", "version_major": 2, "version_minor": 0}, "text/plain": ["Sanity Checking: 0it [00:00, ?it/s]"]}, "metadata": {}, "output_type": 
"display_data"}, {"data": {"application/vnd.jupyter.widget-view+json": {"model_id": "d0f022ec8e814bdfbaf4a6496121a8f1", "version_major": 2, "version_minor": 0}, "text/plain": ["Training: 0it [00:00, ?it/s]"]}, "metadata": {}, "output_type": "display_data"}, {"data": {"application/vnd.jupyter.widget-view+json": {"model_id": "10fd574e2b804b1cbc2665da2542a769", "version_major": 2, "version_minor": 0}, "text/plain": ["Validation: 0it [00:00, ?it/s]"]}, "metadata": {}, "output_type": "display_data"}, {"data": {"application/vnd.jupyter.widget-view+json": {"model_id": "4e0556587a09488f8abf90d4dbd16689", "version_major": 2, "version_minor": 0}, "text/plain": ["Validation: 0it [00:00, ?it/s]"]}, "metadata": {}, "output_type": "display_data"}, {"data": {"application/vnd.jupyter.widget-view+json": {"model_id": "9b8608794b90448088c36ddf9b8e1f4f", "version_major": 2, "version_minor": 0}, "text/plain": ["Validation: 0it [00:00, ?it/s]"]}, "metadata": {}, "output_type": "display_data"}, {"name": "stderr", "output_type": "stream", "text": ["`Trainer.fit` stopped: `max_epochs=3` reached.\n"]}], "source": ["# Init DataModule\n", "dm = MNISTDataModule()\n", "# Init model from datamodule's attributes\n", "model = LitModel(*dm.dims, dm.num_classes)\n", "# Init trainer\n", "trainer = Trainer(\n", " max_epochs=3,\n", " callbacks=[TQDMProgressBar(refresh_rate=20)],\n", " accelerator=\"auto\",\n", " devices=1 if torch.cuda.is_available() else None, # limiting got iPython runs\n", ")\n", "# Pass the datamodule as arg to trainer.fit to override model hooks :)\n", "trainer.fit(model, dm)"]}, {"cell_type": "markdown", "id": "bc49d23e", "metadata": {"lines_to_next_cell": 2, "papermill": {"duration": 0.004545, "end_time": "2022-08-15T07:30:58.830298", "exception": false, "start_time": "2022-08-15T07:30:58.825753", "status": "completed"}, "tags": []}, "source": ["### Defining the CIFAR10 DataModule\n", "\n", "Lets prove the `LitModel` we made earlier is dataset agnostic by defining a new datamodule for the CIFAR10 dataset."]}, {"cell_type": "code", "execution_count": 8, "id": "01cc9483", "metadata": {"execution": {"iopub.execute_input": "2022-08-15T07:30:58.841113Z", "iopub.status.busy": "2022-08-15T07:30:58.840398Z", "iopub.status.idle": "2022-08-15T07:30:58.850388Z", "shell.execute_reply": "2022-08-15T07:30:58.847924Z"}, "papermill": {"duration": 0.017475, "end_time": "2022-08-15T07:30:58.852231", "exception": false, "start_time": "2022-08-15T07:30:58.834756", "status": "completed"}, "tags": []}, "outputs": [], "source": ["class CIFAR10DataModule(LightningDataModule):\n", " def __init__(self, data_dir: str = \"./\"):\n", " super().__init__()\n", " self.data_dir = data_dir\n", " self.transform = transforms.Compose(\n", " [\n", " transforms.ToTensor(),\n", " transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),\n", " ]\n", " )\n", "\n", " self.dims = (3, 32, 32)\n", " self.num_classes = 10\n", "\n", " def prepare_data(self):\n", " # download\n", " CIFAR10(self.data_dir, train=True, download=True)\n", " CIFAR10(self.data_dir, train=False, download=True)\n", "\n", " def setup(self, stage=None):\n", "\n", " # Assign train/val datasets for use in dataloaders\n", " if stage == \"fit\" or stage is None:\n", " cifar_full = CIFAR10(self.data_dir, train=True, transform=self.transform)\n", " self.cifar_train, self.cifar_val = random_split(cifar_full, [45000, 5000])\n", "\n", " # Assign test dataset for use in dataloader(s)\n", " if stage == \"test\" or stage is None:\n", " self.cifar_test = CIFAR10(self.data_dir, train=False, 
transform=self.transform)\n", "\n", " def train_dataloader(self):\n", " return DataLoader(self.cifar_train, batch_size=BATCH_SIZE)\n", "\n", " def val_dataloader(self):\n", " return DataLoader(self.cifar_val, batch_size=BATCH_SIZE)\n", "\n", " def test_dataloader(self):\n", " return DataLoader(self.cifar_test, batch_size=BATCH_SIZE)"]}, {"cell_type": "markdown", "id": "c882f322", "metadata": {"papermill": {"duration": 0.006075, "end_time": "2022-08-15T07:30:58.862735", "exception": false, "start_time": "2022-08-15T07:30:58.856660", "status": "completed"}, "tags": []}, "source": ["### Training the `LitModel` using the `CIFAR10DataModule`\n", "\n", "Our model isn't very good, so it will perform pretty badly on the CIFAR10 dataset.\n", "\n", "The point is that our `LitModel` has no problem using a different datamodule as its input data."]}, {"cell_type": "code", "execution_count": 9, "id": "8bf5ec16", "metadata": {"execution": {"iopub.execute_input": "2022-08-15T07:30:58.873771Z", "iopub.status.busy": "2022-08-15T07:30:58.873096Z", "iopub.status.idle": "2022-08-15T07:32:05.832705Z", "shell.execute_reply": "2022-08-15T07:32:05.831814Z"}, "papermill": {"duration": 66.967081, "end_time": "2022-08-15T07:32:05.834221", "exception": false, "start_time": "2022-08-15T07:30:58.867140", "status": "completed"}, "tags": []}, "outputs": [{"name": "stderr", "output_type": "stream", "text": ["GPU available: True (cuda), used: True\n"]}, {"name": "stderr", "output_type": "stream", "text": ["TPU available: False, using: 0 TPU cores\n"]}, {"name": "stderr", "output_type": "stream", "text": ["IPU available: False, using: 0 IPUs\n"]}, {"name": "stderr", "output_type": "stream", "text": ["HPU available: False, using: 0 HPUs\n"]}, {"name": "stdout", "output_type": "stream", "text": ["Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./cifar-10-python.tar.gz\n"]}, {"data": {"application/vnd.jupyter.widget-view+json": {"model_id": "5b3bbee29fbb4ba1805820826297fc61", "version_major": 2, "version_minor": 0}, "text/plain": [" 0%| | 0/170498071 [00:00