{"cells": [{"cell_type": "markdown", "id": "9214cb4b", "metadata": {"papermill": {"duration": 0.005, "end_time": "2023-03-15T10:40:36.106989", "exception": false, "start_time": "2023-03-15T10:40:36.101989", "status": "completed"}, "tags": []}, "source": ["\n", "# PyTorch Lightning DataModules\n", "\n", "* **Author:** PL team\n", "* **License:** CC BY-SA\n", "* **Generated:** 2023-03-15T10:38:58.977380\n", "\n", "This notebook walks you through how to start using DataModules. With the release of `pytorch-lightning` version 0.9.0, we have included a new class called `LightningDataModule` to help you decouple data-related hooks from your `LightningModule`. The most up-to-date documentation on DataModules can be found [here](https://lightning.ai/docs/pytorch/stable/data/datamodule.html).\n", "\n", "---\n", "Open in [Colab](https://colab.research.google.com/github/PytorchLightning/lightning-tutorials/blob/publication/.notebooks/lightning_examples/datamodules.ipynb)\n", "\n", "Give us a \u2b50 [on GitHub](https://www.github.com/Lightning-AI/lightning/)\n", "| Check out [the documentation](https://pytorch-lightning.readthedocs.io/en/stable/)\n", "| Join us [on Slack](https://www.pytorchlightning.ai/community)"]}, {"cell_type": "markdown", "id": "52d6b0f8", "metadata": {"papermill": {"duration": 0.003497, "end_time": "2023-03-15T10:40:36.114445", "exception": false, "start_time": "2023-03-15T10:40:36.110948", "status": "completed"}, "tags": []}, "source": ["## Setup\n", "This notebook requires a few packages besides PyTorch Lightning."]}, {"cell_type": "code", "execution_count": 1, "id": "4cda7c76", "metadata": {"colab": {}, "colab_type": "code", "execution": {"iopub.execute_input": "2023-03-15T10:40:36.122514Z", "iopub.status.busy": "2023-03-15T10:40:36.122049Z", "iopub.status.idle": "2023-03-15T10:40:39.373356Z", "shell.execute_reply": "2023-03-15T10:40:39.372021Z"}, "id": "LfrJLKPFyhsK", "lines_to_next_cell": 0, 
"papermill": {"duration": 3.257442, "end_time": "2023-03-15T10:40:39.375326", "exception": false, "start_time": "2023-03-15T10:40:36.117884", "status": "completed"}, "tags": []}, "outputs": [], "source": ["! pip install --quiet \"torchmetrics>=0.7, <0.12\" \"setuptools==67.4.0\" \"ipython[notebook]>=8.0.0, <8.12.0\" \"torch>=1.8.1, <1.14.0\" \"torchvision\" \"lightning>=2.0.0\""]}, {"cell_type": "markdown", "id": "6de813cb", "metadata": {"papermill": {"duration": 0.003626, "end_time": "2023-03-15T10:40:39.383174", "exception": false, "start_time": "2023-03-15T10:40:39.379548", "status": "completed"}, "tags": []}, "source": ["## Introduction\n", "\n", "First, we'll go over a regular `LightningModule` implementation without the use of a `LightningDataModule`."]}, {"cell_type": "code", "execution_count": 2, "id": "2179b1f1", "metadata": {"execution": {"iopub.execute_input": "2023-03-15T10:40:39.392552Z", "iopub.status.busy": "2023-03-15T10:40:39.391720Z", "iopub.status.idle": "2023-03-15T10:40:41.953236Z", "shell.execute_reply": "2023-03-15T10:40:41.951759Z"}, "papermill": {"duration": 2.568674, "end_time": "2023-03-15T10:40:41.955375", "exception": false, "start_time": "2023-03-15T10:40:39.386701", "status": "completed"}, "tags": []}, "outputs": [], "source": ["import os\n", "\n", "import lightning as L\n", "import torch\n", "import torch.nn.functional as F\n", "from torch import nn\n", "from torch.utils.data import DataLoader, random_split\n", "from torchmetrics.functional import accuracy\n", "from torchvision import transforms\n", "\n", "# Note - you must have torchvision installed for this example\n", "from torchvision.datasets 
import CIFAR10, MNIST\n", "\n", "PATH_DATASETS = os.environ.get(\"PATH_DATASETS\", \".\")\n", "BATCH_SIZE = 256 if torch.cuda.is_available() else 64"]}, {"cell_type": "markdown", "id": "cac3d299", "metadata": {"lines_to_next_cell": 2, "papermill": {"duration": 0.003759, "end_time": "2023-03-15T10:40:41.963128", "exception": false, "start_time": "2023-03-15T10:40:41.959369", "status": "completed"}, "tags": []}, "source": ["### Defining the LitMNIST Model\n", "\n", "Below, we reuse a `LightningModule` from our hello world tutorial that classifies MNIST handwritten digits.\n", "\n", "Unfortunately, we have hardcoded dataset-specific items within the model,\n", "forever limiting it to working with MNIST data. \ud83d\ude22\n", "\n", "This is fine if you don't plan on training/evaluating your model on different datasets.\n", "However, it quickly becomes bothersome when you want to try out your architecture on other datasets."]}, {"cell_type": "code", "execution_count": 3, "id": "709cf586", "metadata": {"execution": {"iopub.execute_input": "2023-03-15T10:40:41.972713Z", "iopub.status.busy": "2023-03-15T10:40:41.971922Z", "iopub.status.idle": "2023-03-15T10:40:41.992949Z", "shell.execute_reply": "2023-03-15T10:40:41.992110Z"}, "papermill": {"duration": 0.027845, "end_time": "2023-03-15T10:40:41.994581", "exception": false, "start_time": "2023-03-15T10:40:41.966736", "status": "completed"}, "tags": []}, "outputs": [], "source": ["class LitMNIST(L.LightningModule):\n", " def __init__(self, data_dir=PATH_DATASETS, hidden_size=64, learning_rate=2e-4):\n", " super().__init__()\n", "\n", " # We hardcode dataset-specific stuff here.\n", " self.data_dir = data_dir\n", " self.num_classes = 10\n", " self.dims = (1, 28, 28)\n", " channels, width, height = self.dims\n", " self.transform = transforms.Compose(\n", " [\n", " transforms.ToTensor(),\n", " transforms.Normalize((0.1307,), (0.3081,)),\n", " ]\n", " )\n", "\n", " self.hidden_size = hidden_size\n", " 
self.learning_rate = learning_rate\n", "\n", " # Build model\n", " self.model = nn.Sequential(\n", " nn.Flatten(),\n", " nn.Linear(channels * width * height, hidden_size),\n", " nn.ReLU(),\n", " nn.Dropout(0.1),\n", " nn.Linear(hidden_size, hidden_size),\n", " nn.ReLU(),\n", " nn.Dropout(0.1),\n", " nn.Linear(hidden_size, self.num_classes),\n", " )\n", "\n", " def forward(self, x):\n", " x = self.model(x)\n", " return F.log_softmax(x, dim=1)\n", "\n", " def training_step(self, batch):\n", " x, y = batch\n", " logits = self(x)\n", " loss = F.nll_loss(logits, y)\n", " return loss\n", "\n", " def validation_step(self, batch, batch_idx):\n", " x, y = batch\n", " logits = self(x)\n", " loss = F.nll_loss(logits, y)\n", " preds = torch.argmax(logits, dim=1)\n", " acc = accuracy(preds, y, task=\"multiclass\", num_classes=10)\n", " self.log(\"val_loss\", loss, prog_bar=True)\n", " self.log(\"val_acc\", acc, prog_bar=True)\n", "\n", " def configure_optimizers(self):\n", " optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)\n", " return optimizer\n", "\n", " ####################\n", " # DATA RELATED HOOKS\n", " ####################\n", "\n", " def prepare_data(self):\n", " # download\n", " MNIST(self.data_dir, train=True, download=True)\n", " MNIST(self.data_dir, train=False, download=True)\n", "\n", " def setup(self, stage=None):\n", " # Assign train/val datasets for use in dataloaders\n", " if stage == \"fit\" or stage is None:\n", " mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)\n", " self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])\n", "\n", " # Assign test dataset for use in dataloader(s)\n", " if stage == \"test\" or stage is None:\n", " self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)\n", "\n", " def train_dataloader(self):\n", " return DataLoader(self.mnist_train, batch_size=128)\n", "\n", " def val_dataloader(self):\n", " return DataLoader(self.mnist_val, 
batch_size=128)\n", "\n", " def test_dataloader(self):\n", " return DataLoader(self.mnist_test, batch_size=128)"]}, {"cell_type": "markdown", "id": "4a655491", "metadata": {"papermill": {"duration": 0.003778, "end_time": "2023-03-15T10:40:42.002076", "exception": false, "start_time": "2023-03-15T10:40:41.998298", "status": "completed"}, "tags": []}, "source": ["### Training the LitMNIST Model"]}, {"cell_type": "code", "execution_count": 4, "id": "29812281", "metadata": {"execution": {"iopub.execute_input": "2023-03-15T10:40:42.010531Z", "iopub.status.busy": "2023-03-15T10:40:42.010333Z", "iopub.status.idle": "2023-03-15T10:41:10.455990Z", "shell.execute_reply": "2023-03-15T10:41:10.454725Z"}, "papermill": {"duration": 28.452752, "end_time": "2023-03-15T10:41:10.458522", "exception": false, "start_time": "2023-03-15T10:40:42.005770", "status": "completed"}, "tags": []}, "outputs": [{"name": "stderr", "output_type": "stream", "text": ["GPU available: True (cuda), used: True\n"]}, {"name": "stderr", "output_type": "stream", "text": ["TPU available: False, using: 0 TPU cores\n"]}, {"name": "stderr", "output_type": "stream", "text": ["IPU available: False, using: 0 IPUs\n"]}, {"name": "stderr", "output_type": "stream", "text": ["HPU available: False, using: 0 HPUs\n"]}, {"name": "stdout", "output_type": "stream", "text": ["Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz\n", "Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to /__w/6/s/.datasets/MNIST/raw/train-images-idx3-ubyte.gz\n"]}, {"data": {"application/vnd.jupyter.widget-view+json": {"model_id": "bc45212939024bb4ba157142ddb0e9dc", "version_major": 2, "version_minor": 0}, "text/plain": [" 0%| | 0/9912422 [00:00