{"cells": [{"cell_type": "markdown", "id": "100c2262", "metadata": {"papermill": {"duration": 0.018049, "end_time": "2023-03-14T16:08:42.495700", "exception": false, "start_time": "2023-03-14T16:08:42.477651", "status": "completed"}, "tags": []}, "source": ["\n", "# Tutorial 9: Normalizing Flows for Image Modeling\n", "\n", "* **Author:** Phillip Lippe\n", "* **License:** CC BY-SA\n", "* **Generated:** 2023-03-14T16:07:05.259127\n", "\n", "In this tutorial, we will take a closer look at complex, deep normalizing flows.\n", "The most popular, current application of deep normalizing flows is to model datasets of images.\n", "As for other generative models, images are a good domain to start working on because\n", "(1) CNNs are widely studied and strong models exist,\n", "(2) images are high-dimensional and complex,\n", "and (3) images are discrete integers.\n", "In this tutorial, we will review current advances in normalizing flows for image modeling,\n", "and get hands-on experience on coding normalizing flows.\n", "Note that normalizing flows are commonly parameter heavy and therefore computationally expensive.\n", "We will use relatively simple and shallow flows to save computational cost and allow you to run the notebook on CPU,\n", "but keep in mind that a simple way to improve the scores of the flows we study here is to make them deeper.\n", "This notebook is part of a lecture series on Deep Learning at the University of Amsterdam.\n", "The full list of tutorials can be found at https://uvadlc-notebooks.rtfd.io.\n", "\n", "\n", "---\n", "Open in [![Open In Colab](){height=\"20px\" width=\"117px\"}](https://colab.research.google.com/github/PytorchLightning/lightning-tutorials/blob/publication/.notebooks/course_UvA-DL/09-normalizing-flows.ipynb)\n", "\n", "Give us a \u2b50 [on Github](https://www.github.com/Lightning-AI/lightning/)\n", "| Check out [the documentation](https://pytorch-lightning.readthedocs.io/en/stable/)\n", "| Join us [on 
Slack](https://www.pytorchlightning.ai/community)"]}, {"cell_type": "markdown", "id": "87ad8b6e", "metadata": {"papermill": {"duration": 0.008889, "end_time": "2023-03-14T16:08:42.513816", "exception": false, "start_time": "2023-03-14T16:08:42.504927", "status": "completed"}, "tags": []}, "source": ["## Setup\n", "This notebook requires some packages besides pytorch-lightning."]}, {"cell_type": "code", "execution_count": 1, "id": "d2626593", "metadata": {"colab": {}, "colab_type": "code", "execution": {"iopub.execute_input": "2023-03-14T16:08:42.564704Z", "iopub.status.busy": "2023-03-14T16:08:42.564240Z", "iopub.status.idle": "2023-03-14T16:08:45.876740Z", "shell.execute_reply": "2023-03-14T16:08:45.875371Z"}, "id": "LfrJLKPFyhsK", "lines_to_next_cell": 0, "papermill": {"duration": 3.355994, "end_time": "2023-03-14T16:08:45.879658", "exception": false, "start_time": "2023-03-14T16:08:42.523664", "status": "completed"}, "tags": []}, "outputs": [{"name": "stdout", "output_type": "stream", "text": ["\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\r\n", "\u001b[0m"]}], "source": ["! pip install --quiet \"ipython[notebook]>=8.0.0, <8.12.0\" \"seaborn\" \"pytorch-lightning>=1.4, <2.0.0\" \"torch>=1.8.1, <1.14.0\" \"tabulate\" \"lightning>=2.0.0rc0\" \"torchmetrics>=0.7, <0.12\" \"torchvision\" \"matplotlib\" \"setuptools==67.4.0\""]}, {"cell_type": "markdown", "id": "c4fcf3d8", "metadata": {"papermill": {"duration": 0.009032, "end_time": "2023-03-14T16:08:45.902797", "exception": false, "start_time": "2023-03-14T16:08:45.893765", "status": "completed"}, "tags": []}, "source": ["
\n", "Throughout this notebook, we make use of [PyTorch Lightning](https://lightning.ai/docs/pytorch/stable/).\n", "The first cell imports our usual libraries."]}, {"cell_type": "code", "execution_count": 2, "id": "91b34352", "metadata": {"execution": {"iopub.execute_input": "2023-03-14T16:08:45.922980Z", "iopub.status.busy": "2023-03-14T16:08:45.922606Z", "iopub.status.idle": "2023-03-14T16:08:49.148569Z", "shell.execute_reply": "2023-03-14T16:08:49.147852Z"}, "papermill": {"duration": 3.238278, "end_time": "2023-03-14T16:08:49.150284", "exception": false, "start_time": "2023-03-14T16:08:45.912006", "status": "completed"}, "tags": []}, "outputs": [{"name": "stderr", "output_type": "stream", "text": ["Global seed set to 42\n"]}, {"name": "stdout", "output_type": "stream", "text": ["Using device cuda:0\n"]}], "source": ["import math\n", "import os\n", "import time\n", "import urllib.request\n", "from urllib.error import HTTPError\n", "\n", "import lightning as L\n", "import matplotlib\n", "import matplotlib.pyplot as plt\n", "import matplotlib_inline.backend_inline\n", "import numpy as np\n", "import seaborn as sns\n", "import tabulate\n", "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "import torch.optim as optim\n", "import torch.utils.data as data\n", "import torchvision\n", "from IPython.display import HTML, display\n", "from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint\n", "from matplotlib.colors import to_rgb\n", "from torch import Tensor\n", "from torchvision import transforms\n", "from torchvision.datasets import MNIST\n", "from tqdm.notebook import tqdm\n", "\n", "%matplotlib inline\n", "matplotlib_inline.backend_inline.set_matplotlib_formats(\"svg\", \"pdf\") # For export\n", "matplotlib.rcParams[\"lines.linewidth\"] = 2.0\n", "sns.reset_orig()\n", "\n", "# Path to the folder where the datasets are/should be downloaded (e.g. 
MNIST)\n", "DATASET_PATH = os.environ.get(\"PATH_DATASETS\", \"data\")\n", "# Path to the folder where the pretrained models are saved\n", "CHECKPOINT_PATH = os.environ.get(\"PATH_CHECKPOINT\", \"saved_models/tutorial11\")\n", "\n", "# Setting the seed\n", "L.seed_everything(42)\n", "\n", "# Ensure that all operations are deterministic on GPU (if used) for reproducibility\n", "torch.backends.cudnn.deterministic = True\n", "torch.backends.cudnn.benchmark = False\n", "\n", "# Fetching the device that will be used throughout this notebook\n", "device = torch.device(\"cpu\") if not torch.cuda.is_available() else torch.device(\"cuda:0\")\n", "print(\"Using device\", device)"]}, {"cell_type": "markdown", "id": "4bcabe69", "metadata": {"papermill": {"duration": 0.009295, "end_time": "2023-03-14T16:08:49.170996", "exception": false, "start_time": "2023-03-14T16:08:49.161701", "status": "completed"}, "tags": []}, "source": ["Again, we have a few pretrained models. We download them below to the specified path above."]}, {"cell_type": "code", "execution_count": 3, "id": "0a65243d", "metadata": {"execution": {"iopub.execute_input": "2023-03-14T16:08:49.191095Z", "iopub.status.busy": "2023-03-14T16:08:49.190541Z", "iopub.status.idle": "2023-03-14T16:08:50.538855Z", "shell.execute_reply": "2023-03-14T16:08:50.537598Z"}, "papermill": {"duration": 1.361205, "end_time": "2023-03-14T16:08:50.541518", "exception": false, "start_time": "2023-03-14T16:08:49.180313", "status": "completed"}, "tags": []}, "outputs": [{"name": "stdout", "output_type": "stream", "text": ["Downloading https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial11/MNISTFlow_simple.ckpt...\n"]}, {"name": "stdout", "output_type": "stream", "text": ["Downloading https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial11/MNISTFlow_vardeq.ckpt...\n"]}, {"name": "stdout", "output_type": "stream", "text": ["Downloading 
https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial11/MNISTFlow_multiscale.ckpt...\n"]}], "source": ["# GitHub URL where saved models are stored for this tutorial\n", "base_url = \"https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial11/\"\n", "# Files to download\n", "pretrained_files = [\"MNISTFlow_simple.ckpt\", \"MNISTFlow_vardeq.ckpt\", \"MNISTFlow_multiscale.ckpt\"]\n", "# Create checkpoint path if it doesn't exist yet\n", "os.makedirs(CHECKPOINT_PATH, exist_ok=True)\n", "\n", "# For each file, check whether it already exists. If not, try downloading it.\n", "for file_name in pretrained_files:\n", "    file_path = os.path.join(CHECKPOINT_PATH, file_name)\n", "    if not os.path.isfile(file_path):\n", "        file_url = base_url + file_name\n", "        print(\"Downloading %s...\" % file_url)\n", "        try:\n", "            urllib.request.urlretrieve(file_url, file_path)\n", "        except HTTPError as e:\n", "            print(\n", "                \"Something went wrong. Please try to download the file from the GDrive folder, or contact the author with the full output including the following error:\\n\",\n", "                e,\n", "            )"]}, {"cell_type": "markdown", "id": "23698177", "metadata": {"lines_to_next_cell": 2, "papermill": {"duration": 0.009424, "end_time": "2023-03-14T16:08:50.566184", "exception": false, "start_time": "2023-03-14T16:08:50.556760", "status": "completed"}, "tags": []}, "source": ["We will use the MNIST dataset in this notebook.\n", "Despite its simplicity, MNIST is a challenge for small generative models because it requires a global understanding of the image.\n", "At the same time, we can easily judge whether generated images come from the same distribution as the dataset\n", "(i.e. 
represent real digits), or not.\n", "\n", "To deal better with the discrete nature of the images, we transform them\n", "from a range of 0-1 to a range of 0-255 as integers."]}, {"cell_type": "code", "execution_count": 4, "id": "9965c930", "metadata": {"execution": {"iopub.execute_input": "2023-03-14T16:08:50.586882Z", "iopub.status.busy": "2023-03-14T16:08:50.586498Z", "iopub.status.idle": "2023-03-14T16:08:51.988192Z", "shell.execute_reply": "2023-03-14T16:08:51.987570Z"}, "papermill": {"duration": 1.413875, "end_time": "2023-03-14T16:08:51.989455", "exception": false, "start_time": "2023-03-14T16:08:50.575580", "status": "completed"}, "tags": []}, "outputs": [{"name": "stdout", "output_type": "stream", "text": ["Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz\n", "Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to /__w/6/s/.datasets/MNIST/raw/train-images-idx3-ubyte.gz\n"]}, {"data": {"application/vnd.jupyter.widget-view+json": {"model_id": "8e868462f756444d9b1a566c8f77d457", "version_major": 2, "version_minor": 0}, "text/plain": [" 0%| | 0/9912422 [00:00"]}, "metadata": {}, "output_type": "display_data"}], "source": ["# Convert images from 0-1 to 0-255 (integers)\n", "def discretize(sample):\n", "    return (sample * 255).to(torch.int32)\n", "\n", "\n", "# Transformations applied on each image => make them a tensor and discretize\n", "transform = transforms.Compose([transforms.ToTensor(), discretize])\n", "\n", "# Loading the training dataset. 
We need to split it into a training and validation part\n", "train_dataset = MNIST(root=DATASET_PATH, train=True, transform=transform, download=True)\n", "L.seed_everything(42)\n", "train_set, val_set = torch.utils.data.random_split(train_dataset, [50000, 10000])\n", "\n", "# Loading the test set\n", "test_set = MNIST(root=DATASET_PATH, train=False, transform=transform, download=True)\n", "\n", "# We define a set of data loaders that we can use for various purposes later.\n", "# Note that for actually training a model, we will use different data loaders\n", "# with a lower batch size.\n", "train_loader = data.DataLoader(train_set, batch_size=256, shuffle=False, drop_last=False)\n", "val_loader = data.DataLoader(val_set, batch_size=64, shuffle=False, drop_last=False, num_workers=4)\n", "test_loader = data.DataLoader(test_set, batch_size=64, shuffle=False, drop_last=False, num_workers=4)"]}, {"cell_type": "markdown", "id": "f2a87c6f", "metadata": {"lines_to_next_cell": 2, "papermill": {"duration": 0.010169, "end_time": "2023-03-14T16:08:52.011203", "exception": false, "start_time": "2023-03-14T16:08:52.001034", "status": "completed"}, "tags": []}, "source": ["In addition, we define a function below to simplify the visualization of images/samples.\n", "Some training examples of the MNIST dataset are shown below."]}, {"cell_type": "code", "execution_count": 5, "id": "d170c817", "metadata": {"execution": {"iopub.execute_input": "2023-03-14T16:08:52.035150Z", "iopub.status.busy": "2023-03-14T16:08:52.034845Z", "iopub.status.idle": "2023-03-14T16:08:52.192780Z", "shell.execute_reply": "2023-03-14T16:08:52.192170Z"}, "papermill": {"duration": 0.172541, "end_time": "2023-03-14T16:08:52.193954", "exception": false, "start_time": "2023-03-14T16:08:52.021413", "status": "completed"}, "tags": []}, "outputs": [{"data": {"application/pdf": 
"JVBERi0xLjQKJazcIKu6CjEgMCBvYmoKPDwgL1R5cGUgL0NhdGFsb2cgL1BhZ2VzIDIgMCBSID4+CmVuZG9iago4IDAgb2JqCjw8IC9Gb250IDMgMCBSIC9YT2JqZWN0IDcgMCBSIC9FeHRHU3RhdGUgNCAwIFIgL1BhdHRlcm4gNSAwIFIKL1NoYWRpbmcgNiAwIFIgL1Byb2NTZXQgWyAvUERGIC9UZXh0IC9JbWFnZUIgL0ltYWdlQyAvSW1hZ2VJIF0gPj4KZW5kb2JqCjExIDAgb2JqCjw8IC9UeXBlIC9QYWdlIC9QYXJlbnQgMiAwIFIgL1Jlc291cmNlcyA4IDAgUgovTWVkaWFCb3ggWyAwIDAgMzQxLjY3NDgzODcwOTcgMTgwLjcyIF0gL0NvbnRlbnRzIDkgMCBSIC9Bbm5vdHMgMTAgMCBSID4+CmVuZG9iago5IDAgb2JqCjw8IC9MZW5ndGggMTIgMCBSIC9GaWx0ZXIgL0ZsYXRlRGVjb2RlID4+CnN0cmVhbQp4nFWOSw7CMAxE9z7FnCDfKkmXQKWIZWHBAaJQiCioVKLXx61AhcWzPJbHHtnk1zXlQ9xidyS5qjSSRmE6KBRmgkZkOlKserKVFs5XwdYsb79SByW84Zla2wvRmQZ4YRas4TpvB69qD+2csAbPjBPukBv+MvKrwkx8PeI/2LD4HeYgH+v3cOoh9xrNAy219AYPKzF0CmVuZHN0cmVhbQplbmRvYmoKMTIgMCBvYmoKMTQ4CmVuZG9iagoxMCAwIG9iagpbIF0KZW5kb2JqCjMgMCBvYmoKPDwgPj4KZW5kb2JqCjQgMCBvYmoKPDwgL0ExIDw8IC9UeXBlIC9FeHRHU3RhdGUgL0NBIDEgL2NhIDEgPj4gPj4KZW5kb2JqCjUgMCBvYmoKPDwgPj4KZW5kb2JqCjYgMCBvYmoKPDwgPj4KZW5kb2JqCjcgMCBvYmoKPDwgL0kxIDEzIDAgUiA+PgplbmRvYmoKMTMgMCBvYmoKPDwgL1R5cGUgL1hPYmplY3QgL1N1YnR5cGUgL0ltYWdlIC9XaWR0aCA0NTUgL0hlaWdodCAyMzEKL0NvbG9yU3BhY2UgWy9JbmRleGVkIC9EZXZpY2VSR0IgMjIyICj////+/v79/f38/Pz7+/v6+vr5+fn4+Pj39/f19fX09PTz8/Py8vLx8fHw8PDv7+/u7u7t7e3s7Ozr6+vq6urp6eno6Ojn5+fm5ubl5eXk5OTj4+Pi4uLh4eHg4ODf39/e3t7d3d3c3Nzb29va2trZ2dnY2NjX19fW1tbV1dXU1NTT09PR0dHQ0NDPz8/Ozs7Nzc3MzMzLy8vJycnHx8fGxsbFxcXExMTDw8PCwsLBwcHAwMC/v7++vr69vb28vLy7u7u6urq5ubm4uLi3t7e2tra1tbW0tLSzs7OysrKwsLCvr6+urq6tra2srKyqqqqpqamoqKinp6elpaWkpKSioqKhoaGgoKCenp6cnJyampqZmZmYmJiXl5eWlpaVlZWUlJSTk5ORkZGPj4+NjY2Li4uKioqJiYmIiIiHh4eGhoaFhYWEhISDg4OCgoKAgIB+fn59fX18fHx7e3t5eXl4eHh3d3d2dnZ1dXV0dHRzc3NycnJxcXFwcHBvb29ubm5qampnZ2dmZmZkZGRjY2NiYmJgYGBfX19eXl5dXV1cXFxcXFxbW1taWlpZWVlYWFhXV1dWVlZVVVVUVFRTU1NRUVFQUFBPT09MTExLS0tKSkpJSUlISEhHR0dGRkZFRUVERERDQ0NCQkJBQUFAQEA/Pz8+Pj48PDw7Ozs6Ojo5OTk4ODg3Nzc1NTU0NDQzMzMyMjIxMTEwMDAvLy8uLi4sLCwrKysqKipcKFwoXCgnJycmJiYlJSUkJCQjIyMiIiIgICAfHx8eHh4dHR0cHBwbGxsaGhoZGRkYGBgXFxcWFhYVFRUUFBQTExMSEhIREREQEBAPDw8ODg5
cclxyXHIMDAwLCwtcblxuXG4JCQkICAgHBwcGBgYFBQUEBAQDAwMCAgIBAQEAAAApXQovQml0c1BlckNvbXBvbmVudCA4IC9GaWx0ZXIgL0ZsYXRlRGVjb2RlCi9EZWNvZGVQYXJtcyA8PCAvUHJlZGljdG9yIDEwIC9Db2xvcnMgMSAvQ29sdW1ucyA0NTUgPj4gL0xlbmd0aCAxNCAwIFIgPj4Kc3RyZWFtCnic7Z37nxZVAcbZlQphIVuMELQ0tUASgVq20iKoKEouEXQxW4LASLpIUWJYQRchFoKSpTQNdEOWDYtNkizUoFrSYLm4UCjwt/Q87MyH2XnPmTlzeZfl8Hx/AeedmWfOfF98z5wz55wB3xQ+MOBiX4AoBXn0A3n0A3n0A3n0A3n0g9Djub5CeVXJk0c/8uTRjzx59CNPHv3Ik0c/8uTRjzx59CNPHi10dna+HzwB+iQvL/KYjDw65fb7PHl0yu33eZepx4Ng+/btdaAGDACTJ0+25xbNcyZPXjeYOnUqCzILVD2vCPJoRx7l8RL2+F0wHdT2prGx0Z5bJC8TWfNeAlNATQ9rQFXziiKPZuRRHvPkFUUezVxOHv8HWBdoaWlpAINAbSVXXXXVT4E5t7yCpJA1bwYIHF4NWNSq5hVFHs3IozzmySuKPJq5XDz+ExgqNQYK1HMeAmxKWLhwoX0n3uJXQfKpspbvnSDwaPgWppM172fgJoAbxsw7AUt/7NixTHnyGEMe5TFPnjzKYyQvm8eVK1d+HKQ7JM3NzXnKiRIcuxXwFJs3bzbs8UewdOnSN4HxALWtltRyupXv3AWPdXV1W4HzcZnzjh49Og+wknhFDyxx8Ncrxo0btx0458ljDHmUx0x5fe7x7Nmz/wHLweDBgxPEDRs2bCd4JuDMmTN5yvkvEJ5x7dq13ES1vITdu3d/GYwAkdh7QGo5029IQOhxypQpzsfkysOPIIW9GfwY3H///VGPADd78A7glCePMeRRHjPlyaM8VuY5eOzu7k6QBz4BPgXa29tLKCfqNpvDMw8fPnwBeBswR78XsKW3QF6UBx98cCCgx0WLFrkdkyfvHwDPjLTVHty1/fv3/zJgKghcXg+6urrS8+QxijzKY5Y8eZRHc14hj42NjXx/5RVQXjmpha2MfBU2lnfllVeyjAtBsCW5juOUF8IXcyZOnBi0rdZ0dnY6lylz3mqAq/8KMLQQcxO+UfgSD2cZm5qa0vPkMUQe5TFrXp96PAX4UU+nSpR6MAkcOnQodgz/x9+V+D905/vKR9FVq1a9I4CPr9u2beMnnwW15x9Xh7EXLfk0znmbQM0FAo9bwMM9/BqknyY17znAGzhy5EjeLcte+KdxI+CVoOaRniePIfIYQR7lMYY8luSR/V+VdZsbQCtg5WZtBaycfBTw74888kiecibzGGDHI67ErXvQOY9vcEY8Tp8+fSbgQ1ywhT2F3LQo8dEyNY8txEEPY8LV4A6HXZIbN25M2FEeY8ijPGbJk0cij+a8bB6HDh3KHjG+mDcbVFqOMmLECFYPspYzmfeAoKp1AKQfkNejnWtAR0dH3rw5gIpWr15t2eMMmDt3Lne6DZw+fTq9fPIYIo/ymCVPHok8mvMSPIbttBEGDhx4LTA0YZvho3obyFLOZEKPS5YscTsgr0d8C0dWgi/yUA5YQGWv09ySnph3FFwHqMh+JRynH/Q/Jr+bG8mTxxB5lEfXvIvikaVxspXMk8C1nCls2bLljeB94MSJE27H5PF4Bzh+/LhhJz45801F7PQCyJr3bxC+1lj5WHgYLAavB8FO6ZctjzHkUR5d8+TxPPJozbtkPOKJrQMOecZMczDm8fg9YN6JoxPGAKi0dB6m5vH5kaWIdCvySXjTpk2sQEVu3MLE0dixPHkMkUd5zJInj/Joz8vtkb1xEyZM4FstHOlGW6gdxHbio/NukKWcFn4CgrP+Cjgfl8fjh0BlPQdbjnMmSNZzdu7cmTdvGWAVZsi
QIR8I4EBAbGEdjuMr3g3wn/JoKGf6jvIoj1ny5FEe7XkJHpN7GPmSzKoe7gXmnb4BspbTAB+QOcMKzsi1UZIHyuXN+wyItJPju3IoeD+X/YFtbW0fBvxkLCiQx2m47gb19fVhkwD/TeCrEw4iD7SuA87lk8cQeZTHLHl97tH83qM7t9xyC9+Bz1pOA3l/GzPlcVKQC3Pokg8CDofmXY9sXgnKyQsHIP8OBFvZAst/IrD8InAunzyGyKM85syTR3fk8WnAt3Gy+hsNeGXmGdrzePwY4JlHjRq1H2Q6NlPeqVOnbgY1Rlgb4YDw10BJeZX8ArDq4zxdiDwakEd5zJNXiTwml9NtZy89kt+CLCo5D8pvQGquc8FeBu8CPPu0adOcj8ubx4oUvzcRgXxfZsOGDRxiXn5ejDuCRmrUftwOkEcz8iiPefJiVMkjaWtrS5b3VvBkD6eTBwedy15OTirPEP7q2kc5lZdXlAJ5fJOSz478fZTHgnlFkUcz8uiUJ48l0389lkvWvE8DemQFoC/yilIg788g7JKUx4J5RZFHM/LolCePJSOPBp599tnXAXrk/O8LFixImCKxhLwSkEcD8uiaJ48lI48G5NE1r3973Lt3b9jYwLVYnn/++ermlUCBvP+CzwMuULZnz55MefJYMvJoQB5d8/q3R+W55smjH3ny6EeePPqRJ49+5MmjH3ny6EeePPqRF3oUlzby6Afy6Afy6Afy6Afy6Afy6Ad6fvQjTx79yJNHP/Lk0Y88efQjTx79yJNHP/Lk0Y88efQjTx79yJNHP/Lk0Y88efQjTx79yJNHP/L6r8e/gYd6uBpw7FxkfcTm5uaS8yrhOiy3gh8A54Mq854BE0BNTc0N4JPgO+CBBx5g2TgzL1eymjFjRjjNJNeBnDlz5hfBUeCUJ4925DE9N31HeZTHPHmVyGN6rn2HbYCLInIZ4oG9CabtIPX19dzp6+AUKJBn50eA95VzMDofFMt79NFHx4Hvgw0bNiRf6lnARSZ5B1asWMFlIIeAOXPmcFnm5Dx5tCOPCcijFXmUx8wed4HYps3A+fjkvNbWVhiqD6TV2j1e2FK7GOTNS4TLJNPjG4DzQbE8iHsF5Inn3NJcJrXm/KKittXZ5DEVeUxAHl2poseGhgY+3MS2chEsBjqdIjmvq6vrOkBFXH1xSm+CLW8HEY+8zfYnvIvssQBcSPkLQB7T8xKRxwTk0RV5dM9LxF+PBw4cMHikQ8PmxFz7Dg+DdcC+RwdYtmxZtDJ033335c2z0N3d3Qjo8RrgfFxJHvft2zcZMH7evHn4encl58mjBXlMRh6dkMeseRYuQ49sByjPoxutra1Bj+T5LsnyPT7++OPhOp4/BM7HFSgf50N+DnwboGrF5UPZfnzkyJH0PHm0II/JyKOVvvOIH0I+LBo+uQgeq/v72JceHwN33XXXWwDz+GdTU9NfgHOePFqQR6dctzw78tg7Tx4teO2xoaHB0P9I6BG1IOdctzw7Vfc4ceLE6nqEO8hqCtfvwq3le49cZPo1kDVPHi3Io1Ou83VakMfeefJowWuPs2fPtnzCNlan13QuFY/jx4+vrsdJkyaFAXNAW1tb5muM5MmjBXl0ynW+Tgvy2DtPHi1465GvidpfwuGndsuVuZmu1cD8+fMvcY8nT578A/gqGAlQiOsBh9EdBlnz5NGCPDrlOl+nBXnsnefu0d6Cyv7Ha3uAytnckX8fcIGI4YIeT4A7QTAsmScfCrZu3Wo5IE/edlBXV0eHHAr9d+B8bJ68lwCU3gs4wGo0wFeH4wKc8+TRgDzKozw6IY/mPDePNNUAYptQtxnQG1Z3+MkBsLkXRcoZ4WkQGTdHh7eD1HJmCmkGKAw9LgWZji1YPnY6Tup5tHSbJkQe7cijPMqjE/JoznPzyGoL1ezatWtxRB6qL2xXDT7ZZa0IGXKdChWno6ODM0JFPH4OlJv3KmB7J24k6xtuL8kUyDPA2TmgEjWtOk7d4ZQ
njzHkUR7z5BmQxxLy/PUIcfTY0DOonH9YJlopz+M+8C2wZs2aVjAf8E88k0cccp6ohImesuRFYO0iaB+/DTgflzfPwuHDh9l6zrFzTnnyGEMe5TFPnoWSPR4IRsoFTabmdx5DzCMGzLmWT1taWviTdzNIns8KDveDonmVPAEGA3hkif8KnI+153WCTKcBXwO8CU899VR6njzGkEd5tOfJozyW4TEQ6DajIx4nLaMGKnPtn9ZGuhZrexPZMnbsWM7Snn5ReeodowA8zgXs78x0bGXeEuA83VcETkvC6taOHTvS8+TRgDzKozwGn8pjQJkeM011HDS1Ju9kz2NdINJ6mlzPCeCCJnnzzPwccMYT3MD0xk2nvI+AqWDPnj2ZzvR7QI/Lly9Pz5PHGPIoj+a8/uCxfOx5L4De0tI9lt9vxQWtgsbVl0G2whnz+Pb/HWDQoEFsLuViVivB+vXr+b6j5TR4Nr4JsIzt7e3pefIYQx7dkEd5lEdrXj/yyDcqFi9eHPN4I/gSeBE0Nzdzbttgh+r0PwZ1nJpp06ZlK1hiHh9D8RzIJlOOjmPAmDFjOOjg9gBUhNYEcIb90aNHs4wrgFOePMaQRzfkUR7l0ZrXjzySjo6OEeBuQI8o9p9AZAfekHU9VKf/kYPyyvcYgaPwuK7KwYMH94J7QE0ls2bNSlh8pTJPHmPIoxvyKI/n5NGW1888Ki9nnjz6kSePfuTJox958uhHnjz6kSePfuTJox958uhHnjz6kSePfuTJox958uhHnjz6kSePfuSFHsWljTz6gTz6gTz6gTz6gTz6gTz6wf8BhQQ12wplbmRzdHJlYW0KZW5kb2JqCjE0IDAgb2JqCjM0NDQKZW5kb2JqCjIgMCBvYmoKPDwgL1R5cGUgL1BhZ2VzIC9LaWRzIFsgMTEgMCBSIF0gL0NvdW50IDEgPj4KZW5kb2JqCjE1IDAgb2JqCjw8IC9DcmVhdG9yIChNYXRwbG90bGliIHYzLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZykKL1Byb2R1Y2VyIChNYXRwbG90bGliIHBkZiBiYWNrZW5kIHYzLjcuMSkgL0NyZWF0aW9uRGF0ZSAoRDoyMDIzMDMxNDE2MDg1MlopCj4+CmVuZG9iagp4cmVmCjAgMTYKMDAwMDAwMDAwMCA2NTUzNSBmIAowMDAwMDAwMDE2IDAwMDAwIG4gCjAwMDAwMDUxNTUgMDAwMDAgbiAKMDAwMDAwMDYwNyAwMDAwMCBuIAowMDAwMDAwNjI4IDAwMDAwIG4gCjAwMDAwMDA2ODggMDAwMDAgbiAKMDAwMDAwMDcwOSAwMDAwMCBuIAowMDAwMDAwNzMwIDAwMDAwIG4gCjAwMDAwMDAwNjUgMDAwMDAgbiAKMDAwMDAwMDM0NCAwMDAwMCBuIAowMDAwMDAwNTg3IDAwMDAwIG4gCjAwMDAwMDAyMDggMDAwMDAgbiAKMDAwMDAwMDU2NyAwMDAwMCBuIAowMDAwMDAwNzYyIDAwMDAwIG4gCjAwMDAwMDUxMzQgMDAwMDAgbiAKMDAwMDAwNTIxNSAwMDAwMCBuIAp0cmFpbGVyCjw8IC9TaXplIDE2IC9Sb290IDEgMCBSIC9JbmZvIDE1IDAgUiA+PgpzdGFydHhyZWYKNTM2NgolJUVPRgo=", "image/svg+xml": ["\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-03-14T16:08:52.069465\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.7.1, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", "\n"], "text/plain": ["
"]}, "metadata": {}, "output_type": "display_data"}], "source": ["def show_imgs(imgs, title=None, row_size=4):\n", "    # Form a grid of pictures (at most row_size images per row)\n", "    num_imgs = imgs.shape[0] if isinstance(imgs, Tensor) else len(imgs)\n", "    is_int = imgs.dtype == torch.int32 if isinstance(imgs, Tensor) else imgs[0].dtype == torch.int32\n", "    nrow = min(num_imgs, row_size)\n", "    ncol = int(math.ceil(num_imgs / nrow))\n", "    imgs = torchvision.utils.make_grid(imgs, nrow=nrow, pad_value=128 if is_int else 0.5)\n", "    np_imgs = imgs.cpu().numpy()\n", "    # Plot the grid\n", "    plt.figure(figsize=(1.5 * nrow, 1.5 * ncol))\n", "    plt.imshow(np.transpose(np_imgs, (1, 2, 0)), interpolation=\"nearest\")\n", "    plt.axis(\"off\")\n", "    if title is not None:\n", "        plt.title(title)\n", "    plt.show()\n", "    plt.close()\n", "\n", "\n", "show_imgs([train_set[i][0] for i in range(8)])"]}, {"cell_type": "markdown", "id": "bd028bb1", "metadata": {"papermill": {"duration": 0.010648, "end_time": "2023-03-14T16:08:52.215564", "exception": false, "start_time": "2023-03-14T16:08:52.204916", "status": "completed"}, "tags": []}, "source": ["## Normalizing Flows as generative model\n", "\n", "In the previous lectures, we have seen Energy-based models, Variational Autoencoders (VAEs)\n", "and Generative Adversarial Networks (GANs) as examples of generative models.\n", "However, none of them explicitly learn the probability density function $p(x)$ of the real input data.\n", "While VAEs model a lower bound, energy-based models only implicitly learn the probability density.\n", "GANs, on the other hand, provide us with a sampling mechanism for generating new data, without offering a likelihood estimate.\n", "The generative model we will look at here, called Normalizing Flows, actually models the true data distribution\n", "$p(x)$ and provides us with an exact likelihood estimate.\n", "Below, we can visually compare VAEs, GANs and Flows\n", "(figure credit - [Lilian 
Weng](https://lilianweng.github.io/lil-log/2018/10/13/flow-based-deep-generative-models.html)):\n", "\n", "
\n", "\n", "The major difference compared to VAEs is that flows use *invertible* functions $f$\n", "to map the input data $x$ to a latent representation $z$.\n", "To realize this, $z$ must be of the same shape as $x$.\n", "This is in contrast to VAEs where $z$ is usually much lower dimensional than the original input data.\n", "However, an invertible mapping also means that for every data point $x$, we have a corresponding latent representation\n", "$z$ which allows us to perform lossless reconstruction ($z$ to $x$).\n", "In the visualization above, this means that $x=x'$ for flows, no matter what invertible function $f$ and input $x$ we choose.\n", "\n", "But how do normalizing flows model a probability density with an invertible function?\n", "The answer to this question is the rule for change of variables.\n", "Specifically, given a prior density $p_z(z)$ (e.g. Gaussian) and an invertible function $f$,\n", "we can determine $p_x(x)$ as follows:\n", "\n", "$$\n", "\\begin{split}\n", " \\int p_x(x) dx & = \\int p_z(z) dz = 1 \\hspace{1cm}\\text{(by definition of a probability distribution)}\\\\\n", " \\Leftrightarrow p_x(x) & = p_z(z) \\left|\\frac{dz}{dx}\\right| = p_z(f(x)) \\left|\\frac{df(x)}{dx}\\right|\n", "\\end{split}\n", "$$\n", "\n", "Hence, in order to determine the probability of $x$, we only need to determine its probability in latent space,\n", "and compute the derivative of $f$.\n", "Note that this is for a univariate distribution, and $f$ is required to be invertible and smooth.\n", "For a multivariate case, the derivative becomes a Jacobian of which we need to take the determinant.\n", "As we usually use the log-likelihood as the objective, we write the multivariate term with logarithms below:\n", "\n", "$$\n", "\\log p_x(\\mathbf{x}) = \\log p_z(f(\\mathbf{x})) + \\log{} \\left|\\det \\frac{df(\\mathbf{x})}{d\\mathbf{x}}\\right|\n", "$$\n", "\n", "Although we now know how a normalizing flow obtains its likelihood, it might not be clear what a 
normalizing flow does intuitively.\n", "For this, we should look at the flow from the inverse perspective, starting with the prior probability density $p_z(z)$.\n", "If we apply an invertible function on it, we effectively \"transform\" its probability density.\n", "For instance, if $f^{-1}(z)=z+1$, we shift the density by one; the result remains a valid probability distribution,\n", "and the function stays invertible.\n", "We can also apply more complex transformations, like scaling: $f^{-1}(z)=2z+1$, but there you might see a difference.\n", "When you scale, you also change the volume of the probability density, as shown for example on uniform distributions\n", "(figure credit - [Eric Jang](https://blog.evjang.com/2018/01/nf1.html)):\n", "\n", "
\n", "\n", "You can see that the height of $p(y)$ should be lower than $p(x)$ after scaling.\n", "This change in volume represents $\\left|\\frac{df(x)}{dx}\\right|$ in our equation above,\n", "and ensures that even after scaling, we still have a valid probability distribution.\n", "We can go on with making our function $f$ more complex.\n", "However, the more complex $f$ becomes, the harder it will be to find the inverse $f^{-1}$ of it,\n", "and to calculate the log-determinant of the Jacobian $\\log{} \\left|\\det \\frac{df(\\mathbf{x})}{d\\mathbf{x}}\\right|$.\n", "An easier trick is to stack multiple invertible functions $f_{1,...,K}$ after each other, as all together,\n", "they still represent a single, invertible function.\n", "Using multiple, learnable invertible functions, a normalizing flow attempts to transform\n", "$p_z(z)$ slowly into a more complex distribution which should finally be $p_x(x)$.\n", "We visualize the idea below\n", "(figure credit - [Lilian Weng](https://lilianweng.github.io/lil-log/2018/10/13/flow-based-deep-generative-models.html)):\n", "\n", "
\n", "\n", "Starting from $z_0$, which follows the prior Gaussian distribution, we sequentially apply the invertible\n", "functions $f_1,f_2,...,f_K$, until $z_K$ represents $x$.\n", "Note that in the figure above, the functions $f$ are the inverses of the functions $f$ we had above\n", "(here: $f:Z\\to X$, above: $f:X\\to Z$).\n", "This is just a different notation and has no impact on the actual flow design because all $f$ need to be invertible anyway.\n", "When we estimate the log likelihood of a data point $x$ as in the equations above,\n", "we run the flows in the direction opposite to the one visualized above.\n", "Multiple flow layers with learnable, neural-network-based parameters have been proposed,\n", "such as the planar and radial flow.\n", "However, we will focus here on flows that are commonly used in image\n", "modeling, and will discuss them in the rest of the notebook along with\n", "the details of how to train a normalizing flow."]}, {"cell_type": "markdown", "id": "be852389", "metadata": {"lines_to_next_cell": 2, "papermill": {"duration": 0.010512, "end_time": "2023-03-14T16:08:52.236613", "exception": false, "start_time": "2023-03-14T16:08:52.226101", "status": "completed"}, "tags": []}, "source": ["## Normalizing Flows on images\n", "\n", "
\n", "\n", "To become familiar with normalizing flows, especially for the application of image modeling,\n", "it is best to discuss the different elements in a flow along with the implementation.\n", "As a general concept, we want to build a normalizing flow that maps an input image (here MNIST) to an equally sized latent space:\n", "\n", "
\n", "\n", "As a first step, we will implement a template of a normalizing flow in PyTorch Lightning.\n", "During training and validation, a normalizing flow performs density estimation in the forward direction.\n", "For this, we apply a series of flow transformations on the input $x$ and estimate the probability\n", "of the input by determining the probability of the transformed point $z$ given a prior,\n", "and the change of volume caused by the transformations.\n", "During inference, we can perform density estimation as well as sample new points by inverting the flow transformations.\n", "Therefore, we define a function `_get_likelihood` which performs density estimation,\n", "and `sample` to generate new examples.\n", "The functions `training_step`, `validation_step` and `test_step` all make use of `_get_likelihood`.\n", "\n", "The standard metric used in generative models, and in particular normalizing flows, is bits per dimension (bpd).\n", "Bpd is motivated from an information theory perspective and describes how many bits we would need to encode a particular example in our modeled distribution.\n", "The fewer bits we need, the more likely the example is under our distribution.\n", "When we test for the bits per dimension of our test dataset, we can judge whether our model generalizes to new samples of the dataset and didn't memorize the training dataset.\n", "In order to calculate the bits per dimension score, we can rely on the negative log-likelihood and change the log base (as bits use base 2 while the NLL is usually computed with the natural logarithm):\n", "\n", "$$\\text{bpd} = \\text{nll} \\cdot \\log_2\\left(\\exp(1)\\right) \\cdot \\left(\\prod d_i\\right)^{-1}$$\n", "\n", "where $d_1,...,d_K$ are the dimensions of the input.\n", "For images, this would be the height, width and channel number.\n", "We divide the negative log-likelihood by the product of these dimensions to have a metric which we can compare for different image resolutions.\n", "In the original image space, MNIST examples have a bits per 
dimension\n", "score of 8 (we need 8 bits to encode each pixel as there are 256\n", "possible values)."]}, {"cell_type": "code", "execution_count": 6, "id": "e75ff1f5", "metadata": {"execution": {"iopub.execute_input": "2023-03-14T16:08:52.260104Z", "iopub.status.busy": "2023-03-14T16:08:52.259443Z", "iopub.status.idle": "2023-03-14T16:08:52.276630Z", "shell.execute_reply": "2023-03-14T16:08:52.276152Z"}, "lines_to_next_cell": 2, "papermill": {"duration": 0.031425, "end_time": "2023-03-14T16:08:52.278733", "exception": false, "start_time": "2023-03-14T16:08:52.247308", "status": "completed"}, "tags": []}, "outputs": [], "source": ["class ImageFlow(L.LightningModule):\n", "    def __init__(self, flows, import_samples=8):\n", "        \"\"\"\n", "        Args:\n", "            flows: A list of flows (each a nn.Module) that should be applied on the images.\n", "            import_samples: Number of importance samples to use during testing (see explanation below). Can be changed at any time\n", "        \"\"\"\n", "        super().__init__()\n", "        self.flows = nn.ModuleList(flows)\n", "        self.import_samples = import_samples\n", "        # Create prior distribution for final latent space\n", "        self.prior = torch.distributions.normal.Normal(loc=0.0, scale=1.0)\n", "        # Example input for visualizing the graph\n", "        self.example_input_array = train_set[0][0].unsqueeze(dim=0)\n", "\n", "    def forward(self, imgs):\n", "        # The forward function is only used for visualizing the graph\n", "        return self._get_likelihood(imgs)\n", "\n", "    def encode(self, imgs):\n", "        # Given a batch of images, return the latent representation z and ldj of the transformations\n", "        z, ldj = imgs, torch.zeros(imgs.shape[0], device=self.device)\n", "        for flow in self.flows:\n", "            z, ldj = flow(z, ldj, reverse=False)\n", "        return z, ldj\n", "\n", "    def _get_likelihood(self, imgs, return_ll=False):\n", "        \"\"\"Given a batch of images, return the likelihood of those.\n", "\n", "        If return_ll is True, this function returns the log likelihood of the input. 
Otherwise, the ouptut metric is\n", " bits per dimension (scaled negative log likelihood)\n", " \"\"\"\n", " z, ldj = self.encode(imgs)\n", " log_pz = self.prior.log_prob(z).sum(dim=[1, 2, 3])\n", " log_px = ldj + log_pz\n", " nll = -log_px\n", " # Calculating bits per dimension\n", " bpd = nll * np.log2(np.exp(1)) / np.prod(imgs.shape[1:])\n", " return bpd.mean() if not return_ll else log_px\n", "\n", " @torch.no_grad()\n", " def sample(self, img_shape, z_init=None):\n", " \"\"\"Sample a batch of images from the flow.\"\"\"\n", " # Sample latent representation from prior\n", " if z_init is None:\n", " z = self.prior.sample(sample_shape=img_shape).to(device)\n", " else:\n", " z = z_init.to(device)\n", "\n", " # Transform z to x by inverting the flows\n", " ldj = torch.zeros(img_shape[0], device=device)\n", " for flow in reversed(self.flows):\n", " z, ldj = flow(z, ldj, reverse=True)\n", " return z\n", "\n", " def configure_optimizers(self):\n", " optimizer = optim.Adam(self.parameters(), lr=1e-3)\n", " # An scheduler is optional, but can help in flows to get the last bpd improvement\n", " scheduler = optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.99)\n", " return [optimizer], [scheduler]\n", "\n", " def training_step(self, batch, batch_idx):\n", " # Normalizing flows are trained by maximum likelihood => return bpd\n", " loss = self._get_likelihood(batch[0])\n", " self.log(\"train_bpd\", loss)\n", " return loss\n", "\n", " def validation_step(self, batch, batch_idx):\n", " loss = self._get_likelihood(batch[0])\n", " self.log(\"val_bpd\", loss)\n", "\n", " def test_step(self, batch, batch_idx):\n", " # Perform importance sampling during testing => estimate likelihood M times for each image\n", " samples = []\n", " for _ in range(self.import_samples):\n", " img_ll = self._get_likelihood(batch[0], return_ll=True)\n", " samples.append(img_ll)\n", " img_ll = torch.stack(samples, dim=-1)\n", "\n", " # To average the probabilities, we need to go from log-space to exp, and 
back to log.\n", " # Logsumexp provides us a stable implementation for this\n", " img_ll = torch.logsumexp(img_ll, dim=-1) - np.log(self.import_samples)\n", "\n", " # Calculate final bpd\n", " bpd = -img_ll * np.log2(np.exp(1)) / np.prod(batch[0].shape[1:])\n", " bpd = bpd.mean()\n", "\n", " self.log(\"test_bpd\", bpd)"]}, {"cell_type": "markdown", "id": "939bf7af", "metadata": {"papermill": {"duration": 0.01055, "end_time": "2023-03-14T16:08:52.304898", "exception": false, "start_time": "2023-03-14T16:08:52.294348", "status": "completed"}, "tags": []}, "source": ["The `test_step` function differs from the training and validation step in that it makes use of importance sampling.\n", "We will discuss the motiviation and details behind this after\n", "understanding how flows model discrete images in continuous space."]}, {"cell_type": "markdown", "id": "248c83f9", "metadata": {"lines_to_next_cell": 2, "papermill": {"duration": 0.010749, "end_time": "2023-03-14T16:08:52.326437", "exception": false, "start_time": "2023-03-14T16:08:52.315688", "status": "completed"}, "tags": []}, "source": ["### Dequantization\n", "\n", "Normalizing flows rely on the rule of change of variables, which is naturally defined in continuous space.\n", "Applying flows directly on discrete data leads to undesired density models where arbitrarly high likelihood are placed on a few, particular values.\n", "See the illustration below:\n", "\n", "
\n", "\n",
"The black points represent the discrete points, and the green volume the density modeled by a normalizing flow in continuous space.\n",
"The flow would continue to increase the likelihood for $x=0,1,2,3$ while having no volume on any other point.\n",
"Remember that in continuous space, we have the constraint that the overall volume of the probability density must be 1 ($\\int p(x)dx=1$).\n",
"Otherwise, we don't model a probability distribution anymore.\n",
"However, the discrete points $x=0,1,2,3$ represent delta peaks with no width in continuous space.\n",
"This is why the flow can place an infinitely high likelihood on these few points while still representing a distribution in continuous space.\n",
"Nonetheless, the learned density does not tell us anything about the distribution among the discrete points,\n",
"as in discrete space, the likelihoods of those four points would have to sum to 1, not to infinity.\n",
"\n",
"To prevent such degenerate solutions, a common approach is to add a small amount of noise to each discrete value, which is also referred to as dequantization.\n",
"Considering $x$ as an integer (as is the case for images), the dequantized representation $v$ can be formulated as $v=x+u$ where $u\\in[0,1)^D$.\n",
"Thus, the discrete value $1$ is modeled by a distribution over the interval $[1.0, 2.0)$, the value $2$ by a volume over $[2.0, 3.0)$, etc.\n",
"Our objective of modeling $p(x)$ becomes:\n",
"\n",
"$$ p(x) = \\int p(x+u)du = \\int \\frac{q(u|x)}{q(u|x)}p(x+u)du = \\mathbb{E}_{u\\sim q(u|x)}\\left[\\frac{p(x+u)}{q(u|x)} \\right]$$\n",
"\n",
"with $q(u|x)$ being the noise distribution.\n",
"For now, we assume it to be uniform, which can also be written as $p(x)=\\mathbb{E}_{u\\sim U(0,1)^D}\\left[p(x+u) \\right]$.\n",
"\n",
"In the following, we will implement dequantization as a flow transformation itself.\n",
"After adding noise to the discrete values, we additionally transform the volume into a Gaussian-like shape.\n",
"This is done by scaling $x+u$ between $0$ and $1$, and applying the inverse of the sigmoid function, $\\sigma^{-1}(z) = \\log z - \\log(1-z)$.\n",
"If we did not do this, we would face two problems:\n",
"\n",
"1. The input is scaled between 0 and 256 while the prior distribution is a Gaussian with mean $0$ and standard deviation $1$.\n",
"In the first iterations after initializing the parameters of the flow, we would have extremely low likelihoods for large values like $256$.\n",
"This would cause the training to diverge instantaneously.\n",
"2. As the output distribution is a Gaussian, it is beneficial for the flow to have a similarly shaped input distribution.\n",
"This will reduce the modeling complexity that is required by the flow.\n",
"\n",
"Overall, we can implement dequantization as follows:"]}, {"cell_type": "code", "execution_count": 7, "id": "9e1c0242", "metadata": {"execution": {"iopub.execute_input": "2023-03-14T16:08:52.349387Z", "iopub.status.busy": "2023-03-14T16:08:52.348820Z", "iopub.status.idle": "2023-03-14T16:08:52.362864Z", "shell.execute_reply": "2023-03-14T16:08:52.362396Z"}, "papermill": {"duration": 0.028177, "end_time": "2023-03-14T16:08:52.365251", "exception": false, "start_time": "2023-03-14T16:08:52.337074", "status": "completed"}, "tags": []}, "outputs": [], "source": ["class Dequantization(nn.Module):\n",
"    def __init__(self, alpha=1e-5, quants=256):\n",
"        \"\"\"\n",
"        Args:\n",
"            alpha: small constant that is used to scale the original input.\n",
"                Prevents dealing with values very close to 0 and 1 when inverting the sigmoid\n",
"            quants: Number of possible discrete values (usually 256 for 8-bit image)\n",
"        \"\"\"\n",
"        super().__init__()\n",
"        self.alpha = alpha\n",
"        self.quants = quants\n",
"\n",
"    def forward(self, z, ldj, reverse=False):\n",
"        if not reverse:\n",
"            z, ldj = self.dequant(z, ldj)\n",
"            z, ldj = self.sigmoid(z, ldj, reverse=True)\n",
"        else:\n",
"            z, ldj = self.sigmoid(z, ldj, reverse=False)\n",
"            z = z * self.quants\n",
"            ldj += np.log(self.quants) * np.prod(z.shape[1:])\n",
"            z = torch.floor(z).clamp(min=0, max=self.quants - 1).to(torch.int32)\n",
"        return z, ldj\n",
"\n",
"    def sigmoid(self, z, ldj, reverse=False):\n",
"        # Applies an invertible sigmoid transformation\n",
"        if not reverse:\n",
"            ldj += (-z - 2 * F.softplus(-z)).sum(dim=[1, 2, 3])\n",
"            z = torch.sigmoid(z)\n",
"        else:\n",
"            z = z * (1 - self.alpha) + 0.5 * self.alpha  # Scale to prevent boundaries 0 and 1\n",
"            ldj += np.log(1 - self.alpha) * np.prod(z.shape[1:])\n",
"            ldj += (-torch.log(z) - torch.log(1 - z)).sum(dim=[1, 2, 3])\n",
"            z = torch.log(z) - torch.log(1 - z)\n",
"        return z, ldj\n",
"\n",
"    def dequant(self, z, ldj):\n",
"        # Transform discrete values to continuous volumes\n",
"        z = z.to(torch.float32)\n",
"        z = z + torch.rand_like(z).detach()\n",
"        z = z / self.quants\n",
"        ldj -= np.log(self.quants) * np.prod(z.shape[1:])\n",
"        return z, ldj"]}, {"cell_type": "markdown", "id": "8f1d2ba6", "metadata": {"papermill": {"duration": 0.010565, "end_time": "2023-03-14T16:08:52.390006", "exception": false, "start_time": "2023-03-14T16:08:52.379441", "status": "completed"}, "tags": []}, "source": ["A good check of whether a flow is implemented correctly is to verify that it is invertible.\n",
"Hence, we will dequantize a randomly chosen training image, and then quantize it again.\n",
"We would expect to get exactly the same image back:"]}, {"cell_type": "code", "execution_count": 8, "id": "05d6e03d", "metadata": {"execution": {"iopub.execute_input": "2023-03-14T16:08:52.412837Z", "iopub.status.busy": "2023-03-14T16:08:52.412458Z", "iopub.status.idle": "2023-03-14T16:08:52.426314Z", "shell.execute_reply": "2023-03-14T16:08:52.425832Z"}, "papermill": {"duration": 0.027687, "end_time": "2023-03-14T16:08:52.428343", "exception": false, "start_time": "2023-03-14T16:08:52.400656", "status": "completed"}, "tags": []}, "outputs": [{"name": "stderr", "output_type": "stream", "text": 
["Global seed set to 42\n"]}, {"name": "stdout", "output_type": "stream", "text": ["Dequantization was not invertible.\n", "Original value: 0\n", "Reconstructed value: 1\n"]}], "source": ["# Testing invertibility of dequantization layer\n",
"L.seed_everything(42)\n",
"orig_img = train_set[0][0].unsqueeze(dim=0)\n",
"ldj = torch.zeros(1)\n",
"dequant_module = Dequantization()\n",
"deq_img, ldj = dequant_module(orig_img, ldj, reverse=False)\n",
"reconst_img, ldj = dequant_module(deq_img, ldj, reverse=True)\n",
"\n",
"d1, d2 = torch.where(orig_img.squeeze() != reconst_img.squeeze())\n",
"if len(d1) != 0:\n",
"    print(\"Dequantization was not invertible.\")\n",
"    for i in range(d1.shape[0]):\n",
"        print(\"Original value:\", orig_img[0, 0, d1[i], d2[i]].item())\n",
"        print(\"Reconstructed value:\", reconst_img[0, 0, d1[i], d2[i]].item())\n",
"else:\n",
"    print(\"Successfully inverted dequantization\")\n",
"\n",
"# Layer is not strictly invertible due to float precision constraints\n",
"# assert (orig_img == reconst_img).all().item()"]}, {"cell_type": "markdown", "id": "64743184", "metadata": {"papermill": {"duration": 0.010719, "end_time": "2023-03-14T16:08:52.453044", "exception": false, "start_time": "2023-03-14T16:08:52.442325", "status": "completed"}, "tags": []}, "source": ["In contrast to our expectation, the test fails.\n",
"However, this is no reason to doubt our implementation here, as only a single value differs from the original.\n",
"This is caused by numerical inaccuracies in the inverse sigmoid.\n",
"While the input space to the inverted sigmoid is scaled between 0 and 1, the output space is between $-\\infty$ and $\\infty$.\n",
"Since we use 32 bits to represent the numbers (in addition to applying logs over and over again),\n",
"such inaccuracies can occur and are no cause for concern.\n",
"Nevertheless, it is good to be aware of them; they can be reduced by using double precision (float64).\n",
"\n",
"Finally, we can take our dequantization and actually visualize the\n",
"distribution it transforms the discrete values into:"]}, {"cell_type": "code", "execution_count": 9, "id": "d4d49a0a", "metadata": {"execution": {"iopub.execute_input": "2023-03-14T16:08:52.476344Z", "iopub.status.busy": "2023-03-14T16:08:52.475761Z", "iopub.status.idle": "2023-03-14T16:08:52.922166Z", "shell.execute_reply": "2023-03-14T16:08:52.921514Z"}, "papermill": {"duration": 0.462246, "end_time": "2023-03-14T16:08:52.926115", "exception": false, "start_time": "2023-03-14T16:08:52.463869", "status": "completed"}, "tags": []}, "outputs": [{"data": {"application/pdf": "", "image/svg+xml": [], "text/plain": ["[Figure: Dequantization distribution for 8 discrete values; x-axis: z, y-axis: Probability]"]}, "metadata": {}, "output_type": "display_data"}], "source": ["\n",
"\n",
"def visualize_dequantization(quants, prior=None):\n",
"    \"\"\"Function for visualizing the dequantization values of discrete values in continuous space.\"\"\"\n",
"    # Prior over discrete values. If not given, a uniform is assumed\n",
"    if prior is None:\n",
"        prior = np.ones(quants, dtype=np.float32) / quants\n",
"    prior = prior / prior.sum()  # Ensure proper categorical distribution\n",
"\n",
"    inp = torch.arange(-4, 4, 0.01).view(-1, 1, 1, 1)  # Possible continuous values we want to consider\n",
"    ldj = torch.zeros(inp.shape[0])\n",
"    dequant_module = Dequantization(quants=quants)\n",
"    # Invert dequantization on continuous values to find corresponding discrete value\n",
"    out, ldj = dequant_module.forward(inp, ldj, reverse=True)\n",
"    inp, out, prob = inp.squeeze().numpy(), out.squeeze().numpy(), ldj.exp().numpy()\n",
"    prob = prob * prior[out]  # Probability scaled by categorical prior\n",
"\n",
"    # Plot volumes and continuous distribution\n",
"    sns.set_style(\"white\")\n",
"    _ = plt.figure(figsize=(6, 3))\n",
"    x_ticks = []\n",
"    for v in np.unique(out):\n",
"        indices = np.where(out == v)\n",
"        color = to_rgb(\"C%i\" % v)\n",
"        plt.fill_between(inp[indices], prob[indices], np.zeros(indices[0].shape[0]), color=color + (0.5,), label=str(v))\n",
"        plt.plot([inp[indices[0][0]]] * 2, [0, prob[indices[0][0]]], color=color)\n",
"        plt.plot([inp[indices[0][-1]]] * 2, [0, prob[indices[0][-1]]], color=color)\n",
"        x_ticks.append(inp[indices[0][0]])\n",
"    x_ticks.append(inp.max())\n",
"    plt.xticks(x_ticks, [\"%.1f\" % x for x in x_ticks])\n",
"    plt.plot(inp, prob, color=(0.0, 0.0, 0.0))\n",
"    # Set final plot properties\n",
"    plt.ylim(0, prob.max() * 1.1)\n",
"    plt.xlim(inp.min(), inp.max())\n",
"    plt.xlabel(\"z\")\n",
"    plt.ylabel(\"Probability\")\n",
"    plt.title(\"Dequantization distribution for %i discrete values\" % quants)\n",
"    plt.legend()\n",
"    plt.show()\n",
"    plt.close()\n",
"\n",
"\n",
"visualize_dequantization(quants=8)"]}, {"cell_type": "markdown", "id": "044aceb3", "metadata": {"papermill": {"duration": 0.019884, "end_time": "2023-03-14T16:08:52.969239", "exception": false, "start_time": "2023-03-14T16:08:52.949355", "status": "completed"}, "tags": []}, "source": ["The visualized distribution shows the sub-volumes that are assigned to the different discrete values.\n",
"The value $0$ has its volume in $(-\\infty, -1.9)$, the value $1$ is represented by the interval $[-1.9, -1.1)$, etc.\n",
"The volume for each discrete value has the same probability mass.\n",
"That's why the volumes close to the center (e.g. 3 and 4) cover a smaller interval on the z-axis than the others\n",
"($z$ is being used to denote the output of the whole dequantization flow).\n",
"\n",
"Effectively, the subsequent normalizing flow models discrete images by the following objective:\n",
"\n",
"$$\\log p(x) = \\log \\mathbb{E}_{u\\sim q(u|x)}\\left[\\frac{p(x+u)}{q(u|x)} \\right] \\geq \\mathbb{E}_{u}\\left[\\log \\frac{p(x+u)}{q(u|x)} \\right]$$\n",
"\n",
"Although normalizing flows are exact in likelihood, we have a lower bound.\n",
"Specifically, this is an application of Jensen's inequality: we need to move the log inside the expectation so that we can use Monte Carlo estimates.\n",
"In general, the gap of this bound is considerably smaller than that of the ELBO in variational autoencoders.\n",
"Actually, we can tighten the bound ourselves by estimating the expectation not by one, but by $M$ samples.\n",
"In other words, we can apply importance sampling which leads to the following inequality:\n",
"\n",
"$$\\log p(x) = \\log \\mathbb{E}_{u\\sim q(u|x)}\\left[\\frac{p(x+u)}{q(u|x)} \\right] \\geq \\mathbb{E}_{u}\\left[\\log \\frac{1}{M} \\sum_{m=1}^{M} \\frac{p(x+u_m)}{q(u_m|x)} \\right] \\geq \\mathbb{E}_{u}\\left[\\log \\frac{p(x+u)}{q(u|x)} \\right]$$\n",
"\n",
"The importance sampling estimate $\\frac{1}{M} \\sum_{m=1}^{M} \\frac{p(x+u_m)}{q(u_m|x)}$ converges to\n",
"$\\mathbb{E}_{u\\sim q(u|x)}\\left[\\frac{p(x+u)}{q(u|x)} \\right]$ as $M\\to \\infty$,\n",
"so that the more samples we use, the tighter the bound is.\n",
"During testing, we make use of this property, which is implemented in `test_step` of `ImageFlow`.\n",
"In theory, we could also use this tighter bound during training.\n",
"However, related work has shown that this does not necessarily lead to\n",
"an improvement given the additional computational cost, and it is more\n",
"efficient to stick with a single estimate [5]."]}, {"cell_type": "markdown", "id": "aca798a1", "metadata": {"papermill": {"duration": 0.014861, "end_time": "2023-03-14T16:08:53.000621", "exception": false, "start_time": "2023-03-14T16:08:52.985760", "status": "completed"}, "tags": []}, "source": ["### Variational Dequantization\n",
"\n",
"Dequantization uses a uniform distribution for the noise $u$ which effectively leads to images being represented as hypercubes\n",
"(cubes in high dimensions) with sharp borders.\n",
"However, modeling such sharp borders is not easy for a flow as it uses smooth transformations to convert it into a Gaussian distribution.\n",
"\n",
"Another way of looking at it is to change the prior distribution in the previous visualization.\n",
"Imagine we have independent Gaussian noise on the pixels, which is commonly the case for any picture taken in the real world.\n",
"The flow would then have to model a distribution as above, but with the individual volumes scaled as follows:"]}, {"cell_type": "code", "execution_count": 10, "id": "320081f9", "metadata": {"execution": {"iopub.execute_input": "2023-03-14T16:08:53.040814Z", "iopub.status.busy": "2023-03-14T16:08:53.040389Z", "iopub.status.idle": "2023-03-14T16:08:53.454171Z", "shell.execute_reply": "2023-03-14T16:08:53.453564Z"}, "papermill": {"duration": 0.439315, "end_time": "2023-03-14T16:08:53.458705", "exception": false, "start_time": "2023-03-14T16:08:53.019390", "status": "completed"}, "tags": []}, "outputs": [{"data": 
{"application/pdf": "", "image/svg+xml": [], "text/plain": ["[Figure: Dequantization distribution for 8 discrete values with a non-uniform prior; x-axis: z, y-axis: Probability]"]}, "metadata": {}, "output_type": "display_data"}], "source": ["visualize_dequantization(quants=8, prior=np.array([0.075, 0.2, 0.4, 0.2, 0.075, 0.025, 0.0125, 0.0125]))"]}, {"cell_type": "markdown", "id": "b202c9c8", "metadata": {"lines_to_next_cell": 2, "papermill": {"duration": 0.015208, "end_time": "2023-03-14T16:08:53.492761", "exception": false, "start_time": "2023-03-14T16:08:53.477553", "status": "completed"}, "tags": []}, "source": ["Transforming such a probability density into a Gaussian is a difficult task, especially with such hard borders.\n",
"Dequantization has therefore been extended to more sophisticated, learnable distributions beyond uniform in a variational framework.\n",
"In particular, if we remember the learning objective\n",
"$\\log p(x) = \\log \\mathbb{E}_{u}\\left[\\frac{p(x+u)}{q(u|x)} \\right]$,\n",
"the uniform distribution can be replaced by a learned distribution $q_{\\theta}(u|x)$ with support over $u\\in[0,1)^D$.\n",
"This approach is called Variational Dequantization and has been proposed by Ho et al. [3].\n",
"How can we learn such a distribution?\n",
"We can use a second normalizing flow that takes $x$ as external input and learns a flexible distribution over $u$.\n",
"To ensure support over $[0,1)^D$, we can apply a sigmoid activation function as the final flow transformation.\n",
"\n",
"Inheriting from the original dequantization class, we can implement variational dequantization as follows:"]}, {"cell_type": "code", "execution_count": 11, "id": "9977763b", "metadata": {"execution": {"iopub.execute_input": "2023-03-14T16:08:53.525421Z", "iopub.status.busy": "2023-03-14T16:08:53.525020Z", "iopub.status.idle": "2023-03-14T16:08:53.536012Z", "shell.execute_reply": "2023-03-14T16:08:53.535528Z"}, "lines_to_next_cell": 2, "papermill": {"duration": 0.029932, "end_time": "2023-03-14T16:08:53.538205", "exception": false, "start_time": "2023-03-14T16:08:53.508273", "status": "completed"}, "tags": []}, "outputs": [], "source": ["class VariationalDequantization(Dequantization):\n",
"    def __init__(self, var_flows, alpha=1e-5):\n",
"        \"\"\"\n",
"        Args:\n",
"            var_flows: A list of flow transformations to use for modeling q(u|x)\n",
"            alpha: Small constant, see Dequantization for details\n",
"        \"\"\"\n",
"        super().__init__(alpha=alpha)\n",
"        self.flows = nn.ModuleList(var_flows)\n",
"\n",
"    def dequant(self, z, ldj):\n",
"        z = z.to(torch.float32)\n",
"        img = (z / 255.0) * 2 - 1  # We condition the flows on x, i.e. the original image\n",
"\n",
"        # Prior of u is a uniform distribution as before\n",
"        # As most flow transformations are defined on [-infinity,+infinity], we apply an inverse sigmoid first.\n",
"        deq_noise = torch.rand_like(z).detach()\n",
"        deq_noise, ldj = self.sigmoid(deq_noise, ldj, reverse=True)\n",
"        for flow in self.flows:\n",
"            deq_noise, ldj = flow(deq_noise, ldj, reverse=False, orig_img=img)\n",
"        deq_noise, ldj = self.sigmoid(deq_noise, ldj, reverse=False)\n",
"\n",
"        # After the flows, apply u as in standard dequantization\n",
"        z = (z + deq_noise) / 256.0\n",
"        ldj -= np.log(256.0) * np.prod(z.shape[1:])\n",
"        return z, ldj"]}, {"cell_type": "markdown", "id": "b13b49a8", "metadata": {"papermill": {"duration": 0.014961, "end_time": "2023-03-14T16:08:53.572621", "exception": false, "start_time": "2023-03-14T16:08:53.557660", "status": "completed"}, "tags": []}, "source": ["Variational dequantization can be used as a substitute for dequantization.\n",
"We will compare dequantization and variational dequantization in later experiments."]}, {"cell_type": "markdown", "id": "4d25fabd", "metadata": {"lines_to_next_cell": 2, "papermill": {"duration": 0.014996, "end_time": "2023-03-14T16:08:53.602877", "exception": false, "start_time": "2023-03-14T16:08:53.587881", "status": "completed"}, "tags": []}, "source": ["### Coupling layers\n",
"\n", "
\n", "\n",
"Next, we look at possible transformations to apply inside the flow.\n",
"A popular recent flow layer, which works well in combination with deep neural networks,\n",
"is the coupling layer introduced by Dinh et al. [1].\n",
"The input $z$ is arbitrarily split into two parts, $z_{1:j}$ and $z_{j+1:d}$, of which the first remains unchanged by the flow.\n",
"Yet, $z_{1:j}$ is used to parameterize the transformation for the second part, $z_{j+1:d}$.\n",
"Various transformations have been proposed recently [3,4], but here we will settle for the simplest and most efficient one: affine coupling.\n",
"In this coupling layer, we apply an affine transformation by shifting the input by a bias $\\mu$ and scaling it by $\\sigma$.\n",
"In other words, our transformation looks as follows:\n",
"\n",
"$$z'_{j+1:d} = \\mu_{\\theta}(z_{1:j}) + \\sigma_{\\theta}(z_{1:j}) \\odot z_{j+1:d}$$\n",
"\n",
"The functions $\\mu$ and $\\sigma$ are implemented as a shared neural network,\n",
"and the sum and multiplication are performed element-wise.\n",
"The log-determinant of the Jacobian (LDJ) is then the sum of the logs of the scaling factors: $\\sum_i \\left[\\log \\sigma_{\\theta}(z_{1:j})\\right]_i$.\n",
"Inverting the layer is just as simple: we subtract the bias and divide by the scale:\n",
"\n",
"$$z_{j+1:d} = \\left(z'_{j+1:d} - \\mu_{\\theta}(z_{1:j})\\right) / \\sigma_{\\theta}(z_{1:j})$$\n",
"\n",
"We can also visualize the coupling layer in the form of a computation graph,\n",
"where $z_1$ represents $z_{1:j}$, and $z_2$ represents $z_{j+1:d}$:\n",
"\n", "
\n", "\n",
"In our implementation, we will realize the splitting of variables as masking.\n",
"The variables to be transformed, $z_{j+1:d}$, are masked when passing $z$ to the shared network to predict the transformation parameters.\n",
"When applying the transformation, we mask the parameters for $z_{1:j}$\n",
"so that we have an identity operation for those variables:"]}, {"cell_type": "code", "execution_count": 12, "id": "d18ee4c7", "metadata": {"execution": {"iopub.execute_input": "2023-03-14T16:08:53.634921Z", "iopub.status.busy": "2023-03-14T16:08:53.634448Z", "iopub.status.idle": "2023-03-14T16:08:53.647226Z", "shell.execute_reply": "2023-03-14T16:08:53.646704Z"}, "papermill": {"duration": 0.031527, "end_time": "2023-03-14T16:08:53.649422", "exception": false, "start_time": "2023-03-14T16:08:53.617895", "status": "completed"}, "tags": []}, "outputs": [], "source": ["class CouplingLayer(nn.Module):\n",
"    def __init__(self, network, mask, c_in):\n",
"        \"\"\"Coupling layer inside a normalizing flow.\n",
"\n",
"        Args:\n",
"            network: A PyTorch nn.Module constituting the deep neural network for mu and sigma.\n",
"                Output shape should be twice the channel size as the input.\n",
"            mask: Binary mask (0 or 1) where 0 denotes that the element should be transformed,\n",
"                while 1 means the latent will be used as input to the NN.\n",
"            c_in: Number of input channels\n",
"        \"\"\"\n",
"        super().__init__()\n",
"        self.network = network\n",
"        self.scaling_factor = nn.Parameter(torch.zeros(c_in))\n",
"        # Register mask as buffer as it is a tensor which is not a parameter,\n",
"        # but should be part of the module's state.\n",
"        self.register_buffer(\"mask\", mask)\n",
"\n",
"    def forward(self, z, ldj, reverse=False, orig_img=None):\n",
"        \"\"\"\n",
"        Args:\n",
"            z: Latent input to the flow\n",
"            ldj: The current ldj of the previous flows.\n",
"                The ldj of this layer will be added to this tensor.\n",
"            reverse: If True, we apply the inverse of the layer.\n",
"            orig_img (optional): Only needed in VarDeq. Allows external\n",
"                input to condition the flow on (e.g. original image)\n",
"        \"\"\"\n",
"        # Apply network to masked input\n",
"        z_in = z * self.mask\n",
"        if orig_img is None:\n",
"            nn_out = self.network(z_in)\n",
"        else:\n",
"            nn_out = self.network(torch.cat([z_in, orig_img], dim=1))\n",
"        s, t = nn_out.chunk(2, dim=1)\n",
"\n",
"        # Stabilize scaling output\n",
"        s_fac = self.scaling_factor.exp().view(1, -1, 1, 1)\n",
"        s = torch.tanh(s / s_fac) * s_fac\n",
"\n",
"        # Mask outputs (only transform the second part)\n",
"        s = s * (1 - self.mask)\n",
"        t = t * (1 - self.mask)\n",
"\n",
"        # Affine transformation\n",
"        if not reverse:\n",
"            # Whether we first shift and then scale, or the other way round,\n",
"            # is a design choice, and usually does not have a big impact\n",
"            z = (z + t) * torch.exp(s)\n",
"            ldj += s.sum(dim=[1, 2, 3])\n",
"        else:\n",
"            z = (z * torch.exp(-s)) - t\n",
"            ldj -= s.sum(dim=[1, 2, 3])\n",
"\n",
"        return z, ldj"]}, {"cell_type": "markdown", "id": "1e103dac", "metadata": {"papermill": {"duration": 0.014956, "end_time": "2023-03-14T16:08:53.684033", "exception": false, "start_time": "2023-03-14T16:08:53.669077", "status": "completed"}, "tags": []}, "source": ["For stabilization purposes, we apply a $\\tanh$ activation function on the scaling output.\n",
"This prevents sudden large output values for the scaling that can destabilize training.\n",
"To still allow scaling factors smaller than -1 or larger than 1,\n",
"we have a learnable parameter per dimension, called `scaling_factor`.\n",
"This scales the tanh to different limits.\n",
"Below, we visualize the effect of the scaling factor on the output activation of the scaling terms:"]}, {"cell_type": "code", "execution_count": 13, "id": "44e8290a", "metadata": {"execution": {"iopub.execute_input": "2023-03-14T16:08:53.715799Z", "iopub.status.busy": "2023-03-14T16:08:53.715417Z", "iopub.status.idle": "2023-03-14T16:08:54.573050Z", "shell.execute_reply": 
"2023-03-14T16:08:54.572427Z"}, "papermill": {"duration": 0.877475, "end_time": "2023-03-14T16:08:54.576709", "exception": false, "start_time": "2023-03-14T16:08:53.699234", "status": "completed"}, "tags": []}, "outputs": [{"data": {"application/pdf": "", "image/svg+xml": ["\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-03-14T16:08:54.070086\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.7.1, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", 
" \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n"], "text/plain": ["
"]}, "metadata": {}, "output_type": "display_data"}], "source": ["with torch.no_grad():\n", " x = torch.arange(-5, 5, 0.01)\n", " scaling_factors = [0.5, 1, 2]\n", " sns.set()\n", " fig, ax = plt.subplots(1, 3, figsize=(12, 3))\n", " for i, scale in enumerate(scaling_factors):\n", " y = torch.tanh(x / scale) * scale\n", " ax[i].plot(x.numpy(), y.numpy())\n", " ax[i].set_title(\"Scaling factor: \" + str(scale))\n", " ax[i].set_ylim(-3, 3)\n", " plt.subplots_adjust(wspace=0.4)\n", " sns.reset_orig()\n", " plt.show()"]}, {"cell_type": "markdown", "id": "0a4c5cae", "metadata": {"lines_to_next_cell": 2, "papermill": {"duration": 0.016704, "end_time": "2023-03-14T16:08:54.614666", "exception": false, "start_time": "2023-03-14T16:08:54.597962", "status": "completed"}, "tags": []}, "source": ["Coupling layers generalize to any masking technique we could think of.\n", "However, the most common approach for images is to split the input $z$ in half, using a checkerboard mask or channel mask.\n", "A checkerboard mask splits the variables across the height and width dimensions and assigns each other pixel to $z_{j+1:d}$.\n", "Thereby, the mask is shared across channels.\n", "In contrast, the channel mask assigns half of the channels to $z_{j+1:d}$, and the other half to $z_{1:j+1}$.\n", "Note that when we apply multiple coupling layers, we invert the masking for each other layer so that each variable is transformed a similar amount of times.\n", "\n", "Let's implement a function that creates a checkerboard mask and a channel mask for us:"]}, {"cell_type": "code", "execution_count": 14, "id": "124363d3", "metadata": {"execution": {"iopub.execute_input": "2023-03-14T16:08:54.649537Z", "iopub.status.busy": "2023-03-14T16:08:54.649056Z", "iopub.status.idle": "2023-03-14T16:08:54.656432Z", "shell.execute_reply": "2023-03-14T16:08:54.655621Z"}, "papermill": {"duration": 0.027594, "end_time": "2023-03-14T16:08:54.658995", "exception": false, "start_time": "2023-03-14T16:08:54.631401", 
"status": "completed"}, "tags": []}, "outputs": [], "source": ["def create_checkerboard_mask(h, w, invert=False):\n", " x, y = torch.arange(h, dtype=torch.int32), torch.arange(w, dtype=torch.int32)\n", " xx, yy = torch.meshgrid(x, y)\n", " mask = torch.fmod(xx + yy, 2)\n", " mask = mask.to(torch.float32).view(1, 1, h, w)\n", " if invert:\n", " mask = 1 - mask\n", " return mask\n", "\n", "\n", "def create_channel_mask(c_in, invert=False):\n", " mask = torch.cat([torch.ones(c_in // 2, dtype=torch.float32), torch.zeros(c_in - c_in // 2, dtype=torch.float32)])\n", " mask = mask.view(1, c_in, 1, 1)\n", " if invert:\n", " mask = 1 - mask\n", " return mask"]}, {"cell_type": "markdown", "id": "8d6f119f", "metadata": {"papermill": {"duration": 0.016328, "end_time": "2023-03-14T16:08:54.696519", "exception": false, "start_time": "2023-03-14T16:08:54.680191", "status": "completed"}, "tags": []}, "source": ["We can also visualize the corresponding masks for an image of size $8\\times 8\\times 2$ (2 channels):"]}, {"cell_type": "code", "execution_count": 15, "id": "e73db108", "metadata": {"execution": {"iopub.execute_input": "2023-03-14T16:08:54.730666Z", "iopub.status.busy": "2023-03-14T16:08:54.730376Z", "iopub.status.idle": "2023-03-14T16:08:55.071855Z", "shell.execute_reply": "2023-03-14T16:08:55.071060Z"}, "papermill": {"duration": 0.361902, "end_time": "2023-03-14T16:08:55.074928", "exception": false, "start_time": "2023-03-14T16:08:54.713026", "status": "completed"}, "tags": []}, "outputs": [{"name": "stderr", "output_type": "stream", "text": ["/usr/local/lib/python3.9/dist-packages/torch/functional.py:504: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. 
(Triggered internally at ../aten/src/ATen/native/TensorShape.cpp:3190.)\n", " return _VF.meshgrid(tensors, **kwargs) # type: ignore[attr-defined]\n"]}, {"data": {"application/pdf": "", "image/svg+xml": [" 2023-03-14T16:08:54.954296\n", " image/svg+xml\n", " Matplotlib v3.7.1, https://matplotlib.org/\n"], "text/plain": [""]}, "metadata": {}, "output_type": "display_data"}, {"data": {"application/pdf": "", "image/svg+xml": [" 2023-03-14T16:08:55.028934\n", " image/svg+xml\n", " Matplotlib v3.7.1, https://matplotlib.org/\n"], "text/plain": [""]}, "metadata": {}, "output_type": "display_data"}], "source": ["checkerboard_mask = create_checkerboard_mask(h=8, w=8).expand(-1, 2, -1, -1)\n", "channel_mask = create_channel_mask(c_in=2).expand(-1, -1, 8, 8)\n", "\n", "show_imgs(checkerboard_mask.transpose(0, 1), \"Checkerboard mask\")\n", "show_imgs(channel_mask.transpose(0, 1), \"Channel mask\")"]}, 
{"cell_type": "markdown", "id": "1bda3665", "metadata": {"lines_to_next_cell": 2, "papermill": {"duration": 0.017332, "end_time": "2023-03-14T16:08:55.118805", "exception": false, "start_time": "2023-03-14T16:08:55.101473", "status": "completed"}, "tags": []}, "source": ["As a last aspect of coupling layers, we need to decide on the deep neural network we want to apply in the coupling layers.\n", "The input to the layers is an image, and hence we stick with a CNN.\n", "Because the input to a transformation depends on all transformations before,\n", "it is crucial to ensure a good gradient flow through the CNN back to the input,\n", "which a ResNet-like architecture achieves well.\n", "Specifically, we use a Gated ResNet that applies a sigmoid gate to the residual branch,\n", "similar to the input gate in LSTMs.\n", "The details are not necessarily important here, and the network is\n", "strongly inspired by Flow++ [3] in case you are interested in building\n", "even stronger models."]}, 
{"cell_type": "code", "execution_count": 16, "id": "de6ca2e1", "metadata": {"execution": {"iopub.execute_input": "2023-03-14T16:08:55.155408Z", "iopub.status.busy": "2023-03-14T16:08:55.154889Z", "iopub.status.idle": "2023-03-14T16:08:55.171000Z", "shell.execute_reply": "2023-03-14T16:08:55.170454Z"}, "lines_to_next_cell": 2, "papermill": {"duration": 0.037079, "end_time": "2023-03-14T16:08:55.173218", "exception": false, "start_time": "2023-03-14T16:08:55.136139", "status": "completed"}, "tags": []}, "outputs": [], "source": ["class ConcatELU(nn.Module):\n", "    \"\"\"Activation function that applies ELU in both directions (inverted and plain).\n", "\n", "    Allows non-linearity while providing strong gradients for any input (important for final convolution)\n", "    \"\"\"\n", "\n", "    def forward(self, x):\n", "        return torch.cat([F.elu(x), F.elu(-x)], dim=1)\n", "\n", "\n", "class LayerNormChannels(nn.Module):\n", "    def __init__(self, c_in, eps=1e-5):\n", "        \"\"\"\n", "        This module applies layer norm across channels in an image.\n", "        Inputs:\n", "            c_in - Number of channels of the input\n", "            eps - Small constant to stabilize std\n", "        \"\"\"\n", "        super().__init__()\n", "        self.gamma = nn.Parameter(torch.ones(1, c_in, 1, 1))\n", "        self.beta = nn.Parameter(torch.zeros(1, c_in, 1, 1))\n", "        self.eps = eps\n", "\n", "    def forward(self, x):\n", "        mean = x.mean(dim=1, keepdim=True)\n", "        var = x.var(dim=1, unbiased=False, keepdim=True)\n", "        y = (x - mean) / torch.sqrt(var + self.eps)\n", "        y = y * self.gamma + self.beta\n", "        return y\n", "\n", "\n", "class GatedConv(nn.Module):\n", "    def __init__(self, c_in, c_hidden):\n", "        \"\"\"\n", "        This module applies a two-layer convolutional ResNet block with input gate\n", "        Args:\n", "            c_in: Number of channels of the input\n", "            c_hidden: Number of hidden dimensions we want to model (usually similar to c_in)\n", "        \"\"\"\n", "        super().__init__()\n", "        self.net = nn.Sequential(\n", "            ConcatELU(),\n", "            nn.Conv2d(2 * c_in, c_hidden, kernel_size=3, padding=1),\n", "            ConcatELU(),\n", "            nn.Conv2d(2 * c_hidden, 2 * c_in, kernel_size=1),\n", "        )\n", "\n", "    def forward(self, x):\n", "        out = self.net(x)\n", "        val, gate = out.chunk(2, dim=1)\n", "        return x + val * torch.sigmoid(gate)\n", "\n", "\n", "class GatedConvNet(nn.Module):\n", "    def __init__(self, c_in, c_hidden=32, c_out=-1, num_layers=3):\n", "        \"\"\"Module that combines the previous blocks into a full convolutional neural network.\n", "\n", "        Args:\n", "            c_in: Number of input channels\n", "            c_hidden: Number of hidden dimensions to use within the network\n", "            c_out: Number of output channels. If -1, 2 times the input channels are used (affine coupling)\n", "            num_layers: Number of gated ResNet blocks to apply\n", "        \"\"\"\n", "        super().__init__()\n", "        c_out = c_out if c_out > 0 else 2 * c_in\n", "        layers = []\n", "        layers += [nn.Conv2d(c_in, c_hidden, kernel_size=3, padding=1)]\n", "        for layer_index in range(num_layers):\n", "            layers += [GatedConv(c_hidden, c_hidden), LayerNormChannels(c_hidden)]\n", "        layers += [ConcatELU(), nn.Conv2d(2 * c_hidden, c_out, kernel_size=3, padding=1)]\n", "        self.nn = nn.Sequential(*layers)\n", "\n", "        # Zero-initialize the last convolution so that s = t = 0 at the start,\n", "        # i.e. every coupling layer begins as the identity transformation.\n", "        self.nn[-1].weight.data.zero_()\n", "        self.nn[-1].bias.data.zero_()\n", "\n", "    def forward(self, x):\n", "        return self.nn(x)"]}, 
{"cell_type": "markdown", "id": "3084a0b1", "metadata": {"lines_to_next_cell": 2, "papermill": {"duration": 0.017371, "end_time": "2023-03-14T16:08:55.211629", "exception": false, "start_time": "2023-03-14T16:08:55.194258", "status": "completed"}, "tags": []}, "source": ["### Training loop\n", "\n", "Finally, we can combine Dequantization, Variational Dequantization and Coupling Layers to build our full normalizing flow on MNIST images.\n", "We apply 8 coupling layers in the main flow, and 4 for variational dequantization if applied.\n", "We apply a checkerboard mask throughout the network because, with a single channel (grayscale images),\n", "we cannot apply a channel mask.\n", "The overall architecture is visualized below.\n", "\n", "\n", ""]}, 
{"cell_type": "code", "execution_count": 17, "id": "1bec6c41", "metadata": {"execution": {"iopub.execute_input": "2023-03-14T16:08:55.248636Z", "iopub.status.busy": "2023-03-14T16:08:55.248203Z", "iopub.status.idle": "2023-03-14T16:08:55.257903Z", "shell.execute_reply": "2023-03-14T16:08:55.257293Z"}, "lines_to_next_cell": 2, "papermill": {"duration": 0.030947, "end_time": "2023-03-14T16:08:55.260082", "exception": false, "start_time": "2023-03-14T16:08:55.229135", "status": "completed"}, "tags": []}, "outputs": [], "source": ["def create_simple_flow(use_vardeq=True):\n", "    flow_layers = []\n", "    if use_vardeq:\n", "        vardeq_layers = [\n", "            CouplingLayer(\n", "                network=GatedConvNet(c_in=2, c_out=2, c_hidden=16),\n", "                mask=create_checkerboard_mask(h=28, w=28, invert=(i % 2 == 1)),\n", "                c_in=1,\n", "            )\n", "            for i in range(4)\n", "        ]\n", "        flow_layers += [VariationalDequantization(var_flows=vardeq_layers)]\n", "    else:\n", "        flow_layers += [Dequantization()]\n", "\n", "    for i in range(8):\n", "        flow_layers += [\n", "            CouplingLayer(\n", "                network=GatedConvNet(c_in=1, c_hidden=32),\n", "                mask=create_checkerboard_mask(h=28, w=28, invert=(i % 2 == 1)),\n", "                c_in=1,\n", "            )\n", "        ]\n", "\n", "    flow_model = ImageFlow(flow_layers).to(device)\n", "    return flow_model"]}, 
{"cell_type": "markdown", "id": "a2bc5ef0", "metadata": {"lines_to_next_cell": 2, "papermill": {"duration": 0.017328, "end_time": "2023-03-14T16:08:55.296857", "exception": false, "start_time": "2023-03-14T16:08:55.279529", "status": "completed"}, "tags": []}, "source": ["For implementing the training loop, we use PyTorch Lightning to reduce the code overhead.\n", "If interested, you can take a look at the generated tensorboard file,\n", "in particular the graph, to see an overview of the flow transformations that are applied.\n", "Note that we again provide pre-trained models (see later on in the notebook)\n", "as normalizing flows are particularly expensive to train.\n", "We have also run validation and testing in advance, as these can take some time as well with the added importance sampling."]}, 
{"cell_type": "code", "execution_count": 18, "id": "25bb5564", "metadata": {"execution": {"iopub.execute_input": "2023-03-14T16:08:55.333029Z", "iopub.status.busy": "2023-03-14T16:08:55.332680Z", "iopub.status.idle": "2023-03-14T16:08:55.346583Z", "shell.execute_reply": "2023-03-14T16:08:55.345982Z"}, "lines_to_next_cell": 2, "papermill": {"duration": 0.034613, "end_time": "2023-03-14T16:08:55.348820", "exception": false, "start_time": "2023-03-14T16:08:55.314207", "status": "completed"}, "tags": []}, "outputs": [], "source": ["def train_flow(flow, model_name=\"MNISTFlow\"):\n", "    # Create a PyTorch Lightning trainer\n", "    trainer = L.Trainer(\n", "        default_root_dir=os.path.join(CHECKPOINT_PATH, model_name),\n", "        accelerator=\"auto\",\n", "        devices=1,\n", "        max_epochs=200,\n", "        gradient_clip_val=1.0,\n", "        callbacks=[\n", "            ModelCheckpoint(save_weights_only=True, mode=\"min\", monitor=\"val_bpd\"),\n", "            LearningRateMonitor(\"epoch\"),\n", "        ],\n", "    )\n", "    trainer.logger._log_graph = True\n", "    trainer.logger._default_hp_metric = None  # Optional logging argument that we don't need\n", "\n", "    train_data_loader = data.DataLoader(\n", "        train_set, batch_size=128, shuffle=True, drop_last=True, pin_memory=True, num_workers=8\n", "    )\n", "    result = None\n", "\n", "    # Check whether pretrained model exists. If yes, load it and skip training\n", "    pretrained_filename = os.path.join(CHECKPOINT_PATH, model_name + \".ckpt\")\n", "    if os.path.isfile(pretrained_filename):\n", "        print(\"Found pretrained model, loading...\")\n", "        ckpt = torch.load(pretrained_filename, map_location=device)\n", "        flow.load_state_dict(ckpt[\"state_dict\"])\n", "        result = ckpt.get(\"result\", None)\n", "    else:\n", "        print(\"Start training\", model_name)\n", "        trainer.fit(flow, train_data_loader, val_loader)\n", "\n", "    # Test best model on validation and test set if no result has been found\n", "    # Testing can be expensive due to the importance sampling.\n", "    if result is None:\n", "        val_result = trainer.test(flow, dataloaders=val_loader, verbose=False)\n", "        start_time = time.time()\n", "        test_result = trainer.test(flow, dataloaders=test_loader, verbose=False)\n", "        duration = time.time() - start_time\n", "        result = {\"test\": test_result, \"val\": val_result, \"time\": duration / len(test_loader) / flow.import_samples}\n", "\n", "    return flow, result"]}, 
{"cell_type": "markdown", "id": "321c2785", "metadata": {"papermill": {"duration": 0.017358, "end_time": "2023-03-14T16:08:55.387653", "exception": false, "start_time": "2023-03-14T16:08:55.370295", "status": "completed"}, "tags": []}, "source": ["## Multi-scale architecture\n", "\n", "
\n", "\n", "One disadvantage of normalizing flows is that they operate on the exact same dimensions as the input.\n", "If the input is high-dimensional, so is the latent space, which increases the computational cost of learning suitable transformations.\n", "However, particularly in the image domain, many pixels contain less information in the sense\n", "that we could remove them without losing the semantic information of the image.\n", "\n", "Based on this intuition, deep normalizing flows on images commonly apply a multi-scale architecture [1].\n", "After the first $N$ flow transformations, we split off half of the latent dimensions and directly evaluate them on the prior.\n", "The other half is run through $N$ more flow transformations, and depending on the size of the input,\n", "we split it again in half or stop overall at this position.\n", "The two operations involved in this setup are `Squeeze` and `Split`, which\n", "we will review more closely and implement below."]}, {"cell_type": "markdown", "id": "b67b73bc", "metadata": {"lines_to_next_cell": 2, "papermill": {"duration": 0.017403, "end_time": "2023-03-14T16:08:55.422536", "exception": false, "start_time": "2023-03-14T16:08:55.405133", "status": "completed"}, "tags": []}, "source": ["### Squeeze and Split\n", "\n", "When we want to remove half of the pixels in an image, we have the problem of deciding which variables to cut,\n", "and how to rearrange the image.\n", "Thus, a squeeze operation is commonly applied before the split: it divides the image into subsquares\n", "of shape $2\\times 2\\times C$, and reshapes them into $1\\times 1\\times 4C$ blocks.\n", "Effectively, we reduce the height and width of the image by a factor of 2 while scaling the number of channels by 4.\n", "Afterwards, we can perform the split operation over channels without the need to rearrange the pixels.\n", "The smaller scale also makes the overall architecture more efficient.\n", "Visually, the squeeze operation should 
transform the input as follows:\n", "\n", "
\n", "\n", "The input of $4\\times 4\\times 1$ is reshaped to $2\\times 2\\times 4$ following\n", "the idea of grouping the pixels in $2\\times 2\\times 1$ subsquares.\n", "Next, let's try to implement this layer:"]}, 
{"cell_type": "code", "execution_count": 19, "id": "1ae89a3f", "metadata": {"execution": {"iopub.execute_input": "2023-03-14T16:08:55.458800Z", "iopub.status.busy": "2023-03-14T16:08:55.458447Z", "iopub.status.idle": "2023-03-14T16:08:55.467921Z", "shell.execute_reply": "2023-03-14T16:08:55.467065Z"}, "papermill": {"duration": 0.030217, "end_time": "2023-03-14T16:08:55.470193", "exception": false, "start_time": "2023-03-14T16:08:55.439976", "status": "completed"}, "tags": []}, "outputs": [], "source": ["class SqueezeFlow(nn.Module):\n", "    def forward(self, z, ldj, reverse=False):\n", "        B, C, H, W = z.shape\n", "        if not reverse:\n", "            # Forward direction: H x W x C => H/2 x W/2 x 4C\n", "            z = z.reshape(B, C, H // 2, 2, W // 2, 2)\n", "            z = z.permute(0, 1, 3, 5, 2, 4)\n", "            z = z.reshape(B, 4 * C, H // 2, W // 2)\n", "        else:\n", "            # Reverse direction: H/2 x W/2 x 4C => H x W x C\n", "            z = z.reshape(B, C // 4, 2, 2, H, W)\n", "            z = z.permute(0, 1, 4, 2, 5, 3)\n", "            z = z.reshape(B, C // 4, H * 2, W * 2)\n", "        return z, ldj"]}, 
{"cell_type": "markdown", "id": "414c6bf6", "metadata": {"papermill": {"duration": 0.017403, "end_time": "2023-03-14T16:08:55.508292", "exception": false, "start_time": "2023-03-14T16:08:55.490889", "status": "completed"}, "tags": []}, "source": ["Before moving on, we can verify our implementation by comparing our output with the example figure above:"]}, 
{"cell_type": "code", "execution_count": 20, "id": "c1f0cc17", "metadata": {"execution": {"iopub.execute_input": "2023-03-14T16:08:55.544965Z", "iopub.status.busy": "2023-03-14T16:08:55.544742Z", "iopub.status.idle": "2023-03-14T16:08:55.551366Z", "shell.execute_reply": "2023-03-14T16:08:55.550861Z"}, "papermill": {"duration": 0.026587, "end_time": "2023-03-14T16:08:55.552899", "exception": false, "start_time": "2023-03-14T16:08:55.526312", "status": "completed"}, "tags": []}, "outputs": [{"name": "stdout", "output_type": "stream", "text": ["Image (before)\n", " tensor([[[[ 1,  2,  3,  4],\n", "          [ 5,  6,  7,  8],\n", "          [ 9, 10, 11, 12],\n", "          [13, 14, 15, 16]]]])\n", "\n", "Image (forward)\n", " tensor([[[[ 1,  2,  5,  6],\n", "          [ 3,  4,  7,  8]],\n", "\n", "         [[ 9, 10, 13, 14],\n", "          [11, 12, 15, 16]]]])\n", "\n", "Image (reverse)\n", " tensor([[[[ 1,  2,  3,  4],\n", "          [ 5,  6,  7,  8],\n", "          [ 9, 10, 11, 12],\n", "          [13, 14, 15, 16]]]])\n"]}], "source": ["sq_flow = SqueezeFlow()\n", "rand_img = torch.arange(1, 17).view(1, 1, 4, 4)\n", "print(\"Image (before)\\n\", rand_img)\n", "forward_img, _ = sq_flow(rand_img, ldj=None, reverse=False)\n", "print(\"\\nImage (forward)\\n\", forward_img.permute(0, 2, 3, 1))  # Permute for readability\n", "reconst_img, _ = sq_flow(forward_img, ldj=None, reverse=True)\n", "print(\"\\nImage (reverse)\\n\", reconst_img)"]}, 
{"cell_type": "markdown", "id": "b64f0fc8", "metadata": {"lines_to_next_cell": 2, "papermill": {"duration": 0.017584, "end_time": "2023-03-14T16:08:55.591185", "exception": false, "start_time": "2023-03-14T16:08:55.573601", "status": "completed"}, "tags": []}, "source": ["The split operation divides the input into two parts, and evaluates one part directly on the prior.\n", "So that our flow operation fits the interface of the previous layers,\n", "we will return the prior log-probability of the split-off part as the log-determinant Jacobian of the layer.\n", "It has the same effect as if we combined all split-off variables at the\n", "end of the flow and evaluated them together on the prior."]}, 
{"cell_type": "code", "execution_count": 21, "id": "e144a578", "metadata": {"execution": {"iopub.execute_input": "2023-03-14T16:08:55.628014Z", "iopub.status.busy": "2023-03-14T16:08:55.627560Z", "iopub.status.idle": "2023-03-14T16:08:55.634432Z", "shell.execute_reply": "2023-03-14T16:08:55.633786Z"}, "lines_to_next_cell": 2, "papermill": {"duration": 0.027109, "end_time": "2023-03-14T16:08:55.636027", "exception": false, "start_time": "2023-03-14T16:08:55.608918", "status": "completed"}, "tags": []}, "outputs": [], "source": ["class SplitFlow(nn.Module):\n", "    def __init__(self):\n", "        super().__init__()\n", "        self.prior = torch.distributions.normal.Normal(loc=0.0, scale=1.0)\n", "\n", "    def forward(self, z, ldj, reverse=False):\n", "        if not reverse:\n", "            z, z_split = z.chunk(2, dim=1)\n", "            ldj += self.prior.log_prob(z_split).sum(dim=[1, 2, 3])\n", "        else:\n", "            z_split = self.prior.sample(sample_shape=z.shape).to(device)\n", "            z = torch.cat([z, z_split], dim=1)\n", "            ldj -= self.prior.log_prob(z_split).sum(dim=[1, 2, 3])\n", "        return z, ldj"]}, 
{"cell_type": "markdown", "id": "63b1fe64", "metadata": {"lines_to_next_cell": 2, "papermill": {"duration": 0.017596, "end_time": "2023-03-14T16:08:55.676678", "exception": false, "start_time": "2023-03-14T16:08:55.659082", "status": "completed"}, "tags": []}, "source": ["### Building a multi-scale flow\n", "\n", "After defining the squeeze and split operations, we are finally able to build our own multi-scale flow.\n", "Deep normalizing flows such as Glow and Flow++ [2,3] often apply a split operation directly after squeezing.\n", "However, with shallow flows, we need to be more thoughtful about where to place the split operation as we need at least a minimum number of transformations on each variable.\n", "Our setup is inspired by the original RealNVP architecture [1] which is shallower than other,\n", "more recent state-of-the-art architectures.\n", "\n", "Hence, for the MNIST dataset, we will apply the first squeeze operation after two coupling layers, but don't apply a split operation yet.\n", "Because we have only used two coupling layers and each variable has been transformed only once, a split operation would be too early.\n", "We apply two more coupling layers before finally applying a split flow and 
squeeze again.\n", "The last four coupling layers operate on a scale of $7\\times 7\\times 8$.\n", "The full flow architecture is shown below.\n", "\n", "
\n", "\n", "Note that while the feature maps inside the coupling layers shrink with the height and width of the input,\n", "the increased number of channels is not directly accounted for.\n", "To counteract this, we increase the hidden dimensions for the coupling layers on the squeezed input.\n", "The dimensions are often scaled by 2, as this increases the computation cost by approximately a factor of 4, canceling with the squeezing operation.\n", "However, we will choose the hidden dimensionalities $32, 48, 64$ for the\n", "three scales, respectively, to keep the number of parameters reasonable\n", "and show the efficiency of multi-scale architectures."]}, {"cell_type": "code", "execution_count": 22, "id": "f65414af", "metadata": {"execution": {"iopub.execute_input": "2023-03-14T16:08:55.713291Z", "iopub.status.busy": "2023-03-14T16:08:55.712818Z", "iopub.status.idle": "2023-03-14T16:08:55.723271Z", "shell.execute_reply": "2023-03-14T16:08:55.722130Z"}, "lines_to_next_cell": 2, "papermill": {"duration": 0.031541, "end_time": "2023-03-14T16:08:55.725764", "exception": false, "start_time": "2023-03-14T16:08:55.694223", "status": "completed"}, "tags": []}, "outputs": [], "source": ["def create_multiscale_flow():\n", "    flow_layers = []\n", "\n", "    vardeq_layers = [\n", "        CouplingLayer(\n", "            network=GatedConvNet(c_in=2, c_out=2, c_hidden=16),\n", "            mask=create_checkerboard_mask(h=28, w=28, invert=(i % 2 == 1)),\n", "            c_in=1,\n", "        )\n", "        for i in range(4)\n", "    ]\n", "    flow_layers += [VariationalDequantization(vardeq_layers)]\n", "\n", "    flow_layers += [\n", "        CouplingLayer(\n", "            network=GatedConvNet(c_in=1, c_hidden=32),\n", "            mask=create_checkerboard_mask(h=28, w=28, invert=(i % 2 == 1)),\n", "            c_in=1,\n", "        )\n", "        for i in range(2)\n", "    ]\n", "    flow_layers += [SqueezeFlow()]\n", "    for i in range(2):\n", "        flow_layers += [\n", "            CouplingLayer(\n", "                network=GatedConvNet(c_in=4, c_hidden=48), mask=create_channel_mask(c_in=4, invert=(i % 2 == 1)), c_in=4\n", "            )\n", "        ]\n", "    flow_layers += [SplitFlow(), SqueezeFlow()]\n", "    for i in range(4):\n", "        flow_layers += [\n", "            CouplingLayer(\n", "                network=GatedConvNet(c_in=8, c_hidden=64), mask=create_channel_mask(c_in=8, invert=(i % 2 == 1)), c_in=8\n", "            )\n", "        ]\n", "\n", "    flow_model = ImageFlow(flow_layers).to(device)\n", "    return flow_model"]},
{"cell_type": "markdown", "id": "4d5371d1", "metadata": {"lines_to_next_cell": 2, "papermill": {"duration": 0.017539, "end_time": "2023-03-14T16:08:55.765523", "exception": false, "start_time": "2023-03-14T16:08:55.747984", "status": "completed"}, "tags": []}, "source": ["We can show the difference in the number of parameters below:"]}, {"cell_type": "code", "execution_count": 23, "id": "3fd1623f", "metadata": {"execution": {"iopub.execute_input": "2023-03-14T16:08:55.802358Z", "iopub.status.busy": "2023-03-14T16:08:55.801981Z", "iopub.status.idle": "2023-03-14T16:08:57.208234Z", "shell.execute_reply": "2023-03-14T16:08:57.207398Z"}, "papermill": {"duration": 1.426841, "end_time": "2023-03-14T16:08:57.209967", "exception": false, "start_time": "2023-03-14T16:08:55.783126", "status": "completed"}, "tags": []}, "outputs": [{"name": "stdout", "output_type": "stream", "text": ["Number of parameters: 556,312\n", "Number of parameters: 628,388\n", "Number of parameters: 1,711,818\n"]}], "source": ["def print_num_params(model):\n", "    num_params = sum(np.prod(p.shape) for p in model.parameters())\n", "    print(f\"Number of parameters: {num_params:,}\")\n", "\n", "\n", "print_num_params(create_simple_flow(use_vardeq=False))\n", "print_num_params(create_simple_flow(use_vardeq=True))\n", "print_num_params(create_multiscale_flow())"]}, {"cell_type": "markdown", "id": "ee63fa46", "metadata": {"papermill": {"duration": 0.028226, "end_time": "2023-03-14T16:08:57.262815", "exception": false, "start_time": "2023-03-14T16:08:57.234589", "status": "completed"}, "tags": []}, "source": ["Although the multi-scale flow has almost 3 times the parameters of the single-scale
flow,\n", "it is not necessarily more computationally expensive than its counterpart.\n", "We will compare the runtime in the following experiments as well."]}, {"cell_type": "markdown", "id": "fa5919dc", "metadata": {"papermill": {"duration": 0.017593, "end_time": "2023-03-14T16:08:57.301622", "exception": false, "start_time": "2023-03-14T16:08:57.284029", "status": "completed"}, "tags": []}, "source": ["## Analysing the flows\n", "\n", "In the last part of the notebook, we will train all the models we have implemented above,\n", "and try to analyse the effect of the multi-scale architecture and variational dequantization.\n", "\n", "### Training flow variants\n", "\n", "Before we can analyse the flow models, we need to train them first.\n", "We provide pretrained models that contain the validation and test performance, and run-time information.\n", "As flow models are computationally expensive, we advise you to rely on\n", "those pretrained models for a first run through the notebook."]}, {"cell_type": "code", "execution_count": 24, "id": "655f8984", "metadata": {"execution": {"iopub.execute_input": "2023-03-14T16:08:57.338023Z", "iopub.status.busy": "2023-03-14T16:08:57.337827Z", "iopub.status.idle": "2023-03-14T16:08:57.789595Z", "shell.execute_reply": "2023-03-14T16:08:57.788785Z"}, "papermill": {"duration": 0.472976, "end_time": "2023-03-14T16:08:57.792209", "exception": false, "start_time": "2023-03-14T16:08:57.319233", "status": "completed"}, "tags": []}, "outputs": [{"name": "stderr", "output_type": "stream", "text": ["GPU available: True (cuda), used: True\n"]}, {"name": "stderr", "output_type": "stream", "text": ["TPU available: False, using: 0 TPU cores\n"]}, {"name": "stderr", "output_type": "stream", "text": ["IPU available: False, using: 0 IPUs\n"]}, {"name": "stderr", "output_type": "stream", "text": ["HPU available: False, using: 0 HPUs\n"]}, {"name": "stderr", "output_type": "stream", "text": ["GPU available: True (cuda), used: True\n"]}, 
{"name": "stderr", "output_type": "stream", "text": ["TPU available: False, using: 0 TPU cores\n"]}, {"name": "stderr", "output_type": "stream", "text": ["IPU available: False, using: 0 IPUs\n"]}, {"name": "stderr", "output_type": "stream", "text": ["HPU available: False, using: 0 HPUs\n"]}, {"name": "stderr", "output_type": "stream", "text": ["GPU available: True (cuda), used: True\n"]}, {"name": "stderr", "output_type": "stream", "text": ["TPU available: False, using: 0 TPU cores\n"]}, {"name": "stderr", "output_type": "stream", "text": ["IPU available: False, using: 0 IPUs\n"]}, {"name": "stderr", "output_type": "stream", "text": ["HPU available: False, using: 0 HPUs\n"]}, {"name": "stdout", "output_type": "stream", "text": ["Found pretrained model, loading...\n", "Found pretrained model, loading...\n", "Found pretrained model, loading...\n"]}], "source": ["flow_dict = {\"simple\": {}, \"vardeq\": {}, \"multiscale\": {}}\n", "flow_dict[\"simple\"][\"model\"], flow_dict[\"simple\"][\"result\"] = train_flow(\n", " create_simple_flow(use_vardeq=False), model_name=\"MNISTFlow_simple\"\n", ")\n", "flow_dict[\"vardeq\"][\"model\"], flow_dict[\"vardeq\"][\"result\"] = train_flow(\n", " create_simple_flow(use_vardeq=True), model_name=\"MNISTFlow_vardeq\"\n", ")\n", "flow_dict[\"multiscale\"][\"model\"], flow_dict[\"multiscale\"][\"result\"] = train_flow(\n", " create_multiscale_flow(), model_name=\"MNISTFlow_multiscale\"\n", ")"]}, {"cell_type": "markdown", "id": "f704f49e", "metadata": {"papermill": {"duration": 0.017977, "end_time": "2023-03-14T16:08:57.833000", "exception": false, "start_time": "2023-03-14T16:08:57.815023", "status": "completed"}, "tags": []}, "source": ["### Density modeling and sampling\n", "\n", "Firstly, we can compare the models on their quantitative results.\n", "The following table shows all important statistics.\n", "The inference time specifies the time needed to determine the\n", "probability for a batch of 64 images for each model, and the 
sampling\n", "time the duration it took to sample a batch of 64 images."]}, {"cell_type": "code", "execution_count": 25, "id": "5e593b47", "metadata": {"execution": {"iopub.execute_input": "2023-03-14T16:08:57.870186Z", "iopub.status.busy": "2023-03-14T16:08:57.869973Z", "iopub.status.idle": "2023-03-14T16:08:57.878622Z", "shell.execute_reply": "2023-03-14T16:08:57.878204Z"}, "papermill": {"duration": 0.029186, "end_time": "2023-03-14T16:08:57.880201", "exception": false, "start_time": "2023-03-14T16:08:57.851015", "status": "completed"}, "tags": []}, "outputs": [{"data": {"text/html": ["\n", "\n"], "text/plain": [""]}, "metadata": {}, "output_type": "display_data"}], "source": ["%%html\n", "\n", ""]}, {"cell_type": "code", "execution_count": 26, "id": "80d4f2a9", "metadata": {"execution": {"iopub.execute_input": "2023-03-14T16:08:57.922444Z", "iopub.status.busy": "2023-03-14T16:08:57.922013Z", "iopub.status.idle": "2023-03-14T16:08:57.936216Z", "shell.execute_reply": "2023-03-14T16:08:57.935832Z"}, "papermill": {"duration": 0.035789, "end_time": "2023-03-14T16:08:57.938405", "exception": false, "start_time": "2023-03-14T16:08:57.902616", "status": "completed"}, "tags": []}, "outputs": [{"data": {"text/html": ["\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
| Model | Validation Bpd | Test Bpd | Inference time | Sampling time | Num Parameters |
|---|---|---|---|---|---|
| simple | 1.080 bpd | 1.078 bpd | 20 ms | 18 ms | 556,312 |
| vardeq | 1.045 bpd | 1.043 bpd | 26 ms | 18 ms | 628,388 |
| multiscale | 1.022 bpd | 1.020 bpd | 23 ms | 15 ms | 1,711,818 |
"], "text/plain": [""]}, "metadata": {}, "output_type": "display_data"}], "source": ["\n", "table = [\n", "    [\n", "        key,\n", "        \"%4.3f bpd\" % flow_dict[key][\"result\"][\"val\"][0][\"test_bpd\"],\n", "        \"%4.3f bpd\" % flow_dict[key][\"result\"][\"test\"][0][\"test_bpd\"],\n", "        \"%2.0f ms\" % (1000 * flow_dict[key][\"result\"][\"time\"]),\n", "        \"%2.0f ms\" % (1000 * flow_dict[key][\"result\"].get(\"samp_time\", 0)),\n", "        \"{:,}\".format(sum(np.prod(p.shape) for p in flow_dict[key][\"model\"].parameters())),\n", "    ]\n", "    for key in flow_dict\n", "]\n", "display(\n", "    HTML(\n", "        tabulate.tabulate(\n", "            table,\n", "            tablefmt=\"html\",\n", "            headers=[\"Model\", \"Validation Bpd\", \"Test Bpd\", \"Inference time\", \"Sampling time\", \"Num Parameters\"],\n", "        )\n", "    )\n", ")"]}, {"cell_type": "markdown", "id": "b4bc8486", "metadata": {"papermill": {"duration": 0.018627, "end_time": "2023-03-14T16:08:57.980586", "exception": false, "start_time": "2023-03-14T16:08:57.961959", "status": "completed"}, "tags": []}, "source": ["As we initially expected, using variational dequantization improves upon standard dequantization in terms of bits per dimension.\n", "Although the difference of about 0.04bpd doesn't seem impressive at first, it is a considerable step for generative models\n", "(most state-of-the-art models improve upon previous models in a range of 0.02-0.1bpd on CIFAR with three times as high bpd).\n", "While it takes longer to evaluate the probability of an image due to the variational dequantization,\n", "which also leads to a longer training time, it does not have an effect on the sampling time.\n", "This is because inverting variational dequantization is the same as dequantization: finding the next lower integer.\n", "\n", "When we compare the two models to the multi-scale architecture, we can see that the bits per dimension score dropped again, by about 0.02bpd compared to the variational dequantization model.\n", "Additionally, the sampling time improved (15 ms versus 18 ms), and inference is faster than for the variational dequantization model (23 ms versus 26 ms), despite having more parameters.\n", "Thus, we see that the multi-scale flow is not only stronger for density modeling, but can also be more efficient.\n", "\n", "Next, we can test the sampling quality of the models.\n", "We should note that the samples for variational dequantization and standard dequantization are very similar,\n", "and hence we visualize here only the ones for variational dequantization and the multi-scale model.\n", "However, feel free to also test out the `\"simple\"` model.\n", "The seeds are set to obtain reproducible generations and are not cherry-picked."]}, {"cell_type": "code", "execution_count": 27, "id": "e1cbb1be", "metadata": {"execution": {"iopub.execute_input": "2023-03-14T16:08:58.019235Z", "iopub.status.busy": "2023-03-14T16:08:58.018946Z", "iopub.status.idle": "2023-03-14T16:09:00.213140Z", "shell.execute_reply": "2023-03-14T16:09:00.212673Z"}, "papermill": {"duration": 2.214889, "end_time": "2023-03-14T16:09:00.214392", "exception": false, "start_time": "2023-03-14T16:08:57.999503", "status": "completed"}, "tags": []}, "outputs": [{"name": "stderr", "output_type": "stream", "text": ["Global seed set to 44\n"]}, {"data": {"application/pdf": "", "image/svg+xml": ["\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-03-14T16:09:00.147318\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.7.1, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n"], "text/plain": ["
"]}, "metadata": {}, "output_type": "display_data"}], "source": ["L.seed_everything(44)\n", "samples = flow_dict[\"vardeq\"][\"model\"].sample(img_shape=[16, 1, 28, 28])\n", "show_imgs(samples.cpu())"]}, {"cell_type": "code", "execution_count": 28, "id": "e60cb58e", "metadata": {"execution": {"iopub.execute_input": "2023-03-14T16:09:00.254458Z", "iopub.status.busy": "2023-03-14T16:09:00.254291Z", "iopub.status.idle": "2023-03-14T16:09:00.375395Z", "shell.execute_reply": "2023-03-14T16:09:00.374906Z"}, "papermill": {"duration": 0.142231, "end_time": "2023-03-14T16:09:00.376562", "exception": false, "start_time": "2023-03-14T16:09:00.234331", "status": "completed"}, "tags": []}, "outputs": [{"name": "stderr", "output_type": "stream", "text": ["Global seed set to 44\n"]}, {"data": {"application/pdf": "", "image/svg+xml": ["\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-03-14T16:09:00.310180\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.7.1, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n"], "text/plain": ["
"]}, "metadata": {}, "output_type": "display_data"}], "source": ["L.seed_everything(44)\n", "samples = flow_dict[\"multiscale\"][\"model\"].sample(img_shape=[16, 8, 7, 7])\n", "show_imgs(samples.cpu())"]}, {"cell_type": "markdown", "id": "9c6d84f7", "metadata": {"papermill": {"duration": 0.019612, "end_time": "2023-03-14T16:09:00.415899", "exception": false, "start_time": "2023-03-14T16:09:00.396287", "status": "completed"}, "tags": []}, "source": ["From the few samples, we can see a clear difference between the simple and the multi-scale model.\n", "The single-scale model has only learned local, small correlations while the multi-scale model was able to learn full,\n", "global relations that form digits.\n", "This show-cases another benefit of the multi-scale model.\n", "In contrast to VAEs, the outputs are sharp as normalizing flows can naturally model complex,\n", "multi-modal distributions while VAEs have the independent decoder output noise.\n", "Nevertheless, the samples from this flow are far from perfect as not all samples show true digits."]}, {"cell_type": "markdown", "id": "944d6269", "metadata": {"lines_to_next_cell": 2, "papermill": {"duration": 0.019671, "end_time": "2023-03-14T16:09:00.455179", "exception": false, "start_time": "2023-03-14T16:09:00.435508", "status": "completed"}, "tags": []}, "source": ["### Interpolation in latent space\n", "\n", "Another popular test for the smoothness of the latent space of generative models is to interpolate between two training examples.\n", "As normalizing flows are strictly invertible, we can guarantee that any image is represented in the latent space.\n", "We again compare the variational dequantization model with the multi-scale model below."]}, {"cell_type": "code", "execution_count": 29, "id": "7970cd2a", "metadata": {"execution": {"iopub.execute_input": "2023-03-14T16:09:00.495723Z", "iopub.status.busy": "2023-03-14T16:09:00.495365Z", "iopub.status.idle": "2023-03-14T16:09:00.530122Z", 
"shell.execute_reply": "2023-03-14T16:09:00.529642Z"}, "papermill": {"duration": 0.056622, "end_time": "2023-03-14T16:09:00.531336", "exception": false, "start_time": "2023-03-14T16:09:00.474714", "status": "completed"}, "tags": []}, "outputs": [], "source": ["@torch.no_grad()\n", "def interpolate(model, img1, img2, num_steps=8):\n", " \"\"\"\n", " Args:\n", " model: object of ImageFlow class that represents the (trained) flow model\n", " img1, img2: Image tensors of shape [1, 28, 28]. Images between which should be interpolated.\n", " num_steps: Number of interpolation steps. 8 interpolation steps mean 6 intermediate pictures besides img1 and img2\n", " \"\"\"\n", " imgs = torch.stack([img1, img2], dim=0).to(model.device)\n", " z, _ = model.encode(imgs)\n", " alpha = torch.linspace(0, 1, steps=num_steps, device=z.device).view(-1, 1, 1, 1)\n", " interpolations = z[0:1] * alpha + z[1:2] * (1 - alpha)\n", " interp_imgs = model.sample(interpolations.shape[:1] + imgs.shape[1:], z_init=interpolations)\n", " show_imgs(interp_imgs, row_size=8)\n", "\n", "\n", "exmp_imgs, _ = next(iter(train_loader))"]}, {"cell_type": "code", "execution_count": 30, "id": "2e17956c", "metadata": {"execution": {"iopub.execute_input": "2023-03-14T16:09:00.571381Z", "iopub.status.busy": "2023-03-14T16:09:00.571218Z", "iopub.status.idle": "2023-03-14T16:09:00.781262Z", "shell.execute_reply": "2023-03-14T16:09:00.780473Z"}, "papermill": {"duration": 0.231591, "end_time": "2023-03-14T16:09:00.782599", "exception": false, "start_time": "2023-03-14T16:09:00.551008", "status": "completed"}, "tags": []}, "outputs": [{"name": "stderr", "output_type": "stream", "text": ["Global seed set to 42\n"]}, {"data": {"application/pdf": "", "image/svg+xml": ["\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-03-14T16:09:00.640028\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.7.1, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n"], "text/plain": ["
"]}, "metadata": {}, "output_type": "display_data"}, {"data": {"application/pdf": "JVBERi0xLjQKJazcIKu6CjEgMCBvYmoKPDwgL1R5cGUgL0NhdGFsb2cgL1BhZ2VzIDIgMCBSID4+CmVuZG9iago4IDAgb2JqCjw8IC9Gb250IDMgMCBSIC9YT2JqZWN0IDcgMCBSIC9FeHRHU3RhdGUgNCAwIFIgL1BhdHRlcm4gNSAwIFIKL1NoYWRpbmcgNiAwIFIgL1Byb2NTZXQgWyAvUERGIC9UZXh0IC9JbWFnZUIgL0ltYWdlQyAvSW1hZ2VJIF0gPj4KZW5kb2JqCjExIDAgb2JqCjw8IC9UeXBlIC9QYWdlIC9QYXJlbnQgMiAwIFIgL1Jlc291cmNlcyA4IDAgUgovTWVkaWFCb3ggWyAwIDAgNjQzLjI5NzUgOTcuNTYgXSAvQ29udGVudHMgOSAwIFIgL0Fubm90cyAxMCAwIFIgPj4KZW5kb2JqCjkgMCBvYmoKPDwgL0xlbmd0aCAxMiAwIFIgL0ZpbHRlciAvRmxhdGVEZWNvZGUgPj4Kc3RyZWFtCnicVY7LDoJADEX3/Yr7BfPCeS1VkolLdOEHTEaUgAZJ5PctLiAuTtLTtLeVdfk8cjmnA44XkpvliTQ6poVCx8zQSExLim0gt6uEid6y9JtEL6zjhlqrO9GNRnhhfjgTRFgmQyW0w7vgiifknoMnTu+YmSMT/n8ZeTEKE7Ac51Vr1sQ8QJ406hcaaugLxjwu9wplbmRzdHJlYW0KZW5kb2JqCjEyIDAgb2JqCjE0NAplbmRvYmoKMTAgMCBvYmoKWyBdCmVuZG9iagozIDAgb2JqCjw8ID4+CmVuZG9iago0IDAgb2JqCjw8IC9BMSA8PCAvVHlwZSAvRXh0R1N0YXRlIC9DQSAxIC9jYSAxID4+ID4+CmVuZG9iago1IDAgb2JqCjw8ID4+CmVuZG9iago2IDAgb2JqCjw8ID4+CmVuZG9iago3IDAgb2JqCjw8IC9JMSAxMyAwIFIgPj4KZW5kb2JqCjEzIDAgb2JqCjw8IC9UeXBlIC9YT2JqZWN0IC9TdWJ0eXBlIC9JbWFnZSAvV2lkdGggODc0IC9IZWlnaHQgMTE2Ci9Db2xvclNwYWNlIFsvSW5kZXhlZCAvRGV2aWNlUkdCIDI0NSAo/////v7+/f39/Pz8+/v7+vr6+fn5+Pj49/f39vb29fX19PT08/Pz8vLy8fHx8PDw7+/v7u7u7e3t7Ozs6+vr6urq6enp6Ojo5+fn5ubm5eXl5OTk4+Pj4uLi4eHh4ODg39/f3t7e3d3d3Nzc29vb2tra2dnZ19fX1tbW1dXV1NTU09PT0tLS0dHR0NDQz8/Pzs7Ozc3NzMzMy8vLysrKycnJyMjIx8fHxsbGxcXFxMTEw8PDwsLCwcHBwMDAv7+/vr6+vb29vLy8u7u7urq6ubm5uLi4t7e3tra2tbW1tLS0s7OzsrKysbGxsLCwr6+vrq6ura2trKysq6urqqqqqampqKiop6enpqampaWlpKSko6OjoqKioaGhoKCgn5+fnp6enZ2dnJycm5ubmpqamZmZmJiYl5eXlpaWlZWVlJSUk5OTkpKSkZGRkJCQj4+Pjo6OjY2NjIyMi4uLiYmJiIiIh4eHhoaGhYWFg4ODgoKCgICAf39/fn5+fX19fHx8e3t7enp6d3d3dnZ2dXV1c3NzcnJycXFxcHBwb29vbm5ubW1tbGxsa2trampqaWlpZ2dnZmZmZWVlZGRkY2NjYmJiYWFhYGBgX19fXl5eXV1dXFxcXFxcW1tbWlpaWVlZV1dXVlZWVVVVVFRUU1NTUlJSUVFRUFBQT09PTk5OTU1NTExMS0tLSkpKSUlJSEhIR0dHRkZGRUVFREREQ0NDQkJCQUFBQEBAPz8/Pj4+PT09PDw8Ozs7Ojo6OTk5ODg4NjY2NTU1N
DQ0MzMzMjIyMTExMDAwLy8vLi4uLS0tLCwsKysrKioqXClcKVwpXChcKFwoJycnJiYmJSUlJCQkIyMjIiIiISEhICAgHx8fHh4eHR0dHBwcGxsbGhoaGRkZGBgYFxcXFhYWFRUVFBQUExMTEhISEREREBAQDw8PDg4OXHJcclxyDAwMCwsLXG5cblxuCQkJCAgIBwcHBgYGBQUFBAQEAwMDAgICAQEBAAAAKV0KL0JpdHNQZXJDb21wb25lbnQgOCAvRmlsdGVyIC9GbGF0ZURlY29kZQovRGVjb2RlUGFybXMgPDwgL1ByZWRpY3RvciAxMCAvQ29sb3JzIDEgL0NvbHVtbnMgODc0ID4+IC9MZW5ndGggMTQgMCBSID4+CnN0cmVhbQp4nO2de2DVVh3H216kRUHAObC8QatlZWoZMGAM5aFzDJE3bAg4FAXnpt18IFMR2EB8YHECik6QrfiY4qa8EZQ5HHMKm+M1sbDKqzAsCNz//X7Jud7cm5tzTm7TAPH3+af0JjffnOT3CU1yclLwsCAIEVBwtVdAEP4/ENUEIRJENUGIBFFNECJBVBOESBDVBCESUqolo+DhqxAW16zYNizKrEjDRLXrNCu2DYsyS1QLMyyuWbFtWJRZolqYYXHNim3DoswS1cIMi2tWbBsWZZaoFmZYXLNi27Aos0S1MMPimhXbhkWZJaqFGRbXrNg2LMqsoGG1tbW3gw0bNuQRJqpdp1mxbViUWaJamGFxzYptw6LMEtXCDItrVmwbFmWWqBZmWFyzYtuwIF85f/58FejevfujIHhWkLDXwfDhwwvB2LFjg6ylqHZ9Z8W2YUG+IqoZMq/ZPXc9ZcW2YUG+IqoZMq/ZPXc9ZcW2YUG+IqoZMq/ZPRdlFs6w7wMTJkxYCoJn2YadBt8ArVu3LgAzZ84MspYBG3b48OFPgAKH6aDpsrABa4cNG1bkcDcInmUb9hoYCgodqqurA2SJalc5S1RrZJaoZpF5zZZ/lFmiWiOzRDWLTNsw/nX8LTBv3rxzoKGhIY8w29l37dr1ToC91gJ8BEybNq0ONEHWN0HHjh0TAHnNwE3glVdeCZBlG3YvSDgUOY1Du6YdBwHCLLOS48aNK0ozEth+M3gWTpjGprP2gOBZtmFsidLsRoBTxABZoloWopomzDJLVNOFiWoKUU0TZpklqunCRDWFqKYJs8wS1XRhQVVbBcrKytg65I4GKwFOvQNkWoVduHDBVSVsZElJSSXgGgQIs5n1KECtuypS/XMwaLBQPEAWz7C7AxWEdqmsRLdu3WzUDlglbweqHHmpQrWNal+6dMn47YDV0bdvX5dq2GEl7dq1Y69Bmy8HzLoZqCA05jKw/WYy8EZ8F1CqLQcBglxholpSVDOGWWaJarowUS0pqhnDLLNENV2YqJYU1Yxhllmimi7MWrVTp07dA7j50iWZqpKKiorfActMY9hJ8OCDDybSFKZL8o0ADto20KJdpz4NcmbxisV8cPHixTCy2Gfvh4CL9W5EfPgVYDpsBayScuA6hqhkrkFVVdVBEErDFP37908dqNIbsRjwGLl3794ws1yqjR8/3vZbmVlBVWvVqtXPQH5hopqoFkLDFKKaJkxUE9VCaJhCVNOEWauGMzJuN/zt3e6xxx5bCFxVAlqCLVu22GQaw1ylz6B+/fq5yl/xS2ATZsoaBdQie4IOHTp4sxIvvPBCGFlJVHczRfPmzXnQcGXhQ9bNggULbLKsquTll19W54WppSvNmjm/zgOhNExx6623plTDIbka7Ny583HAbTp79uwws3B4r0AW22faPb5ZQVUbOnRowCBXmKgmqoXQMIWopgkT1US1EBqmENU0YaKaqBZCwxSimibMSrV/gLKyMm5BbLqd+OTvYK3DcKCKBa3mFQ1TpjZsM2gLsDgU3YKLYOPGjTVgyZIlrupnh8Xa2lpjA7UNW7FiBeudy+vVq9cBgJPezwN14SJ1/QLHF14dOXPmTP5Zvwalp
aVc5BvA+vXrPwew1NSVJpUFByeD54Auy1gl9QCHkVQTunTpwrEB8El7MAjgw7eCbdu2+S8kQEV+GRQXFxep8j948KC613UCoOGlAwYM0C4gQBbn42ZDFg+VNl/JmWUVtmjRIm5Eqnb//fcHzxLViKgmqhkR1UQ1PaKaOevaU42P5KNVrMH/gMyJ/GTx4sXca5hnFjBlasO+BNTu2gRcUyDdb0BXoE48pk6damyg73SuN2pdFXkzHDDUBFYIzjq3uG5BKb337duXXxb5EVDnMW8C586d46cow/EgfbqWynoG6LK0YZfAQJBwbqb1BhcuXOCU06dP89g4E6isiRMn+i/IsiJ5s+nNQN3n2gXUFGxLTqRqWAntQmyr/8UXXyxTfQPBn4DxK35ZxjB2o8N6q05ZhdpjuylMVBPVRDVfRDVRTVQT1UQ1H0Q1UU2fifLadwPAtuLVkdwznT179h0AKzNRt+eMDdy+fXsrwAp48sknc8/DfoJqM/fs2dN3UcY9x3pQd5hYfPv371cT/gU44EEig/cA3+6Jxh2HhnEBalm8FFNVVfUxkH52J4u/AF2Wtkr2grcBpTYvwwwZMmQMaNmyZerOl9qI5eXl/guyrMjU8VHB4+2MGTMmgGnTpjGddYoi+gNobFayurralcWjLw5bXwe8zNSnTx8eyHwPVJlZxrDVoDCNUq2mpuYJxVNPPWVcYVFNVBPVTFmimqgmqolqoppPw0Q1UU2f+VnAHVNRUeG/sE2bNqmd9xNgyvQPUyMccHg49Tg+b2Hj1FRN58gEHIRAFQp2p7GBunXepHYXb8jj/J0Fc9ddd70XZNU95uFN37yzkmvWrOHxKuFc+Sh2bvWmF16UGcesi37PyNlUCUu6KDMh0y/1K3/ccsstF9Q1kzwaBurq6njVo8gCzue7HFvVnEPIFVD97JSLJqjfCP/ZBuBoacoyhvUHLtXuvPNOHq8SzkONpKSkhJ9o722LaqKaqGbKEtVENVEtfqrxD27uI9/XpfAezqRJkzgPatR/r9k0cPTo0VwOb9Gg9DkWwTqgXjKyfv36LwBVPizZI0eOGBvoO52qqVtLUHbGHXfcwZMK9UlW+eO3H4C8s8gikMi4dZZlQRr27+RRRZelDfsjSFWdCeTd6zuAhE3DsJe6AJus5gDHs/yzSNeuXXMvvTATzPdzoMsKqpovHTp02A20DRPVRDVRzRdRTVQT1UQ1Uc0HUU1U02Sqrcjd77s2fKpMlQi00K64toF874zzFNUV1R566CFW5YfASy+9xBmef/55Vzn+GNiE+U7nqMyusk/Xe5YMpLKyUtvXzmbHnQGq72ZOv7L82wh0WcYqmQt0fnGEY75hpch5F07uPMssHqu8IW8BNzp0AmosiRYtWvBIkHdWMj1cM+ElJiy5I+gBunXrxk6y7JFbeOVW3g05X/CTp2rt2rUr9cBqLSgoYHptzk6SopqoJqqZskQ1UU1UE9V0iGqimqiWO+nYsWOp3e+9NX38+HGOVsFTXTWPacW1DTwCXKWWcEa+WOpcjuHrZ5TzykXfPoJZYb7T7wZeqbDj2AuTO9C1Mqa3j9jsOO4GnV+psUWw83hZ5KdAl6UNY8dVdgX0et2yZUt6gR177MSJE/zhPNaW4NAg+TUM8MU2rqzFAMfHfwPXcCxPP/00Sx/TvwPyzkp27949FYSS2JJzRMRfgbZt23KenAPc56Ha+0B9fb13Hl5fg4Och6PT+IaJaqKaqOaLqCaqiWoxUy3p3FfjGrsejjl8+DAfLnBejJfiPmBacWMDH3jgAVfh8YGcj4MhQ4aot/Cl7hX17dvXlGXcc6nOnZ56L/LS6NGNL1++zNZ481zyKdVw0sHzpkb1gfwgSB0mijI6Pw4ePNi1Tt8F7du354SvgjwaBpYtW8bDrSvrt8A7H2pejYrxbZBfFp8OcpZxBd+B6f4JcMrGlcmpYh6qPQpyzsOHrsrLy9kZM/djZqKaqCaqacOSopqoJqqJanpEN
VFNVPMNmwO4xjiffr8CO4eftGnThuWDsu+L30JRDSfsmWWffdVCfWoa58wV5jud7w/PGeCVoSjnvZkAWbwnmBrJwYJfAFOWtko+AFwiuyzYvn07u5ZyZL5x48ax16fK5MEzj4aBWbNmuY5KXNZsz0tn2INUKVJWVnYI5Jf1e+AK69Gjh/fyGHXkQ42YzuLNuRxb1QYApdow4B3xor6+fiiAZtt8R2QX1UQ1UU0blhTVRDVRLY6qnQefBM7j+lcoKSlhrvojuL1jHl+nblpxYwNxfoK//JcVZJ4GFhcX87UjLi9GjhxpyjLuOb5TpygDr2Opf+per2OTlVQDEPiHpUEzvwa43XVZ/mFbt25Vb2lJjWngyiv630vhC9MffgqcBXk0DPTr1y+rbbwviKMFz1x27949BbhO5ubPn597OTZZrOfMLL7S8ujRoxySug6sXbuWfbTURL7WSJdlVG0qcHXMcrL4cBMfHaP3OCnmBJux20Q1UU1U80VUE9VENVFNh6gmqolqpky+woTd8vie0PT6sw8dH5WAhweBaRk2DWQbjgPYxCrgyF/I4/Bxffr0SZWjhdbGPcehGVatWuXqwskLLji/vck74vCrr77aqKzks88+mymxpi+kw5+BLss/bOXKlcaszE/9B3MwZbGbY69evVKLLEpfhmndujWfnHHJ7nzYeseOHfllEQ5J3bt370TmYgcOHPhhwEt26sYry3/69Om5jx9Je9VQ9adRiS7bhoDZs2fzfx/Xp4888oj/QkQ1UU1U04YlRTVRTVQT1fSIaqKaqGbKzA3lY7O14yh7Mq3C4BcfGvs+UJ9UVlYyq0I7pLknTDtPQ0PD7cClGo4k3wNqX/IHb/Pu3LmzcVl79uwxuZXFX4Euyz9s8uTJRtVcwIeTvpd9TFl8Mi0zLMuCTHz7EdpkEQ6jV1VVxTEUfEMIO1Rg55qyrCoRRxNe5ynMSUlJCR+p9Ouw6g4T1UQ1UU2LqGYZpp1HVBPVjFwbqvEhcLYMBWoze5AGzpkz52bwmvMaGj7Bj1axce8GAcJMs3GnqM6WPEm7fPnyPUDtsZSAI0aMaFzW3r17syocS98POnfuXOQKc/gM4I0iXVbuMA5b4LzmJtsv9dZGT1ZiypQpeTZs3bp17NjpWpxr6emyT/02atSo+pwjBthkEWwuFsDq1av5rkSX2on0m2FKS0tXAN+7kplZtmXPm3eogRFpxXiK/zioqakxfltUcxDVRDUjoppNmGk2UU1UMyKq2YSZZhPVRDUjV1u1AwcO8J4aG9oEqk11UL+xw6LaqB1AgDDTbBuAWjLKvjPE661u3LhqdOHChY3Lwrbylv5HAfZf1oROnTppB97TbkS+8dT7EB4LvX///hy7Iv2AGuHs2vuh2oZBUtcjhbx8NHjw4ErAGkRxrgHV1dU8mOFHdV1dnX+QKYtHns2bN5eDHTt2cIwB9rCcNGkSX1Nr8Mo/K8//YfILE9VENVGtaRHVHEQ1Ua2JEdUcRDVRrYlpvGqoh9QuawLVDh06dDo9lAMHEm8a1f4GOBKFumyQ8PBF0Og+kK+//jpvxquKzB1UdGVg7C5LliyxycodxvfllJWVpRY5aNAgDjPCETImT57MG78qvQWYO3fuHpB/w2pqanhZ5LbbbnsG0C/tskxos9jZctmyZae9o3s0JktUI6KaPktUCyFLVCOimj5LVAshS1Tz8gRQWaiksgBhNrPyLszSpUt5xy6z+vmsfoOuL13wLA4LtxW4HglSLF++PPfouDmzgmxE3s46efIkH/zAyRnPnZ4DNt+8ChUZZZao5kVUy84S1ULIEtW8iGrZWaJaCFmimhdRLTtLVAsh63pRDSXIF1Dy7ZTBdlweYazBMWPG8AaR/9P5OcOCZ+VB3KskrlmimhdRLTsrtg2LMktU8yKqZWfFtmFRZl0vquWdGds9F2VWbBsWZZaoFmZYXLNi27Aos0S1MMPimhXbhkWZJaqFGRbXrNg2LMosUS3Ms
LhmxbZhUWaJamGGxTUrtg2LMktUCzMsrlmxbViUWaJamGFxzYptw6LMuiqqCYLQpIhqghAJopogRIKoJgiRIKoJQiSIaoIQCaKaIETCfwHkuJnRCmVuZHN0cmVhbQplbmRvYmoKMTQgMCBvYmoKNDI4OAplbmRvYmoKMiAwIG9iago8PCAvVHlwZSAvUGFnZXMgL0tpZHMgWyAxMSAwIFIgXSAvQ291bnQgMSA+PgplbmRvYmoKMTUgMCBvYmoKPDwgL0NyZWF0b3IgKE1hdHBsb3RsaWIgdjMuNy4xLCBodHRwczovL21hdHBsb3RsaWIub3JnKQovUHJvZHVjZXIgKE1hdHBsb3RsaWIgcGRmIGJhY2tlbmQgdjMuNy4xKSAvQ3JlYXRpb25EYXRlIChEOjIwMjMwMzE0MTYwOTAwWikKPj4KZW5kb2JqCnhyZWYKMCAxNgowMDAwMDAwMDAwIDY1NTM1IGYgCjAwMDAwMDAwMTYgMDAwMDAgbiAKMDAwMDAwNjA2MCAwMDAwMCBuIAowMDAwMDAwNTk2IDAwMDAwIG4gCjAwMDAwMDA2MTcgMDAwMDAgbiAKMDAwMDAwMDY3NyAwMDAwMCBuIAowMDAwMDAwNjk4IDAwMDAwIG4gCjAwMDAwMDA3MTkgMDAwMDAgbiAKMDAwMDAwMDA2NSAwMDAwMCBuIAowMDAwMDAwMzM3IDAwMDAwIG4gCjAwMDAwMDA1NzYgMDAwMDAgbiAKMDAwMDAwMDIwOCAwMDAwMCBuIAowMDAwMDAwNTU2IDAwMDAwIG4gCjAwMDAwMDA3NTEgMDAwMDAgbiAKMDAwMDAwNjAzOSAwMDAwMCBuIAowMDAwMDA2MTIwIDAwMDAwIG4gCnRyYWlsZXIKPDwgL1NpemUgMTYgL1Jvb3QgMSAwIFIgL0luZm8gMTUgMCBSID4+CnN0YXJ0eHJlZgo2MjcxCiUlRU9GCg==", "image/svg+xml": ["\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-03-14T16:09:00.743094\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.7.1, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n"], "text/plain": ["
"]}, "metadata": {}, "output_type": "display_data"}], "source": ["L.seed_everything(42)\n", "for i in range(2):\n", " interpolate(flow_dict[\"vardeq\"][\"model\"], exmp_imgs[2 * i], exmp_imgs[2 * i + 1])"]}, {"cell_type": "code", "execution_count": 31, "id": "3ca1ae47", "metadata": {"execution": {"iopub.execute_input": "2023-03-14T16:09:00.828856Z", "iopub.status.busy": "2023-03-14T16:09:00.828634Z", "iopub.status.idle": "2023-03-14T16:09:01.038969Z", "shell.execute_reply": "2023-03-14T16:09:01.038355Z"}, "papermill": {"duration": 0.236091, "end_time": "2023-03-14T16:09:01.041318", "exception": false, "start_time": "2023-03-14T16:09:00.805227", "status": "completed"}, "tags": []}, "outputs": [{"name": "stderr", "output_type": "stream", "text": ["Global seed set to 42\n"]}, {"data": {"application/pdf": "", "image/svg+xml": ["\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-03-14T16:09:00.902316\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.7.1, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n"], "text/plain": ["
"]}, "metadata": {}, "output_type": "display_data"}, {"data": {"application/pdf": "", "image/svg+xml": ["\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-03-14T16:09:01.001884\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.7.1, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n"], "text/plain": ["
"]}, "metadata": {}, "output_type": "display_data"}], "source": ["L.seed_everything(42)\n", "for i in range(2):\n", " interpolate(flow_dict[\"multiscale\"][\"model\"], exmp_imgs[2 * i], exmp_imgs[2 * i + 1])"]}, {"cell_type": "markdown", "id": "648f4caa", "metadata": {"papermill": {"duration": 0.020808, "end_time": "2023-03-14T16:09:01.086975", "exception": false, "start_time": "2023-03-14T16:09:01.066167", "status": "completed"}, "tags": []}, "source": ["The interpolations of the multi-scale model result in more realistic digits\n", "(first row $7\\leftrightarrow 8\\leftrightarrow 6$, second row $9\\leftrightarrow 4\\leftrightarrow 6$),\n", "while the variational dequantization model focuses on local patterns that globally do not form a digit.\n", "For the multi-scale model, we actually did not do the \"true\" interpolation between the two images\n", "as we did not consider the variables that were split along the flow (they have been sampled randomly for all samples).\n", "However, as we will see in the next experiment, the early variables do not effect the overall image much."]}, {"cell_type": "markdown", "id": "2628a523", "metadata": {"papermill": {"duration": 0.020664, "end_time": "2023-03-14T16:09:01.128363", "exception": false, "start_time": "2023-03-14T16:09:01.107699", "status": "completed"}, "tags": []}, "source": ["### Visualization of latents in different levels of multi-scale\n", "\n", "In the following we will focus more on the multi-scale flow.\n", "We want to analyse what information is being stored in the variables split at early layers,\n", "and what information for the final variables.\n", "For this, we sample 8 images where each of them share the same final latent variables,\n", "but differ in the other part of the latent variables.\n", "Below we visualize three examples of this:"]}, {"cell_type": "code", "execution_count": 32, "id": "52225998", "metadata": {"execution": {"iopub.execute_input": "2023-03-14T16:09:01.171113Z", "iopub.status.busy": 
"2023-03-14T16:09:01.170822Z", "iopub.status.idle": "2023-03-14T16:09:01.406005Z", "shell.execute_reply": "2023-03-14T16:09:01.405490Z"}, "papermill": {"duration": 0.259501, "end_time": "2023-03-14T16:09:01.408520", "exception": false, "start_time": "2023-03-14T16:09:01.149019", "status": "completed"}, "tags": []}, "outputs": [{"name": "stderr", "output_type": "stream", "text": ["Global seed set to 44\n"]}, {"data": {"application/pdf": "", "image/svg+xml": ["\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-03-14T16:09:01.214433\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.7.1, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n"], "text/plain": ["
"]}, "metadata": {}, "output_type": "display_data"}, {"data": {"application/pdf": "JVBERi0xLjQKJazcIKu6CjEgMCBvYmoKPDwgL1R5cGUgL0NhdGFsb2cgL1BhZ2VzIDIgMCBSID4+CmVuZG9iago4IDAgb2JqCjw8IC9Gb250IDMgMCBSIC9YT2JqZWN0IDcgMCBSIC9FeHRHU3RhdGUgNCAwIFIgL1BhdHRlcm4gNSAwIFIKL1NoYWRpbmcgNiAwIFIgL1Byb2NTZXQgWyAvUERGIC9UZXh0IC9JbWFnZUIgL0ltYWdlQyAvSW1hZ2VJIF0gPj4KZW5kb2JqCjExIDAgb2JqCjw8IC9UeXBlIC9QYWdlIC9QYXJlbnQgMiAwIFIgL1Jlc291cmNlcyA4IDAgUgovTWVkaWFCb3ggWyAwIDAgMzQxLjY3NDgzODcwOTcgMTgwLjcyIF0gL0NvbnRlbnRzIDkgMCBSIC9Bbm5vdHMgMTAgMCBSID4+CmVuZG9iago5IDAgb2JqCjw8IC9MZW5ndGggMTIgMCBSIC9GaWx0ZXIgL0ZsYXRlRGVjb2RlID4+CnN0cmVhbQp4nFWOSw7CMAxE9z7FnCDfKkmXQKWIZWHBAaJQiCioVKLXx61AhcWzPJbHHtnk1zXlQ9xidyS5qjSSRmE6KBRmgkZkOlKserKVFs5XwdYsb79SByW84Zla2wvRmQZ4YRas4TpvB69qD+2csAbPjBPukBv+MvKrwkx8PeI/2LD4HeYgH+v3cOoh9xrNAy219AYPKzF0CmVuZHN0cmVhbQplbmRvYmoKMTIgMCBvYmoKMTQ4CmVuZG9iagoxMCAwIG9iagpbIF0KZW5kb2JqCjMgMCBvYmoKPDwgPj4KZW5kb2JqCjQgMCBvYmoKPDwgL0ExIDw8IC9UeXBlIC9FeHRHU3RhdGUgL0NBIDEgL2NhIDEgPj4gPj4KZW5kb2JqCjUgMCBvYmoKPDwgPj4KZW5kb2JqCjYgMCBvYmoKPDwgPj4KZW5kb2JqCjcgMCBvYmoKPDwgL0kxIDEzIDAgUiA+PgplbmRvYmoKMTMgMCBvYmoKPDwgL1R5cGUgL1hPYmplY3QgL1N1YnR5cGUgL0ltYWdlIC9XaWR0aCA0NTUgL0hlaWdodCAyMzEKL0NvbG9yU3BhY2UgWy9JbmRleGVkIC9EZXZpY2VSR0IgMjM1ICj////+/v79/f38/Pz7+/v6+vr5+fn4+Pj39/f29vb19fX09PTz8/Py8vLx8fHw8PDv7+/u7u7t7e3s7Ozr6+vq6urp6eno6Ojn5+fm5ubl5eXk5OTj4+Pi4uLh4eHg4ODf39/e3t7d3d3c3Nzb29va2trZ2dnY2NjX19fW1tbV1dXT09PS0tLR0dHQ0NDPz8/Ozs7MzMzLy8vKysrJycnIyMjHx8fGxsbFxcXExMTCwsLBwcHAwMC/v7++vr69vb28vLy7u7u6urq5ubm4uLi3t7e2tra1tbW0tLSzs7OysrKxsbGwsLCvr6+urq6srKyrq6uqqqqpqamoqKinp6empqalpaWkpKSjo6OhoaGgoKCfn5+dnZ2cnJybm5uampqZmZmYmJiXl5eVlZWTk5OQkJCPj4+Ojo6NjY2MjIyLi4uKioqJiYmIiIiHh4eGhoaFhYWEhISDg4OCgoKBgYGAgIB/f39+fn59fX18fHx6enp5eXl4eHh3d3d2dnZ1dXV0dHRycnJxcXFwcHBvb29ubm5tbW1sbGxra2tqamppaWloaGhnZ2dlZWVkZGRjY2NhYWFfX19eXl5cXFxcXFxbW1taWlpZWVlYWFhWVlZVVVVUVFRTU1NSUlJRUVFQUFBPT09OTk5NTU1MTExLS0tKSkpJSUlISEhGRkZFRUVERERDQ0NCQkJBQUFAQEA/Pz8+Pj49PT08PDw7Ozs6Ojo5OTk3Nzc2NjY1NTU0NDQzMzMyMjIxMTEwMDAvL
y8uLi4tLS0sLCwrKysqKipcKVwpXClcKFwoXCgmJiYlJSUkJCQjIyMiIiIhISEgICAfHx8eHh4dHR0cHBwbGxsaGhoZGRkYGBgXFxcWFhYVFRUUFBQTExMSEhIREREQEBAPDw8ODg5cclxyXHIMDAwLCwtcblxuXG4JCQkICAgHBwcGBgYFBQUEBAQDAwMCAgIBAQEAAAApXQovQml0c1BlckNvbXBvbmVudCA4IC9GaWx0ZXIgL0ZsYXRlRGVjb2RlCi9EZWNvZGVQYXJtcyA8PCAvUHJlZGljdG9yIDEwIC9Db2xvcnMgMSAvQ29sdW1ucyA0NTUgPj4gL0xlbmd0aCAxNCAwIFIgPj4Kc3RyZWFtCnic7Z3/f1V1HcdhEG3ALPlijsgJiQSG1FSC7AthKqRUYiCBgYWxUpTUvhDfhEIjCZFRkLAsoqwcgUWNvqGFgE3SlQgunZNGEMg/0evlPuexD9u5977P2bny8LPX8xcfj93tPu/nPC+e75/T42siBHqc7Q8gMkEdw0Adw0Adw0Adw0AdwyDqePqNQr6i+NQxDJ86huFTxzB86hiGTx3D8KljGD51DMOnjmH41DEMnzqG4VPHMHzqGIZPHcPwqWMYPnUMw6eOYfjUMQyfOobhU8cwfOoYhk8dw/CpYxi+1B1fBYcPH94JHgaHQBJvYt/LoLGx8Sfge+DgwYO2v0vrO3r06BHwW/BjAHVxfSdPnnwG3Aq+DZL61DEH6mjyJvapY16fOuagO3TcCz4HStp5N9i2bZvZm8jXAGYBz1deXv5dUBzfH8DYsWM9XcmFF174EiiOby1Yvnx5LxD5Bg8e/BQw+9QxBnU0exP51LGgL3HHu0EfUNKZOXPmmL1m300A0co7664H2fs+Bd4COvtuAdn7Pg56g86+ucDsU8cOqKM6pvGpozp6PnvHp8GoUaN6Alr6g+HDh/cAkXj8+PFmb+Ff/CcYNmxYtPK/GCxYsCDSg0EgQx+ZPn36OcAJ/O0O9wky9LW0tHwGDAHxvgngFDD51NFHHdUxqU8d1bGzz9ZxM0DDUc5wKfgdmD17dimIxGPGjMlonPOAt/K/BPBYwIYNG7yOPPSQke+H4JPAvTO+n8Pr6+sbgbdsR4CMfPeD0aNHRwvuMrBv374XgdfxcmD2qSNRR3VM4zsrHaurq3s6ysrKsAqZ3gT4SmVlpf//9JEjR2Ywzg8ANnTvzLXSvYAvrVmzxhvnXSADXw3g27lvCAYxkl9bvLIceD6snRdk4NsE3grwju8EXwbPA7yC/e85nu8rwOxTR3VUxzQ+dVTH3D5Dx/nz5/NtuVy3bNniv+KdguSrK1euzGCcY92ZP77j7t27XwHRSzNnzvTG+QDIwMdX+Xb4jpZt3LixFbhXbgCejzu0Gfj49XPfm1rgvfINEMnw6r+A2aeO6qiOaXzqqI65fYaOWJgfA/8A3k+fAN42jm2n9bRhnA+CK8CTwPvxNNB+KLLkuuuuM+kK+x4Fi4EXkBw4cICbWJFv2bJl/wMZ+A6CiWDPnj3+j+vr66Nj1W8D+Cfyb2D2qaM6qmManzqqY25f6vsC7gNujDxRxrOTpr9L67sHuGPkVQCL2fZ3XfBFDccA620PqX3YxonOAXwImP9OHQv61LGwN7FPHfP6UnXkPh3PCbpxXgOSehP5uErxrn3kpYJF9T0LKioq6OKR3oPmm7tS+njvWvvZ3ZIOu5UmnzrGoI5mbyKfOhb0qWMM3aXjPtAPuPXyeHASJPEm8u0AvBDI+cYB89+m8T0C2s4Qvs614MSJE8Xz/QmUlZX52znHjx9P5FPHGNTR7E3kU8eCPnWMobt0JLvA0KFDeUyQ/k+DbI4/xsCr43nxSo826Lv66qsTHX9M5CNVVVVRR/pg59ZI8XyzZs3q6XHBBRd0OOCb36eOOVBHkzeRSx0L+tQxB92lI3nhhRdmgl6OhoYGszexi0uxsbHxIuB03Doonu/QoUNHAW8icTl/Borna25u5
hB5SStlGB+n0jL71DEH6mjyJnapY15fV+fRXQ+i9Ujb1e02b1rfR4C7Y2A3KLrvC8B9b3iateg+ryNvSzD71DE/6pjfm9anjvE+dcxPd+m4BET7WKa5X7vo+yxwXxvOOll038z2exEeA0X3Vbn9VizOPwKzTx3zo475vWl96hjvU8f8BN2Rk8rfCQYOHBjNjvJhcOzYMbPX7HoNVIM7wNSpU3l3Gzc6+vXr9xzI3kd4nze+lLxbYABwHU+YzkGm8fGAw9atW7khFU2lOWLEiBZg9qljDOpo9ppd6mjy2Tr+GfBiwIULF3JuO/Qb2GFSOz5Tw/SZTb5fAj6p47bbbuNF+d70HYRjXbt2rUln8/0U/AgsWrToq+AdAA6u86nltLp1dXUZ+n4D8F3hd3MK4EFcb3FyNo9169Yl8qkjUUd1TONTR3WM9xk6btq0iaPxZsyNlumNN974TcADndld9/g30BdQ4PpFYJuDD1/pcGd0F33YwNjKUrwL2PtiQsfvy9cBnzuVne99wC3PkjO5CiwDz5sOVJ/pU0d1VMc0PnVUx9w+Q8fa2lqOk2t9DIw7/LxwnTd3WdvFeHP/Ap8oySv/3ffmfMDpLThA7C5n7+NcyJzq2I3xo2ASWLVq1X9BcXycXN5dwDka8CDu5s2bzf8WYnzqqI7qmNanjtn7uk3HpqYmRuMWhumKGJM3/y9tA/WgoaHBNstiF30ZcpZ86pgx6uihjkl9Xb2uI61Xvmx96hiGTx3D8KljGD51DMOnjmH41DEMnzqG4Ys6ijc36hgG6hgG6hgG6hgG6hgG6hgG2n8Mw6eOYfjUMQyfOobhU8cwfOoYhk8dw/CpYxg+dQzDp45h+NQxDJ86huFTxzB86hiGTx3D8KljGD51DMOnjmH41DEMnzqG4VPHMHzqGIZPHcPwpe74H9Dc3PwDcD9oND2E5XT6cXJO/aamJs5c/CCw6rriexl8H/C5AeYpJtP6WlpaDoOfA84tnXR86pjbp46FvYl96pjXl6rjLwAfGuLNATsIFO15xfvBUOD5Jk6c+CtQHN9BMGXKFH+O2z59+kwFhZ/Fksa3B7Q/h/l1zj333HtB4b9Vx9yoo9mbyKeOBX3qGEN36Dgd8GFhJZ2ZN2+e2Wv2cenxwSSddZ8A3NrK1nc3iB/fKPAqyNY3C/QDnX3XgMJ/r47xqKM6pvGpozp6PntHbtxg6yJ60ORg4J7l64nN3sK/yI2bcePGcfr3SOCeGeooB0+DjHzcC588eTLfuWcbHKLnexdoBhn5WltbOXd+tDxLS0s5273nezso/Dbq2AF1VMc0vrPR8VuAD310hksBn4q+fv16v+OQIUPM3vy/dB+oAO6N+QzIJ8CBAwe8x0DyU2Tg4wOY+Qj79wLvnR9//PHVwBteRj7C5Yl/BPyy8J2HgZ07d/Knno9PoTT71FEd1THt+NRRHeN9ho6LFi1yK35KbwU8/shX6urqoo9CzjvvvAzGyacfngPcOz8AopObNTU13jj54MYMfBtBNL4BAwbw0czuAYx8BIzn455eBj5uL/Zpe/oyH+F1CzgO8MpFwPNVArNPHdVRHdP41FEdc/sMHefPn8+3xY5qaYen2iNpJOX2RzbHV1cCviOX6/79+/2XMGRvnE0gA9+XQOTbvn2798pdwPP9HWTgWwro6t27Nx8F6b0yA3g+btmZfeqojuqYxqeO6pjbZ+jY0NCwEBwC3k+fBeXl5ZGUj/gtLD1tGOcz4GKwBWAXPfoxl2JZWVl0HGDFihXHQAa+vwAuUB4n934Me2/vuPyVV16ZUcdXALetOlzNtHfvXv+UwNy5cxONTx3VUR3T+M5ax3j4uGTnHAkQu8H0d2l9vFje+a4CBa/n6Krv9ttvj5bpJcDtMRfP1744S64AbXuwdp865kAdTd7EPnXM61PHHHSHjs8BXhuIvSBud9gOdJ7pTeTDbus6itxRUI6zqD6ejqysrIyW68OgqD5u2g0aNCjahuMh16Q+d
YxBHc3eRD51LOhTxxi6S8cnAa/Vd+ceJ4GTIIk3kY83d/cFzvceUNTtnEdB2/B6RttVBe+X64qvDnjHHD4I2s5I2n3qGIM6mr2JfOpY0KeOMXSXjoQr45tvvjm6suUO0NraavYmcp0CC4A31gkTJti+OmnH545bRy2nTZv2Iiieb8eOHdHY6Kyurk40PnXMgTqavIlc6ljQ16V5Ao8cOXItiG79wkcxexO7jrZNMMV1hzsvyCsUi+cDvNzyMtCrDa7FiudDNN5czWPVbj9y+5lXmeT3qWNu1LGwN7FLHfP61DE33aZjm/4kJ5Zgx1WrVpm9aX3MOQCUZHZfQH42ALf5wb3Kovt4x6DryKkfzT51zI865vem9aljvE8d89NdOn4RRPuu9fX1Zm9aHw8HuOX6V1B0353A+QrP8ZKB7/MgjU8d86OO+b1pfeoY70vVkXtSvA5x0qRJ0bXznOwJqy6z1ybCypd7ibwHmjce9O/fP7r2obS0lBfYZ+xzvAR27dpFZZm7FwHDzOj+rhh2gqVLl3IKxvOB21/lpzD71DEGdTR7bSJ1tPrUMYZgO+4APNy3ZMkSLlPOROhWxm6Mvfk8FtNnNvl+D34NVqxY0cvDOXk8F7rXvLvquuh7DDwFFi9efAPgl9IbHy9hqampydD3CHjooYfuAe8HnLnSDY8+jq+2tjaRTx2JOqpjGp86qmO8z9aRUx2WnAmFVVVVk8F3QGGhdZy8CSASdHBWVFTcBGwnAq2+taC/twXVDmfV4nGA5803sVl83HrixPJ9+/b1h4n/8J8HJ89Hw9qkPnVUR3VM41NHdczts3Xklf/uSk5eIDMRcPinTp0y+zp4c/8CJ9Hnir9HG5eD2YBTlZgmIknqw7bbEnfTAcfHaZ+HgNWrVxd+yEsa3xrADSk3yQpvcrgeLF261DbRSrxPHdVRHdP4zmbHDDH5sLdWswu8UT6us7gXPGPGjDfElyHq6KGOab35f0kdk/rUkahjWq982frUMQyfOobhU8cwfOoYhk8dw/CpYxg+dQzDF3UUb27UMQzUMQzUMQzUMQzUMQzUMQz+D9hCFMoKZW5kc3RyZWFtCmVuZG9iagoxNCAwIG9iagozNTMyCmVuZG9iagoyIDAgb2JqCjw8IC9UeXBlIC9QYWdlcyAvS2lkcyBbIDExIDAgUiBdIC9Db3VudCAxID4+CmVuZG9iagoxNSAwIG9iago8PCAvQ3JlYXRvciAoTWF0cGxvdGxpYiB2My43LjEsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcpCi9Qcm9kdWNlciAoTWF0cGxvdGxpYiBwZGYgYmFja2VuZCB2My43LjEpIC9DcmVhdGlvbkRhdGUgKEQ6MjAyMzAzMTQxNjA5MDFaKQo+PgplbmRvYmoKeHJlZgowIDE2CjAwMDAwMDAwMDAgNjU1MzUgZiAKMDAwMDAwMDAxNiAwMDAwMCBuIAowMDAwMDA1Mjg1IDAwMDAwIG4gCjAwMDAwMDA2MDcgMDAwMDAgbiAKMDAwMDAwMDYyOCAwMDAwMCBuIAowMDAwMDAwNjg4IDAwMDAwIG4gCjAwMDAwMDA3MDkgMDAwMDAgbiAKMDAwMDAwMDczMCAwMDAwMCBuIAowMDAwMDAwMDY1IDAwMDAwIG4gCjAwMDAwMDAzNDQgMDAwMDAgbiAKMDAwMDAwMDU4NyAwMDAwMCBuIAowMDAwMDAwMjA4IDAwMDAwIG4gCjAwMDAwMDA1NjcgMDAwMDAgbiAKMDAwMDAwMDc2MiAwMDAwMCBuIAowMDAwMDA1MjY0IDAwMDAwIG4gCjAwMDAwMDUzNDUgMDAwMDAgbiAKdHJhaWxlcgo8PCAvU2l6ZSAxNiAvUm9vdCAxIDAgUiAvSW5mbyAxNSAwIFIgPj4Kc3RhcnR4cmVmCjU0OTYKJSVFT0YK", "image/svg+xml": ["\n", "\n", "\n", " \n", " \n", " 
\n", " \n", " 2023-03-14T16:09:01.290435\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.7.1, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n"], "text/plain": ["
"]}, "metadata": {}, "output_type": "display_data"}, {"data": {"application/pdf": "", "image/svg+xml": ["\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-03-14T16:09:01.368132\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.7.1, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n"], "text/plain": ["
"]}, "metadata": {}, "output_type": "display_data"}], "source": ["L.seed_everything(44)\n", "for _ in range(3):\n", " z_init = flow_dict[\"multiscale\"][\"model\"].prior.sample(sample_shape=[1, 8, 7, 7])\n", " z_init = z_init.expand(8, -1, -1, -1)\n", " samples = flow_dict[\"multiscale\"][\"model\"].sample(img_shape=z_init.shape, z_init=z_init)\n", " show_imgs(samples.cpu())"]}, {"cell_type": "markdown", "id": "994633e5", "metadata": {"papermill": {"duration": 0.021725, "end_time": "2023-03-14T16:09:01.456409", "exception": false, "start_time": "2023-03-14T16:09:01.434684", "status": "completed"}, "tags": []}, "source": ["We see that the early split variables indeed have a smaller effect on the image.\n", "Still, small differences can be spot when we look carefully at the borders of the digits.\n", "For instance, the hole at the top of the 8 changes for different samples although all of them represent the same coarse structure.\n", "This shows that the flow indeed learns to separate the higher-level\n", "information in the final variables, while the early split ones contain\n", "local noise patterns."]}, {"cell_type": "markdown", "id": "5df5dcc5", "metadata": {"lines_to_next_cell": 2, "papermill": {"duration": 0.021783, "end_time": "2023-03-14T16:09:01.499950", "exception": false, "start_time": "2023-03-14T16:09:01.478167", "status": "completed"}, "tags": []}, "source": ["### Visualizing Dequantization\n", "\n", "As a final part of this notebook, we will look at the effect of variational dequantization.\n", "We have motivated variational dequantization by the issue of sharp edges/boarders being difficult to model,\n", "and a flow would rather prefer smooth, prior-like distributions.\n", "To check how what noise distribution $q(u|x)$ the flows in the\n", "variational dequantization module have learned, we can plot a histogram\n", "of output values from the dequantization and variational dequantization\n", "module."]}, {"cell_type": "code", "execution_count": 33, 
"id": "c062358b", "metadata": {"execution": {"iopub.execute_input": "2023-03-14T16:09:01.545124Z", "iopub.status.busy": "2023-03-14T16:09:01.544837Z", "iopub.status.idle": "2023-03-14T16:09:01.574131Z", "shell.execute_reply": "2023-03-14T16:09:01.573585Z"}, "papermill": {"duration": 0.054855, "end_time": "2023-03-14T16:09:01.576533", "exception": false, "start_time": "2023-03-14T16:09:01.521678", "status": "completed"}, "tags": []}, "outputs": [], "source": ["def visualize_dequant_distribution(model: ImageFlow, imgs: Tensor, title: str = None):\n", " \"\"\"\n", " Args:\n", " model: The flow of which we want to visualize the dequantization distribution\n", " imgs: Example training images of which we want to visualize the dequantization distribution\n", " \"\"\"\n", " imgs = imgs.to(device)\n", " ldj = torch.zeros(imgs.shape[0], dtype=torch.float32).to(device)\n", " with torch.no_grad():\n", " dequant_vals = []\n", " for _ in tqdm(range(8), leave=False):\n", " d, _ = model.flows[0](imgs, ldj, reverse=False)\n", " dequant_vals.append(d)\n", " dequant_vals = torch.cat(dequant_vals, dim=0)\n", " dequant_vals = dequant_vals.view(-1).cpu().numpy()\n", " sns.set()\n", " plt.figure(figsize=(10, 3))\n", " plt.hist(dequant_vals, bins=256, color=to_rgb(\"C0\") + (0.5,), edgecolor=\"C0\", density=True)\n", " if title is not None:\n", " plt.title(title)\n", " plt.show()\n", " plt.close()\n", "\n", "\n", "sample_imgs, _ = next(iter(train_loader))"]}, {"cell_type": "code", "execution_count": 34, "id": "e66ef8de", "metadata": {"execution": {"iopub.execute_input": "2023-03-14T16:09:01.626826Z", "iopub.status.busy": "2023-03-14T16:09:01.626504Z", "iopub.status.idle": "2023-03-14T16:09:02.283549Z", "shell.execute_reply": "2023-03-14T16:09:02.282977Z"}, "papermill": {"duration": 0.684094, "end_time": "2023-03-14T16:09:02.288395", "exception": false, "start_time": "2023-03-14T16:09:01.604301", "status": "completed"}, "tags": []}, "outputs": [{"data": 
{"application/vnd.jupyter.widget-view+json": {"model_id": "bbe213511b3245368a116d011cf7a97d", "version_major": 2, "version_minor": 0}, "text/plain": [" 0%| | 0/8 [00:00\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-03-14T16:09:01.959626\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.7.1, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", 
" \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n"], "text/plain": ["
"]}, "metadata": {}, "output_type": "display_data"}], "source": ["visualize_dequant_distribution(flow_dict[\"simple\"][\"model\"], sample_imgs, title=\"Dequantization\")"]}, {"cell_type": "code", "execution_count": 35, "id": "50826bcf", "metadata": {"execution": {"iopub.execute_input": "2023-03-14T16:09:02.341693Z", "iopub.status.busy": "2023-03-14T16:09:02.341310Z", "iopub.status.idle": "2023-03-14T16:09:03.127501Z", "shell.execute_reply": "2023-03-14T16:09:03.126890Z"}, "papermill": {"duration": 0.815664, "end_time": "2023-03-14T16:09:03.131581", "exception": false, "start_time": "2023-03-14T16:09:02.315917", "status": "completed"}, "tags": []}, "outputs": [{"data": {"application/vnd.jupyter.widget-view+json": {"model_id": "b2b8c9cb3fdd4e77bf734fc7a847d4ff", "version_major": 2, "version_minor": 0}, "text/plain": [" 0%| | 0/8 [00:00\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-03-14T16:09:02.796830\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.7.1, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", 
" \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", 
" \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", 
"\n"], "text/plain": ["
"]}, "metadata": {}, "output_type": "display_data"}], "source": ["visualize_dequant_distribution(flow_dict[\"vardeq\"][\"model\"], sample_imgs, title=\"Variational dequantization\")"]}, {"cell_type": "markdown", "id": "1c94ee1f", "metadata": {"papermill": {"duration": 0.026779, "end_time": "2023-03-14T16:09:03.188381", "exception": false, "start_time": "2023-03-14T16:09:03.161602", "status": "completed"}, "tags": []}, "source": ["The dequantization distribution in the first plot shows that the MNIST images have a strong bias towards 0 (black),\n", "and the distribution of them have a sharp border as mentioned before.\n", "The variational dequantization module has indeed learned a much smoother distribution with a Gaussian-like curve which can be modeled much better.\n", "For the other values, we would need to visualize the distribution $q(u|x)$ on a deeper level, depending on $x$.\n", "However, as all $u$'s interact and depend on each other, we would need\n", "to visualize a distribution in 784 dimensions, which is not that\n", "intuitive anymore."]}, {"cell_type": "markdown", "id": "ae3bae45", "metadata": {"papermill": {"duration": 0.026724, "end_time": "2023-03-14T16:09:03.241886", "exception": false, "start_time": "2023-03-14T16:09:03.215162", "status": "completed"}, "tags": []}, "source": ["## Conclusion\n", "\n", "In conclusion, we have seen how to implement our own normalizing flow, and what difficulties arise if we want to apply them on images.\n", "Dequantization is a crucial step in mapping the discrete images into continuous space to prevent underisable delta-peak solutions.\n", "While dequantization creates hypercubes with hard border, variational dequantization allows us to fit a flow much better on the data.\n", "This allows us to obtain a lower bits per dimension score, while not affecting the sampling speed.\n", "The most common flow element, the coupling layer, is simple to implement, and yet effective.\n", "Furthermore, multi-scale architectures 
help to capture the global image context while allowing us to efficiently scale up the flow.\n", "Normalizing flows are an interesting alternative to VAEs as they allow an exact likelihood estimate in continuous space,\n", "and we have the guarantee that every possible input $x$ has a corresponding latent vector $z$.\n", "However, flows can also be applied beyond continuous inputs and images, where they allow us to exploit\n", "the data structure in latent space, e.g., on graphs for the task of molecule generation [6].\n", "Recent advances in [Neural ODEs](https://arxiv.org/pdf/1806.07366.pdf) allow a flow with an infinite number of layers,\n", "called Continuous Normalizing Flows, whose potential is yet to be fully explored.\n", "Overall, normalizing flows are an exciting research area that will continue to develop over the next couple of years."]}, {"cell_type": "markdown", "id": "ab8843c8", "metadata": {"papermill": {"duration": 0.026911, "end_time": "2023-03-14T16:09:03.295703", "exception": false, "start_time": "2023-03-14T16:09:03.268792", "status": "completed"}, "tags": []}, "source": ["## References\n", "\n", "[1] Dinh, L., Sohl-Dickstein, J., and Bengio, S. (2017).\n", "\u201cDensity estimation using Real NVP,\u201d In: 5th International Conference on Learning Representations, ICLR 2017.\n", "[Link](https://arxiv.org/abs/1605.08803)\n", "\n", "[2] Kingma, D. P., and Dhariwal, P. (2018).\n", "\u201cGlow: Generative Flow with Invertible 1x1 Convolutions,\u201d In: Advances in Neural Information Processing Systems, vol.\n", "31, pp.\n", "10215\u201310224.\n", "[Link](http://papers.nips.cc/paper/8224-glow-generative-flow-with-invertible-1x1-convolutions.pdf)\n", "\n", "[3] Ho, J., Chen, X., Srinivas, A., Duan, Y., and Abbeel, P. 
(2019).\n", "\u201cFlow++: Improving Flow-Based Generative Models with Variational Dequantization and Architecture Design,\u201d\n", "In: Proceedings of the 36th International Conference on Machine Learning, vol.\n", "97, pp.\n", "2722\u20132730.\n", "[Link](https://arxiv.org/abs/1902.00275)\n", "\n", "[4] Durkan, C., Bekasov, A., Murray, I., and Papamakarios, G. (2019).\n", "\u201cNeural Spline Flows,\u201d In: Advances in Neural Information Processing Systems, pp.\n", "7509\u20137520.\n", "[Link](http://papers.neurips.cc/paper/8969-neural-spline-flows.pdf)\n", "\n", "[5] Hoogeboom, E., Cohen, T. S., and Tomczak, J. M. (2020).\n", "\u201cLearning Discrete Distributions by Dequantization,\u201d arXiv preprint arXiv:2001.11235v1.\n", "[Link](https://arxiv.org/abs/2001.11235)\n", "\n", "[6] Lippe, P., and Gavves, E. (2021).\n", "\u201cCategorical Normalizing Flows via Continuous Transformations,\u201d\n", "In: International Conference on Learning Representations, ICLR 2021.\n", "[Link](https://openreview.net/pdf?id=-GLNZeVDuik)"]}, {"cell_type": "markdown", "id": "24215753", "metadata": {"papermill": {"duration": 0.026354, "end_time": "2023-03-14T16:09:03.348421", "exception": false, "start_time": "2023-03-14T16:09:03.322067", "status": "completed"}, "tags": []}, "source": ["## Congratulations - Time to Join the Community!\n", "\n", "Congratulations on completing this notebook tutorial! If you enjoyed this and would like to join the Lightning\n", "movement, you can do so in the following ways!\n", "\n", "### Star [Lightning](https://github.com/Lightning-AI/lightning) on GitHub\n", "The easiest way to help our community is just by starring the GitHub repos! This helps raise awareness of the cool\n", "tools we're building.\n", "\n", "### Join our [Slack](https://www.pytorchlightning.ai/community)!\n", "The best way to keep up to date on the latest advancements is to join our community! 
Make sure to introduce yourself\n", "and share your interests in the `#general` channel.\n", "\n", "\n", "### Contributions!\n", "The best way to contribute to our community is to become a code contributor! At any time you can go to the\n", "[Lightning](https://github.com/Lightning-AI/lightning) or [Bolt](https://github.com/Lightning-AI/lightning-bolts)\n", "GitHub Issues page and filter for \"good first issue\".\n", "\n", "* [Lightning good first issue](https://github.com/Lightning-AI/lightning/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)\n", "* [Bolt good first issue](https://github.com/Lightning-AI/lightning-bolts/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)\n", "* You can also contribute your own notebooks with useful examples!\n", "\n", "### Great thanks from the entire PyTorch Lightning team for your interest!\n", "\n", "[![Pytorch Lightning](){height=\"60px\" width=\"240px\"}](https://pytorchlightning.ai)"]}, {"cell_type": "raw", "metadata": {"raw_mimetype": "text/restructuredtext"}, "source": [".. customcarditem::\n", " :header: Tutorial 9: Normalizing Flows for Image Modeling\n", " :card_description: In this tutorial, we will take a closer look at complex, deep normalizing flows. 
The most popular, current application of deep normalizing flows is to model datasets of...\n", " :tags: Image,GPU/TPU,UvA-DL-Course\n", " :image: _static/images/course_UvA-DL/09-normalizing-flows.jpg"]}], "metadata": {"jupytext": {"cell_metadata_filter": "id,colab_type,colab,-all", "formats": "ipynb,py:percent", "main_language": "python"}, "language_info": {"codemirror_mode": {"name": "ipython", "version": 3}, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.16"}, "papermill": {"default_parameters": {}, "duration": 23.837617, "end_time": "2023-03-14T16:09:04.999019", "environment_variables": {}, "exception": null, "input_path": "course_UvA-DL/09-normalizing-flows/NF_image_modeling.ipynb", "output_path": ".notebooks/course_UvA-DL/09-normalizing-flows.ipynb", "parameters": {}, "start_time": "2023-03-14T16:08:41.161402", "version": "2.4.0"}, "widgets": {"application/vnd.jupyter.widget-state+json": {"state": {"022d5e4feba64ae0928ed7f3872066d4": {"model_module": "@jupyter-widgets/base", "model_module_version": "2.0.0", "model_name": "LayoutModel", "state": {"_model_module": "@jupyter-widgets/base", "_model_module_version": "2.0.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "2.0.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border_bottom": null, "border_left": null, "border_right": null, "border_top": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, 
"min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null}}, "03403854507e4c1d82b2987d56a49a06": {"model_module": "@jupyter-widgets/controls", "model_module_version": "2.0.0", "model_name": "HTMLModel", "state": {"_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "2.0.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "2.0.0", "_view_name": "HTMLView", "description": "", "description_allow_html": false, "layout": "IPY_MODEL_10dd727b74354aedafd5edc7db9549c8", "placeholder": "\u200b", "style": "IPY_MODEL_72e2bef6f95849f3b2a0157e2ae7d38f", "tabbable": null, "tooltip": null, "value": " 4542/4542 [00:00<00:00, 302461.36it/s]"}}, "04427c133ec74f93997b1bb09c212b93": {"model_module": "@jupyter-widgets/base", "model_module_version": "2.0.0", "model_name": "LayoutModel", "state": {"_model_module": "@jupyter-widgets/base", "_model_module_version": "2.0.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "2.0.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border_bottom": null, "border_left": null, "border_right": null, "border_top": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "padding": null, "right": null, 
"top": null, "visibility": null, "width": null}}, "07c5caa1cdc1422c953a73791e4deb9e": {"model_module": "@jupyter-widgets/base", "model_module_version": "2.0.0", "model_name": "LayoutModel", "state": {"_model_module": "@jupyter-widgets/base", "_model_module_version": "2.0.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "2.0.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border_bottom": null, "border_left": null, "border_right": null, "border_top": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null}}, "0aeadef176b2496588e4f17a7b736316": {"model_module": "@jupyter-widgets/controls", "model_module_version": "2.0.0", "model_name": "FloatProgressModel", "state": {"_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "2.0.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "2.0.0", "_view_name": "ProgressView", "bar_style": "success", "description": "", "description_allow_html": false, "layout": "IPY_MODEL_93e00f14a44948dcbd5b7b278de95bf1", "max": 1648877.0, "min": 0.0, "orientation": "horizontal", "style": "IPY_MODEL_35c8abf4b2d44eb1a9e990d070993da4", "tabbable": null, "tooltip": null, "value": 1648877.0}}, "0dfc77c724fe4dda8fa5ed6733af68d9": {"model_module": 
"@jupyter-widgets/controls", "model_module_version": "2.0.0", "model_name": "HTMLModel", "state": {"_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "2.0.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "2.0.0", "_view_name": "HTMLView", "description": "", "description_allow_html": false, "layout": "IPY_MODEL_bb1e7ecd74f646189cde10378bfed99a", "placeholder": "\u200b", "style": "IPY_MODEL_758c5cc85a124fd6b1a1e5fc08e4578b", "tabbable": null, "tooltip": null, "value": " 0/8 [00:00<?, ?it/s]"}}, "0eb9bbfb6b8b423e90e7466594745546": {"model_module": "@jupyter-widgets/base", "model_module_version": "2.0.0", "model_name": "LayoutModel", "state": {"_model_module": "@jupyter-widgets/base", "_model_module_version": "2.0.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "2.0.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border_bottom": null, "border_left": null, "border_right": null, "border_top": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null}}, "10dd727b74354aedafd5edc7db9549c8": {"model_module": "@jupyter-widgets/base", "model_module_version": "2.0.0", "model_name": "LayoutModel", "state": {"_model_module": "@jupyter-widgets/base", "_model_module_version": 
"2.0.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "2.0.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border_bottom": null, "border_left": null, "border_right": null, "border_top": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null}}, "16f7dfe27ae0431f882ae8a7521045b1": {"model_module": "@jupyter-widgets/controls", "model_module_version": "2.0.0", "model_name": "HTMLStyleModel", "state": {"_model_module": "@jupyter-widgets/controls", "_model_module_version": "2.0.0", "_model_name": "HTMLStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "2.0.0", "_view_name": "StyleView", "background": null, "description_width": "", "font_size": null, "text_color": null}}, "1995f3296f964c10b7b2c669529f7ae2": {"model_module": "@jupyter-widgets/controls", "model_module_version": "2.0.0", "model_name": "HTMLStyleModel", "state": {"_model_module": "@jupyter-widgets/controls", "_model_module_version": "2.0.0", "_model_name": "HTMLStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "2.0.0", "_view_name": "StyleView", "background": null, "description_width": "", "font_size": null, "text_color": null}}, "1d86d0fabf6a402bb5546127da7d9bfd": {"model_module": "@jupyter-widgets/base", 
"model_module_version": "2.0.0", "model_name": "LayoutModel", "state": {"_model_module": "@jupyter-widgets/base", "_model_module_version": "2.0.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "2.0.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border_bottom": null, "border_left": null, "border_right": null, "border_top": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null}}, "1db7ae52a4e14296ad74930df9a97e79": {"model_module": "@jupyter-widgets/controls", "model_module_version": "2.0.0", "model_name": "ProgressStyleModel", "state": {"_model_module": "@jupyter-widgets/controls", "_model_module_version": "2.0.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "2.0.0", "_view_name": "StyleView", "bar_color": null, "description_width": ""}}, "21b86501888548e285936fadbc245102": {"model_module": "@jupyter-widgets/controls", "model_module_version": "2.0.0", "model_name": "HTMLModel", "state": {"_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "2.0.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "2.0.0", "_view_name": "HTMLView", "description": "", "description_allow_html": false, 
"layout": "IPY_MODEL_c2ebe72c3bd64974b4786470efc97363", "placeholder": "\u200b", "style": "IPY_MODEL_1995f3296f964c10b7b2c669529f7ae2", "tabbable": null, "tooltip": null, "value": " 1648877/1648877 [00:00<00:00, 15388647.04it/s]"}}, "25676a4efd8540fb93d4cf022d30d53f": {"model_module": "@jupyter-widgets/controls", "model_module_version": "2.0.0", "model_name": "HTMLModel", "state": {"_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "2.0.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "2.0.0", "_view_name": "HTMLView", "description": "", "description_allow_html": false, "layout": "IPY_MODEL_ced0fd121fc2474ba4c1abb075e13f5d", "placeholder": "\u200b", "style": "IPY_MODEL_e69479b2f0cf41f2b1b48ebb2ebf58d4", "tabbable": null, "tooltip": null, "value": " 0%"}}, "25d20e49e81241cebfe05a6a212c597c": {"model_module": "@jupyter-widgets/controls", "model_module_version": "2.0.0", "model_name": "ProgressStyleModel", "state": {"_model_module": "@jupyter-widgets/controls", "_model_module_version": "2.0.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "2.0.0", "_view_name": "StyleView", "bar_color": null, "description_width": ""}}, "29068e73798d4582a1e52fb849969cf1": {"model_module": "@jupyter-widgets/controls", "model_module_version": "2.0.0", "model_name": "HTMLStyleModel", "state": {"_model_module": "@jupyter-widgets/controls", "_model_module_version": "2.0.0", "_model_name": "HTMLStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "2.0.0", "_view_name": "StyleView", "background": null, "description_width": "", "font_size": null, "text_color": null}}, "29f6933ebb224f55bdab78facfbd0916": {"model_module": "@jupyter-widgets/base", "model_module_version": "2.0.0", "model_name": "LayoutModel", "state": {"_model_module": "@jupyter-widgets/base", 
"_model_module_version": "2.0.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "2.0.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border_bottom": null, "border_left": null, "border_right": null, "border_top": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null}}, "2c5126982c5e4f6698e08f10d0afc360": {"model_module": "@jupyter-widgets/controls", "model_module_version": "2.0.0", "model_name": "FloatProgressModel", "state": {"_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "2.0.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "2.0.0", "_view_name": "ProgressView", "bar_style": "", "description": "", "description_allow_html": false, "layout": "IPY_MODEL_758427a471744b47ab6192598eefc69b", "max": 8.0, "min": 0.0, "orientation": "horizontal", "style": "IPY_MODEL_1db7ae52a4e14296ad74930df9a97e79", "tabbable": null, "tooltip": null, "value": 8.0}}, "312e269520644f68abf6c03135cf0e0d": {"model_module": "@jupyter-widgets/controls", "model_module_version": "2.0.0", "model_name": "ProgressStyleModel", "state": {"_model_module": "@jupyter-widgets/controls", "_model_module_version": "2.0.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": 
"@jupyter-widgets/base", "_view_module_version": "2.0.0", "_view_name": "StyleView", "bar_color": null, "description_width": ""}}, "339023c700c64174b9a147df03787e07": {"model_module": "@jupyter-widgets/base", "model_module_version": "2.0.0", "model_name": "LayoutModel", "state": {"_model_module": "@jupyter-widgets/base", "_model_module_version": "2.0.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "2.0.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border_bottom": null, "border_left": null, "border_right": null, "border_top": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null}}, "35c8abf4b2d44eb1a9e990d070993da4": {"model_module": "@jupyter-widgets/controls", "model_module_version": "2.0.0", "model_name": "ProgressStyleModel", "state": {"_model_module": "@jupyter-widgets/controls", "_model_module_version": "2.0.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "2.0.0", "_view_name": "StyleView", "bar_color": null, "description_width": ""}}, "411e2bdd0d104a759cf2219457a9ea10": {"model_module": "@jupyter-widgets/base", "model_module_version": "2.0.0", "model_name": "LayoutModel", "state": {"_model_module": "@jupyter-widgets/base", "_model_module_version": "2.0.0", "_model_name": "LayoutModel", 
"_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "2.0.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border_bottom": null, "border_left": null, "border_right": null, "border_top": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "padding": null, "right": null, "top": null, "visibility": "hidden", "width": null}}, "558441ee629d4bd9a4fdc90089a3586c": {"model_module": "@jupyter-widgets/controls", "model_module_version": "2.0.0", "model_name": "HTMLStyleModel", "state": {"_model_module": "@jupyter-widgets/controls", "_model_module_version": "2.0.0", "_model_name": "HTMLStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "2.0.0", "_view_name": "StyleView", "background": null, "description_width": "", "font_size": null, "text_color": null}}, "60985b1c703d4f8aa3aba175811d2a69": {"model_module": "@jupyter-widgets/controls", "model_module_version": "2.0.0", "model_name": "HBoxModel", "state": {"_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "2.0.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "2.0.0", "_view_name": "HBoxView", "box_style": "", "children": ["IPY_MODEL_8a2bd283700345c0a750d978690229cc", "IPY_MODEL_0aeadef176b2496588e4f17a7b736316", "IPY_MODEL_21b86501888548e285936fadbc245102"], "layout": 
"IPY_MODEL_04427c133ec74f93997b1bb09c212b93", "tabbable": null, "tooltip": null}}, "6224fdd5ab06423482c0205dc09216d1": {"model_module": "@jupyter-widgets/controls", "model_module_version": "2.0.0", "model_name": "HTMLModel", "state": {"_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "2.0.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "2.0.0", "_view_name": "HTMLView", "description": "", "description_allow_html": false, "layout": "IPY_MODEL_d97bd4b5fe8a49499d119431d07cf0fb", "placeholder": "\u200b", "style": "IPY_MODEL_a0812f0a0f8e40a0991db82254c34e4f", "tabbable": null, "tooltip": null, "value": " 0%"}}, "66fc512f4207461b9e493495d88854e7": {"model_module": "@jupyter-widgets/base", "model_module_version": "2.0.0", "model_name": "LayoutModel", "state": {"_model_module": "@jupyter-widgets/base", "_model_module_version": "2.0.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "2.0.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border_bottom": null, "border_left": null, "border_right": null, "border_top": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null}}, "67a200eee455445bad5167f0917cb600": {"model_module": "@jupyter-widgets/controls", 
"model_module_version": "2.0.0", "model_name": "FloatProgressModel", "state": {"_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "2.0.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "2.0.0", "_view_name": "ProgressView", "bar_style": "success", "description": "", "description_allow_html": false, "layout": "IPY_MODEL_07c5caa1cdc1422c953a73791e4deb9e", "max": 4542.0, "min": 0.0, "orientation": "horizontal", "style": "IPY_MODEL_25d20e49e81241cebfe05a6a212c597c", "tabbable": null, "tooltip": null, "value": 4542.0}}, "6e93eead70d64710b21580c9dfdcab47": {"model_module": "@jupyter-widgets/controls", "model_module_version": "2.0.0", "model_name": "HBoxModel", "state": {"_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "2.0.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "2.0.0", "_view_name": "HBoxView", "box_style": "", "children": ["IPY_MODEL_c496aa960e8f441a96ebcfb22e62e988", "IPY_MODEL_fa7f221d159f44a399abb85eb67a0201", "IPY_MODEL_e249257f48ab4642814161e5cbe3a3c4"], "layout": "IPY_MODEL_339023c700c64174b9a147df03787e07", "tabbable": null, "tooltip": null}}, "6ef3a3aa22d64d32972473557f20dc75": {"model_module": "@jupyter-widgets/base", "model_module_version": "2.0.0", "model_name": "LayoutModel", "state": {"_model_module": "@jupyter-widgets/base", "_model_module_version": "2.0.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "2.0.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border_bottom": null, "border_left": null, "border_right": null, "border_top": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, 
"grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "padding": null, "right": null, "top": null, "visibility": "hidden", "width": null}}, "70b1f09cd80c4b6d954c95765acf99ac": {"model_module": "@jupyter-widgets/base", "model_module_version": "2.0.0", "model_name": "LayoutModel", "state": {"_model_module": "@jupyter-widgets/base", "_model_module_version": "2.0.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "2.0.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border_bottom": null, "border_left": null, "border_right": null, "border_top": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null}}, "72e2bef6f95849f3b2a0157e2ae7d38f": {"model_module": "@jupyter-widgets/controls", "model_module_version": "2.0.0", "model_name": "HTMLStyleModel", "state": {"_model_module": "@jupyter-widgets/controls", "_model_module_version": "2.0.0", "_model_name": "HTMLStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", 
"_view_module_version": "2.0.0", "_view_name": "StyleView", "background": null, "description_width": "", "font_size": null, "text_color": null}}, "758427a471744b47ab6192598eefc69b": {"model_module": "@jupyter-widgets/base", "model_module_version": "2.0.0", "model_name": "LayoutModel", "state": {"_model_module": "@jupyter-widgets/base", "_model_module_version": "2.0.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "2.0.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border_bottom": null, "border_left": null, "border_right": null, "border_top": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null}}, "758c5cc85a124fd6b1a1e5fc08e4578b": {"model_module": "@jupyter-widgets/controls", "model_module_version": "2.0.0", "model_name": "HTMLStyleModel", "state": {"_model_module": "@jupyter-widgets/controls", "_model_module_version": "2.0.0", "_model_name": "HTMLStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "2.0.0", "_view_name": "StyleView", "background": null, "description_width": "", "font_size": null, "text_color": null}}, "85c1da41f2e24252b1bfd646e70cdce8": {"model_module": "@jupyter-widgets/controls", "model_module_version": "2.0.0", "model_name": "HBoxModel", "state": {"_dom_classes": [], "_model_module": 
"@jupyter-widgets/controls", "_model_module_version": "2.0.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "2.0.0", "_view_name": "HBoxView", "box_style": "", "children": ["IPY_MODEL_c73775d8b0f64318b72d97215a38fdd9", "IPY_MODEL_67a200eee455445bad5167f0917cb600", "IPY_MODEL_03403854507e4c1d82b2987d56a49a06"], "layout": "IPY_MODEL_d5feb3c4282f42c4950e191fc88be2a9", "tabbable": null, "tooltip": null}}, "8a2bd283700345c0a750d978690229cc": {"model_module": "@jupyter-widgets/controls", "model_module_version": "2.0.0", "model_name": "HTMLModel", "state": {"_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "2.0.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "2.0.0", "_view_name": "HTMLView", "description": "", "description_allow_html": false, "layout": "IPY_MODEL_ab54194ee2f84f5e83f08a888c5a305f", "placeholder": "\u200b", "style": "IPY_MODEL_ff9e6a2733e34e72bfb2541bdf073901", "tabbable": null, "tooltip": null, "value": "100%"}}, "8e868462f756444d9b1a566c8f77d457": {"model_module": "@jupyter-widgets/controls", "model_module_version": "2.0.0", "model_name": "HBoxModel", "state": {"_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "2.0.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "2.0.0", "_view_name": "HBoxView", "box_style": "", "children": ["IPY_MODEL_d63d18e016b24e3786cada711468334a", "IPY_MODEL_a5e6be0f0d264bd68359c69613627269", "IPY_MODEL_ed5753a1b5e24a52b3de7d67a56394be"], "layout": "IPY_MODEL_1d86d0fabf6a402bb5546127da7d9bfd", "tabbable": null, "tooltip": null}}, "93b1209a50b74dabbb125bc942d22e42": {"model_module": "@jupyter-widgets/controls", "model_module_version": "2.0.0", "model_name": "HTMLStyleModel", "state": {"_model_module": "@jupyter-widgets/controls", 
"_model_module_version": "2.0.0", "_model_name": "HTMLStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "2.0.0", "_view_name": "StyleView", "background": null, "description_width": "", "font_size": null, "text_color": null}}, "93e00f14a44948dcbd5b7b278de95bf1": {"model_module": "@jupyter-widgets/base", "model_module_version": "2.0.0", "model_name": "LayoutModel", "state": {"_model_module": "@jupyter-widgets/base", "_model_module_version": "2.0.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "2.0.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border_bottom": null, "border_left": null, "border_right": null, "border_top": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null}}, "a0812f0a0f8e40a0991db82254c34e4f": {"model_module": "@jupyter-widgets/controls", "model_module_version": "2.0.0", "model_name": "HTMLStyleModel", "state": {"_model_module": "@jupyter-widgets/controls", "_model_module_version": "2.0.0", "_model_name": "HTMLStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "2.0.0", "_view_name": "StyleView", "background": null, "description_width": "", "font_size": null, "text_color": null}}, "a50cba50ade742a4b7490036e1e11dc5": {"model_module": "@jupyter-widgets/base", 
"model_module_version": "2.0.0", "model_name": "LayoutModel", "state": {"_model_module": "@jupyter-widgets/base", "_model_module_version": "2.0.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "2.0.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border_bottom": null, "border_left": null, "border_right": null, "border_top": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null}}, "a5e6be0f0d264bd68359c69613627269": {"model_module": "@jupyter-widgets/controls", "model_module_version": "2.0.0", "model_name": "FloatProgressModel", "state": {"_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "2.0.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "2.0.0", "_view_name": "ProgressView", "bar_style": "success", "description": "", "description_allow_html": false, "layout": "IPY_MODEL_29f6933ebb224f55bdab78facfbd0916", "max": 9912422.0, "min": 0.0, "orientation": "horizontal", "style": "IPY_MODEL_d6df66553ca0434198434c840620964a", "tabbable": null, "tooltip": null, "value": 9912422.0}}, "a9b352eea54247f68d5c4fb54771ffff": {"model_module": "@jupyter-widgets/controls", "model_module_version": "2.0.0", "model_name": "HTMLModel", "state": {"_dom_classes": [], 
"_model_module": "@jupyter-widgets/controls", "_model_module_version": "2.0.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "2.0.0", "_view_name": "HTMLView", "description": "", "description_allow_html": false, "layout": "IPY_MODEL_a50cba50ade742a4b7490036e1e11dc5", "placeholder": "\u200b", "style": "IPY_MODEL_cf866df37d1d4998924e0ec0647dc3bb", "tabbable": null, "tooltip": null, "value": " 0/8 [00:00<?, ?it/s]"}}, "ab54194ee2f84f5e83f08a888c5a305f": {"model_module": "@jupyter-widgets/base", "model_module_version": "2.0.0", "model_name": "LayoutModel", "state": {"_model_module": "@jupyter-widgets/base", "_model_module_version": "2.0.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "2.0.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border_bottom": null, "border_left": null, "border_right": null, "border_top": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null}}, "b2b8c9cb3fdd4e77bf734fc7a847d4ff": {"model_module": "@jupyter-widgets/controls", "model_module_version": "2.0.0", "model_name": "HBoxModel", "state": {"_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "2.0.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": 
"@jupyter-widgets/controls", "_view_module_version": "2.0.0", "_view_name": "HBoxView", "box_style": "", "children": ["IPY_MODEL_6224fdd5ab06423482c0205dc09216d1", "IPY_MODEL_2c5126982c5e4f6698e08f10d0afc360", "IPY_MODEL_0dfc77c724fe4dda8fa5ed6733af68d9"], "layout": "IPY_MODEL_411e2bdd0d104a759cf2219457a9ea10", "tabbable": null, "tooltip": null}}, "b86389928edd42a6a9e52ef9cd86f311": {"model_module": "@jupyter-widgets/controls", "model_module_version": "2.0.0", "model_name": "HTMLStyleModel", "state": {"_model_module": "@jupyter-widgets/controls", "_model_module_version": "2.0.0", "_model_name": "HTMLStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "2.0.0", "_view_name": "StyleView", "background": null, "description_width": "", "font_size": null, "text_color": null}}, "bb1e7ecd74f646189cde10378bfed99a": {"model_module": "@jupyter-widgets/base", "model_module_version": "2.0.0", "model_name": "LayoutModel", "state": {"_model_module": "@jupyter-widgets/base", "_model_module_version": "2.0.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "2.0.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border_bottom": null, "border_left": null, "border_right": null, "border_top": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null}}, 
"bbe213511b3245368a116d011cf7a97d": {"model_module": "@jupyter-widgets/controls", "model_module_version": "2.0.0", "model_name": "HBoxModel", "state": {"_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "2.0.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "2.0.0", "_view_name": "HBoxView", "box_style": "", "children": ["IPY_MODEL_25676a4efd8540fb93d4cf022d30d53f", "IPY_MODEL_ed5fb11f3a9e47b69aad48f8bf1f7461", "IPY_MODEL_a9b352eea54247f68d5c4fb54771ffff"], "layout": "IPY_MODEL_6ef3a3aa22d64d32972473557f20dc75", "tabbable": null, "tooltip": null}}, "c2ebe72c3bd64974b4786470efc97363": {"model_module": "@jupyter-widgets/base", "model_module_version": "2.0.0", "model_name": "LayoutModel", "state": {"_model_module": "@jupyter-widgets/base", "_model_module_version": "2.0.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "2.0.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border_bottom": null, "border_left": null, "border_right": null, "border_top": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null}}, "c496aa960e8f441a96ebcfb22e62e988": {"model_module": "@jupyter-widgets/controls", "model_module_version": "2.0.0", "model_name": "HTMLModel", "state": 
{"_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "2.0.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "2.0.0", "_view_name": "HTMLView", "description": "", "description_allow_html": false, "layout": "IPY_MODEL_0eb9bbfb6b8b423e90e7466594745546", "placeholder": "\u200b", "style": "IPY_MODEL_b86389928edd42a6a9e52ef9cd86f311", "tabbable": null, "tooltip": null, "value": "100%"}}, "c73775d8b0f64318b72d97215a38fdd9": {"model_module": "@jupyter-widgets/controls", "model_module_version": "2.0.0", "model_name": "HTMLModel", "state": {"_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "2.0.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "2.0.0", "_view_name": "HTMLView", "description": "", "description_allow_html": false, "layout": "IPY_MODEL_ebe1de3fe436417e9648098be8d9b038", "placeholder": "\u200b", "style": "IPY_MODEL_558441ee629d4bd9a4fdc90089a3586c", "tabbable": null, "tooltip": null, "value": "100%"}}, "ced0fd121fc2474ba4c1abb075e13f5d": {"model_module": "@jupyter-widgets/base", "model_module_version": "2.0.0", "model_name": "LayoutModel", "state": {"_model_module": "@jupyter-widgets/base", "_model_module_version": "2.0.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "2.0.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border_bottom": null, "border_left": null, "border_right": null, "border_top": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, 
"justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null}}, "cf866df37d1d4998924e0ec0647dc3bb": {"model_module": "@jupyter-widgets/controls", "model_module_version": "2.0.0", "model_name": "HTMLStyleModel", "state": {"_model_module": "@jupyter-widgets/controls", "_model_module_version": "2.0.0", "_model_name": "HTMLStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "2.0.0", "_view_name": "StyleView", "background": null, "description_width": "", "font_size": null, "text_color": null}}, "d1b544192fff4a4da583c624d12890b3": {"model_module": "@jupyter-widgets/base", "model_module_version": "2.0.0", "model_name": "LayoutModel", "state": {"_model_module": "@jupyter-widgets/base", "_model_module_version": "2.0.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "2.0.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border_bottom": null, "border_left": null, "border_right": null, "border_top": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null}}, "d5feb3c4282f42c4950e191fc88be2a9": 
{"model_module": "@jupyter-widgets/base", "model_module_version": "2.0.0", "model_name": "LayoutModel", "state": {"_model_module": "@jupyter-widgets/base", "_model_module_version": "2.0.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "2.0.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border_bottom": null, "border_left": null, "border_right": null, "border_top": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null}}, "d63d18e016b24e3786cada711468334a": {"model_module": "@jupyter-widgets/controls", "model_module_version": "2.0.0", "model_name": "HTMLModel", "state": {"_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "2.0.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "2.0.0", "_view_name": "HTMLView", "description": "", "description_allow_html": false, "layout": "IPY_MODEL_e21ced0da6c14114927d42a13bb051a8", "placeholder": "\u200b", "style": "IPY_MODEL_93b1209a50b74dabbb125bc942d22e42", "tabbable": null, "tooltip": null, "value": "100%"}}, "d6df66553ca0434198434c840620964a": {"model_module": "@jupyter-widgets/controls", "model_module_version": "2.0.0", "model_name": "ProgressStyleModel", "state": {"_model_module": "@jupyter-widgets/controls", 
"_model_module_version": "2.0.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "2.0.0", "_view_name": "StyleView", "bar_color": null, "description_width": ""}}, "d97bd4b5fe8a49499d119431d07cf0fb": {"model_module": "@jupyter-widgets/base", "model_module_version": "2.0.0", "model_name": "LayoutModel", "state": {"_model_module": "@jupyter-widgets/base", "_model_module_version": "2.0.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "2.0.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border_bottom": null, "border_left": null, "border_right": null, "border_top": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null}}, "e21ced0da6c14114927d42a13bb051a8": {"model_module": "@jupyter-widgets/base", "model_module_version": "2.0.0", "model_name": "LayoutModel", "state": {"_model_module": "@jupyter-widgets/base", "_model_module_version": "2.0.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "2.0.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border_bottom": null, "border_left": null, "border_right": null, "border_top": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, 
"grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null}}, "e249257f48ab4642814161e5cbe3a3c4": {"model_module": "@jupyter-widgets/controls", "model_module_version": "2.0.0", "model_name": "HTMLModel", "state": {"_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "2.0.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "2.0.0", "_view_name": "HTMLView", "description": "", "description_allow_html": false, "layout": "IPY_MODEL_70b1f09cd80c4b6d954c95765acf99ac", "placeholder": "\u200b", "style": "IPY_MODEL_16f7dfe27ae0431f882ae8a7521045b1", "tabbable": null, "tooltip": null, "value": " 28881/28881 [00:00<00:00, 1838816.19it/s]"}}, "e69479b2f0cf41f2b1b48ebb2ebf58d4": {"model_module": "@jupyter-widgets/controls", "model_module_version": "2.0.0", "model_name": "HTMLStyleModel", "state": {"_model_module": "@jupyter-widgets/controls", "_model_module_version": "2.0.0", "_model_name": "HTMLStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "2.0.0", "_view_name": "StyleView", "background": null, "description_width": "", "font_size": null, "text_color": null}}, "ebe1de3fe436417e9648098be8d9b038": {"model_module": "@jupyter-widgets/base", "model_module_version": "2.0.0", "model_name": "LayoutModel", "state": {"_model_module": "@jupyter-widgets/base", "_model_module_version": "2.0.0", "_model_name": "LayoutModel", 
"_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "2.0.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border_bottom": null, "border_left": null, "border_right": null, "border_top": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null}}, "ed5753a1b5e24a52b3de7d67a56394be": {"model_module": "@jupyter-widgets/controls", "model_module_version": "2.0.0", "model_name": "HTMLModel", "state": {"_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "2.0.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "2.0.0", "_view_name": "HTMLView", "description": "", "description_allow_html": false, "layout": "IPY_MODEL_022d5e4feba64ae0928ed7f3872066d4", "placeholder": "\u200b", "style": "IPY_MODEL_29068e73798d4582a1e52fb849969cf1", "tabbable": null, "tooltip": null, "value": " 9912422/9912422 [00:00<00:00, 26360107.18it/s]"}}, "ed5fb11f3a9e47b69aad48f8bf1f7461": {"model_module": "@jupyter-widgets/controls", "model_module_version": "2.0.0", "model_name": "FloatProgressModel", "state": {"_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "2.0.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "2.0.0", 
"_view_name": "ProgressView", "bar_style": "", "description": "", "description_allow_html": false, "layout": "IPY_MODEL_d1b544192fff4a4da583c624d12890b3", "max": 8.0, "min": 0.0, "orientation": "horizontal", "style": "IPY_MODEL_ff4e278d816b407d972984608281faaa", "tabbable": null, "tooltip": null, "value": 8.0}}, "fa7f221d159f44a399abb85eb67a0201": {"model_module": "@jupyter-widgets/controls", "model_module_version": "2.0.0", "model_name": "FloatProgressModel", "state": {"_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "2.0.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "2.0.0", "_view_name": "ProgressView", "bar_style": "success", "description": "", "description_allow_html": false, "layout": "IPY_MODEL_66fc512f4207461b9e493495d88854e7", "max": 28881.0, "min": 0.0, "orientation": "horizontal", "style": "IPY_MODEL_312e269520644f68abf6c03135cf0e0d", "tabbable": null, "tooltip": null, "value": 28881.0}}, "ff4e278d816b407d972984608281faaa": {"model_module": "@jupyter-widgets/controls", "model_module_version": "2.0.0", "model_name": "ProgressStyleModel", "state": {"_model_module": "@jupyter-widgets/controls", "_model_module_version": "2.0.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "2.0.0", "_view_name": "StyleView", "bar_color": null, "description_width": ""}}, "ff9e6a2733e34e72bfb2541bdf073901": {"model_module": "@jupyter-widgets/controls", "model_module_version": "2.0.0", "model_name": "HTMLStyleModel", "state": {"_model_module": "@jupyter-widgets/controls", "_model_module_version": "2.0.0", "_model_name": "HTMLStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "2.0.0", "_view_name": "StyleView", "background": null, "description_width": "", "font_size": null, "text_color": null}}}, "version_major": 2, 
"version_minor": 0}}}, "nbformat": 4, "nbformat_minor": 5}