{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "5e8f7886",
   "metadata": {},
   "source": [
    "\n",
    "# TPU training with PyTorch Lightning\n",
    "\n",
    "* **Author:** PL team\n",
    "* **License:** CC BY-SA\n",
    "* **Generated:** 2021-07-17T09:05:13.252067\n",
    "\n",
    "In this notebook, we'll train a model on TPUs. Updating one Trainer flag is all you need for that. The most up-to-date documentation on TPU training can be found [here](https://pytorch-lightning.readthedocs.io/en/latest/advanced/tpu.html).\n",
    "\n",
    "---\n",
    "Open in [Colab](https://colab.research.google.com/github/PytorchLightning/lightning-tutorials/blob/publication/.notebooks/lightning_examples/mnist-tpu-training.ipynb)\n",
    "\n",
    "Give us a ⭐ [on GitHub](https://www.github.com/PytorchLightning/pytorch-lightning/)\n",
    "| Check out [the documentation](https://pytorch-lightning.readthedocs.io/en/latest/)\n",
    "| Join us [on Slack](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-pw5v393p-qRaDgEk24~EjiZNBpSQFgQ)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fc9e852b",
   "metadata": {},
   "source": [
    "### Setup\n",
    "This notebook requires some packages besides pytorch-lightning."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d4059183",
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "LfrJLKPFyhsK",
    "lines_to_next_cell": 0
   },
   "outputs": [],
   "source": [
    "! pip install --quiet \"torchvision\" \"torchmetrics>=0.3\" \"pytorch-lightning>=1.3\" \"torch>=1.6, <1.9\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8a0e3ea2",
   "metadata": {},
   "source": [
    "### Install Colab-compatible PyTorch/XLA (TPU) wheels and dependencies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ea42bbf9",
   "metadata": {},
   "outputs": [],
   "source": [
    "! pip install cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.8-cp37-cp37m-linux_x86_64.whl"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a98eb4da",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn.functional as F\n",
    "from pytorch_lightning import LightningDataModule, LightningModule, Trainer\n",
    "from torch import nn\n",
    "from torch.utils.data import DataLoader, random_split\n",
    "from torchmetrics.functional import accuracy\n",
    "# Note - you must have torchvision installed for this example\n",
    "from torchvision import transforms\n",
    "from torchvision.datasets import MNIST\n",
    "\n",
    "BATCH_SIZE = 1024"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "28bc18a6",
   "metadata": {
    "lines_to_next_cell": 2
   },
   "source": [
    "### Defining the `MNISTDataModule`\n",
    "\n",
    "Below we define `MNISTDataModule`. You can learn more about datamodules\n",
    "in the [docs](https://pytorch-lightning.readthedocs.io/en/latest/extensions/datamodules.html)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "be77bbfc",
   "metadata": {
    "lines_to_next_cell": 2
   },
   "outputs": [],
   "source": [
    "class MNISTDataModule(LightningDataModule):\n",
    "\n",
    "    def __init__(self, data_dir: str = './'):\n",
    "        super().__init__()\n",
    "        self.data_dir = data_dir\n",
    "        self.transform = transforms.Compose([\n",
    "            transforms.ToTensor(), transforms.Normalize((0.1307, ), (0.3081, ))\n",
    "        ])\n",
    "\n",
    "        # self.dims is returned when you call dm.size()\n",
    "        # Setting default dims here because we know them.\n",
    "        # Could optionally be assigned dynamically in dm.setup()\n",
    "        self.dims = (1, 28, 28)\n",
    "        self.num_classes = 10\n",
    "\n",
    "    def prepare_data(self):\n",
    "        # download\n",
    "        MNIST(self.data_dir, train=True, download=True)\n",
    "        MNIST(self.data_dir, train=False, download=True)\n",
    "\n",
    "    def setup(self, stage=None):\n",
    "\n",
    "        # Assign train/val datasets for use in dataloaders\n",
    "        if stage == 'fit' or stage is None:\n",
    "            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)\n",
    "            self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])\n",
    "\n",
    "        # Assign test dataset for use in dataloader(s)\n",
    "        if stage == 'test' or stage is None:\n",
    "            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)\n",
    "\n",
    "    def train_dataloader(self):\n",
    "        # Shuffle the training set each epoch; validation/test order does not matter\n",
    "        return DataLoader(self.mnist_train, batch_size=BATCH_SIZE, shuffle=True)\n",
    "\n",
    "    def val_dataloader(self):\n",
    "        return DataLoader(self.mnist_val, batch_size=BATCH_SIZE)\n",
    "\n",
    "    def test_dataloader(self):\n",
    "        return DataLoader(self.mnist_test, batch_size=BATCH_SIZE)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "77df25f1",
   "metadata": {
    "lines_to_next_cell": 2
   },
   "source": [
    "### Defining the `LitModel`\n",
    "\n",
    "Below, we define the model `LitModel`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "79e57e98",
   "metadata": {},
   "outputs": [],
   "source": [
    "class LitModel(LightningModule):\n",
    "\n",
    "    def __init__(self, channels, width, height, num_classes, hidden_size=64, learning_rate=2e-4):\n",
    "\n",
    "        super().__init__()\n",
    "\n",
    "        self.save_hyperparameters()\n",
    "\n",
    "        self.model = nn.Sequential(\n",
    "            nn.Flatten(), nn.Linear(channels * width * height, hidden_size), nn.ReLU(), nn.Dropout(0.1),\n",
    "            nn.Linear(hidden_size, hidden_size), nn.ReLU(), nn.Dropout(0.1),\n",
    "            nn.Linear(hidden_size, num_classes)\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.model(x)\n",
    "        return F.log_softmax(x, dim=1)\n",
    "\n",
    "    def training_step(self, batch, batch_idx):\n",
    "        x, y = batch\n",
    "        logits = self(x)\n",
    "        loss = F.nll_loss(logits, y)\n",
    "        self.log('train_loss', loss)\n",
    "        return loss\n",
    "\n",
    "    def validation_step(self, batch, batch_idx):\n",
    "        x, y = batch\n",
    "        logits = self(x)\n",
    "        loss = F.nll_loss(logits, y)\n",
    "        preds = torch.argmax(logits, dim=1)\n",
    "        acc = accuracy(preds, y)\n",
    "        self.log('val_loss', loss, prog_bar=True)\n",
    "        self.log('val_acc', acc, prog_bar=True)\n",
    "        return loss\n",
    "\n",
    "    def configure_optimizers(self):\n",
    "        optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)\n",
    "        return optimizer"
   ]
  },
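  {
   "cell_type": "markdown",
   "id": "b7d2e1a0",
   "metadata": {},
   "source": [
    "`LitModel.forward` returns log-probabilities (`log_softmax`), which is why the steps above use `F.nll_loss`. The optional cell below is an illustrative check, using random scores rather than MNIST data, that this pairing gives the same loss as `F.cross_entropy` on raw scores, and that taking `argmax` of the log-probabilities (as `validation_step` does) picks the same class as taking it on the raw scores. It is a side check, not part of the training pipeline."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b7d2e1a1",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn.functional as F\n",
    "\n",
    "# Illustrative values only: a batch of 4 samples over 10 classes.\n",
    "torch.manual_seed(0)\n",
    "scores = torch.randn(4, 10)  # raw, unnormalized model outputs\n",
    "targets = torch.tensor([3, 0, 9, 1])\n",
    "\n",
    "# Two-step loss used by LitModel vs. the fused PyTorch loss.\n",
    "loss_two_step = F.nll_loss(F.log_softmax(scores, dim=1), targets)\n",
    "loss_fused = F.cross_entropy(scores, targets)\n",
    "print(torch.allclose(loss_two_step, loss_fused))  # True\n",
    "\n",
    "# log_softmax only shifts each row by a constant, so argmax is unchanged.\n",
    "print(torch.equal(scores.argmax(dim=1), F.log_softmax(scores, dim=1).argmax(dim=1)))  # True"
   ]
  },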
  {
   "cell_type": "markdown",
   "id": "c75ebe94",
   "metadata": {},
   "source": [
    "### TPU Training\n",
    "\n",
    "Lightning supports training on a single TPU core or on 8 TPU cores.\n",
    "\n",
    "The Trainer parameter `tpu_cores` defines either how many TPU cores to train on (1 or 8) or, when given as a one-element list such as `[1]`, which single TPU core to train on.\n",
    "\n",
    "To train on one specific core, pass its TPU core ID (1-8) in a list.\n",
    "Setting `tpu_cores=[5]` trains on TPU core ID 5."
   ]
  },
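  {
   "cell_type": "markdown",
   "id": "c4f9a3d0",
   "metadata": {},
   "source": [
    "The rule above can be restated in plain Python. The helper below is hypothetical and written only for this tutorial: `describe_tpu_cores` is not a Lightning function, and Lightning's own argument parsing may accept additional forms that it does not cover. It only spells out which `tpu_cores` values the paragraph above describes and what each one means."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c4f9a3d1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical helper, NOT part of Lightning: it only restates the rule above\n",
    "# for the values of `Trainer(tpu_cores=...)` used in this notebook.\n",
    "from typing import List, Union\n",
    "\n",
    "\n",
    "def describe_tpu_cores(tpu_cores: Union[int, List[int]]) -> str:\n",
    "    if isinstance(tpu_cores, list):\n",
    "        # A one-element list selects one specific core by its ID (1-8).\n",
    "        if len(tpu_cores) == 1 and 1 <= tpu_cores[0] <= 8:\n",
    "            return f'1 core (core ID {tpu_cores[0]})'\n",
    "        raise ValueError(f'expected one core ID between 1 and 8, got {tpu_cores}')\n",
    "    # An integer is a number of cores: either 1 or all 8.\n",
    "    if tpu_cores == 1:\n",
    "        return '1 core'\n",
    "    if tpu_cores == 8:\n",
    "        return '8 cores'\n",
    "    raise ValueError(f'expected 1 or 8 cores, got {tpu_cores}')\n",
    "\n",
    "\n",
    "print(describe_tpu_cores([5]))  # 1 core (core ID 5)\n",
    "print(describe_tpu_cores(1))  # 1 core\n",
    "print(describe_tpu_cores(8))  # 8 cores"
   ]
  },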
  {
   "cell_type": "markdown",
   "id": "fa173f57",
   "metadata": {},
   "source": [
    "Train on TPU core ID 5 with `tpu_cores=[5]`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e23e2f58",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Init DataModule\n",
    "dm = MNISTDataModule()\n",
    "# Init model from datamodule's attributes\n",
    "model = LitModel(*dm.size(), dm.num_classes)\n",
    "# Init trainer\n",
    "trainer = Trainer(max_epochs=3, progress_bar_refresh_rate=20, tpu_cores=[5])\n",
    "# Train\n",
    "trainer.fit(model, dm)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "589f8c4e",
   "metadata": {},
   "source": [
    "Train on a single TPU core with `tpu_cores=1`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d9d469e8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Init DataModule\n",
    "dm = MNISTDataModule()\n",
    "# Init model from datamodule's attributes\n",
    "model = LitModel(*dm.size(), dm.num_classes)\n",
    "# Init trainer\n",
    "trainer = Trainer(max_epochs=3, progress_bar_refresh_rate=20, tpu_cores=1)\n",
    "# Train\n",
    "trainer.fit(model, dm)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "40552324",
   "metadata": {},
   "source": [
    "Train on 8 TPU cores with `tpu_cores=8`.\n",
    "You might have to restart the notebook to run on 8 TPU cores after training on a single TPU core."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "432fc169",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Init DataModule\n",
    "dm = MNISTDataModule()\n",
    "# Init model from datamodule's attributes\n",
    "model = LitModel(*dm.size(), dm.num_classes)\n",
    "# Init trainer\n",
    "trainer = Trainer(max_epochs=3, progress_bar_refresh_rate=20, tpu_cores=8)\n",
    "# Train\n",
    "trainer.fit(model, dm)"
   ]
  },
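  {
   "cell_type": "markdown",
   "id": "e8a5b6c0",
   "metadata": {},
   "source": [
    "With `tpu_cores=8`, each core processes its own batches, so one optimizer step covers more samples than on a single core. The cell below is a back-of-the-envelope sketch of that arithmetic for this notebook's 55,000-sample training split and `BATCH_SIZE = 1024`. It assumes that every core runs its own `DataLoader` with the full `batch_size` and that the training set is sharded across cores `DistributedSampler`-style (each core sees `ceil(N / cores)` samples). Treat the numbers as an estimate, not as logged values; the helper name `steps_per_epoch` is hypothetical."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e8a5b6c1",
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "# Assumptions (see the note above): per-core batch size equals BATCH_SIZE, and\n",
    "# the training split is sharded so each core sees ceil(N / cores) samples.\n",
    "TRAIN_SIZE = 55000  # size of mnist_train from random_split\n",
    "PER_CORE_BATCH = 1024  # same value as BATCH_SIZE\n",
    "\n",
    "\n",
    "def steps_per_epoch(num_cores: int, n: int = TRAIN_SIZE, batch: int = PER_CORE_BATCH) -> int:\n",
    "    samples_per_core = math.ceil(n / num_cores)\n",
    "    return math.ceil(samples_per_core / batch)\n",
    "\n",
    "\n",
    "for cores in (1, 8):\n",
    "    print(f'{cores} core(s): up to {cores * PER_CORE_BATCH} samples per step, '\n",
    "          f'{steps_per_epoch(cores)} steps per epoch')\n",
    "# 1 core(s): up to 1024 samples per step, 54 steps per epoch\n",
    "# 8 core(s): up to 8192 samples per step, 7 steps per epoch"
   ]
  },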
  {
   "cell_type": "markdown",
   "id": "be63a2ce",
   "metadata": {},
   "source": [
    "## Congratulations - Time to Join the Community!\n",
    "\n",
    "Congratulations on completing this notebook tutorial! If you enjoyed this and would like to join the Lightning\n",
    "movement, you can do so in the following ways!\n",
    "\n",
    "### Star [Lightning](https://github.com/PyTorchLightning/pytorch-lightning) on GitHub\n",
    "The easiest way to help our community is just by starring the GitHub repos! This helps raise awareness of the cool\n",
    "tools we're building.\n",
    "\n",
    "### Join our [Slack](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-pw5v393p-qRaDgEk24~EjiZNBpSQFgQ)!\n",
    "The best way to keep up to date on the latest advancements is to join our community! Make sure to introduce yourself\n",
    "and share your interests in the `#general` channel.\n",
    "\n",
    "\n",
    "### Contributions!\n",
    "The best way to contribute to our community is to become a code contributor! At any time you can go to the\n",
    "[Lightning](https://github.com/PyTorchLightning/pytorch-lightning) or [Bolt](https://github.com/PyTorchLightning/lightning-bolts)\n",
    "GitHub Issues page and filter for \"good first issue\".\n",
    "\n",
    "* [Lightning good first issue](https://github.com/PyTorchLightning/pytorch-lightning/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)\n",
    "* [Bolt good first issue](https://github.com/PyTorchLightning/lightning-bolts/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)\n",
    "* You can also contribute your own notebooks with useful examples!\n",
    "\n",
    "### Great thanks from the entire PyTorch Lightning Team for your interest!"
   ]
  }
 ],
 "metadata": {
  "jupytext": {
   "cell_metadata_filter": "colab_type,id,colab,-all",
   "formats": "ipynb,py:percent",
   "main_language": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}