{"cells": [{"cell_type": "markdown", "id": "bdec8c9b", "metadata": {"papermill": {"duration": 0.008108, "end_time": "2021-09-09T13:20:18.870349", "exception": false, "start_time": "2021-09-09T13:20:18.862241", "status": "completed"}, "tags": []}, "source": ["\n", "# PyTorch Lightning Basic GAN Tutorial\n", "\n", "* **Author:** PL team\n", "* **License:** CC BY-SA\n", "* **Generated:** 2021-09-09T15:08:28.322630\n", "\n", "How to train a GAN!\n", "\n", "Main takeaways:\n", "1. Generator and discriminator are arbitrary PyTorch modules.\n", "2. training_step does both the generator and discriminator training.\n", "\n", "\n", "---\n", "Open in [![Open In Colab](){height=\"20px\" width=\"117px\"}](https://colab.research.google.com/github/PytorchLightning/lightning-tutorials/blob/publication/.notebooks/lightning_examples/basic-gan.ipynb)\n", "\n", "Give us a \u2b50 [on Github](https://www.github.com/PytorchLightning/pytorch-lightning/)\n", "| Check out [the documentation](https://pytorch-lightning.readthedocs.io/en/latest/)\n", "| Join us [on Slack](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-pw5v393p-qRaDgEk24~EjiZNBpSQFgQ)"]}, {"cell_type": "markdown", "id": "9f242a79", "metadata": {"papermill": {"duration": 0.006659, "end_time": "2021-09-09T13:20:18.884263", "exception": false, "start_time": "2021-09-09T13:20:18.877604", "status": "completed"}, "tags": []}, "source": ["## Setup\n", "This notebook requires some packages besides pytorch-lightning."]}, {"cell_type": "code", "execution_count": 1, "id": "6deebb04", "metadata": {"colab": {}, "colab_type": "code", "execution": {"iopub.execute_input": "2021-09-09T13:20:18.901922Z", "iopub.status.busy": "2021-09-09T13:20:18.901452Z", "iopub.status.idle": "2021-09-09T13:20:18.904360Z", "shell.execute_reply": "2021-09-09T13:20:18.903838Z"}, "id": "LfrJLKPFyhsK", "lines_to_next_cell": 0, "papermill": {"duration": 0.013341, "end_time": "2021-09-09T13:20:18.904472", "exception": false, "start_time": "2021-09-09T13:20:18.891131", "status": "completed"}, "tags": []}, "outputs": [], "source": ["# ! pip install --quiet \"pytorch-lightning>=1.3\" \"torch>=1.6, <1.9\" \"torchvision\" \"torchmetrics>=0.3\""]}, {"cell_type": "code", "execution_count": 2, "id": "e24f349d", "metadata": {"execution": {"iopub.execute_input": "2021-09-09T13:20:18.926320Z", "iopub.status.busy": "2021-09-09T13:20:18.925848Z", "iopub.status.idle": "2021-09-09T13:20:20.005775Z", "shell.execute_reply": "2021-09-09T13:20:20.005309Z"}, "papermill": {"duration": 1.094307, "end_time": "2021-09-09T13:20:20.005890", "exception": false, "start_time": "2021-09-09T13:20:18.911583", "status": "completed"}, "tags": []}, "outputs": [], "source": ["import os\n", "from collections import OrderedDict\n", "\n", "import numpy as np\n", "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "import torchvision\n", "import torchvision.transforms as transforms\n", "from pytorch_lightning import LightningDataModule, LightningModule, Trainer\n", "from torch.utils.data import DataLoader, random_split\n", "from torchvision.datasets import MNIST\n", "\n", "PATH_DATASETS = os.environ.get(\"PATH_DATASETS\", \".\")\n", "AVAIL_GPUS = min(1, torch.cuda.device_count())\n", "BATCH_SIZE = 256 if AVAIL_GPUS else 64\n", "NUM_WORKERS = int(os.cpu_count() / 2)"]}, {"cell_type": "markdown", "id": "a666589a", "metadata": {"lines_to_next_cell": 2, "papermill": {"duration": 0.006778, "end_time": "2021-09-09T13:20:20.020013", "exception": false, "start_time": "2021-09-09T13:20:20.013235", "status": "completed"}, "tags": []}, "source": ["### MNIST DataModule\n", "\n", "Below, we define a DataModule for the MNIST Dataset. To learn more about DataModules, check out our tutorial\n", "on them or see the [latest docs](https://pytorch-lightning.readthedocs.io/en/latest/extensions/datamodules.html)."]}, {"cell_type": "code", "execution_count": 3, "id": "90f2067a", "metadata": {"execution": {"iopub.execute_input": "2021-09-09T13:20:20.042212Z", "iopub.status.busy": "2021-09-09T13:20:20.041153Z", "iopub.status.idle": "2021-09-09T13:20:20.044348Z", "shell.execute_reply": "2021-09-09T13:20:20.043889Z"}, "lines_to_next_cell": 2, "papermill": {"duration": 0.0172, "end_time": "2021-09-09T13:20:20.044446", "exception": false, "start_time": "2021-09-09T13:20:20.027246", "status": "completed"}, "tags": []}, "outputs": [], "source": ["class MNISTDataModule(LightningDataModule):\n", "    def __init__(\n", "        self,\n", "        data_dir: str = PATH_DATASETS,\n", "        batch_size: int = BATCH_SIZE,\n", "        num_workers: int = NUM_WORKERS,\n", "    ):\n", "        super().__init__()\n", "        self.data_dir = data_dir\n", "        self.batch_size = batch_size\n", "        self.num_workers = num_workers\n", "\n", "        self.transform = transforms.Compose(\n", "            [\n", "                transforms.ToTensor(),\n", "                transforms.Normalize((0.1307,), (0.3081,)),\n", "            ]\n", "        )\n", "\n", "        # self.dims is returned when you call dm.size()\n", "        # Setting default dims here because we know them.\n", "        # Could optionally be assigned dynamically in dm.setup()\n", "        self.dims = (1, 28, 28)\n", "        self.num_classes = 10\n", "\n", "    def prepare_data(self):\n", "        # download\n", "        MNIST(self.data_dir, train=True, download=True)\n", "        MNIST(self.data_dir, train=False, download=True)\n", "\n", "    def setup(self, stage=None):\n", "        # Assign train/val datasets for use in dataloaders\n", "        if stage == \"fit\" or stage is None:\n", "            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)\n", "            self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])\n", "\n", "        # Assign test dataset for use in dataloader(s)\n", "        if stage == \"test\" or stage is None:\n", "            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)\n", "\n", "    def train_dataloader(self):\n", "        return DataLoader(\n", "            self.mnist_train,\n", "            batch_size=self.batch_size,\n", "            num_workers=self.num_workers,\n", "        )\n", "\n", "    def val_dataloader(self):\n", "        return DataLoader(self.mnist_val, batch_size=self.batch_size, num_workers=self.num_workers)\n", "\n", "    def test_dataloader(self):\n", "        return DataLoader(self.mnist_test, batch_size=self.batch_size, num_workers=self.num_workers)"]}, {"cell_type": "markdown", "id": "8a15e6b5", "metadata": {"lines_to_next_cell": 2, "papermill": {"duration": 0.006938, "end_time": "2021-09-09T13:20:20.058387", "exception": false, "start_time": "2021-09-09T13:20:20.051449", "status": "completed"}, "tags": []}, "source": ["### A. Generator"]}, {"cell_type": "code", "execution_count": 4, "id": "e6cb2832", "metadata": {"execution": {"iopub.execute_input": "2021-09-09T13:20:20.078127Z", "iopub.status.busy": "2021-09-09T13:20:20.077652Z", "iopub.status.idle": "2021-09-09T13:20:20.079885Z", "shell.execute_reply": "2021-09-09T13:20:20.079427Z"}, "lines_to_next_cell": 2, "papermill": {"duration": 0.014603, "end_time": "2021-09-09T13:20:20.079978", "exception": false, "start_time": "2021-09-09T13:20:20.065375", "status": "completed"}, "tags": []}, "outputs": [], "source": ["class Generator(nn.Module):\n", "    def __init__(self, latent_dim, img_shape):\n", "        super().__init__()\n", "        self.img_shape = img_shape\n", "\n", "        def block(in_feat, out_feat, normalize=True):\n", "            layers = [nn.Linear(in_feat, out_feat)]\n", "            if normalize:\n", "                layers.append(nn.BatchNorm1d(out_feat, 0.8))\n", "            layers.append(nn.LeakyReLU(0.2, inplace=True))\n", "            return layers\n", "\n", "        self.model = nn.Sequential(\n", "            *block(latent_dim, 128, normalize=False),\n", "            *block(128, 256),\n", "            *block(256, 512),\n", "            *block(512, 1024),\n", "            nn.Linear(1024, int(np.prod(img_shape))),\n", "            nn.Tanh(),\n", "        )\n", "\n", "    def forward(self, z):\n", "        img = self.model(z)\n", "        img = img.view(img.size(0), *self.img_shape)\n", "        return img"]}, {"cell_type": "markdown", "id": "0802db9a", "metadata": {"lines_to_next_cell": 2, "papermill": {"duration": 0.007295, "end_time": "2021-09-09T13:20:20.110849", "exception": false, "start_time": "2021-09-09T13:20:20.103554", "status": "completed"}, "tags": []}, "source": ["### B. Discriminator"]}, {"cell_type": "code", "execution_count": 5, "id": "af46f802", "metadata": {"execution": {"iopub.execute_input": "2021-09-09T13:20:20.131267Z", "iopub.status.busy": "2021-09-09T13:20:20.130793Z", "iopub.status.idle": "2021-09-09T13:20:20.132847Z", "shell.execute_reply": "2021-09-09T13:20:20.132379Z"}, "lines_to_next_cell": 2, "papermill": {"duration": 0.014411, "end_time": "2021-09-09T13:20:20.132950", "exception": false, "start_time": "2021-09-09T13:20:20.118539", "status": "completed"}, "tags": []}, "outputs": [], "source": ["class Discriminator(nn.Module):\n", "    def __init__(self, img_shape):\n", "        super().__init__()\n", "\n", "        self.model = nn.Sequential(\n", "            nn.Linear(int(np.prod(img_shape)), 512),\n", "            nn.LeakyReLU(0.2, inplace=True),\n", "            nn.Linear(512, 256),\n", "            nn.LeakyReLU(0.2, inplace=True),\n", "            nn.Linear(256, 1),\n", "            nn.Sigmoid(),\n", "        )\n", "\n", "    def forward(self, img):\n", "        img_flat = img.view(img.size(0), -1)\n", "        validity = self.model(img_flat)\n", "\n", "        return validity"]}, {"cell_type": "markdown", "id": "35cfa85f", "metadata": {"lines_to_next_cell": 2, "papermill": {"duration": 0.010453, "end_time": "2021-09-09T13:20:20.151538", "exception": false, "start_time": "2021-09-09T13:20:20.141085", "status": "completed"}, "tags": []}, "source": ["### C. GAN\n", "\n", "#### A couple of cool features to check out in this example...\n", "\n", "  - We use `some_tensor.type_as(another_tensor)` to make sure we initialize new tensors on the right device (i.e. GPU, CPU).\n", "    - Lightning will put your dataloader data on the right device automatically\n", "    - In this example, we pull from latent dim on the fly, so we need to dynamically add tensors to the right device.\n", "    - `type_as` is the way we recommend to do this.\n", "  - This example shows how to use multiple dataloaders in your `LightningModule`."]}, {"cell_type": "code", "execution_count": 6, "id": "c03d4d75", "metadata": {"execution": {"iopub.execute_input": "2021-09-09T13:20:20.182014Z", "iopub.status.busy": "2021-09-09T13:20:20.181514Z", "iopub.status.idle": "2021-09-09T13:20:20.183781Z", "shell.execute_reply": "2021-09-09T13:20:20.183302Z"}, "papermill": {"duration": 0.023137, "end_time": "2021-09-09T13:20:20.183887", "exception": false, "start_time": "2021-09-09T13:20:20.160750", "status": "completed"}, "tags": []}, "outputs": [], "source": ["class GAN(LightningModule):\n", "    def __init__(\n", "        self,\n", "        channels,\n", "        width,\n", "        height,\n", "        latent_dim: int = 100,\n", "        lr: float = 0.0002,\n", "        b1: float = 0.5,\n", "        b2: float = 0.999,\n", "        batch_size: int = BATCH_SIZE,\n", "        **kwargs\n", "    ):\n", "        super().__init__()\n", "        self.save_hyperparameters()\n", "\n", "        # networks\n", "        data_shape = (channels, width, height)\n", "        self.generator = Generator(latent_dim=self.hparams.latent_dim, img_shape=data_shape)\n", "        self.discriminator = Discriminator(img_shape=data_shape)\n", "\n", "        self.validation_z = torch.randn(8, self.hparams.latent_dim)\n", "\n", "        self.example_input_array = torch.zeros(2, self.hparams.latent_dim)\n", "\n", "    def forward(self, z):\n", "        return self.generator(z)\n", "\n", "    def adversarial_loss(self, y_hat, y):\n", "        return F.binary_cross_entropy(y_hat, y)\n", "\n", "    def training_step(self, batch, batch_idx, optimizer_idx):\n", "        imgs, _ = batch\n", "\n", "        # sample noise\n", "        z = torch.randn(imgs.shape[0], self.hparams.latent_dim)\n", "        z = z.type_as(imgs)\n", "\n", "        # train generator\n", "        if optimizer_idx == 0:\n", "\n", "            # generate images\n", "            self.generated_imgs = self(z)\n", "\n", "            # log sampled images\n", "            sample_imgs = self.generated_imgs[:6]\n", "            grid = torchvision.utils.make_grid(sample_imgs)\n", "            self.logger.experiment.add_image(\"generated_images\", grid, 0)\n", "\n", "            # ground truth result (ie: all fake)\n", "            # put on GPU because we created this tensor inside training_loop\n", "            valid = torch.ones(imgs.size(0), 1)\n", "            valid = valid.type_as(imgs)\n", "\n", "            # adversarial loss is binary cross-entropy\n", "            g_loss = self.adversarial_loss(self.discriminator(self(z)), valid)\n", "            tqdm_dict = {\"g_loss\": g_loss}\n", "            output = OrderedDict({\"loss\": g_loss, \"progress_bar\": tqdm_dict, \"log\": tqdm_dict})\n", "            return output\n", "\n", "        # train discriminator\n", "        if optimizer_idx == 1:\n", "            # Measure discriminator's ability to classify real from generated samples\n", "\n", "            # how well can it label as real?\n", "            valid = torch.ones(imgs.size(0), 1)\n", "            valid = valid.type_as(imgs)\n", "\n", "            real_loss = self.adversarial_loss(self.discriminator(imgs), valid)\n", "\n", "            # how well can it label as fake?\n", "            fake = torch.zeros(imgs.size(0), 1)\n", "            fake = fake.type_as(imgs)\n", "\n", "            fake_loss = self.adversarial_loss(self.discriminator(self(z).detach()), fake)\n", "\n", "            # discriminator loss is the average of these\n", "            d_loss = (real_loss + fake_loss) / 2\n", "            tqdm_dict = {\"d_loss\": d_loss}\n", "            output = OrderedDict({\"loss\": d_loss, \"progress_bar\": tqdm_dict, \"log\": tqdm_dict})\n", "            return output\n", "\n", "    def configure_optimizers(self):\n", "        lr = self.hparams.lr\n", "        b1 = self.hparams.b1\n", "        b2 = self.hparams.b2\n", "\n", "        opt_g = torch.optim.Adam(self.generator.parameters(), lr=lr, betas=(b1, b2))\n", "        opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=lr, betas=(b1, b2))\n", "        return [opt_g, opt_d], []\n", "\n", "    def on_epoch_end(self):\n", "        z = self.validation_z.type_as(self.generator.model[0].weight)\n", "\n", "        # log sampled images\n", "        sample_imgs = self(z)\n", "        grid = torchvision.utils.make_grid(sample_imgs)\n", "        self.logger.experiment.add_image(\"generated_images\", grid, self.current_epoch)"]}, {"cell_type": "code", "execution_count": 7, "id": "239a8875", "metadata": {"execution": {"iopub.execute_input": "2021-09-09T13:20:20.202337Z", "iopub.status.busy": "2021-09-09T13:20:20.201868Z", "iopub.status.idle": "2021-09-09T13:20:42.180320Z", "shell.execute_reply": "2021-09-09T13:20:42.179750Z"}, "papermill": {"duration": 21.989268, "end_time": "2021-09-09T13:20:42.180442", "exception": false, "start_time": "2021-09-09T13:20:20.191174", "status": "completed"}, "tags": []}, "outputs": [{"name": "stderr", "output_type": "stream", "text": ["GPU available: True, used: True\n"]}, {"name": "stderr", "output_type": "stream", "text": ["TPU available: False, using: 0 TPU cores\n"]}, {"name": "stderr", "output_type": "stream", "text": ["IPU available: False, using: 0 IPUs\n"]}, {"name": "stderr", "output_type": "stream", "text": ["/home/AzDevOps_azpcontainer/.local/lib/python3.9/site-packages/pytorch_lightning/trainer/configuration_validator.py:99: UserWarning: you passed in a val_dataloader but have no validation_step. Skipping val loop\n", "  rank_zero_warn(f\"you passed in a {loader_name} but have no {step_name}. Skipping {stage} loop\")\n", "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n"]}, {"name": "stderr", "output_type": "stream", "text": ["\n", "  | Name          | Type          | Params | In sizes | Out sizes     \n", "----------------------------------------------------------------------------\n", "0 | generator     | Generator     | 1.5 M  | [2, 100] | [2, 1, 28, 28]\n", "1 | discriminator | Discriminator | 533 K  | ?        | ?             \n", "----------------------------------------------------------------------------\n", "2.0 M     Trainable params\n", "0         Non-trainable params\n", "2.0 M     Total params\n", "8.174     Total estimated model params size (MB)\n"]}, {"data": {"application/vnd.jupyter.widget-view+json": {"model_id": "de9ba680bbcc4700acf95102f4837035", "version_major": 2, "version_minor": 0}, "text/plain": ["Training: -1it [00:00, ?it/s]"]}, "metadata": {}, "output_type": "display_data"}, {"name": "stderr", "output_type": "stream", "text": ["/home/AzDevOps_azpcontainer/.local/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/logger_connector/result.py:405: LightningDeprecationWarning: One of the returned values {'progress_bar', 'log'} has a `grad_fn`. We will detach it automatically but this behaviour will change in v1.6. Please detach it manually: `return {'loss': ..., 'something': something.detach()}`\n", "  warning_cache.deprecation(\n"]}], "source": ["dm = MNISTDataModule()\n", "model = GAN(*dm.size())\n", "trainer = Trainer(gpus=AVAIL_GPUS, max_epochs=5, progress_bar_refresh_rate=20)\n", "trainer.fit(model, dm)"]}, {"cell_type": "code", "execution_count": 8, "id": "28e960d9", "metadata": {"execution": {"iopub.execute_input": "2021-09-09T13:20:42.202343Z", "iopub.status.busy": "2021-09-09T13:20:42.201872Z", "iopub.status.idle": "2021-09-09T13:20:42.203977Z", "shell.execute_reply": "2021-09-09T13:20:42.203515Z"}, "papermill": {"duration": 0.013867, "end_time": "2021-09-09T13:20:42.204076", "exception": false, "start_time": "2021-09-09T13:20:42.190209", "status": "completed"}, "tags": []}, "outputs": [], "source": ["# Start tensorboard.\n", "# %load_ext tensorboard\n", "# %tensorboard --logdir lightning_logs/"]}, {"cell_type": "markdown", "id": "bbe2c8c5", "metadata": {"papermill": {"duration": 0.008883, "end_time": "2021-09-09T13:20:42.221829", "exception": false, "start_time": "2021-09-09T13:20:42.212946", "status": "completed"}, "tags": []}, "source": ["## Congratulations - Time to Join the Community!\n", "\n", "Congratulations on completing this notebook tutorial! If you enjoyed this and would like to join the Lightning\n", "movement, you can do so in the following ways!\n", "\n", "### Star [Lightning](https://github.com/PyTorchLightning/pytorch-lightning) on GitHub\n", "The easiest way to help our community is just by starring the GitHub repos! This helps raise awareness of the cool\n", "tools we're building.\n", "\n", "### Join our [Slack](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-pw5v393p-qRaDgEk24~EjiZNBpSQFgQ)!\n", "The best way to keep up to date on the latest advancements is to join our community! Make sure to introduce yourself\n", "and share your interests in `#general` channel\n", "\n", "\n", "### Contributions !\n", "The best way to contribute to our community is to become a code contributor! At any time you can go to\n", "[Lightning](https://github.com/PyTorchLightning/pytorch-lightning) or [Bolt](https://github.com/PyTorchLightning/lightning-bolts)\n", "GitHub Issues page and filter for \"good first issue\".\n", "\n", "* [Lightning good first issue](https://github.com/PyTorchLightning/pytorch-lightning/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)\n", "* [Bolt good first issue](https://github.com/PyTorchLightning/lightning-bolts/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)\n", "* You can also contribute your own notebooks with useful examples !\n", "\n", "### Great thanks from the entire Pytorch Lightning Team for your interest !\n", "\n", "![Pytorch Lightning](){height=\"60px\" width=\"240px\"}"]}, {"cell_type": "raw", "metadata": {"raw_mimetype": "text/restructuredtext"}, "source": [".. customcarditem::\n", "   :header: PyTorch Lightning Basic GAN Tutorial\n", "   :card_description: How to train a GAN!  Main takeaways: 1. Generator and discriminator are arbitrary PyTorch modules. 2. training_step does both the generator and discriminator training.\n", "   :tags: Image,GPU/TPU,Lightning-Examples"]}], "metadata": {"jupytext": {"cell_metadata_filter": "id,colab,colab_type,-all", "formats": "ipynb,py:percent", "main_language": "python"}, "language_info": {"codemirror_mode": {"name": "ipython", "version": 3}, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.6"}, "papermill": {"default_parameters": {}, "duration": 25.027239, "end_time": "2021-09-09T13:20:42.837710", "environment_variables": {}, "exception": null, "input_path": "lightning_examples/basic-gan/gan.ipynb", "output_path": ".notebooks/lightning_examples/basic-gan.ipynb", "parameters": {}, "start_time": "2021-09-09T13:20:17.810471", "version": "2.3.3"}, "widgets": {"application/vnd.jupyter.widget-state+json": {"state": {"0f79481079664977ab8b20e9a5e1a2df": {"model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "FloatProgressModel", "state": {"_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "success", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_574f810fd7df44e48a5a6446177d5496", "max": 215.0, "min": 0.0, "orientation": "horizontal", "style": "IPY_MODEL_c0ac37a29c364a18be33dbab437ddacb", "value": 215.0}}, "267507daf8b9404e924bf1dfdea8d03c": {"model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "DescriptionStyleModel", "state": {"_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": ""}}, "574f810fd7df44e48a5a6446177d5496": {"model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": {"_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": "2", "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null}}, "62f59cd6e2024994bf924a3234631810": {"model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": {"_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null}}, "7e103219c81a4fe2881b4b2ead09fdf6": {"model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "HTMLModel", "state": {"_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_62f59cd6e2024994bf924a3234631810", "placeholder": "\u200b", "style": "IPY_MODEL_fc1a7faf1b504a5791c0c496961f190f", "value": "Epoch 4: 100%"}}, "8ad238fc0f774fe19e253b46e5e2653f": {"model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "HTMLModel", "state": {"_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_b195643141c643c19878a26d98156a78", "placeholder": "\u200b", "style": "IPY_MODEL_267507daf8b9404e924bf1dfdea8d03c", "value": " 215/215 [00:03&lt;00:00, 56.12it/s, loss=2.8, v_num=0]"}}, "b195643141c643c19878a26d98156a78": {"model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": {"_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null}}, "c0ac37a29c364a18be33dbab437ddacb": {"model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "ProgressStyleModel", "state": {"_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": ""}}, "c3567fea3abc4ae7a076388428252fce": {"model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": {"_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": "inline-flex", "flex": null, "flex_flow": "row wrap", "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": "100%"}}, "de9ba680bbcc4700acf95102f4837035": {"model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "HBoxModel", "state": {"_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": ["IPY_MODEL_7e103219c81a4fe2881b4b2ead09fdf6", "IPY_MODEL_0f79481079664977ab8b20e9a5e1a2df", "IPY_MODEL_8ad238fc0f774fe19e253b46e5e2653f"], "layout": "IPY_MODEL_c3567fea3abc4ae7a076388428252fce"}}, "fc1a7faf1b504a5791c0c496961f190f": {"model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "DescriptionStyleModel", "state": {"_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": ""}}}, "version_major": 2, "version_minor": 0}}}, "nbformat": 4, "nbformat_minor": 5}