{"cells": [{"cell_type": "markdown", "id": "4a709358", "metadata": {"papermill": {"duration": 0.012367, "end_time": "2023-10-11T16:28:35.825360", "exception": false, "start_time": "2023-10-11T16:28:35.812993", "status": "completed"}, "tags": []}, "source": ["\n", "# Tutorial 10: Autoregressive Image Modeling\n", "\n", "* **Author:** Phillip Lippe\n", "* **License:** CC BY-SA\n", "* **Generated:** 2023-10-11T16:26:07.497672\n", "\n", "In this tutorial, we implement an autoregressive likelihood model for the task of image modeling.\n", "Autoregressive models are naturally strong generative models that constitute one of the current\n", "state-of-the-art architectures for likelihood-based image modeling,\n", "and are also the basis for large language generation models such as GPT-3.\n", "We will focus on the PixelCNN architecture in this tutorial, and apply it to MNIST modeling.\n", "This notebook is part of a lecture series on Deep Learning at the University of Amsterdam.\n", "The full list of tutorials can be found at https://uvadlc-notebooks.rtfd.io.\n", "\n", "\n", "---\n", "Open in [Colab](https://colab.research.google.com/github/PytorchLightning/lightning-tutorials/blob/publication/.notebooks/course_UvA-DL/10-autoregressive-image-modeling.ipynb)"]}, {"cell_type": "markdown", "id": "158b4700", "metadata": {"papermill": {"duration": 0.015621, "end_time": "2023-10-11T16:28:35.853167", "exception": false, "start_time": "2023-10-11T16:28:35.837546", "status": "completed"}, "tags": []}, "source": ["## Setup\n", "This notebook requires some packages besides pytorch-lightning."]}, {"cell_type": "code", "execution_count": 1, "id": "36e0ec76", "metadata": {"colab": {}, "colab_type": 
"code", "execution": {"iopub.execute_input": "2023-10-11T16:28:35.877471Z", "iopub.status.busy": "2023-10-11T16:28:35.876916Z", "iopub.status.idle": "2023-10-11T16:32:50.702473Z", "shell.execute_reply": "2023-10-11T16:32:50.701484Z"}, "id": "LfrJLKPFyhsK", "lines_to_next_cell": 0, "papermill": {"duration": 254.842354, "end_time": "2023-10-11T16:32:50.706380", "exception": false, "start_time": "2023-10-11T16:28:35.864026", "status": "completed"}, "tags": []}, "outputs": [{"name": "stdout", "output_type": "stream", "text": ["\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\r\n", "\u001b[0m"]}], "source": ["! pip install --quiet \"lightning>=2.0.0\" \"torch>=1.8.1, <2.1.0\" \"setuptools>=68.0.0, <68.3.0\" \"matplotlib\" \"torchmetrics>=0.7, <1.3\" \"ipython[notebook]>=8.0.0, <8.17.0\" \"matplotlib>=3.0.0, <3.9.0\" \"pytorch-lightning>=1.4, <2.1.0\" \"torchvision\" \"urllib3\" \"seaborn\""]}, {"cell_type": "markdown", "id": "2b2df267", "metadata": {"papermill": {"duration": 0.017062, "end_time": "2023-10-11T16:32:50.742553", "exception": false, "start_time": "2023-10-11T16:32:50.725491", "status": "completed"}, "tags": []}, "source": ["
\n", "\n", "Similar to the language generation you have seen in assignment 2, autoregressive models work on images by modeling the likelihood of a pixel given all previous ones.\n", "For instance, in the picture below, we model the pixel $x_i$ as a conditional probability distribution\n", "based on all previous (here blue) pixels (figure credit - [Aaron van den Oord et al. ](https://arxiv.org/abs/1601.06759)):\n", "\n", "
\n", "\n", "Generally, autoregressive models over high-dimensional data $\\mathbf{x}$ factor the joint distribution as the following product of conditionals:\n", "\n", "$$p(\\mathbf{x})=p(x_1, ..., x_n)=\\prod_{i=1}^{n} p(x_i|x_1,...,x_{i-1})$$\n", "\n", "Learning these conditionals is often much simpler than learning the joint distribution $p(\\mathbf{x})$ all at once.\n", "However, disadvantages of autoregressive models include slow sampling, especially for large images,\n", "as we need height-times-width sequential forward passes through the model (one per pixel).\n", "In addition, some applications require a latent space, as modeled in VAEs and Normalizing Flows.\n", "For instance, autoregressive models lack a latent representation, so we cannot interpolate between two images in latent space.\n", "We will explore and discuss these benefits and drawbacks alongside our implementation.\n", "\n", "Our implementation will focus on the [PixelCNN](https://arxiv.org/pdf/1606.05328.pdf) [2] model, which has been discussed in detail in the lecture.\n", "Most current state-of-the-art autoregressive image models use PixelCNN as their fundamental architecture,\n", "and various additions have been proposed to improve the performance\n", "(e.g. [PixelCNN++](https://arxiv.org/pdf/1701.05517.pdf) and [PixelSNAIL](http://proceedings.mlr.press/v80/chen18h/chen18h.pdf)).\n", "Hence, implementing PixelCNN is a good starting point for our short tutorial.\n", "\n", "First of all, we need to import our standard libraries. 
As in\n", "the last couple of tutorials, we will use [PyTorch\n", "Lightning](https://lightning.ai/docs/pytorch/stable/) here as\n", "well."]}, {"cell_type": "code", "execution_count": 2, "id": "7fd6bdb9", "metadata": {"execution": {"iopub.execute_input": "2023-10-11T16:32:50.768763Z", "iopub.status.busy": "2023-10-11T16:32:50.768292Z", "iopub.status.idle": "2023-10-11T16:32:56.292995Z", "shell.execute_reply": "2023-10-11T16:32:56.292200Z"}, "papermill": {"duration": 5.539522, "end_time": "2023-10-11T16:32:56.294556", "exception": false, "start_time": "2023-10-11T16:32:50.755034", "status": "completed"}, "tags": []}, "outputs": [{"name": "stderr", "output_type": "stream", "text": ["Global seed set to 42\n"]}, {"name": "stdout", "output_type": "stream", "text": ["Using device cuda:0\n"]}, {"data": {"text/plain": ["
"]}, "metadata": {}, "output_type": "display_data"}], "source": ["\n", "import math\n", "import os\n", "import urllib.request\n", "from urllib.error import HTTPError\n", "\n", "import lightning as L\n", "\n", "# Imports for plotting\n", "import matplotlib.pyplot as plt\n", "import matplotlib_inline.backend_inline\n", "import numpy as np\n", "import seaborn as sns\n", "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "import torch.optim as optim\n", "import torch.utils.data as data\n", "import torchvision\n", "from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint\n", "from matplotlib.colors import to_rgb\n", "from torch import Tensor\n", "from torchvision import transforms\n", "from torchvision.datasets import MNIST\n", "from tqdm.notebook import tqdm\n", "\n", "plt.set_cmap(\"cividis\")\n", "%matplotlib inline\n", "matplotlib_inline.backend_inline.set_matplotlib_formats(\"svg\", \"pdf\") # For export\n", "\n", "# Path to the folder where the datasets are/should be downloaded (e.g. 
MNIST)\n", "DATASET_PATH = os.environ.get(\"PATH_DATASETS\", \"data\")\n", "# Path to the folder where the pretrained models are saved\n", "CHECKPOINT_PATH = os.environ.get(\"PATH_CHECKPOINT\", \"saved_models/tutorial12\")\n", "\n", "# Setting the seed\n", "L.seed_everything(42)\n", "\n", "# Ensure that all operations are deterministic on GPU (if used) for reproducibility\n", "torch.backends.cudnn.deterministic = True\n", "torch.backends.cudnn.benchmark = False\n", "\n", "# Fetching the device that will be used throughout this notebook\n", "device = torch.device(\"cpu\") if not torch.cuda.is_available() else torch.device(\"cuda:0\")\n", "print(\"Using device\", device)"]}, {"cell_type": "markdown", "id": "815134e9", "metadata": {"papermill": {"duration": 0.012459, "end_time": "2023-10-11T16:32:56.317820", "exception": false, "start_time": "2023-10-11T16:32:56.305361", "status": "completed"}, "tags": []}, "source": ["We again provide a pretrained model, which is downloaded below:"]}, {"cell_type": "code", "execution_count": 3, "id": "251ddfda", "metadata": {"execution": {"iopub.execute_input": "2023-10-11T16:32:56.401586Z", "iopub.status.busy": "2023-10-11T16:32:56.400791Z", "iopub.status.idle": "2023-10-11T16:32:56.828302Z", "shell.execute_reply": "2023-10-11T16:32:56.827305Z"}, "papermill": {"duration": 0.443193, "end_time": "2023-10-11T16:32:56.830959", "exception": false, "start_time": "2023-10-11T16:32:56.387766", "status": "completed"}, "tags": []}, "outputs": [{"name": "stdout", "output_type": "stream", "text": ["Downloading https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial12/PixelCNN.ckpt...\n"]}], "source": ["# Github URL where saved models are stored for this tutorial\n", "base_url = \"https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial12/\"\n", "# Files to download\n", "pretrained_files = [\"PixelCNN.ckpt\"]\n", "# Create checkpoint path if it doesn't exist yet\n", "os.makedirs(CHECKPOINT_PATH, exist_ok=True)\n", "\n", 
"# For each file, check whether it already exists. If not, try downloading it.\n", "for file_name in pretrained_files:\n", " file_path = os.path.join(CHECKPOINT_PATH, file_name)\n", " if not os.path.isfile(file_path):\n", " file_url = base_url + file_name\n", " print(\"Downloading %s...\" % file_url)\n", " try:\n", " urllib.request.urlretrieve(file_url, file_path)\n", " except HTTPError as e:\n", " print(\n", " \"Something went wrong. Please try to download the file from the GDrive folder, or contact the author with the full output including the following error:\\n\",\n", " e,\n", " )"]}, {"cell_type": "markdown", "id": "5a5a3eee", "metadata": {"lines_to_next_cell": 2, "papermill": {"duration": 0.011057, "end_time": "2023-10-11T16:32:56.854980", "exception": false, "start_time": "2023-10-11T16:32:56.843923", "status": "completed"}, "tags": []}, "source": ["Similar to the Normalizing Flows in Tutorial 11, we will work on the\n", "MNIST dataset and use 8-bits per pixel (values between 0 and 255). 
The\n", "dataset is loaded below:"]}, {"cell_type": "code", "execution_count": 4, "id": "38deeb8b", "metadata": {"execution": {"iopub.execute_input": "2023-10-11T16:32:56.881735Z", "iopub.status.busy": "2023-10-11T16:32:56.881173Z", "iopub.status.idle": "2023-10-11T16:32:58.173909Z", "shell.execute_reply": "2023-10-11T16:32:58.173260Z"}, "papermill": {"duration": 1.315515, "end_time": "2023-10-11T16:32:58.182836", "exception": false, "start_time": "2023-10-11T16:32:56.867321", "status": "completed"}, "tags": []}, "outputs": [{"name": "stdout", "output_type": "stream", "text": ["Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz\n", "Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to /__w/13/s/.datasets/MNIST/raw/train-images-idx3-ubyte.gz\n"]}, {"name": "stderr", "output_type": "stream", "text": ["\r", " 0%| | 0/9912422 [00:00"]}], "source": ["# Convert images from floats in [0, 1] to integers in 0-255 (8 bits per pixel).\n", "# Rounding before the cast avoids truncation errors from floating-point arithmetic.\n", "def discretize(sample):\n", "    return (sample * 255).round().to(torch.long)\n", "\n", "\n", "# Transformations applied on each image => convert them to a tensor and discretize them\n", "transform = transforms.Compose([transforms.ToTensor(), discretize])\n", "\n", "# Loading the training dataset. 
We need to split it into a training and validation part\n", "train_dataset = MNIST(root=DATASET_PATH, train=True, transform=transform, download=True)\n", "L.seed_everything(42)\n", "train_set, val_set = torch.utils.data.random_split(train_dataset, [50000, 10000])\n", "\n", "# Loading the test set\n", "test_set = MNIST(root=DATASET_PATH, train=False, transform=transform, download=True)\n", "\n", "# We define a set of data loaders that we can use for various purposes later.\n", "train_loader = data.DataLoader(train_set, batch_size=128, shuffle=True, drop_last=True, pin_memory=True, num_workers=4)\n", "val_loader = data.DataLoader(val_set, batch_size=128, shuffle=False, drop_last=False, num_workers=4)\n", "test_loader = data.DataLoader(test_set, batch_size=128, shuffle=False, drop_last=False, num_workers=4)"]}, {"cell_type": "markdown", "id": "60fbf039", "metadata": {"lines_to_next_cell": 2, "papermill": {"duration": 0.015206, "end_time": "2023-10-11T16:32:58.213957", "exception": false, "start_time": "2023-10-11T16:32:58.198751", "status": "completed"}, "tags": []}, "source": ["A good practice is to always visualize some data examples to get an intuition of the data:"]}, {"cell_type": "code", "execution_count": 5, "id": "bbf284d2", "metadata": {"execution": {"iopub.execute_input": "2023-10-11T16:32:58.254674Z", "iopub.status.busy": "2023-10-11T16:32:58.254203Z", "iopub.status.idle": "2023-10-11T16:32:58.430171Z", "shell.execute_reply": "2023-10-11T16:32:58.429503Z"}, "papermill": {"duration": 0.210926, "end_time": "2023-10-11T16:32:58.438607", "exception": false, "start_time": "2023-10-11T16:32:58.227681", "status": "completed"}, "tags": []}, "outputs": [{"data": {"text/plain": [""]}, "metadata": {}, "output_type": "display_data"}], "source": ["def show_imgs(imgs):\n", "    num_imgs = imgs.shape[0] if isinstance(imgs, Tensor) else len(imgs)\n", "    nrow = min(num_imgs, 4)\n", "    ncol = int(math.ceil(num_imgs / nrow))\n", "    imgs = torchvision.utils.make_grid(imgs, nrow=nrow, pad_value=128)\n", "    imgs = imgs.clamp(min=0, max=255)\n", "    np_imgs = imgs.cpu().numpy()\n", "    plt.figure(figsize=(1.5 * nrow, 1.5 * ncol))\n", "    plt.imshow(np.transpose(np_imgs, (1, 2, 0)), interpolation=\"nearest\")\n", "    plt.axis(\"off\")\n", "    plt.show()\n", "    plt.close()\n", "\n", "\n", "show_imgs([train_set[i][0] for i in range(8)])"]}, {"cell_type": "markdown", "id": "02183103", "metadata": {"lines_to_next_cell": 2, "papermill": {"duration": 0.012911, "end_time": "2023-10-11T16:32:58.470168", "exception": false, "start_time": "2023-10-11T16:32:58.457257", "status": "completed"}, "tags": []}, "source": ["## Masked autoregressive convolutions\n", "\n", "The core module of PixelCNN is the masked convolution.\n", "In contrast to language models, we don't apply an LSTM on each pixel one-by-one.\n", "This would be inefficient because images are grids instead of sequences.\n", "Thus, it is better to rely on convolutions, which have shown great success in deep CNN classification models.\n", "\n", "Nevertheless, we cannot just apply standard convolutions without any changes.\n", "Remember that during training of autoregressive models, we want to use teacher forcing, which both eases model training and significantly reduces the time needed for training,\n", "since the predictions for all pixels can be computed in a single forward pass.\n", "For image modeling, teacher forcing is implemented by feeding a training image as input to the model and obtaining as output a prediction for each pixel that is based *only* on its predecessors.\n", "Thus, we need to ensure that the prediction for a specific pixel can only be influenced by its predecessors and not by its own value or any \"future\" pixels.\n", "For this, we apply convolutions with a mask.\n", "\n", "Which mask we use depends on the ordering of pixels we decide on, i.e. which is the first pixel we predict,\n", "which is the second one, etc.\n", "The most commonly used ordering is to denote the upper left pixel as the start pixel,\n", "and sort the pixels row by row, as shown in the visualization at the top of the tutorial.\n", "Thus, the second pixel is on the right of the first one (first row, second column),\n", "and once we reach the end of the row, we start in the second row, first column.\n", "If we now want to apply this to our convolutions, we need to ensure that the prediction of a pixel\n", "is influenced neither by its own \"true\" input nor by any pixel to its right in the same row or in any lower row.\n", "In convolutions, this means that we set to zero those entries of the weight matrix that correspond to pixels on the right and below.\n", "As an example, the mask for a 5x5 kernel is shown below (figure credit - [Aaron van den Oord](https://arxiv.org/pdf/1606.05328.pdf)):\n", "\n", "
\n", "\n", "Before looking into the application of masked convolutions in PixelCNN\n", "in detail, let's first implement a module that allows us to apply an\n", "arbitrary mask to a convolution:"]}, {"cell_type": "code", "execution_count": 6, "id": "b20c6c1e", "metadata": {"execution": {"iopub.execute_input": "2023-10-11T16:32:58.510742Z", "iopub.status.busy": "2023-10-11T16:32:58.510255Z", "iopub.status.idle": "2023-10-11T16:32:58.516902Z", "shell.execute_reply": "2023-10-11T16:32:58.516237Z"}, "lines_to_next_cell": 2, "papermill": {"duration": 0.039385, "end_time": "2023-10-11T16:32:58.522867", "exception": false, "start_time": "2023-10-11T16:32:58.483482", "status": "completed"}, "tags": []}, "outputs": [], "source": ["class MaskedConvolution(nn.Module):\n", "    def __init__(self, c_in, c_out, mask, **kwargs):\n", "        \"\"\"Implements a convolution with mask applied on its weights.\n", "\n", "        Args:\n", "            c_in: Number of input channels\n", "            c_out: Number of output channels\n", "            mask: Tensor of shape [kernel_size_H, kernel_size_W] with 0s where\n", "                  the convolution should be masked, and 1s otherwise.\n", "            kwargs: Additional arguments for the convolution\n", "        \"\"\"\n", "        super().__init__()\n", "        # For simplicity: calculate padding automatically\n", "        kernel_size = (mask.shape[0], mask.shape[1])\n", "        dilation = 1 if \"dilation\" not in kwargs else kwargs[\"dilation\"]\n", "        padding = tuple(dilation * (kernel_size[i] - 1) // 2 for i in range(2))\n", "        # Actual convolution\n", "        self.conv = nn.Conv2d(c_in, c_out, kernel_size, padding=padding, **kwargs)\n", "\n", "        # Mask as buffer => it is not a parameter, but still a tensor of the module\n", "        # (it is moved together with the module across devices)\n", "        self.register_buffer(\"mask\", mask[None, None])\n", "\n", "    def forward(self, x):\n", "        self.conv.weight.data *= self.mask  # Ensures zeros at masked positions\n", "        return self.conv(x)"]}, {"cell_type": "markdown", "id": "e61f2bd5", "metadata": {"lines_to_next_cell": 2, "papermill": 
{"duration": 0.013881, "end_time": "2023-10-11T16:32:58.550654", "exception": false, "start_time": "2023-10-11T16:32:58.536773", "status": "completed"}, "tags": []}, "source": ["### Vertical and horizontal convolution stacks\n", "\n", "To build our own autoregressive image model, we could simply stack a few masked convolutions on top of each other.\n", "This is the approach taken by the original PixelCNN model, introduced in the paper\n", "[Pixel Recurrent Neural Networks](https://arxiv.org/pdf/1601.06759.pdf), but it leads to a considerable issue.\n", "When several masked convolutions are applied in sequence, the receptive field of a pixel\n", "develops a \"blind spot\" in the upper right, as shown in the figure below\n", "(figure credit - [Aaron van den Oord et al.](https://arxiv.org/pdf/1606.05328.pdf)):\n", "\n", "\n", "\n", "Although a pixel should be able to take into account all other pixels above and to the left of it,\n", "a stack of masked convolutions does not allow it to see the pixels above and to the right.\n", "This is because the features of the pixels above, which we use for convolution,\n", "do not contain any information about the pixels to their right in the same row.\n", "If they did, we would be \"cheating\" and actually looking into the future.\n", "To overcome this issue, van den Oord et al.\n", "[2] proposed to split the convolutions into a vertical and a horizontal stack.\n", "The vertical stack looks at all pixels above the current one, while the horizontal stack takes into account all pixels to its left in the same row.\n", "Because the two stacks are kept separate, the vertical stack can look at the pixels above and to the right without breaking any of our assumptions.\n", "The two convolutions are also shown in the figure above.\n", "\n", "Let us implement them as follows:"]}, {"cell_type": "code", "execution_count": 7, "id": "bc16dcf7", "metadata": {"execution": {"iopub.execute_input": "2023-10-11T16:32:58.579953Z", "iopub.status.busy": "2023-10-11T16:32:58.579133Z", "iopub.status.idle": "2023-10-11T16:32:58.586061Z", "shell.execute_reply": "2023-10-11T16:32:58.585382Z"}, "papermill": {"duration": 0.024054, "end_time": "2023-10-11T16:32:58.587949", "exception": false, "start_time": "2023-10-11T16:32:58.563895", "status": "completed"}, "tags": []}, "outputs": [], "source": ["class VerticalStackConvolution(MaskedConvolution):\n", "    def __init__(self, c_in, c_out, kernel_size=3, mask_center=False, **kwargs):\n", "        # Mask out all pixels below. For efficiency, we could also reduce the kernel\n", "        # size in height, but for simplicity, we stick with masking here.\n", "        mask = torch.ones(kernel_size, kernel_size)\n", "        mask[kernel_size // 2 + 1 :, :] = 0\n", "\n", "        # For the very first convolution, we will also mask the center row\n", "        if mask_center:\n", "            mask[kernel_size // 2, :] = 0\n", "\n", "        super().__init__(c_in, c_out, mask, **kwargs)\n", "\n", "\n", "class HorizontalStackConvolution(MaskedConvolution):\n", "    def __init__(self, c_in, c_out, kernel_size=3, mask_center=False, **kwargs):\n", "        # Mask out all pixels on the right. Note that our kernel has a size of 1\n", "        # in height because we only look at the pixels in the same row.\n", "        mask = torch.ones(1, kernel_size)\n", "        mask[0, kernel_size // 2 + 1 :] = 0\n", "\n", "        # For the very first convolution, we will also mask the center pixel\n", "        if mask_center:\n", "            mask[0, kernel_size // 2] = 0\n", "\n", "        super().__init__(c_in, c_out, mask, **kwargs)"]}, {"cell_type": "markdown", "id": "d6ad6a83", "metadata": {"papermill": {"duration": 0.013791, "end_time": "2023-10-11T16:32:58.615712", "exception": false, "start_time": "2023-10-11T16:32:58.601921", "status": "completed"}, "tags": []}, "source": ["Note that we have an input argument called `mask_center`. Remember that\n", "the input to the model is the actual input image. Hence, the very first\n", "convolution we apply cannot use the center pixel as input, but must be\n", "masked. All subsequent convolutions, however, should use the center\n", "pixel as we otherwise lose the features of the previous layer. 
Hence,\n", "the input argument `mask_center` is True for the very first\n", "convolutions, and False for all others."]}, {"cell_type": "markdown", "id": "e2cb5d6f", "metadata": {"papermill": {"duration": 0.013643, "end_time": "2023-10-11T16:32:58.642976", "exception": false, "start_time": "2023-10-11T16:32:58.629333", "status": "completed"}, "tags": []}, "source": ["### Visualizing the receptive field\n", "\n", "To validate our implementation of masked convolutions, we can visualize the receptive field we obtain with such convolutions.\n", "We should see that with an increasing number of convolutional layers, the receptive field grows in both the vertical and horizontal direction, without the issue of a blind spot.\n", "The receptive field can be measured empirically by backpropagating an arbitrary loss on the output features of a specific pixel with respect to the input:\n", "every input pixel that receives a nonzero gradient lies inside the receptive field.\n", "We implement this idea below and use it to visualize the receptive field."]}, {"cell_type": "code", "execution_count": 8, "id": "ed9b059e", "metadata": {"execution": {"iopub.execute_input": "2023-10-11T16:32:58.671501Z", "iopub.status.busy": "2023-10-11T16:32:58.670577Z", "iopub.status.idle": "2023-10-11T16:32:59.242548Z", "shell.execute_reply": "2023-10-11T16:32:59.241837Z"}, "papermill": {"duration": 0.592577, "end_time": "2023-10-11T16:32:59.248506", "exception": false, "start_time": "2023-10-11T16:32:58.655929", "status": "completed"}, "tags": []}, "outputs": [{"data": {"application/pdf": "", "image/svg+xml": [" 2023-10-11T16:32:59.086723\n", " image/svg+xml\n", " Matplotlib v3.8.0, https://matplotlib.org/\n"], "text/plain": ["
"]}, "metadata": {}, "output_type": "display_data"}], "source": ["inp_img = torch.zeros(1, 1, 11, 11)\n", "inp_img.requires_grad_()\n", "\n", "\n", "def show_center_recep_field(img, out):\n", "    \"\"\"Calculates the gradients of the input with respect to the output center pixel, and visualizes the overall\n", "    receptive field.\n", "\n", "    Args:\n", "        img: Input image for which we want to calculate the receptive field.\n", "        out: Output features/loss which is used for backpropagation, and should be\n", "            the output of the network/computation graph.\n", "    \"\"\"\n", "    # Determine gradients\n", "    loss = out[0, :, img.shape[2] // 2, img.shape[3] // 2].sum()  # Sum of center-pixel features as a scalar loss\n", "    # Retain graph as we want to stack multiple layers and show the receptive field of all of them\n", "    loss.backward(retain_graph=True)\n", "    img_grads = img.grad.abs()\n", "    img.grad.fill_(0)  # Reset grads\n", "\n", "    # Plot receptive field\n", "    img = img_grads.squeeze().cpu().numpy()\n", "    fig, ax = plt.subplots(1, 2)\n", "    _ = ax[0].imshow(img)\n", "    ax[1].imshow(img > 0)\n", "    # Mark the center pixel in red if it doesn't have any gradients (should be\n", "    # the case for standard autoregressive models)\n", "    show_center = img[img.shape[0] // 2, img.shape[1] // 2] == 0\n", "    if show_center:\n", "        center_pixel = np.zeros(img.shape + (4,))\n", "        center_pixel[center_pixel.shape[0] // 2, center_pixel.shape[1] // 2, :] = np.array([1.0, 0.0, 0.0, 1.0])\n", "    for i in range(2):\n", "        ax[i].axis(\"off\")\n", "        if show_center:\n", "            ax[i].imshow(center_pixel)\n", "    ax[0].set_title(\"Weighted receptive field\")\n", "    ax[1].set_title(\"Binary receptive field\")\n", "    plt.show()\n", "    plt.close()\n", "\n", "\n", "show_center_recep_field(inp_img, inp_img)"]}, {"cell_type": "markdown", "id": "4b2cc8e1", "metadata": {"papermill": {"duration": 0.023645, "end_time": "2023-10-11T16:32:59.294318", "exception": false, "start_time": "2023-10-11T16:32:59.270673", "status": "completed"}, "tags": []}, 
"source": ["Let's first visualize the receptive field of a horizontal convolution\n", "without the center pixel. We use a small, arbitrary input image\n", "($11\\times 11$ pixels), and calculate the loss for the center pixel. For\n", "simplicity, we initialize all weights with 1 and the bias with 0, and\n", "use a single channel. This is sufficient for our visualization purposes."]}, {"cell_type": "code", "execution_count": 9, "id": "693bf136", "metadata": {"execution": {"iopub.execute_input": "2023-10-11T16:32:59.325572Z", "iopub.status.busy": "2023-10-11T16:32:59.324600Z", "iopub.status.idle": "2023-10-11T16:32:59.664773Z", "shell.execute_reply": "2023-10-11T16:32:59.664021Z"}, "papermill": {"duration": 0.358925, "end_time": "2023-10-11T16:32:59.667560", "exception": false, "start_time": "2023-10-11T16:32:59.308635", "status": "completed"}, "tags": []}, "outputs": [{"data": {"application/pdf": "", "image/svg+xml": [" 2023-10-11T16:32:59.497963\n", " image/svg+xml\n", " Matplotlib v3.8.0, https://matplotlib.org/\n"], "text/plain": ["
"]}, "metadata": {}, "output_type": "display_data"}], "source": ["horiz_conv = HorizontalStackConvolution(c_in=1, c_out=1, kernel_size=3, mask_center=True)\n", "horiz_conv.conv.weight.data.fill_(1)\n", "horiz_conv.conv.bias.data.fill_(0)\n", "horiz_img = horiz_conv(inp_img)\n", "show_center_recep_field(inp_img, horiz_img)"]}, {"cell_type": "markdown", "id": "42097881", "metadata": {"papermill": {"duration": 0.016191, "end_time": "2023-10-11T16:32:59.700384", "exception": false, "start_time": "2023-10-11T16:32:59.684193", "status": "completed"}, "tags": []}, "source": ["The receptive field is shown in yellow, the center pixel in red, and all other pixels outside of the receptive field are dark blue.\n", "As expected, the receptive field of a single horizontal convolution with the center pixel masked and a kernel size of 3 (i.e., a $1\\times3$ kernel) is only the pixel on the left.\n", "With a larger kernel size, more pixels on the left would be taken into account.\n", "\n", "Next, let's take a look at the vertical convolution:"]}, {"cell_type": "code", "execution_count": 10, "id": "5bdf4270", "metadata": {"execution": {"iopub.execute_input": "2023-10-11T16:32:59.735907Z", "iopub.status.busy": "2023-10-11T16:32:59.735254Z", "iopub.status.idle": "2023-10-11T16:33:00.098283Z", "shell.execute_reply": "2023-10-11T16:33:00.097301Z"}, "papermill": {"duration": 0.386416, "end_time": "2023-10-11T16:33:00.103087", "exception": false, "start_time": "2023-10-11T16:32:59.716671", "status": "completed"}, "tags": []}, "outputs": [{"data": {"application/pdf": "", "image/svg+xml": [" 2023-10-11T16:32:59.935650\n", " image/svg+xml\n", " Matplotlib v3.8.0, https://matplotlib.org/\n"], "text/plain": ["
"]}, "metadata": {}, "output_type": "display_data"}], "source": ["vert_conv = VerticalStackConvolution(c_in=1, c_out=1, kernel_size=3, mask_center=True)\n", "vert_conv.conv.weight.data.fill_(1)\n", "vert_conv.conv.bias.data.fill_(0)\n", "vert_img = vert_conv(inp_img)\n", "show_center_recep_field(inp_img, vert_img)"]}, {"cell_type": "markdown", "id": "777aa66b", "metadata": {"papermill": {"duration": 0.017694, "end_time": "2023-10-11T16:33:00.148304", "exception": false, "start_time": "2023-10-11T16:33:00.130610", "status": "completed"}, "tags": []}, "source": ["The vertical convolution takes the pixels above the center into account; for a single $3\\times3$ layer,\n", "these are the three pixels in the row directly above. Combining\n", "these two, we get the L-shaped receptive field of the original masked\n", "convolution:"]}, {"cell_type": "code", "execution_count": 11, "id": "9c0e0e05", "metadata": {"execution": {"iopub.execute_input": "2023-10-11T16:33:00.187044Z", "iopub.status.busy": "2023-10-11T16:33:00.186447Z", "iopub.status.idle": "2023-10-11T16:33:00.518898Z", "shell.execute_reply": "2023-10-11T16:33:00.518237Z"}, "papermill": {"duration": 0.358833, "end_time": "2023-10-11T16:33:00.524310", "exception": false, "start_time": "2023-10-11T16:33:00.165477", "status": "completed"}, "tags": []}, "outputs": [{"data": {"application/pdf": "", "image/svg+xml": [" 2023-10-11T16:33:00.357656\n", " image/svg+xml\n", " Matplotlib v3.8.0, https://matplotlib.org/\n"], "text/plain": ["
"]}, "metadata": {}, "output_type": "display_data"}], "source": ["horiz_img = vert_img + horiz_img\n", "show_center_recep_field(inp_img, horiz_img)"]}, {"cell_type": "markdown", "id": "b1ffbf74", "metadata": {"papermill": {"duration": 0.016053, "end_time": "2023-10-11T16:33:00.555854", "exception": false, "start_time": "2023-10-11T16:33:00.539801", "status": "completed"}, "tags": []}, "source": ["If we stack multiple horizontal and vertical convolutions, we need to take two aspects into account:\n", "\n", "1. The center should not be masked anymore for the following convolutions, as the features at the pixel's position are already independent of its actual value.\n", "   If it is hard to imagine why we can do this, just change the value below to `mask_center=True` and see what happens.\n", "2. The vertical convolution is not allowed to work on features from the horizontal convolution.\n", "   In the feature map of the horizontal convolutions, a pixel contains information about all of the \"true\" pixels on its left.\n", "   If a vertical convolution also used these features at positions to the right of the current pixel, its receptive field would expand to true input pixels that it must not see (including the current pixel itself).\n", "   Thus, the two feature maps can only be merged in the horizontal stack.\n", "\n", "Using this, we can stack the convolutions in the following way. We have\n", "two feature streams: one for the vertical stack, and one for the\n", "horizontal stack. The horizontal convolutions can operate on the joint\n", "features of the previous horizontal and vertical convolutions, while\n", "the vertical stack only takes its own previous features as input. For a\n", "quick implementation, we can therefore sum the horizontal and vertical\n", "output features at each layer, and use those as final output features to\n", "calculate the loss on. An implementation of 4 consecutive layers is\n", "shown below. 
Note that we reuse the output features of the convolutions with\n", "`mask_center=True` from above as the first layer."]}, {"cell_type": "code", "execution_count": 12, "id": "b8d27ef2", "metadata": {"execution": {"iopub.execute_input": "2023-10-11T16:33:00.588553Z", "iopub.status.busy": "2023-10-11T16:33:00.587962Z", "iopub.status.idle": "2023-10-11T16:33:02.580244Z", "shell.execute_reply": "2023-10-11T16:33:02.579479Z"}, "papermill": {"duration": 2.014108, "end_time": "2023-10-11T16:33:02.586272", "exception": false, "start_time": "2023-10-11T16:33:00.572164", "status": "completed"}, "tags": []}, "outputs": [{"name": "stdout", "output_type": "stream", "text": ["Layer 2\n"]}, {"data": {"application/pdf": "", "image/svg+xml": [" 2023-10-11T16:33:00.852363\n", " image/svg+xml\n", " Matplotlib v3.8.0, https://matplotlib.org/\n"], "text/plain": ["
"]}, "metadata": {}, "output_type": "display_data"}, {"name": "stdout", "output_type": "stream", "text": ["Layer 3\n"]}, {"data": {"application/pdf": "", "image/svg+xml": [" 2023-10-11T16:33:01.342255\n", " image/svg+xml\n", " Matplotlib v3.8.0, https://matplotlib.org/\n"], "text/plain": ["
"]}, "metadata": {}, "output_type": "display_data"}, {"name": "stdout", "output_type": "stream", "text": ["Layer 4\n"]}, {"data": {"application/pdf": "", "image/svg+xml": [" 2023-10-11T16:33:01.860988\n", " image/svg+xml\n", " Matplotlib v3.8.0, https://matplotlib.org/\n"], "text/plain": ["
"]}, "metadata": {}, "output_type": "display_data"}, {"name": "stdout", "output_type": "stream", "text": ["Layer 5\n"]}, {"data": {"application/pdf": "", "image/svg+xml": [" 2023-10-11T16:33:02.421336\n", " image/svg+xml\n", " Matplotlib v3.8.0, https://matplotlib.org/\n"], "text/plain": ["
"]}, "metadata": {}, "output_type": "display_data"}], "source": ["# Initialize convolutions with equal weight to all input pixels\n", "horiz_conv = HorizontalStackConvolution(c_in=1, c_out=1, kernel_size=3, mask_center=False)\n", "horiz_conv.conv.weight.data.fill_(1)\n", "horiz_conv.conv.bias.data.fill_(0)\n", "vert_conv = VerticalStackConvolution(c_in=1, c_out=1, kernel_size=3, mask_center=False)\n", "vert_conv.conv.weight.data.fill_(1)\n", "vert_conv.conv.bias.data.fill_(0)\n", "\n", "# We reuse our convolutions for the 4 layers here. Note that in a standard network,\n", "# we don't do that, and instead learn 4 separate convolutions. As this cell is only for\n", "# visualization purposes, we reuse the convolutions for all layers.\n", "for l_idx in range(4):\n", "    vert_img = vert_conv(vert_img)\n", "    horiz_img = horiz_conv(horiz_img) + vert_img\n", "    print(\"Layer %i\" % (l_idx + 2))\n", "    show_center_recep_field(inp_img, horiz_img)"]}, {"cell_type": "markdown", "id": "0c34bd39", "metadata": {"papermill": {"duration": 0.018461, "end_time": "2023-10-11T16:33:02.628825", "exception": false, "start_time": "2023-10-11T16:33:02.610364", "status": "completed"}, "tags": []}, "source": ["The receptive field above is visualized for the horizontal stack, which includes the features of the vertical convolutions.\n", "It grows over the layers without the blind spot we had before.\n", "The difference between the \"weighted\" and \"binary\" receptive field is that for the latter, we only check whether any gradient flows back to an input pixel.\n", "If it does, the center pixel can indeed use information from that input pixel.\n", "However, due to the convolution weights, some pixels have a stronger effect on the prediction than others.\n", "This is visualized in the weighted receptive field by plotting the gradient magnitude for each pixel instead of a binary yes/no.\n", "\n", "\n", "Another receptive field we can check is the one for the vertical stack,\n", "as the one 
above is for the horizontal stack. Let's visualize it below:"]}, {"cell_type": "code", "execution_count": 13, "id": "6cb0b47d", "metadata": {"execution": {"iopub.execute_input": "2023-10-11T16:33:02.679685Z", "iopub.status.busy": "2023-10-11T16:33:02.678503Z", "iopub.status.idle": "2023-10-11T16:33:03.078740Z", "shell.execute_reply": "2023-10-11T16:33:03.078030Z"}, "papermill": {"duration": 0.430988, "end_time": "2023-10-11T16:33:03.080865", "exception": false, "start_time": "2023-10-11T16:33:02.649877", "status": "completed"}, "tags": []}, "outputs": [{"data": {"application/pdf": "", "image/svg+xml": [" 2023-10-11T16:33:02.906042\n", " image/svg+xml\n", " Matplotlib v3.8.0, https://matplotlib.org/\n"], "text/plain": ["
"]}, "metadata": {}, "output_type": "display_data"}], "source": ["show_center_recep_field(inp_img, vert_img)"]}, {"cell_type": "markdown", "id": "4f11b075", "metadata": {"papermill": {"duration": 0.025983, "end_time": "2023-10-11T16:33:03.129794", "exception": false, "start_time": "2023-10-11T16:33:03.103811", "status": "completed"}, "tags": []}, "source": ["As discussed before, the vertical stack only looks at pixels above the one we want to predict.\n", "This confirms that our implementation works as we initially expected.\n", "As a final step, let's clean up the computation graph we kept in memory\n", "for the visualization of the receptive field. Besides the input and the convolutions,\n", "we also delete `horiz_img` and `vert_img`, since these output tensors still reference the graph:"]}, {"cell_type": "code", "execution_count": 14, "id": "6a3e2981", "metadata": {"execution": {"iopub.execute_input": "2023-10-11T16:33:03.177517Z", "iopub.status.busy": "2023-10-11T16:33:03.173856Z", "iopub.status.idle": "2023-10-11T16:33:03.187847Z", "shell.execute_reply": "2023-10-11T16:33:03.186786Z"}, "papermill": {"duration": 0.035686, "end_time": "2023-10-11T16:33:03.189444", "exception": false, "start_time": "2023-10-11T16:33:03.153758", "status": "completed"}, "tags": []}, "outputs": [], "source": ["del inp_img, horiz_img, vert_img, horiz_conv, vert_conv"]}, {"cell_type": "markdown", "id": "ce9f1ee6", "metadata": {"lines_to_next_cell": 2, "papermill": {"duration": 0.019031, "end_time": "2023-10-11T16:33:03.227796", "exception": false, "start_time": "2023-10-11T16:33:03.208765", "status": "completed"}, "tags": []}, "source": ["## Gated PixelCNN\n", "\n", "
\n", "\n", "In the next step, we will use the masked convolutions to build a full autoregressive model, called Gated PixelCNN.\n", "The main difference between the original PixelCNN and Gated PixelCNN is the use of separate horizontal and vertical stacks, combined with gated activations.\n", "However, in the literature, the Gated PixelCNN is often simply referred to as \"PixelCNN\".\n", "Hence, in the following, \"PixelCNN\" usually refers to the gated version.\n", "What \"Gated\" refers to in the model name is explained next.\n", "\n", "### Gated Convolutions\n", "\n", "For visualizing the receptive field, we assumed a very simplified stack of vertical and horizontal convolutions.\n", "In practice, there are more sophisticated ways of combining them, and PixelCNN uses gated convolutions for this.\n", "Specifically, the Gated Convolution block in PixelCNN looks as follows\n", "(figure credit - [Aaron van den Oord et al.](https://arxiv.org/pdf/1606.05328.pdf)):\n", "\n", "
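\n", "\n", "As a reading aid, the computation graph in the figure can be written out as equations. The notation is our own (it is not taken from the paper) and mirrors the `GatedMaskedConv` implementation further below: $\\mathbf{v}$ and $\\mathbf{h}$ are the vertical and horizontal stack features with $C$ channels, $\\ast$ is a masked convolution producing $2C$ channels, $\\left[\\mathbf{a};\\mathbf{b}\\right]$ splits these channels into two halves of $C$ channels each, $\\mathbf{W}_{v\\to h}$ and $\\mathbf{W}_{o}$ are $1\\times 1$ convolutions, and biases are omitted:\n", "\n", "$$\\left[\\mathbf{a}_v;\\mathbf{b}_v\\right] = \\mathbf{W}_{v} \\ast \\mathbf{v},\\qquad \\mathbf{v}' = \\tanh\\left(\\mathbf{a}_v\\right)\\odot\\sigma\\left(\\mathbf{b}_v\\right)$$\n", "\n", "$$\\left[\\mathbf{a}_h;\\mathbf{b}_h\\right] = \\mathbf{W}_{h} \\ast \\mathbf{h} + \\mathbf{W}_{v\\to h}\\left[\\mathbf{a}_v;\\mathbf{b}_v\\right],\\qquad \\mathbf{h}' = \\mathbf{h} + \\mathbf{W}_{o}\\left(\\tanh\\left(\\mathbf{a}_h\\right)\\odot\\sigma\\left(\\mathbf{b}_h\\right)\\right)$$\n", "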
\n", "\n", "The left path is the vertical stack (the $N\\times N$ convolution is masked correspondingly),\n", "and the right path is the horizontal stack.\n", "Gated convolutions are implemented by doubling the number of output channels, splitting the result into two halves,\n", "and combining them by an element-wise multiplication of $\\tanh$ and a sigmoid.\n", "For a linear layer, we can express a gated activation unit as follows:\n", "\n", "$$\\mathbf{y} = \\tanh\\left(\\mathbf{W}_{f}\\mathbf{x}\\right)\\odot\\sigma\\left(\\mathbf{W}_{g}\\mathbf{x}\\right)$$\n", "\n", "For simplicity, biases have been neglected and the linear layer has been split into two parts, $\\mathbf{W}_{f}$ and $\\mathbf{W}_{g}$.\n", "This concept resembles the input and modulation gates in an LSTM, and has been used in many other architectures as well.\n", "The main motivation behind this gated activation is that it might allow the model to capture more complex interactions and simplify learning.\n", "But as in any other architecture, this is mostly a design choice and can be considered a hyperparameter.\n", "\n", "Besides the gated convolutions, we also see that the horizontal stack uses a residual connection while the vertical stack does not.\n", "This is because we use the output of the horizontal stack for prediction.\n", "Each convolution in the vertical stack also receives a strong gradient signal\n", "as it is only two $1\\times 1$ convolutions away from the residual connection,\n", "and does not require another residual connection to all its earlier layers.\n", "\n", "The implementation in PyTorch is fairly straightforward for this block,\n", "because the visualization above gives us a computation graph to follow:"]}, {"cell_type": "code", "execution_count": 15, "id": "47bf5a5e", "metadata": {"execution": {"iopub.execute_input": "2023-10-11T16:33:03.290181Z", "iopub.status.busy": "2023-10-11T16:33:03.289672Z", "iopub.status.idle": "2023-10-11T16:33:03.298299Z", "shell.execute_reply": "2023-10-11T16:33:03.297335Z"}, 
"lines_to_next_cell": 2, "papermill": {"duration": 0.041182, "end_time": "2023-10-11T16:33:03.300732", "exception": false, "start_time": "2023-10-11T16:33:03.259550", "status": "completed"}, "tags": []}, "outputs": [], "source": ["class GatedMaskedConv(nn.Module):\n", "    def __init__(self, c_in, **kwargs):\n", "        \"\"\"Gated convolution block implementing the computation graph shown above.\"\"\"\n", "        super().__init__()\n", "        self.conv_vert = VerticalStackConvolution(c_in, c_out=2 * c_in, **kwargs)\n", "        self.conv_horiz = HorizontalStackConvolution(c_in, c_out=2 * c_in, **kwargs)\n", "        self.conv_vert_to_horiz = nn.Conv2d(2 * c_in, 2 * c_in, kernel_size=1, padding=0)\n", "        self.conv_horiz_1x1 = nn.Conv2d(c_in, c_in, kernel_size=1, padding=0)\n", "\n", "    def forward(self, v_stack, h_stack):\n", "        # Vertical stack (left)\n", "        v_stack_feat = self.conv_vert(v_stack)\n", "        v_val, v_gate = v_stack_feat.chunk(2, dim=1)\n", "        v_stack_out = torch.tanh(v_val) * torch.sigmoid(v_gate)\n", "\n", "        # Horizontal stack (right)\n", "        h_stack_feat = self.conv_horiz(h_stack)\n", "        h_stack_feat = h_stack_feat + self.conv_vert_to_horiz(v_stack_feat)\n", "        h_val, h_gate = h_stack_feat.chunk(2, dim=1)\n", "        h_stack_feat = torch.tanh(h_val) * torch.sigmoid(h_gate)\n", "        h_stack_out = self.conv_horiz_1x1(h_stack_feat)\n", "        h_stack_out = h_stack_out + h_stack\n", "\n", "        return v_stack_out, h_stack_out"]}, {"cell_type": "markdown", "id": "70fb33a4", "metadata": {"lines_to_next_cell": 2, "papermill": {"duration": 0.023851, "end_time": "2023-10-11T16:33:03.349989", "exception": false, "start_time": "2023-10-11T16:33:03.326138", "status": "completed"}, "tags": []}, "source": ["### Building the model\n", "\n", "Using the gated convolutions, we can now build our PixelCNN model.\n", "The architecture consists of multiple stacked `GatedMaskedConv` blocks, where we add a dilation factor to a few of the convolutions.\n", "This increases the receptive field of the model and allows it to take a larger context into account during generation.\n", "As a reminder, a dilated convolution works as follows\n", "(figure credit - [Vincent Dumoulin and Francesco Visin](https://arxiv.org/pdf/1603.07285.pdf)):\n", "\n", "
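\n", "\n", "As a small worked example (assuming stride 1 and the same padding $p$ on both sides of a dimension), the standard size formulas for a convolution with kernel size $k$ and dilation $d$ are:\n", "\n", "$$k_{\\text{eff}} = d\\,(k-1)+1,\\qquad H_{\\text{out}} = H + 2p - d\\,(k-1)$$\n", "\n", "Here, $k_{\\text{eff}}$ is the number of pixels a dilated kernel spans per dimension, so choosing $p = d\\,(k-1)/2$ keeps the spatial size unchanged.\n", "Assuming the default kernel size of 3 together with the dilations of 2 and 4 used in the model below, this gives:\n", "\n", "$$d=2:\\ k_{\\text{eff}}=5,\\ p=2;\\qquad d=4:\\ k_{\\text{eff}}=9,\\ p=4$$\n", "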
\n", "\n", "Note that the smaller output size in the animation is only due to it assuming no padding.\n", "In our implementation, we pad the input image correspondingly.\n", "As an alternative to dilated convolutions, we could downsample the input and use an encoder-decoder architecture as in PixelCNN++ [3].\n", "This is especially beneficial if we want to build a very deep autoregressive model.\n", "Nonetheless, as we seek to train a reasonably small model, dilated convolutions are the more efficient option here.\n", "\n", "Below, we implement the PixelCNN model as a PyTorch Lightning module.\n", "Besides the stack of gated convolutions, we also have the initial\n", "horizontal and vertical convolutions which mask the center pixel, and a\n", "final $1\\times 1$ convolution which maps the output features to class\n", "predictions (256 logits, one per possible pixel value, for each input channel). To determine the likelihood of a batch of images, we first\n", "create our initial features using the masked horizontal and vertical\n", "input convolutions. Next, we forward the features through the stack of\n", "gated convolutions. Finally, we take the output features of the\n", "horizontal stack, and apply the $1\\times 1$ convolution for\n", "classification. 
We use the bits per dimension metric for the likelihood (the negative log-likelihood in nats, averaged over all dimensions and converted to bits by dividing by $\\ln 2$),\n", "similar to Tutorial 11 and assignment 3."]}, {"cell_type": "code", "execution_count": 16, "id": "227ece04", "metadata": {"execution": {"iopub.execute_input": "2023-10-11T16:33:03.399604Z", "iopub.status.busy": "2023-10-11T16:33:03.399067Z", "iopub.status.idle": "2023-10-11T16:33:03.417286Z", "shell.execute_reply": "2023-10-11T16:33:03.416282Z"}, "papermill": {"duration": 0.044676, "end_time": "2023-10-11T16:33:03.419123", "exception": false, "start_time": "2023-10-11T16:33:03.374447", "status": "completed"}, "tags": []}, "outputs": [], "source": ["class PixelCNN(L.LightningModule):\n", "    def __init__(self, c_in, c_hidden):\n", "        super().__init__()\n", "        self.save_hyperparameters()\n", "\n", "        # Initial convolutions skipping the center pixel\n", "        self.conv_vstack = VerticalStackConvolution(c_in, c_hidden, mask_center=True)\n", "        self.conv_hstack = HorizontalStackConvolution(c_in, c_hidden, mask_center=True)\n", "        # Convolution block of PixelCNN. We use dilation instead of downscaling\n", "        self.conv_layers = nn.ModuleList(\n", "            [\n", "                GatedMaskedConv(c_hidden),\n", "                GatedMaskedConv(c_hidden, dilation=2),\n", "                GatedMaskedConv(c_hidden),\n", "                GatedMaskedConv(c_hidden, dilation=4),\n", "                GatedMaskedConv(c_hidden),\n", "                GatedMaskedConv(c_hidden, dilation=2),\n", "                GatedMaskedConv(c_hidden),\n", "            ]\n", "        )\n", "        # Output classification convolution (1x1): 256 logits per input channel\n", "        self.conv_out = nn.Conv2d(c_hidden, c_in * 256, kernel_size=1, padding=0)\n", "\n", "        self.example_input_array = train_set[0][0][None]\n", "\n", "    def forward(self, x):\n", "        \"\"\"Forward image through model and return logits for each pixel.\n", "\n", "        Args:\n", "            x: Image tensor with integer values between 0 and 255.\n", "        \"\"\"\n", "        # Scale input from 0 to 255 back to -1 to 1\n", "        x = (x.float() / 255.0) * 2 - 1\n", "\n", "        # Initial convolutions\n", "        v_stack = self.conv_vstack(x)\n", "        h_stack = self.conv_hstack(x)\n", "        # Gated Convolutions\n", "        for layer in self.conv_layers:\n", "            v_stack, h_stack = layer(v_stack, h_stack)\n", "        # 1x1 classification convolution\n", "        # Apply ELU before 1x1 convolution for non-linearity on residual connection\n", "        out = self.conv_out(F.elu(h_stack))\n", "\n", "        # Output dimensions: [Batch, Classes, Channels, Height, Width]\n", "        out = out.reshape(out.shape[0], 256, out.shape[1] // 256, out.shape[2], out.shape[3])\n", "        return out\n", "\n", "    def calc_likelihood(self, x):\n", "        # Forward pass with bpd likelihood calculation\n", "        pred = self.forward(x)\n", "        nll = F.cross_entropy(pred, x, reduction=\"none\")\n", "        # Average NLL (in nats) per dimension, converted to bits via log2(e) = 1 / ln(2)\n", "        bpd = nll.mean(dim=[1, 2, 3]) * np.log2(np.exp(1))\n", "        return bpd.mean()\n", "\n", "    @torch.no_grad()\n", "    def sample(self, img_shape, img=None):\n", "        \"\"\"Sampling function for the autoregressive model.\n", "\n", "        Args:\n", "            img_shape: Shape of the image to generate (B,C,H,W)\n", "            img (optional): If given, this tensor will be used as\n", "                a starting image. The pixels to fill\n", "                should be -1 in the input tensor.\n", "        \"\"\"\n", "        # Create empty image (all pixels set to -1, i.e. still to be filled)\n", "        if img is None:\n", "            img = torch.zeros(img_shape, dtype=torch.long, device=self.device) - 1\n", "        # Generation loop\n", "        for h in tqdm(range(img_shape[2]), leave=False):\n", "            for w in range(img_shape[3]):\n", "                for c in range(img_shape[1]):\n", "                    # Skip if not to be filled (-1)\n", "                    if (img[:, c, h, w] != -1).all().item():\n", "                        continue\n", "                    # For efficiency, we only have to input the upper part of the image\n", "                    # as all other parts will be skipped by the masked convolutions anyway\n", "                    pred = self.forward(img[:, :, : h + 1, :])\n", "                    probs = F.softmax(pred[:, :, c, h, w], dim=-1)\n", "                    img[:, c, h, w] = torch.multinomial(probs, num_samples=1).squeeze(dim=-1)\n", "        return img\n", "\n", "    def configure_optimizers(self):\n", "        optimizer = optim.Adam(self.parameters(), lr=1e-3)\n", "        scheduler = optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.99)\n", "        return [optimizer], [scheduler]\n", "\n", "    def training_step(self, batch, batch_idx):\n", "        loss = self.calc_likelihood(batch[0])\n", "        self.log(\"train_bpd\", loss)\n", "        return loss\n", "\n", "    def validation_step(self, batch, batch_idx):\n", "        loss = self.calc_likelihood(batch[0])\n", "        self.log(\"val_bpd\", loss)\n", "\n", "    def test_step(self, batch, batch_idx):\n", "        loss = self.calc_likelihood(batch[0])\n", "        self.log(\"test_bpd\", loss)"]}, {"cell_type": "markdown", "id": "ed569c97", "metadata": {"papermill": {"duration": 0.023337, "end_time": "2023-10-11T16:33:03.476301", "exception": false, "start_time": "2023-10-11T16:33:03.452964", "status": "completed"}, "tags": []}, "source": ["To sample from the autoregressive model, we need to iterate over all dimensions of the input.\n", "We start with an empty image, and fill the pixels one by one, starting from the upper left corner.\n", "Note that when predicting $x_i$, the pixels below it have no influence on the prediction.\n", "Hence, we can cut the image 
in height without changing the prediction while increasing efficiency.\n", "Nevertheless, all the loops in the sampling function already show that it will take us quite some time to sample.\n", "A lot of computation could be reused across loop iterations as those the features on the already predicted pixels will not change over iterations.\n", "Nevertheless, this takes quite some effort to implement, and is often not done in implementations because in the end,\n", "autoregressive sampling remains sequential and slow.\n", "Hence, we settle with the default implementation here.\n", "\n", "Before training the model, we can check the full receptive field of the model on an MNIST image of size $28\\times 28$:"]}, {"cell_type": "code", "execution_count": 17, "id": "37ef0017", "metadata": {"execution": {"iopub.execute_input": "2023-10-11T16:33:03.526873Z", "iopub.status.busy": "2023-10-11T16:33:03.526391Z", "iopub.status.idle": "2023-10-11T16:33:13.921338Z", "shell.execute_reply": "2023-10-11T16:33:13.911749Z"}, "papermill": {"duration": 10.42969, "end_time": "2023-10-11T16:33:13.930120", "exception": false, "start_time": "2023-10-11T16:33:03.500430", "status": "completed"}, "tags": []}, "outputs": [{"data": {"application/pdf": "", "image/svg+xml": ["\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-10-11T16:33:13.736191\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.8.0, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", 
" \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n"], "text/plain": ["
"]}, "metadata": {}, "output_type": "display_data"}], "source": ["test_model = PixelCNN(c_in=1, c_hidden=64)\n", "inp = torch.zeros(1, 1, 28, 28)\n", "inp.requires_grad_()\n", "out = test_model(inp)\n", "show_center_recep_field(inp, out.squeeze(dim=2))\n", "del inp, out, test_model"]}, {"cell_type": "markdown", "id": "1f00c4b2", "metadata": {"papermill": {"duration": 0.023927, "end_time": "2023-10-11T16:33:13.977670", "exception": false, "start_time": "2023-10-11T16:33:13.953743", "status": "completed"}, "tags": []}, "source": ["The visualization shows that, when predicting the center pixel, the model can take almost half of the image into account.\n", "However, keep in mind that this is the \"theoretical\" receptive field and not necessarily\n", "the [effective receptive field](https://arxiv.org/pdf/1701.04128.pdf), which is usually much smaller.\n", "For a stronger model, we should therefore try to increase the receptive\n", "field even further. In particular, the very last pixel (bottom right) would\n", "be allowed to take the whole image into account. However, our current\n", "receptive field only covers about a quarter of the image. An encoder-decoder\n", "architecture can help with this, but it also shows that autoregressive\n", "models require much deeper, more complex networks than VAEs or\n", "energy-based models."]}, {"cell_type": "markdown", "id": "d26b4d9e", "metadata": {"lines_to_next_cell": 2, "papermill": {"duration": 0.02414, "end_time": "2023-10-11T16:33:14.026163", "exception": false, "start_time": "2023-10-11T16:33:14.002023", "status": "completed"}, "tags": []}, "source": ["### Training loop\n", "\n", "To train the model, we can again rely on PyTorch Lightning and write a\n", "function below that loads the pretrained model if it exists.
To reduce\n", "the computational cost, we have already saved the validation and test scores in\n", "the checkpoint:"]}, {"cell_type": "code", "execution_count": 18, "id": "5efc54ef", "metadata": {"execution": {"iopub.execute_input": "2023-10-11T16:33:14.076835Z", "iopub.status.busy": "2023-10-11T16:33:14.076302Z", "iopub.status.idle": "2023-10-11T16:33:14.085316Z", "shell.execute_reply": "2023-10-11T16:33:14.084465Z"}, "papermill": {"duration": 0.03622, "end_time": "2023-10-11T16:33:14.086824", "exception": false, "start_time": "2023-10-11T16:33:14.050604", "status": "completed"}, "tags": []}, "outputs": [], "source": ["def train_model(**kwargs):\n", "    # Create a PyTorch Lightning trainer with checkpointing and learning rate monitoring callbacks\n", "    trainer = L.Trainer(\n", "        default_root_dir=os.path.join(CHECKPOINT_PATH, \"PixelCNN\"),\n", "        accelerator=\"auto\",\n", "        devices=1,\n", "        max_epochs=150,\n", "        callbacks=[\n", "            ModelCheckpoint(save_weights_only=True, mode=\"min\", monitor=\"val_bpd\"),\n", "            LearningRateMonitor(\"epoch\"),\n", "        ],\n", "    )\n", "    result = None\n", "    # Check whether pretrained model exists. If yes, load it and skip training\n", "    pretrained_filename = os.path.join(CHECKPOINT_PATH, \"PixelCNN.ckpt\")\n", "    if os.path.isfile(pretrained_filename):\n", "        print(\"Found pretrained model, loading...\")\n", "        model = PixelCNN.load_from_checkpoint(pretrained_filename)\n", "        ckpt = torch.load(pretrained_filename, map_location=device)\n", "        result = ckpt.get(\"result\", None)\n", "    else:\n", "        model = PixelCNN(**kwargs)\n", "        trainer.fit(model, train_loader, val_loader)\n", "    model = model.to(device)\n", "\n", "    if result is None:\n", "        # Evaluate the model on the validation and test set\n", "        val_result = trainer.test(model, dataloaders=val_loader, verbose=False)\n", "        test_result = trainer.test(model, dataloaders=test_loader, verbose=False)\n", "        result = {\"test\": test_result, \"val\": val_result}\n", "    return model, result"]}, {"cell_type": "markdown", "id": "458169ad", "metadata": {"papermill": {"duration": 0.024165, "end_time": "2023-10-11T16:33:14.135326", "exception": false, "start_time": "2023-10-11T16:33:14.111161", "status": "completed"}, "tags": []}, "source": ["Training the model is time-consuming, so we recommend using the provided pre-trained model when going through this notebook.\n", "However, feel free to play around with hyperparameters such as the number of layers\n", "if you want to get a feeling for their effect.\n", "\n", "When calling the training function with a pre-trained model, we automatically load it and print its test performance:"]}, {"cell_type": "code", "execution_count": 19, "id": "e4c53b11", "metadata": {"execution": {"iopub.execute_input": "2023-10-11T16:33:14.186116Z", "iopub.status.busy": "2023-10-11T16:33:14.185398Z", "iopub.status.idle": "2023-10-11T16:33:15.110991Z", "shell.execute_reply": "2023-10-11T16:33:15.110044Z"}, "papermill": {"duration": 0.95267, "end_time": "2023-10-11T16:33:15.112725", "exception": false, "start_time": "2023-10-11T16:33:14.160055", "status": "completed"}, "tags": []}, "outputs": [{"name": "stderr",
"output_type": "stream", "text": ["GPU available: True (cuda), used: True\n"]}, {"name": "stderr", "output_type": "stream", "text": ["TPU available: False, using: 0 TPU cores\n"]}, {"name": "stderr", "output_type": "stream", "text": ["IPU available: False, using: 0 IPUs\n"]}, {"name": "stderr", "output_type": "stream", "text": ["HPU available: False, using: 0 HPUs\n"]}, {"name": "stderr", "output_type": "stream", "text": ["/usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py:67: UserWarning: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `lightning.pytorch` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default\n", " warning_cache.warn(\n"]}, {"name": "stdout", "output_type": "stream", "text": ["Found pretrained model, loading...\n"]}, {"name": "stderr", "output_type": "stream", "text": ["Lightning automatically upgraded your loaded checkpoint from v0.9.0 to v2.0.9.post0. 
To apply the upgrade to your files permanently, run `python -m lightning.pytorch.utilities.upgrade_checkpoint --file saved_models/tutorial12/PixelCNN.ckpt`\n"]}, {"name": "stdout", "output_type": "stream", "text": ["Test bits per dimension: 0.808bpd\n"]}], "source": ["model, result = train_model(c_in=1, c_hidden=64)\n", "test_res = result[\"test\"][0]\n", "print(\n", "    \"Test bits per dimension: %4.3fbpd\" % (test_res[\"test_loss\"] if \"test_loss\" in test_res else test_res[\"test_bpd\"])\n", ")"]}, {"cell_type": "markdown", "id": "b190872e", "metadata": {"papermill": {"duration": 0.025107, "end_time": "2023-10-11T16:33:15.163281", "exception": false, "start_time": "2023-10-11T16:33:15.138174", "status": "completed"}, "tags": []}, "source": ["With a test performance of 0.808bpd, the PixelCNN significantly outperforms the normalizing flows we have seen in the previous tutorial.\n", "Treating image modeling as an autoregressive problem simplifies the learning process, as predicting\n", "one pixel given the ground truth of all previous pixels is much easier than predicting all pixels at once.\n", "In addition, PixelCNN can explicitly predict the pixel values with a discrete softmax, while\n", "normalizing flows have to learn transformations in a continuous latent space.\n", "These two aspects allow the PixelCNN to achieve notably better performance.\n", "\n", "To fully compare the models, let's also measure the number of parameters of the PixelCNN:"]}, {"cell_type": "code", "execution_count": 20, "id": "06692cba", "metadata": {"execution": {"iopub.execute_input": "2023-10-11T16:33:15.217137Z", "iopub.status.busy": "2023-10-11T16:33:15.216804Z", "iopub.status.idle": "2023-10-11T16:33:15.223414Z", "shell.execute_reply": "2023-10-11T16:33:15.222521Z"}, "papermill": {"duration": 0.036785, "end_time": "2023-10-11T16:33:15.224982", "exception": false, "start_time": "2023-10-11T16:33:15.188197", "status": "completed"}, "tags": []}, "outputs": [{"name": "stdout", "output_type": "stream", "text": ["Number of parameters: 852,160\n"]}], "source": ["num_params = sum(np.prod(param.shape) for param in model.parameters())\n", "print(f\"Number of parameters: {num_params:,}\")"]}, {"cell_type": "markdown", "id": "6c29b2dc", "metadata": {"papermill": {"duration": 0.025038, "end_time": "2023-10-11T16:33:15.275381", "exception": false, "start_time": "2023-10-11T16:33:15.250343", "status": "completed"}, "tags": []}, "source": ["Compared to the multi-scale normalizing flows, the PixelCNN has considerably fewer parameters.\n", "Of course, the number of parameters depends on our hyperparameter choices.\n", "Nevertheless, autoregressive models generally\n", "require considerably fewer parameters than normalizing flows to reach\n", "good performance, for the reasons stated above. Still,\n", "autoregressive models are much slower in sampling than normalizing\n", "flows, which limits their possible applications."]}, {"cell_type": "markdown", "id": "650e8e62", "metadata": {"papermill": {"duration": 0.024652, "end_time": "2023-10-11T16:33:15.324881", "exception": false, "start_time": "2023-10-11T16:33:15.300229", "status": "completed"}, "tags": []}, "source": ["## Sampling\n", "\n", "One way of qualitatively analysing generative models is to look at the samples they generate.\n", "Let's therefore use our sampling function to generate a few digits:"]}, {"cell_type": "code", "execution_count": 21, "id": "18b6732a", "metadata": {"execution": {"iopub.execute_input": "2023-10-11T16:33:15.376913Z", "iopub.status.busy": "2023-10-11T16:33:15.376547Z", "iopub.status.idle": "2023-10-11T16:36:13.221467Z", "shell.execute_reply": "2023-10-11T16:36:13.220786Z"}, "papermill": {"duration": 177.87393, "end_time": "2023-10-11T16:36:13.223466", "exception": false, "start_time": "2023-10-11T16:33:15.349536", "status": "completed"}, "tags": []}, "outputs": [{"name": "stderr", "output_type": "stream", "text": ["Global seed set to 1\n"]}, {"data":
{"application/vnd.jupyter.widget-view+json": {"model_id": "41ccb9d6e78546d0ac13c9a76568eb34", "version_major": 2, "version_minor": 0}, "text/plain": [" 0%| | 0/28 [00:00\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-10-11T16:36:13.161773\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.8.0, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n"], "text/plain": ["
"]}, "metadata": {}, "output_type": "display_data"}], "source": ["L.seed_everything(1)\n", "samples = model.sample(img_shape=(16, 1, 28, 28))\n", "show_imgs(samples.cpu())"]}, {"cell_type": "markdown", "id": "f4b907f6", "metadata": {"papermill": {"duration": 0.020867, "end_time": "2023-10-11T16:36:13.267347", "exception": false, "start_time": "2023-10-11T16:36:13.246480", "status": "completed"}, "tags": []}, "source": ["Most of the samples can be identified as digits, and overall the quality is better than what we obtained with normalizing flows.\n", "This is consistent with the lower bits per dimension (i.e., the higher likelihood) achieved by the autoregressive model.\n", "Nevertheless, we also see that there is still room for improvement,\n", "as a considerable number of samples cannot be identified (for example, in the first row).\n", "Deeper autoregressive models are expected to achieve better quality,\n", "as they can take more context into account when generating the pixels.\n", "\n", "Note that on Google Colab, you might see different results, specifically a white line at the top.\n", "After some debugging, the difference appears to originate inside the dilated convolution,\n", "which gives different results for different batch sizes.\n", "However, it is hard to debug this further, as it might be a bug in the PyTorch version installed on Google Colab.\n", "\n", "The trained model itself is not restricted to any specific image size, since it is fully convolutional.\n", "However, what happens if we sample an image that is larger than those\n", "seen in our training dataset?
Let's try below to sample images of size\n", "$64\\times64$ instead of $28\\times28$:"]}, {"cell_type": "code", "execution_count": 22, "id": "4b949b08", "metadata": {"execution": {"iopub.execute_input": "2023-10-11T16:36:13.313820Z", "iopub.status.busy": "2023-10-11T16:36:13.313477Z", "iopub.status.idle": "2023-10-11T16:43:24.891241Z", "shell.execute_reply": "2023-10-11T16:43:24.890046Z"}, "papermill": {"duration": 431.60664, "end_time": "2023-10-11T16:43:24.894244", "exception": false, "start_time": "2023-10-11T16:36:13.287604", "status": "completed"}, "tags": []}, "outputs": [{"name": "stderr", "output_type": "stream", "text": ["Global seed set to 1\n"]}, {"data": {"application/vnd.jupyter.widget-view+json": {"model_id": "3c1504bd07f7429992e82357ed0c93c3", "version_major": 2, "version_minor": 0}, "text/plain": [" 0%| | 0/64 [00:00\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-10-11T16:43:24.829956\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.8.0, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n"], "text/plain": ["
"]}, "metadata": {}, "output_type": "display_data"}], "source": ["L.seed_everything(1)\n", "samples = model.sample(img_shape=(8, 1, 64, 64))\n", "show_imgs(samples.cpu())"]}, {"cell_type": "markdown", "id": "15942ca5", "metadata": {"papermill": {"duration": 0.021495, "end_time": "2023-10-11T16:43:24.942585", "exception": false, "start_time": "2023-10-11T16:43:24.921090", "status": "completed"}, "tags": []}, "source": ["The larger images show that changing the image size at test time confuses the model,\n", "which then generates abstract figures (you can sometimes spot a digit in the upper left corner).\n", "In addition, sampling images of 64x64 pixels takes several minutes on a GPU.\n", "Clearly, autoregressive models cannot be scaled to large images without changing the sampling procedure, for example with [forecasting](https://arxiv.org/abs/2002.09928).\n", "Our implementation is also not the most efficient, as many computations could be cached and reused throughout the sampling process.\n", "Nevertheless, the sampling procedure remains sequential, which is\n", "inherently slower than the parallel generation used in normalizing\n", "flows."]}, {"cell_type": "markdown", "id": "f75b50b6", "metadata": {"lines_to_next_cell": 2, "papermill": {"duration": 0.021261, "end_time": "2023-10-11T16:43:24.985654", "exception": false, "start_time": "2023-10-11T16:43:24.964393", "status": "completed"}, "tags": []}, "source": ["### Autocompletion\n", "\n", "A common application of autoregressive models is\n", "autocompleting an image. As autoregressive models predict pixels one by\n", "one, we can set the first $N$ pixels to predefined values and check how\n", "the model completes the image. To implement this, we only need to\n", "skip the iterations in the sampling loop for pixels that already have a value\n", "other than -1. See the `sample` method in our PyTorch Lightning module above for the specific\n", "implementation.
In the cell below, we randomly take three images from\n", "the training set, mask about the lower half of the image, and let the\n", "model autocomplete it. To see the diversity of samples, we do this 12\n", "times for each image:"]}, {"cell_type": "code", "execution_count": 23, "id": "6b4a0719", "metadata": {"execution": {"iopub.execute_input": "2023-10-11T16:43:25.029639Z", "iopub.status.busy": "2023-10-11T16:43:25.029208Z", "iopub.status.idle": "2023-10-11T16:43:30.928842Z", "shell.execute_reply": "2023-10-11T16:43:30.923862Z"}, "papermill": {"duration": 5.923941, "end_time": "2023-10-11T16:43:30.930528", "exception": false, "start_time": "2023-10-11T16:43:25.006587", "status": "completed"}, "tags": []}, "outputs": [{"name": "stdout", "output_type": "stream", "text": ["Original image and input image to sampling:\n"]}, {"data": {"application/pdf": "JVBERi0xLjQKJazcIKu6CjEgMCBvYmoKPDwgL1R5cGUgL0NhdGFsb2cgL1BhZ2VzIDIgMCBSID4+CmVuZG9iago4IDAgb2JqCjw8IC9Gb250IDMgMCBSIC9YT2JqZWN0IDcgMCBSIC9FeHRHU3RhdGUgNCAwIFIgL1BhdHRlcm4gNSAwIFIKL1NoYWRpbmcgNiAwIFIgL1Byb2NTZXQgWyAvUERGIC9UZXh0IC9JbWFnZUIgL0ltYWdlQyAvSW1hZ2VJIF0gPj4KZW5kb2JqCjExIDAgb2JqCjw8IC9UeXBlIC9QYWdlIC9QYXJlbnQgMiAwIFIgL1Jlc291cmNlcyA4IDAgUgovTWVkaWFCb3ggWyAwIDAgMTc1LjUyMjUgOTcuNTYgXSAvQ29udGVudHMgOSAwIFIgL0Fubm90cyAxMCAwIFIgPj4KZW5kb2JqCjkgMCBvYmoKPDwgL0xlbmd0aCAxMiAwIFIgL0ZpbHRlciAvRmxhdGVEZWNvZGUgPj4Kc3RyZWFtCnicVY4xDsIwDEV3n+KfIImDkpQRqBQxFgYOEIVCRItKJXp93A6tGJ7kZ9nf1nX+PlO+xCNOV9KbpZEYRWhhUIQJjCi0ZMQ64uCUs9aJvDbZB+W8NMxaPYjuNCAou8CeFc+T1U6xxyfjhh76IMGjpBdhksiI/1+GZdFWmI/LqrNrYuqgz4z6jYYa+gGxYy7BCmVuZHN0cmVhbQplbmRvYmoKMTIgMCBvYmoKMTQzCmVuZG9iagoxMCAwIG9iagpbIF0KZW5kb2JqCjMgMCBvYmoKPDwgPj4KZW5kb2JqCjQgMCBvYmoKPDwgL0ExIDw8IC9UeXBlIC9FeHRHU3RhdGUgL0NBIDEgL2NhIDEgPj4gPj4KZW5kb2JqCjUgMCBvYmoKPDwgPj4KZW5kb2JqCjYgMCBvYmoKPDwgPj4KZW5kb2JqCjcgMCBvYmoKPDwgL0kxIDEzIDAgUiA+PgplbmRvYmoKMTMgMCBvYmoKPDwgL1R5cGUgL1hPYmplY3QgL1N1YnR5cGUgL0ltYWdlIC9XaWR0aCAyMjQgL0hlaWdodCAxMTYKL0NvbG9yU3BhY2UgWyAvSW5kZXhlZCAvRGV2aWNlUkdCIDcwC
ij////9/f37+/v19fXv7+/t7e3r6+vj4+Pf39/R0dHPz8/Hx8fDw8PBwcG/v7+5ubmlpaWVlZWTk5OPj4+Li4t9fX1jY2NfX19RUVFPT09HR0c7Ozs1NTUjIyMdHR0bGxsZGRkTExP6+vry8vLw8PDs7Ozo6Ojk5OTY2NjQ0NDCwsK6urq0tLSurq6srKygoKCKioqIiIiAgIBycnJgYGBWVlZQUFBMTExEREQyMjIwMDAsLCwqKioiIiIaGhoUFBQSEhIODg4MDAxcblxuXG4GBgYEBAQAAAApCl0KL0JpdHNQZXJDb21wb25lbnQgOCAvRmlsdGVyIC9GbGF0ZURlY29kZQovRGVjb2RlUGFybXMgPDwgL1ByZWRpY3RvciAxMCAvQ29sb3JzIDEgL0NvbHVtbnMgMjI0IC9CaXRzUGVyQ29tcG9uZW50IDggPj4KL0xlbmd0aCAxNCAwIFIgPj4Kc3RyZWFtCnic7dxrU9pAGMXxRFoLFEUUWyxUUYpoixe8lWJBQOH7f6SenY0dlSRuMs/qZDn/l8Bw+L0Id/AuHM977xtgOwKzHoFZb2mA87fo4h3GCLSwSaDkGIEWNgmUHCPQwiaBkmMEWtgkUHKMQAubBEqOEWhhk0DJMQItbBIoOWYEvEKe522iyMs8oNlsZrL56pinx6Ivo7eMxggMNgnMLnAymRwj3/ev0eL5f1Cv1/uIisXiDXptM3psosd8PRa51dNbxbgtApcH2Gw2fd1XFNyMfdRoNLbQIfL/d47SA5vPxh7JequxsOXHbRFIoDNAHO3B9a2iCsrn835opVLpHqUHXj8bC7bCx0p6jEAC57iWE9RqtV5M5HK5X0g9gQtOiT0mTICSEfgYgVkHBtXr9Ut0pNtG/X5fnXGGgPuBRqORySaBIhH4IueBoX1D6oUggDvIcJNAkQg06TNSD4DtdnsXGW4SKBKBJi0NsFqtJtgkUCQCTXIbWKvVPqCfaDqdJtgkUCQCYxsg4NTxd4qSbRIoEoGxOQ/8goL3Cn+jZJsEikRgdOPxWH0FALg1FPt5UugmgSIRGJ3zwFR3ME82CRSJwOicB3a7XaXrdDrfUfJNAkUiMLw7VCgUFLBcLqfbJFAkAsNzHqi+NQ3cJzQYDNJtEigSgeE5D9xAAK6j1JsEikTgYn/RAQJwBVUqlT2UfJNAkQhczHngLQreq1A/ZxgOh+k2CRSJQHubBEqOEWhhk0DJMQItbBIoObY0f0PtbARmPQKznvPAfz0XsDAKZW5kc3RyZWFtCmVuZG9iagoxNCAwIG9iago1ODYKZW5kb2JqCjIgMCBvYmoKPDwgL1R5cGUgL1BhZ2VzIC9LaWRzIFsgMTEgMCBSIF0gL0NvdW50IDEgPj4KZW5kb2JqCjE1IDAgb2JqCjw8IC9DcmVhdG9yIChNYXRwbG90bGliIHYzLjguMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZykKL1Byb2R1Y2VyIChNYXRwbG90bGliIHBkZiBiYWNrZW5kIHYzLjguMCkgL0NyZWF0aW9uRGF0ZSAoRDoyMDIzMTAxMTE2NDMyNVopCj4+CmVuZG9iagp4cmVmCjAgMTYKMDAwMDAwMDAwMCA2NTUzNSBmIAowMDAwMDAwMDE2IDAwMDAwIG4gCjAwMDAwMDE4NDAgMDAwMDAgbiAKMDAwMDAwMDU5NSAwMDAwMCBuIAowMDAwMDAwNjE2IDAwMDAwIG4gCjAwMDAwMDA2NzYgMDAwMDAgbiAKMDAwMDAwMDY5NyAwMDAwMCBuIAowMDAwMDAwNzE4IDAwMDAwIG4gCjAwMDAwMDAwNjUgMDAwMDAgbiAKMDAwMDAwMDMzNyAwMDAwMCBuIAowMDAwMDAwNTc1IDAwMDAwIG4gCjAwMDAwMDAyMDggMDAwMDAgbiAKMDAwMDAwMDU1NSAwMDAwMCBuIAowMDAwMDAwNzUwIDAwMDAwIG4gCjAwMDAwMDE4MjAgM
DAwMDAgbiAKMDAwMDAwMTkwMCAwMDAwMCBuIAp0cmFpbGVyCjw8IC9TaXplIDE2IC9Sb290IDEgMCBSIC9JbmZvIDE1IDAgUiA+PgpzdGFydHhyZWYKMjA1MQolJUVPRgo=", "image/svg+xml": ["\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-10-11T16:43:25.050269\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.8.0, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n"], "text/plain": ["
"]}, "metadata": {}, "output_type": "display_data"}, {"name": "stderr", "output_type": "stream", "text": ["Global seed set to 1\n"]}, {"data": {"application/vnd.jupyter.widget-view+json": {"model_id": "2edb8f542ed54952acbb2523f3684bfc", "version_major": 2, "version_minor": 0}, "text/plain": [" 0%| | 0/28 [00:00\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-10-11T16:43:27.120153\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.8.0, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n"], "text/plain": ["
"]}, "metadata": {}, "output_type": "display_data"}, {"name": "stdout", "output_type": "stream", "text": ["Original image and input image to sampling:\n"]}, {"data": {"application/pdf": "JVBERi0xLjQKJazcIKu6CjEgMCBvYmoKPDwgL1R5cGUgL0NhdGFsb2cgL1BhZ2VzIDIgMCBSID4+CmVuZG9iago4IDAgb2JqCjw8IC9Gb250IDMgMCBSIC9YT2JqZWN0IDcgMCBSIC9FeHRHU3RhdGUgNCAwIFIgL1BhdHRlcm4gNSAwIFIKL1NoYWRpbmcgNiAwIFIgL1Byb2NTZXQgWyAvUERGIC9UZXh0IC9JbWFnZUIgL0ltYWdlQyAvSW1hZ2VJIF0gPj4KZW5kb2JqCjExIDAgb2JqCjw8IC9UeXBlIC9QYWdlIC9QYXJlbnQgMiAwIFIgL1Jlc291cmNlcyA4IDAgUgovTWVkaWFCb3ggWyAwIDAgMTc1LjUyMjUgOTcuNTYgXSAvQ29udGVudHMgOSAwIFIgL0Fubm90cyAxMCAwIFIgPj4KZW5kb2JqCjkgMCBvYmoKPDwgL0xlbmd0aCAxMiAwIFIgL0ZpbHRlciAvRmxhdGVEZWNvZGUgPj4Kc3RyZWFtCnicVY4xDsIwDEV3n+KfIImDkpQRqBQxFgYOEIVCRItKJXp93A6tGJ7kZ9nf1nX+PlO+xCNOV9KbpZEYRWhhUIQJjCi0ZMQ64uCUs9aJvDbZB+W8NMxaPYjuNCAou8CeFc+T1U6xxyfjhh76IMGjpBdhksiI/1+GZdFWmI/LqrNrYuqgz4z6jYYa+gGxYy7BCmVuZHN0cmVhbQplbmRvYmoKMTIgMCBvYmoKMTQzCmVuZG9iagoxMCAwIG9iagpbIF0KZW5kb2JqCjMgMCBvYmoKPDwgPj4KZW5kb2JqCjQgMCBvYmoKPDwgL0ExIDw8IC9UeXBlIC9FeHRHU3RhdGUgL0NBIDEgL2NhIDEgPj4gPj4KZW5kb2JqCjUgMCBvYmoKPDwgPj4KZW5kb2JqCjYgMCBvYmoKPDwgPj4KZW5kb2JqCjcgMCBvYmoKPDwgL0kxIDEzIDAgUiA+PgplbmRvYmoKMTMgMCBvYmoKPDwgL1R5cGUgL1hPYmplY3QgL1N1YnR5cGUgL0ltYWdlIC9XaWR0aCAyMjQgL0hlaWdodCAxMTYKL0NvbG9yU3BhY2UgWyAvSW5kZXhlZCAvRGV2aWNlUkdCIDY4Cij////7+/v5+fn39/fx8fHp6enl5eXd3d3Z2dnX19fV1dXBwcG/v7+9vb27u7u1tbWtra2hoaGVlZWRkZF1dXVzc3NxcXFnZ2dXV1dTU1NRUVFPT09LS0s7OzsbGxsXFxcVFRUTExNcclxyXHIDAwP+/v78/Pzy8vLq6uro6Ojm5ubW1tbMzMzGxsa8vLy6urq4uLi0tLSysrKsrKyYmJiAgIBqampiYmJaWlpMTExCQkJAQEA6OjowMDAgICAYGBgODg5cblxuXG4ICAgGBgYCAgIAAAApCl0KL0JpdHNQZXJDb21wb25lbnQgOCAvRmlsdGVyIC9GbGF0ZURlY29kZQovRGVjb2RlUGFybXMgPDwgL1ByZWRpY3RvciAxMCAvQ29sb3JzIDEgL0NvbHVtbnMgMjI0IC9CaXRzUGVyQ29tcG9uZW50IDggPj4KL0xlbmd0aCAxNCAwIFIgPj4Kc3RyZWFtCnic7ZrZktJAGEZR0VFZBBVFUZYxigtR1AEURBjl/Z/JryuxBkJIL3Q3lfCdy06nTs5N8leS0lXBKZ36AlzDwLzDwLxzNoEbH1ydQMZAB06vMgY6cHqVMdCB06vMMPA3qFQqP4G+00BWiWQ6ZzFQ5mTgQaeBzG/ger3+DIIg+Ab0nTqydSQLIpnOVTIwU8rALKeOz
HtgH4RhGEQ8B/pOVVns2pJpuBiYLS1u4FcQC4fD4R+g71SV7bqGei4GKkkZmOpUlZ0k8AGIpdPpVEe40Q7cdRnKGJgpNXOqyhioIjMMfAc6nY6ZU1X2PzByGcoYmC4tfKAYETWFG+NAIxcDs6UMzHKqyrwHPgO3AJxPgLFTVRa7AiMXA9OdBQ7s9/ufQDw6iTfbxk65zAYMTMBAdadcZgPdwHa7HdwQB74ELyJeAUWnXGYDBiZgYN4DR6PRVuAj0Gg03oJ45Q7ASuYEwECbMDDBbuBB7oFfIMspl9mAgQkYWKzA9+DDHuVyuQQuQOqsykCbMDDBOQVWq9VLsL9nPp/fBWLPcrk86JTLbMDABOcU2Gq10vcMBoOPQER2u92DTrnMBgxMwMBNcQLr9fr+TUashGEo4l6DLKdcZgMGJih8YLPZ3BrVfoBerycO/AWLxeILwIGHQOaUy2zAwAQMzHvgarXa+n9TUKvVHoPv4GbxKZA55TIbMDABA/MeCK7BfRCkIV7dTyaTN0DmVJIdDQP3KXygQDwAx+PxVtltMJvNxLdQRaey7CgYmA4DFZzKsqMwDLTh9CpjoAOnVxkDHTi9yhjowOlVxkAHTq8yBjpwepUx0IHTq4yBDpxeZQx04PQqY6ADp1cZAx04vcrOJrCwMDDvMDDvMDDv/ANQi27OCmVuZHN0cmVhbQplbmRvYmoKMTQgMCBvYmoKNjU5CmVuZG9iagoyIDAgb2JqCjw8IC9UeXBlIC9QYWdlcyAvS2lkcyBbIDExIDAgUiBdIC9Db3VudCAxID4+CmVuZG9iagoxNSAwIG9iago8PCAvQ3JlYXRvciAoTWF0cGxvdGxpYiB2My44LjAsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcpCi9Qcm9kdWNlciAoTWF0cGxvdGxpYiBwZGYgYmFja2VuZCB2My44LjApIC9DcmVhdGlvbkRhdGUgKEQ6MjAyMzEwMTExNjQzMjdaKQo+PgplbmRvYmoKeHJlZgowIDE2CjAwMDAwMDAwMDAgNjU1MzUgZiAKMDAwMDAwMDAxNiAwMDAwMCBuIAowMDAwMDAxOTEwIDAwMDAwIG4gCjAwMDAwMDA1OTUgMDAwMDAgbiAKMDAwMDAwMDYxNiAwMDAwMCBuIAowMDAwMDAwNjc2IDAwMDAwIG4gCjAwMDAwMDA2OTcgMDAwMDAgbiAKMDAwMDAwMDcxOCAwMDAwMCBuIAowMDAwMDAwMDY1IDAwMDAwIG4gCjAwMDAwMDAzMzcgMDAwMDAgbiAKMDAwMDAwMDU3NSAwMDAwMCBuIAowMDAwMDAwMjA4IDAwMDAwIG4gCjAwMDAwMDA1NTUgMDAwMDAgbiAKMDAwMDAwMDc1MCAwMDAwMCBuIAowMDAwMDAxODkwIDAwMDAwIG4gCjAwMDAwMDE5NzAgMDAwMDAgbiAKdHJhaWxlcgo8PCAvU2l6ZSAxNiAvUm9vdCAxIDAgUiAvSW5mbyAxNSAwIFIgPj4Kc3RhcnR4cmVmCjIxMjEKJSVFT0YK", "image/svg+xml": ["\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-10-11T16:43:27.187600\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.8.0, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n"], "text/plain": ["
"]}, "metadata": {}, "output_type": "display_data"}, {"name": "stderr", "output_type": "stream", "text": ["Global seed set to 1\n"]}, {"data": {"application/vnd.jupyter.widget-view+json": {"model_id": "d1e2d868d5cb4fc3acdfc43ba118d35b", "version_major": 2, "version_minor": 0}, "text/plain": [" 0%| | 0/28 [00:00\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-10-11T16:43:29.009993\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.8.0, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n"], "text/plain": ["
"]}, "metadata": {}, "output_type": "display_data"}, {"name": "stdout", "output_type": "stream", "text": ["Original image and input image to sampling:\n"]}, {"data": {"application/pdf": "JVBERi0xLjQKJazcIKu6CjEgMCBvYmoKPDwgL1R5cGUgL0NhdGFsb2cgL1BhZ2VzIDIgMCBSID4+CmVuZG9iago4IDAgb2JqCjw8IC9Gb250IDMgMCBSIC9YT2JqZWN0IDcgMCBSIC9FeHRHU3RhdGUgNCAwIFIgL1BhdHRlcm4gNSAwIFIKL1NoYWRpbmcgNiAwIFIgL1Byb2NTZXQgWyAvUERGIC9UZXh0IC9JbWFnZUIgL0ltYWdlQyAvSW1hZ2VJIF0gPj4KZW5kb2JqCjExIDAgb2JqCjw8IC9UeXBlIC9QYWdlIC9QYXJlbnQgMiAwIFIgL1Jlc291cmNlcyA4IDAgUgovTWVkaWFCb3ggWyAwIDAgMTc1LjUyMjUgOTcuNTYgXSAvQ29udGVudHMgOSAwIFIgL0Fubm90cyAxMCAwIFIgPj4KZW5kb2JqCjkgMCBvYmoKPDwgL0xlbmd0aCAxMiAwIFIgL0ZpbHRlciAvRmxhdGVEZWNvZGUgPj4Kc3RyZWFtCnicVY4xDsIwDEV3n+KfIImDkpQRqBQxFgYOEIVCRItKJXp93A6tGJ7kZ9nf1nX+PlO+xCNOV9KbpZEYRWhhUIQJjCi0ZMQ64uCUs9aJvDbZB+W8NMxaPYjuNCAou8CeFc+T1U6xxyfjhh76IMGjpBdhksiI/1+GZdFWmI/LqrNrYuqgz4z6jYYa+gGxYy7BCmVuZHN0cmVhbQplbmRvYmoKMTIgMCBvYmoKMTQzCmVuZG9iagoxMCAwIG9iagpbIF0KZW5kb2JqCjMgMCBvYmoKPDwgPj4KZW5kb2JqCjQgMCBvYmoKPDwgL0ExIDw8IC9UeXBlIC9FeHRHU3RhdGUgL0NBIDEgL2NhIDEgPj4gPj4KZW5kb2JqCjUgMCBvYmoKPDwgPj4KZW5kb2JqCjYgMCBvYmoKPDwgPj4KZW5kb2JqCjcgMCBvYmoKPDwgL0kxIDEzIDAgUiA+PgplbmRvYmoKMTMgMCBvYmoKPDwgL1R5cGUgL1hPYmplY3QgL1N1YnR5cGUgL0ltYWdlIC9XaWR0aCAyMjQgL0hlaWdodCAxMTYKL0NvbG9yU3BhY2UgWyAvSW5kZXhlZCAvRGV2aWNlUkdCIDc4Cij////9/f35+fn39/fz8/Pv7+/t7e3r6+vp6enh4eHb29vT09PJycm/v7+7u7upqamnp6eTk5ORkZGPj4+NjY2JiYmFhYV3d3dxcXFdXV1ZWVlXV1dVVVVDQ0M/Pz8xMTErKysjIyMfHx8dHR0bGxsPDw8LCwsHBwcFBQX+/v78/Pz09PTs7Ozo6Ojm5ube3t7a2trOzs6+vr66urqurq6oqKikpKScnJyYmJiWlpaAgIB2dnZqampkZGRgYGBeXl5UVFRQUFBAQEA+Pj44ODgmJiYgICAWFhYSEhIQEBAODg4GBgYEBAQCAgIAAAApCl0KL0JpdHNQZXJDb21wb25lbnQgOCAvRmlsdGVyIC9GbGF0ZURlY29kZQovRGVjb2RlUGFybXMgPDwgL1ByZWRpY3RvciAxMCAvQ29sb3JzIDEgL0NvbHVtbnMgMjI0IC9CaXRzUGVyQ29tcG9uZW50IDggPj4KL0xlbmd0aCAxNCAwIFIgPj4Kc3RyZWFtCnic7Zxrc9JAGEZBQEVUEBTxVlAEBLl6t0otIIhB+f8/x2cnmwEqNJvhXSvLcz61YScnZ6bsdpOB2KnjxK76AmzDwEOHgYfO0QQu/wWnVyBjoAUnAyVlDLTgZKCkjIEWnAyUlDHQgpOBkjIGWnAyUFLGQAtOBkrKGGjBy
UBJWdTAR6Df78dBs9l8B86A53kRnKYy7er7rubKFUXGwO1SBl7m/D8Da7XaB5AALR949U+tTCbzDBg6Q2U1X7bpiq9cGQMXA48uEH/96vQ3wXA4/ALWAsFrMJlMRALPfJl2DX3XBdnEUMZABroQ+BJgPVKn/g5w5Dk493kLtLfX6/0E+wX6rv7Ktdzl6l3mYuBRBX4FeBM8AL/A5ovqSKlUugEw5jPYL1AKBq7BwIMOnE6nbdDtdtVcs31MvV6/BbCneQ/CnAwUgYEBzgcWCgW9Dds9ZjweqzGYZB6DMCcDRWBgQDabVRePpXD7679BLpdTYwaDwSsQ5mSgCAwMcDpQ3b9Lp9Pq4nee6CHQe7TRaGTiZKAIDFQ4H/gD6Iv/ewGfz+f3wTWgxxg6GSgCAwOwDqr7ymvboGq1+gTEYrH4invA0MlAERgY4HxgsVhUE0gymXyj6XQ66kgqlVI3KG4D/MbAJQNlMQ5cLBafQLvd1qtdK5FI3AX6CWTH7/0GDJ0MFIGBAc4HajzP049Zz09OTvRB9b/qdYD6F8DQyUARGHgB5wO38hSoaSefz0dwMlAEBppwB6i9Et6YEZwMFIGBJrgdOJvN1BqoJhkGbjgZKAIDQ6lUKsEOkYEbTgaKwMBQGLjLyUAR9g9sNBofgfpMT7lcjuBkoAgMtOdkoKSMgRacDJSUMdCCk4GSsqP5GmpnYeChw8BDx/nAPyUA2gAKZW5kc3RyZWFtCmVuZG9iagoxNCAwIG9iago2NjcKZW5kb2JqCjIgMCBvYmoKPDwgL1R5cGUgL1BhZ2VzIC9LaWRzIFsgMTEgMCBSIF0gL0NvdW50IDEgPj4KZW5kb2JqCjE1IDAgb2JqCjw8IC9DcmVhdG9yIChNYXRwbG90bGliIHYzLjguMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZykKL1Byb2R1Y2VyIChNYXRwbG90bGliIHBkZiBiYWNrZW5kIHYzLjguMCkgL0NyZWF0aW9uRGF0ZSAoRDoyMDIzMTAxMTE2NDMyOVopCj4+CmVuZG9iagp4cmVmCjAgMTYKMDAwMDAwMDAwMCA2NTUzNSBmIAowMDAwMDAwMDE2IDAwMDAwIG4gCjAwMDAwMDE5NDIgMDAwMDAgbiAKMDAwMDAwMDU5NSAwMDAwMCBuIAowMDAwMDAwNjE2IDAwMDAwIG4gCjAwMDAwMDA2NzYgMDAwMDAgbiAKMDAwMDAwMDY5NyAwMDAwMCBuIAowMDAwMDAwNzE4IDAwMDAwIG4gCjAwMDAwMDAwNjUgMDAwMDAgbiAKMDAwMDAwMDMzNyAwMDAwMCBuIAowMDAwMDAwNTc1IDAwMDAwIG4gCjAwMDAwMDAyMDggMDAwMDAgbiAKMDAwMDAwMDU1NSAwMDAwMCBuIAowMDAwMDAwNzUwIDAwMDAwIG4gCjAwMDAwMDE5MjIgMDAwMDAgbiAKMDAwMDAwMjAwMiAwMDAwMCBuIAp0cmFpbGVyCjw8IC9TaXplIDE2IC9Sb290IDEgMCBSIC9JbmZvIDE1IDAgUiA+PgpzdGFydHhyZWYKMjE1MwolJUVPRgo=", "image/svg+xml": ["\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-10-11T16:43:29.073236\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.8.0, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n"], "text/plain": ["
"]}, "metadata": {}, "output_type": "display_data"}, {"name": "stderr", "output_type": "stream", "text": ["Global seed set to 1\n"]}, {"data": {"application/vnd.jupyter.widget-view+json": {"model_id": "1efb8e0f64c14d1db0009992b98008c5", "version_major": 2, "version_minor": 0}, "text/plain": [" 0%| | 0/28 [00:00\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-10-11T16:43:30.870659\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.8.0, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n"], "text/plain": ["
"]}, "metadata": {}, "output_type": "display_data"}], "source": ["def autocomplete_image(img):\n", "    # Mask out all rows from row 10 downwards (-1 marks pixels to be sampled)\n", "    img_init = img.clone()\n", "    img_init[:, 10:, :] = -1\n", "    print(\"Original image and input image to sampling:\")\n", "    show_imgs([img, img_init])\n", "    # Generate 12 example completions; repeat() copies the data so the sampler can fill each image in place\n", "    img_init = img_init.unsqueeze(dim=0).repeat(12, 1, 1, 1).to(device)\n", "    L.seed_everything(1)\n", "    img_generated = model.sample(img_init.shape, img_init)\n", "    print(\"Autocompletion samples:\")\n", "    show_imgs(img_generated)\n", "\n", "\n", "for i in range(1, 4):\n", "    img = train_set[i][0]\n", "    autocomplete_image(img)"]}, {"cell_type": "markdown", "id": "ff51c9a7", "metadata": {"papermill": {"duration": 0.023729, "end_time": "2023-10-11T16:43:30.979962", "exception": false, "start_time": "2023-10-11T16:43:30.956233", "status": "completed"}, "tags": []}, "source": ["For the first two digits (7 and 6), we see that the 12 samples all\n", "result in a shape that resembles the original digit. Nevertheless, there\n", "are some stylistic differences in how the 7 is written, and some deformed sixes in\n", "the samples. When autocompleting the 9 below, we see that the model can\n", "fit multiple digits to it. We obtain diverse samples of 0, 3, 8, and 9.\n", "This shows that despite having no latent space, we can still obtain\n", "diverse samples from an autoregressive model."]}, {"cell_type": "markdown", "id": "030bf4ea", "metadata": {"papermill": {"duration": 0.025649, "end_time": "2023-10-11T16:43:31.030490", "exception": false, "start_time": "2023-10-11T16:43:31.004841", "status": "completed"}, "tags": []}, "source": ["### Visualization of the predictive distribution (softmax)\n", "\n", "Autoregressive models use a softmax over 256 values to predict the next pixel.\n", "This gives the model a lot of flexibility, as the probabilities for each pixel value can be learned independently if necessary.\n", "However, the values are not actually independent, because the values 32 and 33 are much closer to each other than 32 and 255.\n", "In the following, we visualize the softmax distribution that the model predicts to gain insight into how it has learned the relationships between nearby pixel values.\n", "\n", "To do this, we first run the model on a batch of images and store the output softmax distributions:"]}, {"cell_type": "code", "execution_count": 24, "id": "55d2964b", "metadata": {"execution": {"iopub.execute_input": "2023-10-11T16:43:31.080165Z", "iopub.status.busy": "2023-10-11T16:43:31.079886Z", "iopub.status.idle": "2023-10-11T16:43:31.175788Z", "shell.execute_reply": "2023-10-11T16:43:31.174763Z"}, "papermill": {"duration": 0.122471, "end_time": "2023-10-11T16:43:31.177292", "exception": false, "start_time": "2023-10-11T16:43:31.054821", "status": "completed"}, "tags": []}, "outputs": [], "source": ["det_loader = data.DataLoader(train_set, batch_size=128, shuffle=False, drop_last=False)\n", "imgs, _ = next(iter(det_loader))\n", "imgs = imgs.to(device)\n", "with torch.no_grad():\n", "    # Logits of shape [Batch, 256, Channels, Height, Width]; softmax over the 256 pixel values\n", "    out = model(imgs)\n", "    out = F.softmax(out, dim=1)\n", "    # Average the predicted distribution over all images, channels and pixel positions\n", "    mean_out = out.mean(dim=[0, 2, 3, 4]).cpu().numpy()\n", "    out = out.cpu().numpy()"]}, {"cell_type": "markdown", "id": "94e3f1cb", "metadata":
{"papermill": {"duration": 0.02434, "end_time": "2023-10-11T16:43:31.226172", "exception": false, "start_time": "2023-10-11T16:43:31.201832", "status": "completed"}, "tags": []}, "source": ["Before diving into the model, let's visualize the distribution of the pixel values in the whole dataset:"]}, {"cell_type": "code", "execution_count": 25, "id": "7f6ac797", "metadata": {"execution": {"iopub.execute_input": "2023-10-11T16:43:31.275410Z", "iopub.status.busy": "2023-10-11T16:43:31.275133Z", "iopub.status.idle": "2023-10-11T16:43:32.705125Z", "shell.execute_reply": "2023-10-11T16:43:32.704222Z"}, "papermill": {"duration": 1.456515, "end_time": "2023-10-11T16:43:32.706522", "exception": false, "start_time": "2023-10-11T16:43:31.250007", "status": "completed"}, "tags": []}, "outputs": [{"data": {"application/pdf": "", "image/svg+xml": ["\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-10-11T16:43:31.994806\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.8.0, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", 
" \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n"], "text/plain": ["
"]}, "metadata": {}, "output_type": "display_data"}], "source": ["sns.set()\n", "plot_args = {\"color\": to_rgb(\"C0\") + (0.5,), \"edgecolor\": \"C0\", \"linewidth\": 0.5, \"width\": 1.0}\n", "plt.hist(imgs.view(-1).cpu().numpy(), bins=256, density=True, **plot_args)\n", "plt.yscale(\"log\")\n", "plt.xticks([0, 64, 128, 192, 256])\n", "plt.show()\n", "plt.close()"]}, {"cell_type": "markdown", "id": "d72383c6", "metadata": {"papermill": {"duration": 0.026872, "end_time": "2023-10-11T16:43:32.762902", "exception": false, "start_time": "2023-10-11T16:43:32.736030", "status": "completed"}, "tags": []}, "source": ["As we would expect from the seen images, the pixel value 0 (black) is the dominant value, followed by a batch of values between 250 and 255.\n", "Note that we use a log scale on the y-axis due to the big imbalance in the dataset.\n", "Interestingly, the pixel values 64, 128 and 191 also stand out which is likely due to the quantization used during the creation of the dataset.\n", "For RGB images, we would also see two peaks around 0 and 255,\n", "but the values in between would be much more frequent than in MNIST\n", "(see Figure 1 in the [PixelCNN++](https://arxiv.org/pdf/1701.05517.pdf) for a visualization on CIFAR10).\n", "\n", "Next, we can visualize the distribution our model predicts (in average):"]}, {"cell_type": "code", "execution_count": 26, "id": "af7ec5d7", "metadata": {"execution": {"iopub.execute_input": "2023-10-11T16:43:32.817636Z", "iopub.status.busy": "2023-10-11T16:43:32.817237Z", "iopub.status.idle": "2023-10-11T16:43:34.550640Z", "shell.execute_reply": "2023-10-11T16:43:34.549871Z"}, "papermill": {"duration": 1.762708, "end_time": "2023-10-11T16:43:34.552080", "exception": false, "start_time": "2023-10-11T16:43:32.789372", "status": "completed"}, "tags": []}, "outputs": [{"data": {"application/pdf": "", "image/svg+xml": ["\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-10-11T16:43:33.841615\n", " image/svg+xml\n", " \n", " \n", " 
Matplotlib v3.8.0, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", 
" \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", 
" \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n"], "text/plain": ["
"]}, "metadata": {}, "output_type": "display_data"}], "source": ["plt.bar(np.arange(mean_out.shape[0]), mean_out, **plot_args)\n", "plt.yscale(\"log\")\n", "plt.xticks([0, 64, 128, 192, 256])\n", "plt.show()\n", "plt.close()"]}, {"cell_type": "markdown", "id": "09600e38", "metadata": {"papermill": {"duration": 0.029302, "end_time": "2023-10-11T16:43:34.612326", "exception": false, "start_time": "2023-10-11T16:43:34.583024", "status": "completed"}, "tags": []}, "source": ["This distribution is very close to the actual dataset distribution.\n", "This is in general a good sign, but we can see a slightly smoother histogram than above.\n", "\n", "Finally, to take a closer look at learned value relations, we can\n", "visualize the distribution for individual pixel predictions to get a\n", "better intuition. For this, we pick 4 random images and pixels, and\n", "visualize their distribution below:"]}, {"cell_type": "code", "execution_count": 27, "id": "4f580dd0", "metadata": {"execution": {"iopub.execute_input": "2023-10-11T16:43:34.676189Z", "iopub.status.busy": "2023-10-11T16:43:34.675468Z", "iopub.status.idle": "2023-10-11T16:43:38.754010Z", "shell.execute_reply": "2023-10-11T16:43:38.753235Z"}, "papermill": {"duration": 4.11602, "end_time": "2023-10-11T16:43:38.757667", "exception": false, "start_time": "2023-10-11T16:43:34.641647", "status": "completed"}, "tags": []}, "outputs": [{"data": {"application/pdf": "", "image/svg+xml": ["\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-10-11T16:43:36.491534\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.8.0, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", 
" \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", 
" \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", 
" \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n"], "text/plain": ["
"]}, "metadata": {}, "output_type": "display_data"}], "source": ["fig, ax = plt.subplots(2, 2, figsize=(10, 6))\n", "for i in range(4):\n", " ax_sub = ax[i // 2][i % 2]\n", " ax_sub.bar(np.arange(out.shape[1], dtype=np.int32), out[i + 4, :, 0, 14, 14], **plot_args)\n", " ax_sub.set_yscale(\"log\")\n", " ax_sub.set_xticks([0, 64, 128, 192, 256])\n", "plt.show()\n", "plt.close()"]}, {"cell_type": "markdown", "id": "49dde0fb", "metadata": {"papermill": {"duration": 0.044463, "end_time": "2023-10-11T16:43:38.843419", "exception": false, "start_time": "2023-10-11T16:43:38.798956", "status": "completed"}, "tags": []}, "source": ["Overall, we see a very diverse set of distributions, with the usual peaks\n", "at 0 and close to 1. However, the distributions in the first row show\n", "potentially undesirable behavior. For instance, the value 242 has a\n", "1000x lower likelihood than 243, although the two values are extremely close and\n", "often cannot be distinguished. This suggests that the model might not have\n", "generalized well over pixel values. A better solution to this problem\n", "is to use a mixture of discrete logistics instead of a softmax distribution.\n", "A discrete logistic distribution can be imagined as a discretized, binned\n", "logistic distribution, which has a bell shape similar to a Gaussian. Using a mixture of discrete logistics instead of a softmax\n", "introduces an inductive bias that leads the model to assign similar likelihoods\n", "to close-by values. 
We can visualize a discrete logistic below:"]}, {"cell_type": "code", "execution_count": 28, "id": "a45a1bc4", "metadata": {"execution": {"iopub.execute_input": "2023-10-11T16:43:38.921790Z", "iopub.status.busy": "2023-10-11T16:43:38.921403Z", "iopub.status.idle": "2023-10-11T16:43:39.681965Z", "shell.execute_reply": "2023-10-11T16:43:39.681237Z"}, "papermill": {"duration": 0.801647, "end_time": "2023-10-11T16:43:39.683536", "exception": false, "start_time": "2023-10-11T16:43:38.881889", "status": "completed"}, "tags": []}, "outputs": [{"data": {"application/pdf": "", "image/svg+xml": ["\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-10-11T16:43:39.263469\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.8.0, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n"], "text/plain": ["
"]}, "metadata": {}, "output_type": "display_data"}], "source": ["mu = Tensor([128])\n", "sigma = Tensor([2.0])\n", "\n", "\n", "def discrete_logistic(x, mu, sigma):\n", " # Mass of the bin [x - 0.5, x + 0.5]: logistic CDF (a sigmoid) at the upper edge minus at the lower edge\n", " return torch.sigmoid((x + 0.5 - mu) / sigma) - torch.sigmoid((x - 0.5 - mu) / sigma)\n", "\n", "\n", "x = torch.arange(256)\n", "p = discrete_logistic(x, mu, sigma)\n", "\n", "# Visualization\n", "plt.figure(figsize=(6, 3))\n", "plt.bar(x.numpy(), p.numpy(), **plot_args)\n", "plt.xlim(96, 160)\n", "plt.title(\"Discrete logistic distribution\")\n", "plt.xlabel(\"Pixel value\")\n", "plt.ylabel(\"Probability\")\n", "plt.show()\n", "plt.close()"]}, {"cell_type": "markdown", "id": "aa95bd8a", "metadata": {"papermill": {"duration": 0.042576, "end_time": "2023-10-11T16:43:39.769774", "exception": false, "start_time": "2023-10-11T16:43:39.727198", "status": "completed"}, "tags": []}, "source": ["Instead of the softmax, the model would output the mixture weights, means, and scales\n", "of the $K$ logistics we use in the mixture. This is one of\n", "the improvements to autoregressive models that PixelCNN++ [3]\n", "introduced over the original PixelCNN."]}, {"cell_type": "markdown", "id": "0419b360", "metadata": {"papermill": {"duration": 0.043157, "end_time": "2023-10-11T16:43:39.854743", "exception": false, "start_time": "2023-10-11T16:43:39.811586", "status": "completed"}, "tags": []}, "source": ["## Conclusion\n", "\n", "In this tutorial, we have looked at autoregressive image modeling and\n", "implemented the PixelCNN architecture. By using masked\n", "convolutions, we can apply a convolutional network in which each\n", "pixel is influenced only by its predecessors. Splitting the masked\n", "convolution into a vertical and a horizontal stack allowed us to remove\n", "the well-known blind spot in the region to the upper right of a pixel. In experiments,\n", "autoregressive models outperformed normalizing flows in terms of bits\n", "per dimension, but they are much slower to sample from. 
Improvements that we\n", "have not implemented here include discrete logistic mixtures, a\n", "downsampling architecture, and changing the pixel order in a diagonal\n", "fashion (see PixelSNAIL). Overall, autoregressive models are another\n", "strong family of generative models. However, they are mostly used for\n", "sequence tasks, because their sampling time scales linearly with the sequence length,\n", "whereas on images it scales quadratically with the image side length."]}, {"cell_type": "markdown", "id": "05634737", "metadata": {"papermill": {"duration": 0.042491, "end_time": "2023-10-11T16:43:39.939987", "exception": false, "start_time": "2023-10-11T16:43:39.897496", "status": "completed"}, "tags": []}, "source": ["## References\n", "[1] van den Oord, A., et al.\n", "\"Pixel Recurrent Neural Networks.\"\n", "arXiv preprint arXiv:1601.06759 (2016).\n", "[Link](https://arxiv.org/abs/1601.06759)\n", "\n", "[2] van den Oord, A., et al.\n", "\"Conditional Image Generation with PixelCNN Decoders.\"\n", "In Advances in Neural Information Processing Systems 29, pp.\n", "4790\u20134798 (2016).\n", "[Link](http://papers.nips.cc/paper/6527-conditional-image-generation-with-pixelcnn-decoders.pdf)\n", "\n", "[3] Salimans, T., et al.\n", "\"PixelCNN++: Improving the PixelCNN with Discretized Logistic Mixture Likelihood and Other Modifications.\"\n", "arXiv preprint arXiv:1701.05517 (2017).\n", "[Link](https://arxiv.org/abs/1701.05517)"]}, {"cell_type": "markdown", "id": "d3b1e11e", "metadata": {"papermill": {"duration": 0.041858, "end_time": "2023-10-11T16:43:40.023939", "exception": false, "start_time": "2023-10-11T16:43:39.982081", "status": "completed"}, "tags": []}, "source": ["## Congratulations - Time to Join the Community!\n", "\n", "Congratulations on completing this notebook tutorial! 
If you enjoyed this and would like to join the Lightning\n", "movement, you can do so in the following ways!\n", "\n", "### Star [Lightning](https://github.com/Lightning-AI/lightning) on GitHub\n", "The easiest way to help our community is just by starring the GitHub repos! This helps raise awareness of the cool\n", "tools we're building.\n", "\n", "### Join our [Slack](https://www.pytorchlightning.ai/community)!\n", "The best way to keep up to date on the latest advancements is to join our community! Make sure to introduce yourself\n", "and share your interests in the `#general` channel.\n", "\n", "\n", "### Contributions!\n", "The best way to contribute to our community is to become a code contributor! At any time you can go to the\n", "[Lightning](https://github.com/Lightning-AI/lightning) or [Bolt](https://github.com/Lightning-AI/lightning-bolts)\n", "GitHub Issues page and filter for \"good first issue\".\n", "\n", "* [Lightning good first issue](https://github.com/Lightning-AI/lightning/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)\n", "* [Bolt good first issue](https://github.com/Lightning-AI/lightning-bolts/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)\n", "* You can also contribute your own notebooks with useful examples!\n", "\n", "### Great thanks from the entire Pytorch Lightning Team for your interest!\n", "\n", "[![Pytorch Lightning](){height=\"60px\" width=\"240px\"}](https://pytorchlightning.ai)"]}, {"cell_type": "raw", "metadata": {"raw_mimetype": "text/restructuredtext"}, "source": [".. customcarditem::\n", " :header: Tutorial 10: Autoregressive Image Modeling\n", " :card_description: In this tutorial, we implement an autoregressive likelihood model for the task of image modeling. 
Autoregressive models are naturally strong generative models that constitute...\n", " :tags: Image,GPU/TPU,UvA-DL-Course\n", " :image: _static/images/course_UvA-DL/10-autoregressive-image-modeling.jpg"]}], "metadata": {"jupytext": {"cell_metadata_filter": "colab,colab_type,id,-all", "formats": "ipynb,py:percent", "main_language": "python"}, "language_info": {"codemirror_mode": {"name": "ipython", "version": 3}, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12"}, "papermill": {"default_parameters": {}, "duration": 907.793146, "end_time": "2023-10-11T16:43:42.431236", "environment_variables": {}, "exception": null, "input_path": "course_UvA-DL/10-autoregressive-image-modeling/Autoregressive_Image_Modeling.ipynb", "output_path": ".notebooks/course_UvA-DL/10-autoregressive-image-modeling.ipynb", "parameters": {}, "start_time": "2023-10-11T16:28:34.638090", "version": "2.4.0"}, "widgets": {"application/vnd.jupyter.widget-state+json": {"state": {"0e56a919ff994f698fd233e0cedbaaf8": {"model_module": "@jupyter-widgets/controls", "model_module_version": "2.0.0", "model_name": "HTMLStyleModel", "state": {"_model_module": "@jupyter-widgets/controls", "_model_module_version": "2.0.0", "_model_name": "HTMLStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "2.0.0", "_view_name": "StyleView", "background": null, "description_width": "", "font_size": null, "text_color": null}}, "0f631515cd3e46bc90b7534b7288422b": {"model_module": "@jupyter-widgets/controls", "model_module_version": "2.0.0", "model_name": "HTMLModel", "state": {"_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "2.0.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "2.0.0", "_view_name": "HTMLView", "description": "", "description_allow_html": false, 
"layout": "IPY_MODEL_b76de0e9228b426695822f45773f07fd", "placeholder": "\u200b", "style": "IPY_MODEL_1ded9662ce6c4e1ea617f11a7e99c44a", "tabbable": null, "tooltip": null, "value": " 93%"}}, "1ded9662ce6c4e1ea617f11a7e99c44a": {"model_module": "@jupyter-widgets/controls", "model_module_version": "2.0.0", "model_name": "HTMLStyleModel", "state": {"_model_module": "@jupyter-widgets/controls", "_model_module_version": "2.0.0", "_model_name": "HTMLStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "2.0.0", "_view_name": "StyleView", "background": null, "description_width": "", "font_size": null, "text_color": null}}, "1e120d9cf8004b40aa5e2160aa0a54a8": {"model_module": "@jupyter-widgets/base", "model_module_version": "2.0.0", "model_name": "LayoutModel", "state": {"_model_module": "@jupyter-widgets/base", "_model_module_version": "2.0.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "2.0.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border_bottom": null, "border_left": null, "border_right": null, "border_top": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "padding": null, "right": null, "top": null, "visibility": "hidden", "width": null}}, "1efb8e0f64c14d1db0009992b98008c5": {"model_module": "@jupyter-widgets/controls", "model_module_version": "2.0.0", "model_name": "HBoxModel", "state": 
{"_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "2.0.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "2.0.0", "_view_name": "HBoxView", "box_style": "", "children": ["IPY_MODEL_0f631515cd3e46bc90b7534b7288422b", "IPY_MODEL_e2299437f50e43a3a63aa7cb2138121d", "IPY_MODEL_70670ef05b0b466fa4b10336919ade05"], "layout": "IPY_MODEL_ef71c3339e4d41caaefad4ae3fa58506", "tabbable": null, "tooltip": null}}, "1f7aab1c20b94af1a21f27e0e17eab70": {"model_module": "@jupyter-widgets/base", "model_module_version": "2.0.0", "model_name": "LayoutModel", "state": {"_model_module": "@jupyter-widgets/base", "_model_module_version": "2.0.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "2.0.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border_bottom": null, "border_left": null, "border_right": null, "border_top": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null}}, "20aaeac66e434350ba7b9a3887c28d3e": {"model_module": "@jupyter-widgets/base", "model_module_version": "2.0.0", "model_name": "LayoutModel", "state": {"_model_module": "@jupyter-widgets/base", "_model_module_version": "2.0.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", 
"_view_module_version": "2.0.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border_bottom": null, "border_left": null, "border_right": null, "border_top": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null}}, "286aab3bed5640b9b86f53c61774d760": {"model_module": "@jupyter-widgets/base", "model_module_version": "2.0.0", "model_name": "LayoutModel", "state": {"_model_module": "@jupyter-widgets/base", "_model_module_version": "2.0.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "2.0.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border_bottom": null, "border_left": null, "border_right": null, "border_top": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "padding": null, "right": null, "top": null, "visibility": null, 
"width": null}}, "2cfff5ba0d0544238647ca6b66571412": {"model_module": "@jupyter-widgets/controls", "model_module_version": "2.0.0", "model_name": "HTMLModel", "state": {"_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "2.0.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "2.0.0", "_view_name": "HTMLView", "description": "", "description_allow_html": false, "layout": "IPY_MODEL_9ab63dc599f949769c2f3964a10bf1b7", "placeholder": "\u200b", "style": "IPY_MODEL_96e30cb44cf941e18da0a827aa22fa58", "tabbable": null, "tooltip": null, "value": " 28/28 [00:01<00:00, 13.40it/s]"}}, "2edb8f542ed54952acbb2523f3684bfc": {"model_module": "@jupyter-widgets/controls", "model_module_version": "2.0.0", "model_name": "HBoxModel", "state": {"_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "2.0.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "2.0.0", "_view_name": "HBoxView", "box_style": "", "children": ["IPY_MODEL_e479166e719048508d553287dd5d1fa2", "IPY_MODEL_fee1273514f844129fc63c113e48983d", "IPY_MODEL_bd8457fb6ec842adae2927a139b75dd6"], "layout": "IPY_MODEL_1e120d9cf8004b40aa5e2160aa0a54a8", "tabbable": null, "tooltip": null}}, "3c1504bd07f7429992e82357ed0c93c3": {"model_module": "@jupyter-widgets/controls", "model_module_version": "2.0.0", "model_name": "HBoxModel", "state": {"_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "2.0.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "2.0.0", "_view_name": "HBoxView", "box_style": "", "children": ["IPY_MODEL_be291b59904c43568a1c98da286dce34", "IPY_MODEL_4de6e0c4c5a140c7899db1e8d813ba90", "IPY_MODEL_7358c01648b1468eafa9b22ce5bc22b7"], "layout": "IPY_MODEL_4db3183a6fb04702bfadd7f316f4a672", "tabbable": 
null, "tooltip": null}}, "41ccb9d6e78546d0ac13c9a76568eb34": {"model_module": "@jupyter-widgets/controls", "model_module_version": "2.0.0", "model_name": "HBoxModel", "state": {"_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "2.0.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "2.0.0", "_view_name": "HBoxView", "box_style": "", "children": ["IPY_MODEL_457d3022ee93421ebfa582ed2b36fb95", "IPY_MODEL_9c79aca42dae4225b6671202e0240022", "IPY_MODEL_b1f79a205e8e4cdc94b5b590f8cc35f8"], "layout": "IPY_MODEL_aae0ad24a9544268adc47b7ee1d74ee0", "tabbable": null, "tooltip": null}}, "4287a288447f40d6ab0b5baa22a19c34": {"model_module": "@jupyter-widgets/controls", "model_module_version": "2.0.0", "model_name": "HTMLModel", "state": {"_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "2.0.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "2.0.0", "_view_name": "HTMLView", "description": "", "description_allow_html": false, "layout": "IPY_MODEL_20aaeac66e434350ba7b9a3887c28d3e", "placeholder": "\u200b", "style": "IPY_MODEL_442571c5f61f4c7d917e4cba64662d96", "tabbable": null, "tooltip": null, "value": "100%"}}, "442571c5f61f4c7d917e4cba64662d96": {"model_module": "@jupyter-widgets/controls", "model_module_version": "2.0.0", "model_name": "HTMLStyleModel", "state": {"_model_module": "@jupyter-widgets/controls", "_model_module_version": "2.0.0", "_model_name": "HTMLStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "2.0.0", "_view_name": "StyleView", "background": null, "description_width": "", "font_size": null, "text_color": null}}, "457d3022ee93421ebfa582ed2b36fb95": {"model_module": "@jupyter-widgets/controls", "model_module_version": "2.0.0", "model_name": "HTMLModel", "state": {"_dom_classes": [], 
"_model_module": "@jupyter-widgets/controls", "_model_module_version": "2.0.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "2.0.0", "_view_name": "HTMLView", "description": "", "description_allow_html": false, "layout": "IPY_MODEL_b6a7298a39a54cc18f0ac65f46925b18", "placeholder": "\u200b", "style": "IPY_MODEL_0e56a919ff994f698fd233e0cedbaaf8", "tabbable": null, "tooltip": null, "value": "100%"}}, "4b9ecf4f1b57450ca305f2a942a42708": {"model_module": "@jupyter-widgets/base", "model_module_version": "2.0.0", "model_name": "LayoutModel", "state": {"_model_module": "@jupyter-widgets/base", "_model_module_version": "2.0.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "2.0.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border_bottom": null, "border_left": null, "border_right": null, "border_top": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "padding": null, "right": null, "top": null, "visibility": "hidden", "width": null}}, "4db3183a6fb04702bfadd7f316f4a672": {"model_module": "@jupyter-widgets/base", "model_module_version": "2.0.0", "model_name": "LayoutModel", "state": {"_model_module": "@jupyter-widgets/base", "_model_module_version": "2.0.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "2.0.0", 
"_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border_bottom": null, "border_left": null, "border_right": null, "border_top": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "padding": null, "right": null, "top": null, "visibility": "hidden", "width": null}}, "4de6e0c4c5a140c7899db1e8d813ba90": {"model_module": "@jupyter-widgets/controls", "model_module_version": "2.0.0", "model_name": "FloatProgressModel", "state": {"_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "2.0.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "2.0.0", "_view_name": "ProgressView", "bar_style": "", "description": "", "description_allow_html": false, "layout": "IPY_MODEL_be17ea1ee216442c866045e4ba435bc3", "max": 64.0, "min": 0.0, "orientation": "horizontal", "style": "IPY_MODEL_86fb5ea0a828471b87200aaf2523170d", "tabbable": null, "tooltip": null, "value": 64.0}}, "576fbf78aee640b8858dd5732441d85d": {"model_module": "@jupyter-widgets/controls", "model_module_version": "2.0.0", "model_name": "ProgressStyleModel", "state": {"_model_module": "@jupyter-widgets/controls", "_model_module_version": "2.0.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "2.0.0", "_view_name": "StyleView", "bar_color": null, "description_width": ""}}, 
"688b1e51f42b4c1d8a8e5225df145632": {"model_module": "@jupyter-widgets/base", "model_module_version": "2.0.0", "model_name": "LayoutModel", "state": {"_model_module": "@jupyter-widgets/base", "_model_module_version": "2.0.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "2.0.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border_bottom": null, "border_left": null, "border_right": null, "border_top": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null}}, "6f2ca2f208994e7ab6ffb973d117e609": {"model_module": "@jupyter-widgets/controls", "model_module_version": "2.0.0", "model_name": "ProgressStyleModel", "state": {"_model_module": "@jupyter-widgets/controls", "_model_module_version": "2.0.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "2.0.0", "_view_name": "StyleView", "bar_color": null, "description_width": ""}}, "70670ef05b0b466fa4b10336919ade05": {"model_module": "@jupyter-widgets/controls", "model_module_version": "2.0.0", "model_name": "HTMLModel", "state": {"_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "2.0.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "2.0.0", 
"_view_name": "HTMLView", "description": "", "description_allow_html": false, "layout": "IPY_MODEL_286aab3bed5640b9b86f53c61774d760", "placeholder": "\u200b", "style": "IPY_MODEL_ac869c606c6e4610acd99ba41923888f", "tabbable": null, "tooltip": null, "value": " 26/28 [00:01<00:00, 14.35it/s]"}}, "7243cacdc61044ffac82a99d63361d34": {"model_module": "@jupyter-widgets/controls", "model_module_version": "2.0.0", "model_name": "HTMLStyleModel", "state": {"_model_module": "@jupyter-widgets/controls", "_model_module_version": "2.0.0", "_model_name": "HTMLStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "2.0.0", "_view_name": "StyleView", "background": null, "description_width": "", "font_size": null, "text_color": null}}, "7358c01648b1468eafa9b22ce5bc22b7": {"model_module": "@jupyter-widgets/controls", "model_module_version": "2.0.0", "model_name": "HTMLModel", "state": {"_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "2.0.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "2.0.0", "_view_name": "HTMLView", "description": "", "description_allow_html": false, "layout": "IPY_MODEL_7c2b64e05c954190a7b7320b62e05850", "placeholder": "\u200b", "style": "IPY_MODEL_a77a947b74664433ae9085a5418e4a76", "tabbable": null, "tooltip": null, "value": " 64/64 [07:11<00:00, 6.56s/it]"}}, "7395d51d66f74a91b813bd91570e0dce": {"model_module": "@jupyter-widgets/controls", "model_module_version": "2.0.0", "model_name": "ProgressStyleModel", "state": {"_model_module": "@jupyter-widgets/controls", "_model_module_version": "2.0.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "2.0.0", "_view_name": "StyleView", "bar_color": null, "description_width": ""}}, "7c2b64e05c954190a7b7320b62e05850": {"model_module": "@jupyter-widgets/base", "model_module_version": 
"2.0.0", "model_name": "LayoutModel", "state": {"_model_module": "@jupyter-widgets/base", "_model_module_version": "2.0.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "2.0.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border_bottom": null, "border_left": null, "border_right": null, "border_top": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null}}, "7f31ccc1ef3944fe81aacfa0a09a980d": {"model_module": "@jupyter-widgets/base", "model_module_version": "2.0.0", "model_name": "LayoutModel", "state": {"_model_module": "@jupyter-widgets/base", "_model_module_version": "2.0.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "2.0.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border_bottom": null, "border_left": null, "border_right": null, "border_top": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, 
"max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null}}, "86fb5ea0a828471b87200aaf2523170d": {"model_module": "@jupyter-widgets/controls", "model_module_version": "2.0.0", "model_name": "ProgressStyleModel", "state": {"_model_module": "@jupyter-widgets/controls", "_model_module_version": "2.0.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "2.0.0", "_view_name": "StyleView", "bar_color": null, "description_width": ""}}, "88989ec6bb41465fac2c1a7eda8be3dc": {"model_module": "@jupyter-widgets/controls", "model_module_version": "2.0.0", "model_name": "FloatProgressModel", "state": {"_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "2.0.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "2.0.0", "_view_name": "ProgressView", "bar_style": "", "description": "", "description_allow_html": false, "layout": "IPY_MODEL_9ad8569c97f8429e981c1065729821c0", "max": 28.0, "min": 0.0, "orientation": "horizontal", "style": "IPY_MODEL_6f2ca2f208994e7ab6ffb973d117e609", "tabbable": null, "tooltip": null, "value": 28.0}}, "96e30cb44cf941e18da0a827aa22fa58": {"model_module": "@jupyter-widgets/controls", "model_module_version": "2.0.0", "model_name": "HTMLStyleModel", "state": {"_model_module": "@jupyter-widgets/controls", "_model_module_version": "2.0.0", "_model_name": "HTMLStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "2.0.0", "_view_name": "StyleView", "background": null, "description_width": "", "font_size": null, "text_color": null}}, "9ab63dc599f949769c2f3964a10bf1b7": {"model_module": "@jupyter-widgets/base", "model_module_version": "2.0.0", "model_name": 
"LayoutModel", "state": {"_model_module": "@jupyter-widgets/base", "_model_module_version": "2.0.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "2.0.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border_bottom": null, "border_left": null, "border_right": null, "border_top": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null}}, "9ad8569c97f8429e981c1065729821c0": {"model_module": "@jupyter-widgets/base", "model_module_version": "2.0.0", "model_name": "LayoutModel", "state": {"_model_module": "@jupyter-widgets/base", "_model_module_version": "2.0.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "2.0.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border_bottom": null, "border_left": null, "border_right": null, "border_top": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, 
"max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null}}, "9c79aca42dae4225b6671202e0240022": {"model_module": "@jupyter-widgets/controls", "model_module_version": "2.0.0", "model_name": "FloatProgressModel", "state": {"_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "2.0.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "2.0.0", "_view_name": "ProgressView", "bar_style": "", "description": "", "description_allow_html": false, "layout": "IPY_MODEL_7f31ccc1ef3944fe81aacfa0a09a980d", "max": 28.0, "min": 0.0, "orientation": "horizontal", "style": "IPY_MODEL_576fbf78aee640b8858dd5732441d85d", "tabbable": null, "tooltip": null, "value": 28.0}}, "9d39ae4886cb42588072d053b24224e4": {"model_module": "@jupyter-widgets/controls", "model_module_version": "2.0.0", "model_name": "HTMLStyleModel", "state": {"_model_module": "@jupyter-widgets/controls", "_model_module_version": "2.0.0", "_model_name": "HTMLStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "2.0.0", "_view_name": "StyleView", "background": null, "description_width": "", "font_size": null, "text_color": null}}, "a7554b9341ee43e6b29cd9a39a830c18": {"model_module": "@jupyter-widgets/base", "model_module_version": "2.0.0", "model_name": "LayoutModel", "state": {"_model_module": "@jupyter-widgets/base", "_model_module_version": "2.0.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "2.0.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border_bottom": null, "border_left": null, "border_right": null, "border_top": null, "bottom": null, "display": null, "flex": null, "flex_flow": 
null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null}}, "a77a947b74664433ae9085a5418e4a76": {"model_module": "@jupyter-widgets/controls", "model_module_version": "2.0.0", "model_name": "HTMLStyleModel", "state": {"_model_module": "@jupyter-widgets/controls", "_model_module_version": "2.0.0", "_model_name": "HTMLStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "2.0.0", "_view_name": "StyleView", "background": null, "description_width": "", "font_size": null, "text_color": null}}, "aae0ad24a9544268adc47b7ee1d74ee0": {"model_module": "@jupyter-widgets/base", "model_module_version": "2.0.0", "model_name": "LayoutModel", "state": {"_model_module": "@jupyter-widgets/base", "_model_module_version": "2.0.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "2.0.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border_bottom": null, "border_left": null, "border_right": null, "border_top": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": 
null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "padding": null, "right": null, "top": null, "visibility": "hidden", "width": null}}, "ac869c606c6e4610acd99ba41923888f": {"model_module": "@jupyter-widgets/controls", "model_module_version": "2.0.0", "model_name": "HTMLStyleModel", "state": {"_model_module": "@jupyter-widgets/controls", "_model_module_version": "2.0.0", "_model_name": "HTMLStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "2.0.0", "_view_name": "StyleView", "background": null, "description_width": "", "font_size": null, "text_color": null}}, "b1f79a205e8e4cdc94b5b590f8cc35f8": {"model_module": "@jupyter-widgets/controls", "model_module_version": "2.0.0", "model_name": "HTMLModel", "state": {"_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "2.0.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "2.0.0", "_view_name": "HTMLView", "description": "", "description_allow_html": false, "layout": "IPY_MODEL_e872215a19ca467a9542e0d2c4a7a91b", "placeholder": "\u200b", "style": "IPY_MODEL_e9181668620f4323bfead217a729c30a", "tabbable": null, "tooltip": null, "value": " 28/28 [02:57<00:00, 6.34s/it]"}}, "b6a7298a39a54cc18f0ac65f46925b18": {"model_module": "@jupyter-widgets/base", "model_module_version": "2.0.0", "model_name": "LayoutModel", "state": {"_model_module": "@jupyter-widgets/base", "_model_module_version": "2.0.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "2.0.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border_bottom": null, "border_left": null, "border_right": null, "border_top": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, 
"grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null}}, "b76de0e9228b426695822f45773f07fd": {"model_module": "@jupyter-widgets/base", "model_module_version": "2.0.0", "model_name": "LayoutModel", "state": {"_model_module": "@jupyter-widgets/base", "_model_module_version": "2.0.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "2.0.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border_bottom": null, "border_left": null, "border_right": null, "border_top": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null}}, "bd8457fb6ec842adae2927a139b75dd6": {"model_module": "@jupyter-widgets/controls", "model_module_version": "2.0.0", "model_name": "HTMLModel", "state": {"_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "2.0.0", "_model_name": "HTMLModel", 
"_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "2.0.0", "_view_name": "HTMLView", "description": "", "description_allow_html": false, "layout": "IPY_MODEL_1f7aab1c20b94af1a21f27e0e17eab70", "placeholder": "\u200b", "style": "IPY_MODEL_d411295bc6f045cd96e4d8d77b9342db", "tabbable": null, "tooltip": null, "value": " 27/28 [00:01<00:00, 11.90it/s]"}}, "be17ea1ee216442c866045e4ba435bc3": {"model_module": "@jupyter-widgets/base", "model_module_version": "2.0.0", "model_name": "LayoutModel", "state": {"_model_module": "@jupyter-widgets/base", "_model_module_version": "2.0.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "2.0.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border_bottom": null, "border_left": null, "border_right": null, "border_top": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null}}, "be291b59904c43568a1c98da286dce34": {"model_module": "@jupyter-widgets/controls", "model_module_version": "2.0.0", "model_name": "HTMLModel", "state": {"_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "2.0.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "2.0.0", "_view_name": "HTMLView", "description": "", 
"description_allow_html": false, "layout": "IPY_MODEL_a7554b9341ee43e6b29cd9a39a830c18", "placeholder": "\u200b", "style": "IPY_MODEL_9d39ae4886cb42588072d053b24224e4", "tabbable": null, "tooltip": null, "value": "100%"}}, "c3eba258a0de491fb084306e80ba90e9": {"model_module": "@jupyter-widgets/controls", "model_module_version": "2.0.0", "model_name": "ProgressStyleModel", "state": {"_model_module": "@jupyter-widgets/controls", "_model_module_version": "2.0.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "2.0.0", "_view_name": "StyleView", "bar_color": null, "description_width": ""}}, "c58bb9d0ea7b4d489b3c1d6d4303a34c": {"model_module": "@jupyter-widgets/base", "model_module_version": "2.0.0", "model_name": "LayoutModel", "state": {"_model_module": "@jupyter-widgets/base", "_model_module_version": "2.0.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "2.0.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border_bottom": null, "border_left": null, "border_right": null, "border_top": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null}}, "d1e2d868d5cb4fc3acdfc43ba118d35b": {"model_module": "@jupyter-widgets/controls", "model_module_version": "2.0.0", "model_name": "HBoxModel", "state": 
{"_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "2.0.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "2.0.0", "_view_name": "HBoxView", "box_style": "", "children": ["IPY_MODEL_4287a288447f40d6ab0b5baa22a19c34", "IPY_MODEL_88989ec6bb41465fac2c1a7eda8be3dc", "IPY_MODEL_2cfff5ba0d0544238647ca6b66571412"], "layout": "IPY_MODEL_4b9ecf4f1b57450ca305f2a942a42708", "tabbable": null, "tooltip": null}}, "d411295bc6f045cd96e4d8d77b9342db": {"model_module": "@jupyter-widgets/controls", "model_module_version": "2.0.0", "model_name": "HTMLStyleModel", "state": {"_model_module": "@jupyter-widgets/controls", "_model_module_version": "2.0.0", "_model_name": "HTMLStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "2.0.0", "_view_name": "StyleView", "background": null, "description_width": "", "font_size": null, "text_color": null}}, "e2299437f50e43a3a63aa7cb2138121d": {"model_module": "@jupyter-widgets/controls", "model_module_version": "2.0.0", "model_name": "FloatProgressModel", "state": {"_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "2.0.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "2.0.0", "_view_name": "ProgressView", "bar_style": "", "description": "", "description_allow_html": false, "layout": "IPY_MODEL_c58bb9d0ea7b4d489b3c1d6d4303a34c", "max": 28.0, "min": 0.0, "orientation": "horizontal", "style": "IPY_MODEL_7395d51d66f74a91b813bd91570e0dce", "tabbable": null, "tooltip": null, "value": 28.0}}, "e479166e719048508d553287dd5d1fa2": {"model_module": "@jupyter-widgets/controls", "model_module_version": "2.0.0", "model_name": "HTMLModel", "state": {"_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "2.0.0", "_model_name": "HTMLModel", 
"_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "2.0.0", "_view_name": "HTMLView", "description": "", "description_allow_html": false, "layout": "IPY_MODEL_688b1e51f42b4c1d8a8e5225df145632", "placeholder": "\u200b", "style": "IPY_MODEL_7243cacdc61044ffac82a99d63361d34", "tabbable": null, "tooltip": null, "value": " 96%"}}, "e58f6fe439334ac5a5b801be93a545a8": {"model_module": "@jupyter-widgets/base", "model_module_version": "2.0.0", "model_name": "LayoutModel", "state": {"_model_module": "@jupyter-widgets/base", "_model_module_version": "2.0.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "2.0.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border_bottom": null, "border_left": null, "border_right": null, "border_top": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null}}, "e872215a19ca467a9542e0d2c4a7a91b": {"model_module": "@jupyter-widgets/base", "model_module_version": "2.0.0", "model_name": "LayoutModel", "state": {"_model_module": "@jupyter-widgets/base", "_model_module_version": "2.0.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "2.0.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border_bottom": null, 
"border_left": null, "border_right": null, "border_top": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null}}, "e9181668620f4323bfead217a729c30a": {"model_module": "@jupyter-widgets/controls", "model_module_version": "2.0.0", "model_name": "HTMLStyleModel", "state": {"_model_module": "@jupyter-widgets/controls", "_model_module_version": "2.0.0", "_model_name": "HTMLStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "2.0.0", "_view_name": "StyleView", "background": null, "description_width": "", "font_size": null, "text_color": null}}, "ef71c3339e4d41caaefad4ae3fa58506": {"model_module": "@jupyter-widgets/base", "model_module_version": "2.0.0", "model_name": "LayoutModel", "state": {"_model_module": "@jupyter-widgets/base", "_model_module_version": "2.0.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "2.0.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border_bottom": null, "border_left": null, "border_right": null, "border_top": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, 
"grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "padding": null, "right": null, "top": null, "visibility": "hidden", "width": null}}, "fee1273514f844129fc63c113e48983d": {"model_module": "@jupyter-widgets/controls", "model_module_version": "2.0.0", "model_name": "FloatProgressModel", "state": {"_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "2.0.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "2.0.0", "_view_name": "ProgressView", "bar_style": "", "description": "", "description_allow_html": false, "layout": "IPY_MODEL_e58f6fe439334ac5a5b801be93a545a8", "max": 28.0, "min": 0.0, "orientation": "horizontal", "style": "IPY_MODEL_c3eba258a0de491fb084306e80ba90e9", "tabbable": null, "tooltip": null, "value": 28.0}}}, "version_major": 2, "version_minor": 0}}}, "nbformat": 4, "nbformat_minor": 5}