{"cells": [{"cell_type": "markdown", "id": "da3bd453", "metadata": {"papermill": {"duration": 0.031743, "end_time": "2022-04-28T12:55:01.902365", "exception": false, "start_time": "2022-04-28T12:55:01.870622", "status": "completed"}, "tags": []}, "source": ["\n", "# Introduction to PyTorch Lightning\n", "\n", "* **Author:** PL team\n", "* **License:** CC BY-SA\n", "* **Generated:** 2022-04-28T08:05:32.100192\n", "\n", "In this notebook, we'll go over the basics of Lightning by preparing models to train on the [MNIST Handwritten Digits dataset](https://en.wikipedia.org/wiki/MNIST_database).\n", "\n", "---\n", "Open in [Google Colab](https://colab.research.google.com/github/PytorchLightning/lightning-tutorials/blob/publication/.notebooks/lightning_examples/mnist-hello-world.ipynb)\n", "\n", "Give us a \u2b50 [on GitHub](https://www.github.com/PytorchLightning/pytorch-lightning/)\n", "| Check out [the documentation](https://pytorch-lightning.readthedocs.io/en/stable/)\n", "| Join us [on Slack](https://www.pytorchlightning.ai/community)"]}, {"cell_type": "markdown", "id": "8c5149be", "metadata": {"papermill": {"duration": 0.028475, "end_time": "2022-04-28T12:55:01.961385", "exception": false, "start_time": "2022-04-28T12:55:01.932910", "status": "completed"}, "tags": []}, "source": ["## Setup\n", "This notebook requires some packages besides pytorch-lightning."]}, {"cell_type": "code", "execution_count": 1, "id": "dfb9657a", "metadata": {"colab": {}, "colab_type": "code", "execution": {"iopub.execute_input": "2022-04-28T12:55:02.026114Z", "iopub.status.busy": "2022-04-28T12:55:02.025584Z", "iopub.status.idle": "2022-04-28T12:55:05.329954Z", "shell.execute_reply": "2022-04-28T12:55:05.329375Z"}, "id": "LfrJLKPFyhsK", "lines_to_next_cell": 0, "papermill": {"duration": 3.340052, "end_time": "2022-04-28T12:55:05.330108", "exception": false, "start_time": "2022-04-28T12:55:01.990056", "status": "completed"}, "tags": []}, "outputs": [{"name": "stdout", "output_type": "stream", "text": ["\u001b[33mWARNING: You are using pip version 21.3.1; however, version 22.0.4 is available.\r\n", "You should consider upgrading via the '/usr/bin/python3.8 -m pip install --upgrade pip' command.\u001b[0m\r\n"]}], "source": ["! 
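# Version bounds below come from the original tutorial requirements; newer releases may need adjustments.\n", "! 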
pip install --quiet \"seaborn\" \"pytorch-lightning>=1.4\" \"ipython[notebook]\" \"torch>=1.6, <1.9\" \"pandas\" \"torchvision\" \"torchmetrics>=0.6\""]}, {"cell_type": "code", "execution_count": 2, "id": "b6076c42", "metadata": {"execution": {"iopub.execute_input": "2022-04-28T12:55:05.401729Z", "iopub.status.busy": "2022-04-28T12:55:05.401163Z", "iopub.status.idle": "2022-04-28T12:55:07.701858Z", "shell.execute_reply": "2022-04-28T12:55:07.702297Z"}, "papermill": {"duration": 2.339285, "end_time": "2022-04-28T12:55:07.702484", "exception": false, "start_time": "2022-04-28T12:55:05.363199", "status": "completed"}, "tags": []}, "outputs": [], "source": ["import os\n", "\n", "import pandas as pd\n", "import seaborn as sn\n", "import torch\n", "from IPython.core.display import display\n", "from pytorch_lightning import LightningModule, Trainer\n", "from pytorch_lightning.callbacks.progress import TQDMProgressBar\n", "from pytorch_lightning.loggers import CSVLogger\n", "from torch import nn\n", "from torch.nn import functional as F\n", "from torch.utils.data import DataLoader, random_split\n", "from torchmetrics import Accuracy\n", "from torchvision import transforms\n", "from torchvision.datasets import MNIST\n", "\n", "PATH_DATASETS = os.environ.get(\"PATH_DATASETS\", \".\")\n", "BATCH_SIZE = 256 if torch.cuda.is_available() else 64"]}, {"cell_type": "markdown", "id": "56bda26b", "metadata": {"lines_to_next_cell": 2, "papermill": {"duration": 0.029813, "end_time": "2022-04-28T12:55:07.762240", "exception": false, "start_time": "2022-04-28T12:55:07.732427", "status": "completed"}, "tags": []}, "source": ["## Simplest example\n", "\n", "Here's the simplest, most minimal example, with just a training loop (no validation, no testing).\n", "\n", "**Keep in Mind** - A `LightningModule` *is* a PyTorch `nn.Module` - it just has a few more helpful features."]}, {"cell_type": "code", "execution_count": 3, "id": "2b9004c0", "metadata": {"execution": {"iopub.execute_input": "2022-04-28T12:55:07.828553Z", "iopub.status.busy": "2022-04-28T12:55:07.827990Z", "iopub.status.idle": "2022-04-28T12:55:07.830545Z", "shell.execute_reply": "2022-04-28T12:55:07.830954Z"}, "papermill": {"duration": 0.037885, "end_time": "2022-04-28T12:55:07.831100", "exception": false, "start_time": "2022-04-28T12:55:07.793215", "status": "completed"}, "tags": []}, "outputs": [], "source": ["class MNISTModel(LightningModule):\n", "    def __init__(self):\n", "        super().__init__()\n", "        self.l1 = torch.nn.Linear(28 * 28, 10)\n", "\n", "    def forward(self, x):\n", "        return torch.relu(self.l1(x.view(x.size(0), -1)))\n", "\n", "    def training_step(self, batch, batch_nb):\n", "        x, y = batch\n", "        loss = F.cross_entropy(self(x), y)\n", "        return loss\n", "\n", "    def configure_optimizers(self):\n", "        return torch.optim.Adam(self.parameters(), lr=0.02)"]}, {"cell_type": "markdown", "id": "d1051b01", "metadata": {"papermill": {"duration": 0.029818, "end_time": "2022-04-28T12:55:07.891613", "exception": false, "start_time": "2022-04-28T12:55:07.861795", "status": "completed"}, "tags": []}, "source": ["By using the `Trainer` you automatically get:\n", "1. TensorBoard logging\n", "2. Model checkpointing\n", "3. Training and validation loop\n", "4. 
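Early stopping (see the sketch just below)\n", "\n", "As a minimal, hedged sketch of that last point (it assumes you log a `val_loss` metric, as the fuller example later in this notebook does):\n", "\n", "```python\n", "from pytorch_lightning import Trainer\n", "from pytorch_lightning.callbacks.early_stopping import EarlyStopping\n", "\n", "# Stop once the logged `val_loss` has not improved for 3 consecutive validation runs\n", "trainer = Trainer(callbacks=[EarlyStopping(monitor=\"val_loss\", mode=\"min\", patience=3)])\n", "```"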
]}, {"cell_type": "code", "execution_count": 4, "id": "e5854cb1", "metadata": {"execution": {"iopub.execute_input": "2022-04-28T12:55:07.955938Z", "iopub.status.busy": "2022-04-28T12:55:07.955411Z", "iopub.status.idle": "2022-04-28T12:55:27.894253Z", "shell.execute_reply": "2022-04-28T12:55:27.893785Z"}, "papermill": {"duration": 19.973059, "end_time": "2022-04-28T12:55:27.894398", "exception": false, "start_time": "2022-04-28T12:55:07.921339", "status": "completed"}, "tags": []}, "outputs": [{"name": "stderr", "output_type": "stream", "text": ["GPU available: True, used: True\n"]}, {"name": "stderr", "output_type": "stream", "text": ["TPU available: False, using: 0 TPU cores\n"]}, {"name": "stderr", "output_type": "stream", "text": ["IPU available: False, using: 0 IPUs\n"]}, {"name": "stderr", "output_type": "stream", "text": ["HPU available: False, using: 0 HPUs\n"]}, {"name": "stderr", "output_type": "stream", "text": ["LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n"]}, {"name": "stderr", "output_type": "stream", "text": ["\n", " | Name | Type | Params\n", "--------------------------------\n", "0 | l1 | Linear | 7.9 K \n", "--------------------------------\n", "7.9 K Trainable params\n", "0 Non-trainable params\n", "7.9 K Total params\n", "0.031 Total estimated model params size (MB)\n"]}, {"name": "stderr", "output_type": "stream", "text": ["/home/AzDevOps_azpcontainer/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:240: PossibleUserWarning: The dataloader, train_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 12 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.\n", " rank_zero_warn(\n"]}, {"data": {"application/vnd.jupyter.widget-view+json": {"model_id": "7666e9e388ef4e8f8adcf93a08f2dafd", "version_major": 2, "version_minor": 0}, "text/plain": ["Training: 0it [00:00, ?it/s]"]}, "metadata": {}, "output_type": "display_data"}], "source": ["# Init our model\n", "mnist_model = MNISTModel()\n", "\n", "# Init DataLoader from MNIST Dataset\n", "train_ds = MNIST(PATH_DATASETS, train=True, download=True, transform=transforms.ToTensor())\n", "train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE)\n", "\n", "# Initialize a trainer\n", "trainer = Trainer(\n", "    accelerator=\"auto\",\n", "    devices=1 if torch.cuda.is_available() else None,  # limit to one device for IPython/notebook runs\n", "    max_epochs=3,\n", "    callbacks=[TQDMProgressBar(refresh_rate=20)],\n", ")\n", "\n", "# Train the model \u26a1\n", "trainer.fit(mnist_model, train_loader)"]}, {"cell_type": "markdown", "id": "05c8b7d1", "metadata": {"lines_to_next_cell": 2, "papermill": {"duration": 0.036078, "end_time": "2022-04-28T12:55:27.967422", "exception": false, "start_time": "2022-04-28T12:55:27.931344", "status": "completed"}, "tags": []}, "source": ["## A more complete MNIST Lightning Module Example\n", "\n", "That wasn't so hard, was it?\n", "\n", "Now that we've got our feet wet, let's dive in a bit deeper and write a more complete `LightningModule` for MNIST...\n", "\n", "This time, we'll bake all the dataset-specific pieces directly into the `LightningModule`.\n", "This way, we can avoid writing extra code at the beginning of our script every time we want to run it.\n", "\n", "---\n", "\n", "### Note what the following built-in functions are doing:\n", "\n", "1. 
[prepare_data()](https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_module.html#prepare-data) \ud83d\udcbe\n", "    - This is where we can download the dataset. We point to our desired dataset and ask torchvision's `MNIST` dataset class to download if the dataset isn't found there.\n", "    - **Note we do not make any state assignments in this function** (i.e. `self.something = ...`)\n", "\n", "2. [setup(stage)](https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_module.html#setup) \u2699\ufe0f\n", "    - Loads in data from file and prepares PyTorch tensor datasets for each split (train, val, test).\n", "    - Setup expects a 'stage' arg which is used to separate logic for 'fit' and 'test'.\n", "    - If you don't mind loading all your datasets at once, you can set up a condition to allow for both 'fit'-related setup and 'test'-related setup to run whenever `None` is passed to `stage` (or ignore it altogether and exclude any conditionals).\n", "    - **Note this runs across all GPUs and it *is* safe to make state assignments here**\n", "\n", "3. [x_dataloader()](https://pytorch-lightning.readthedocs.io/en/stable/api_references.html#core-api) \u267b\ufe0f\n", "    - `train_dataloader()`, `val_dataloader()`, and `test_dataloader()` all return PyTorch `DataLoader` instances that are created by wrapping their respective datasets that we prepared in `setup()`"]}, {"cell_type": "code", "execution_count": 5, "id": "f9fce56b", "metadata": {"execution": {"iopub.execute_input": "2022-04-28T12:55:28.053056Z", "iopub.status.busy": "2022-04-28T12:55:28.042329Z", "iopub.status.idle": "2022-04-28T12:55:28.055187Z", "shell.execute_reply": "2022-04-28T12:55:28.055599Z"}, "papermill": {"duration": 0.052173, "end_time": "2022-04-28T12:55:28.055746", "exception": false, "start_time": "2022-04-28T12:55:28.003573", "status": "completed"}, "tags": []}, "outputs": [], "source": ["class LitMNIST(LightningModule):\n", "    def __init__(self, data_dir=PATH_DATASETS, hidden_size=64, learning_rate=2e-4):\n", "        super().__init__()\n", "\n", "        # Set our init args as class attributes\n", "        self.data_dir = data_dir\n", "        self.hidden_size = hidden_size\n", "        self.learning_rate = learning_rate\n", "\n", "        # Hardcode some dataset-specific attributes\n", "        self.num_classes = 10\n", "        self.dims = (1, 28, 28)\n", "        channels, width, height = self.dims\n", "        self.transform = transforms.Compose(\n", "            [\n", "                transforms.ToTensor(),\n", "                transforms.Normalize((0.1307,), (0.3081,)),\n", "            ]\n", "        )\n", "\n", "        # Define PyTorch model\n", "        self.model = nn.Sequential(\n", "            nn.Flatten(),\n", "            nn.Linear(channels * width * height, hidden_size),\n", "            nn.ReLU(),\n", "            nn.Dropout(0.1),\n", "            nn.Linear(hidden_size, hidden_size),\n", "            nn.ReLU(),\n", "            nn.Dropout(0.1),\n", "            nn.Linear(hidden_size, self.num_classes),\n", "        )\n", "\n", "        self.val_accuracy = Accuracy()\n", "        self.test_accuracy = Accuracy()\n", "\n", "    def forward(self, x):\n", "        x = self.model(x)\n", "        return F.log_softmax(x, dim=1)\n", "\n", "    def training_step(self, batch, batch_idx):\n", "        x, y = batch\n", "        logits = self(x)\n", "        loss = F.nll_loss(logits, y)\n", "        return loss\n", "\n", "    def validation_step(self, batch, batch_idx):\n", "        x, y = batch\n", "        logits = self(x)\n", "        loss = F.nll_loss(logits, y)\n", "        preds = torch.argmax(logits, dim=1)\n", "        self.val_accuracy.update(preds, y)\n", "\n", "        # Calling self.log will surface up scalars for you in TensorBoard\n", "        self.log(\"val_loss\", loss, prog_bar=True)\n", 
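"        # Logging the torchmetrics object itself (not a computed value) lets Lightning aggregate and reset it each epoch\n", 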
"        self.log(\"val_acc\", self.val_accuracy, prog_bar=True)\n", "\n", "    def test_step(self, batch, batch_idx):\n", "        x, y = batch\n", "        logits = self(x)\n", "        loss = F.nll_loss(logits, y)\n", "        preds = torch.argmax(logits, dim=1)\n", "        self.test_accuracy.update(preds, y)\n", "\n", "        # Calling self.log will surface up scalars for you in TensorBoard\n", "        self.log(\"test_loss\", loss, prog_bar=True)\n", "        self.log(\"test_acc\", self.test_accuracy, prog_bar=True)\n", "\n", "    def configure_optimizers(self):\n", "        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)\n", "        return optimizer\n", "\n", "    ####################\n", "    # DATA RELATED HOOKS\n", "    ####################\n", "\n", "    def prepare_data(self):\n", "        # download\n", "        MNIST(self.data_dir, train=True, download=True)\n", "        MNIST(self.data_dir, train=False, download=True)\n", "\n", "    def setup(self, stage=None):\n", "        # Assign train/val datasets for use in dataloaders\n", "        if stage == \"fit\" or stage is None:\n", "            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)\n", "            self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])\n", "\n", "        # Assign test dataset for use in dataloader(s)\n", "        if stage == \"test\" or stage is None:\n", "            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)\n", "\n", "    def train_dataloader(self):\n", "        return DataLoader(self.mnist_train, batch_size=BATCH_SIZE)\n", "\n", "    def val_dataloader(self):\n", "        return DataLoader(self.mnist_val, batch_size=BATCH_SIZE)\n", "\n", "    def test_dataloader(self):\n", "        return DataLoader(self.mnist_test, batch_size=BATCH_SIZE)"]}, {"cell_type": "code", "execution_count": 6, "id": "66816adb", "metadata": {"execution": {"iopub.execute_input": "2022-04-28T12:55:28.141512Z", "iopub.status.busy": "2022-04-28T12:55:28.140982Z", "iopub.status.idle": "2022-04-28T12:56:00.878261Z", "shell.execute_reply": "2022-04-28T12:56:00.878683Z"}, "papermill": {"duration": 32.779357, "end_time": "2022-04-28T12:56:00.878860", "exception": false, "start_time": "2022-04-28T12:55:28.099503", "status": "completed"}, "tags": []}, "outputs": [{"name": "stderr", "output_type": "stream", "text": ["GPU available: True, used: True\n"]}, {"name": "stderr", "output_type": "stream", "text": ["TPU available: False, using: 0 TPU cores\n"]}, {"name": "stderr", "output_type": "stream", "text": ["IPU available: False, using: 0 IPUs\n"]}, {"name": "stderr", "output_type": "stream", "text": ["HPU available: False, using: 0 HPUs\n"]}, {"name": "stderr", "output_type": "stream", "text": ["LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n"]}, {"name": "stderr", "output_type": "stream", "text": ["\n", " | Name | Type | Params\n", "---------------------------------------------\n", "0 | model | Sequential | 55.1 K\n", "1 | val_accuracy | Accuracy | 0 \n", "2 | test_accuracy | Accuracy | 0 \n", "---------------------------------------------\n", "55.1 K Trainable params\n", "0 Non-trainable params\n", "55.1 K Total params\n", "0.220 Total estimated model params size (MB)\n"]}, {"data": {"application/vnd.jupyter.widget-view+json": {"model_id": "c913d5cdf1a84bf6aeb3ce66641ba038", "version_major": 2, "version_minor": 0}, "text/plain": ["Sanity Checking: 0it [00:00, ?it/s]"]}, "metadata": {}, "output_type": "display_data"}, {"name": "stderr", "output_type": "stream", "text": ["/home/AzDevOps_azpcontainer/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:240: PossibleUserWarning: The dataloader, val_dataloader 0, does not have 
many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 12 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.\n", " rank_zero_warn(\n"]}, {"data": {"application/vnd.jupyter.widget-view+json": {"model_id": "d94d7beaed274e70b3f4ce83a035f40d", "version_major": 2, "version_minor": 0}, "text/plain": ["Training: 0it [00:00, ?it/s]"]}, "metadata": {}, "output_type": "display_data"}, {"data": {"application/vnd.jupyter.widget-view+json": {"model_id": "1b41411e37b44436adc2cdf8e9e56959", "version_major": 2, "version_minor": 0}, "text/plain": ["Validation: 0it [00:00, ?it/s]"]}, "metadata": {}, "output_type": "display_data"}, {"data": {"application/vnd.jupyter.widget-view+json": {"model_id": "a851416ba0bd49f6bd4537685cfec517", "version_major": 2, "version_minor": 0}, "text/plain": ["Validation: 0it [00:00, ?it/s]"]}, "metadata": {}, "output_type": "display_data"}, {"data": {"application/vnd.jupyter.widget-view+json": {"model_id": "b0f3fe4d77f945c7afb01e4f28a19fc8", "version_major": 2, "version_minor": 0}, "text/plain": ["Validation: 0it [00:00, ?it/s]"]}, "metadata": {}, "output_type": "display_data"}], "source": ["model = LitMNIST()\n", "trainer = Trainer(\n", "    accelerator=\"auto\",\n", "    devices=1 if torch.cuda.is_available() else None,  # limit to one device for IPython/notebook runs\n", "    max_epochs=3,\n", "    callbacks=[TQDMProgressBar(refresh_rate=20)],\n", "    logger=CSVLogger(save_dir=\"logs/\"),\n", ")\n", "trainer.fit(model)"]}, {"cell_type": "markdown", "id": "c021d8f5", "metadata": {"papermill": {"duration": 0.048069, "end_time": "2022-04-28T12:56:00.974852", "exception": false, "start_time": "2022-04-28T12:56:00.926783", "status": "completed"}, "tags": []}, "source": ["### Testing\n", "\n", "To test a model, call `trainer.test(model)`.\n", "\n", "Or, if you've just trained a model, you can simply call `trainer.test()` and Lightning will automatically\n", "test using the best saved checkpoint (as ranked by `val_loss`).\n", "\n", 
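"As a hedged sketch of both call styles (`ckpt_path=\"best\"` assumes a checkpoint was saved during `fit`, which Lightning does by default):\n", "\n", "```python\n", "trainer.test(model)  # test the in-memory model as-is\n", "trainer.test(ckpt_path=\"best\")  # or explicitly reload the best checkpoint from the `fit` run\n", "```"]}, {"cell_type": "code", "execution_count": 7, "id": "6f31bbf6", "metadata": {"execution": {"iopub.execute_input": "2022-04-28T12:56:01.073612Z", "iopub.status.busy": "2022-04-28T12:56:01.072364Z", "iopub.status.idle": "2022-04-28T12:56:02.883287Z", "shell.execute_reply": "2022-04-28T12:56:02.883707Z"}, "papermill": {"duration": 1.861867, "end_time": "2022-04-28T12:56:02.883878", "exception": false, "start_time": "2022-04-28T12:56:01.022011", "status": "completed"}, "tags": []}, "outputs": [{"name": "stderr", "output_type": "stream", "text": ["/home/AzDevOps_azpcontainer/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py:1444: UserWarning: `.test(ckpt_path=None)` was called without a model. The best model of the previous `fit` call will be used. 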
You can pass `test(ckpt_path='best')` to use and best model checkpoint and avoid this warning or `ckpt_path=trainer.checkpoint_callback.last_model_path` to use the last model.\n", " rank_zero_warn(\n", "Restoring states from the checkpoint path at logs/lightning_logs/version_3/checkpoints/epoch=2-step=645.ckpt\n"]}, {"name": "stderr", "output_type": "stream", "text": ["LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n"]}, {"name": "stderr", "output_type": "stream", "text": ["Loaded model weights from checkpoint at logs/lightning_logs/version_3/checkpoints/epoch=2-step=645.ckpt\n"]}, {"name": "stderr", "output_type": "stream", "text": ["/home/AzDevOps_azpcontainer/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:240: PossibleUserWarning: The dataloader, test_dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 12 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.\n", " rank_zero_warn(\n"]}, {"data": {"application/vnd.jupyter.widget-view+json": {"model_id": "46147a5285ee4712a703d338c52be49e", "version_major": 2, "version_minor": 0}, "text/plain": ["Testing: 0it [00:00, ?it/s]"]}, "metadata": {}, "output_type": "display_data"}, {"data": {"text/html": ["
\u250f\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2533\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2513\n", "\u2503 Test metric \u2503 DataLoader 0 \u2503\n", "\u2521\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2547\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2529\n", "\u2502 test_acc \u2502 0.9236999750137329 \u2502\n", "\u2502 test_loss \u2502 0.25315943360328674 \u2502\n", "\u2514\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2534\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2518\n", "\n"], "text/plain": ["\u250f\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2533\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2513\n", "\u2503\u001b[1m \u001b[0m\u001b[1m Test metric \u001b[0m\u001b[1m \u001b[0m\u2503\u001b[1m \u001b[0m\u001b[1m DataLoader 0 \u001b[0m\u001b[1m \u001b[0m\u2503\n", "\u2521\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2547\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2529\n", "\u2502\u001b[36m \u001b[0m\u001b[36m test_acc \u001b[0m\u001b[36m \u001b[0m\u2502\u001b[35m \u001b[0m\u001b[35m 0.9236999750137329 \u001b[0m\u001b[35m \u001b[0m\u2502\n", "\u2502\u001b[36m \u001b[0m\u001b[36m test_loss \u001b[0m\u001b[36m \u001b[0m\u2502\u001b[35m \u001b[0m\u001b[35m 0.25315943360328674 \u001b[0m\u001b[35m \u001b[0m\u2502\n", "\u2514\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2534\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2518\n"]}, "metadata": {}, "output_type": "display_data"}, {"data": {"text/plain": ["[{'test_loss': 0.25315943360328674, 'test_acc': 0.9236999750137329}]"]}, "execution_count": 7, "metadata": {}, "output_type": "execute_result"}], "source": ["trainer.test()"]}, {"cell_type": "markdown", "id": "7c7f9eda", "metadata": {"papermill": {"duration": 0.05379, "end_time": "2022-04-28T12:56:02.991785", "exception": false, "start_time": "2022-04-28T12:56:02.937995", "status": "completed"}, "tags": []}, "source": ["### Bonus Tip\n", "\n", "You can keep calling `trainer.fit(model)` as many times as you'd like to continue training"]}, {"cell_type": "code", "execution_count": 8, "id": "6a6fbcd5", "metadata": {"execution": {"iopub.execute_input": "2022-04-28T12:56:03.104409Z", 
"iopub.status.busy": "2022-04-28T12:56:03.103870Z", "iopub.status.idle": "2022-04-28T12:56:03.285768Z", "shell.execute_reply": "2022-04-28T12:56:03.286185Z"}, "papermill": {"duration": 0.240786, "end_time": "2022-04-28T12:56:03.286358", "exception": false, "start_time": "2022-04-28T12:56:03.045572", "status": "completed"}, "tags": []}, "outputs": [{"name": "stderr", "output_type": "stream", "text": ["/home/AzDevOps_azpcontainer/.local/lib/python3.8/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:611: UserWarning: Checkpoint directory logs/lightning_logs/version_3/checkpoints exists and is not empty.\n", " rank_zero_warn(f\"Checkpoint directory {dirpath} exists and is not empty.\")\n", "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n"]}, {"name": "stderr", "output_type": "stream", "text": ["\n", " | Name | Type | Params\n", "---------------------------------------------\n", "0 | model | Sequential | 55.1 K\n", "1 | val_accuracy | Accuracy | 0 \n", "2 | test_accuracy | Accuracy | 0 \n", "---------------------------------------------\n", "55.1 K Trainable params\n", "0 Non-trainable params\n", "55.1 K Total params\n", "0.220 Total estimated model params size (MB)\n"]}, {"data": {"application/vnd.jupyter.widget-view+json": {"model_id": "460f5141753046f9b54925aca37129ce", "version_major": 2, "version_minor": 0}, "text/plain": ["Sanity Checking: 0it [00:00, ?it/s]"]}, "metadata": {}, "output_type": "display_data"}], "source": ["trainer.fit(model)"]}, {"cell_type": "markdown", "id": "ca97777f", "metadata": {"papermill": {"duration": 0.056861, "end_time": "2022-04-28T12:56:03.401171", "exception": false, "start_time": "2022-04-28T12:56:03.344310", "status": "completed"}, "tags": []}, "source": ["In Colab, you can use the TensorBoard magic function to view the logs that Lightning has created for you!"]}, {"cell_type": "code", "execution_count": 9, "id": "248d37f8", "metadata": {"execution": {"iopub.execute_input": "2022-04-28T12:56:03.519798Z", "iopub.status.busy": "2022-04-28T12:56:03.519287Z", "iopub.status.idle": "2022-04-28T12:56:03.863561Z", "shell.execute_reply": "2022-04-28T12:56:03.863975Z"}, "papermill": {"duration": 0.405875, "end_time": "2022-04-28T12:56:03.864175", "exception": false, "start_time": "2022-04-28T12:56:03.458300", "status": "completed"}, "tags": []}, "outputs": [{"data": {"text/html": ["
<table border=\"1\" class=\"dataframe\">\n", "  <thead>\n", "    <tr style=\"text-align: right;\">\n", "      <th></th>\n", "      <th>val_loss</th>\n", "      <th>val_acc</th>\n", "      <th>test_loss</th>\n", "      <th>test_acc</th>\n", "    </tr>\n", "    <tr>\n", "      <th>epoch</th>\n", "      <th></th>\n", "      <th></th>\n", "      <th></th>\n", "      <th></th>\n", "    </tr>\n", "  </thead>\n", "  <tbody>\n", "    <tr>\n", "      <th>0</th>\n", "      <td>0.439232</td>\n", "      <td>0.8828</td>\n", "      <td>NaN</td>\n", "      <td>NaN</td>\n", "    </tr>\n", "    <tr>\n", "      <th>1</th>\n", "      <td>0.314095</td>\n", "      <td>0.9080</td>\n", "      <td>NaN</td>\n", "      <td>NaN</td>\n", "    </tr>\n", "    <tr>\n", "      <th>2</th>\n", "      <td>0.268804</td>\n", "      <td>0.9198</td>\n", "      <td>NaN</td>\n", "      <td>NaN</td>\n", "    </tr>\n", "    <tr>\n", "      <th>2</th>\n", "      <td>NaN</td>\n", "      <td>NaN</td>\n", "      <td>0.253159</td>\n", "      <td>0.9237</td>\n", "    </tr>\n", "  </tbody>\n", "</table>\n", "