{"cells": [{"cell_type": "markdown", "id": "636c4a3d", "metadata": {"papermill": {"duration": 0.009294, "end_time": "2023-03-15T10:17:04.100884", "exception": false, "start_time": "2023-03-15T10:17:04.091590", "status": "completed"}, "tags": []}, "source": ["\n", "# Barlow Twins Tutorial\n", "\n", "* **Author:** Ananya Harsh Jha\n", "* **License:** CC BY-SA\n", "* **Generated:** 2023-03-15T10:15:32.745667\n", "\n", "This notebook describes the self-supervised learning method Barlow Twins.\n", "Barlow Twins differs from other recently proposed algorithms in that it does not\n", "fall under the category of contrastive learning or of methods like knowledge\n", "distillation or clustering. The simplicity of the loss function and its effectiveness\n", "in comparison to the current state of the art make Barlow Twins an interesting\n", "case study.\n", "\n", "\n", "---\n", "Open in [![Open In Colab](){height=\"20px\" width=\"117px\"}](https://colab.research.google.com/github/PytorchLightning/lightning-tutorials/blob/publication/.notebooks/lightning_examples/barlow-twins.ipynb)\n", "\n", "Give us a \u2b50 [on GitHub](https://www.github.com/Lightning-AI/lightning/)\n", "| Check out [the documentation](https://pytorch-lightning.readthedocs.io/en/stable/)\n", "| Join us [on Slack](https://www.pytorchlightning.ai/community)"]}, {"cell_type": "markdown", "id": "ea7bda73", "metadata": {"papermill": {"duration": 0.003524, "end_time": "2023-03-15T10:17:04.108984", "exception": false, "start_time": "2023-03-15T10:17:04.105460", "status": "completed"}, "tags": []}, "source": ["## Setup\n", "This notebook requires some packages besides pytorch-lightning."]}, {"cell_type": "code", "execution_count": 1, "id": "033687e1", "metadata": {"colab": {}, "colab_type": "code", "execution": {"iopub.execute_input": "2023-03-15T10:17:04.117522Z", "iopub.status.busy": "2023-03-15T10:17:04.117046Z", "iopub.status.idle": "2023-03-15T10:17:07.370173Z", "shell.execute_reply": 
"2023-03-15T10:17:07.368604Z"}, "id": "LfrJLKPFyhsK", "lines_to_next_cell": 0, "papermill": {"duration": 3.260711, "end_time": "2023-03-15T10:17:07.373152", "exception": false, "start_time": "2023-03-15T10:17:04.112441", "status": "completed"}, "tags": []}, "outputs": [{"name": "stdout", "output_type": "stream", "text": ["\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\r\n", "\u001b[0m"]}], "source": ["! pip install --quiet \"setuptools==67.4.0\" \"ipython[notebook]>=8.0.0, <8.12.0\" \"pytorch-lightning>=1.4, <2.0.0\" \"lightning>=2.0.0rc0\" \"torch>=1.8.1, <1.14.0\" \"torchmetrics>=0.7, <0.12\" \"torchvision\" \"matplotlib\""]}, {"cell_type": "markdown", "id": "9b96f5ac", "metadata": {"papermill": {"duration": 0.006333, "end_time": "2023-03-15T10:17:07.388176", "exception": false, "start_time": "2023-03-15T10:17:07.381843", "status": "completed"}, "tags": []}, "source": ["## Barlow Twins\n", "\n", "Barlow Twins occupies a unique place among current state-of-the-art self-supervised learning methods. It does not fall under the existing categories of contrastive learning, knowledge distillation or clustering-based methods. Instead, it defines its own category of redundancy reduction and achieves competitive performance with a simple yet effective loss function. 
In this tutorial, we implement a small version of the Barlow Twins algorithm using PyTorch Lightning."]}, {"cell_type": "code", "execution_count": 2, "id": "c292211b", "metadata": {"execution": {"iopub.execute_input": "2023-03-15T10:17:07.397278Z", "iopub.status.busy": "2023-03-15T10:17:07.396895Z", "iopub.status.idle": "2023-03-15T10:17:10.093146Z", "shell.execute_reply": "2023-03-15T10:17:10.091433Z"}, "papermill": {"duration": 2.704698, "end_time": "2023-03-15T10:17:10.096449", "exception": false, "start_time": "2023-03-15T10:17:07.391751", "status": "completed"}, "tags": []}, "outputs": [], "source": ["from functools import partial\n", "from typing import Sequence, Tuple, Union\n", "\n", "import lightning as L\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "import torchvision.transforms as transforms\n", "import torchvision.transforms.functional as VisionF\n", "from lightning.pytorch.callbacks import Callback, ModelCheckpoint\n", "from torch import Tensor\n", "from torch.utils.data import DataLoader\n", "from torchmetrics.functional import accuracy\n", "from torchvision.datasets import CIFAR10\n", "from torchvision.models.resnet import resnet18\n", "from torchvision.utils import make_grid\n", "\n", "batch_size = 32\n", "num_workers = 0  # load data in the main process so the notebook also runs on CPU\n", "max_epochs = 200\n", "z_dim = 128  # output dimension of the projection head"]}, {"cell_type": "markdown", "id": "1c50d589", "metadata": {"papermill": {"duration": 0.004811, "end_time": "2023-03-15T10:17:10.109938", "exception": false, "start_time": "2023-03-15T10:17:10.105127", "status": "completed"}, "tags": []}, "source": ["### Transforms\n", "\n", "We first define the data augmentation pipeline used in Barlow Twins. 
Here, we use the pipeline proposed in SimCLR, which generates two views of an input image by applying the following transformations in sequence.\n", "\n", "First, it takes a random crop of the image and resizes it to a fixed, pre-specified size. Then, it applies a random horizontal flip with a probability of 0.5. This is followed by color jitter (applied with a probability of 0.8), conversion to grayscale with a probability of 0.2 and, optionally, a Gaussian blur (applied with a probability of 0.5). Finally, the image is converted to a tensor and normalized.\n", "\n", "Within this transform, we also add a third view for the online finetuner, which is explained later. In short, this extra view lets us evaluate the encoder on a downstream classification task while it is being pretrained."]}, {"cell_type": "code", "execution_count": 3, "id": "3f95ed87", "metadata": {"execution": {"iopub.execute_input": "2023-03-15T10:17:10.122017Z", "iopub.status.busy": "2023-03-15T10:17:10.120628Z", "iopub.status.idle": "2023-03-15T10:17:10.136355Z", "shell.execute_reply": "2023-03-15T10:17:10.135536Z"}, "papermill": {"duration": 0.024148, "end_time": "2023-03-15T10:17:10.138739", "exception": false, "start_time": "2023-03-15T10:17:10.114591", "status": "completed"}, "tags": []}, "outputs": [], "source": ["class BarlowTwinsTransform:\n", "    def __init__(self, train=True, input_height=224, gaussian_blur=True, jitter_strength=1.0, normalize=None):\n", "        self.input_height = input_height\n", "        self.gaussian_blur = gaussian_blur\n", "        self.jitter_strength = jitter_strength\n", "        self.normalize = normalize\n", "        self.train = train\n", "\n", "        color_jitter = transforms.ColorJitter(\n", "            0.8 * self.jitter_strength,\n", "            0.8 * self.jitter_strength,\n", "            0.8 * self.jitter_strength,\n", "            0.2 * self.jitter_strength,\n", "        )\n", "\n", "        color_transform = [transforms.RandomApply([color_jitter], p=0.8), 
transforms.RandomGrayscale(p=0.2)]\n", "\n", "        if self.gaussian_blur:\n", "            kernel_size = int(0.1 * self.input_height)\n", "            if kernel_size % 2 == 0:\n", "                kernel_size += 1\n", "\n", "            color_transform.append(transforms.RandomApply([transforms.GaussianBlur(kernel_size=kernel_size)], p=0.5))\n", "\n", "        self.color_transform = transforms.Compose(color_transform)\n", "\n", "        if normalize is None:\n", "            self.final_transform = transforms.ToTensor()\n", "        else:\n", "            self.final_transform = transforms.Compose([transforms.ToTensor(), normalize])\n", "\n", "        self.transform = transforms.Compose(\n", "            [\n", "                transforms.RandomResizedCrop(self.input_height),\n", "                transforms.RandomHorizontalFlip(p=0.5),\n", "                self.color_transform,\n", "                self.final_transform,\n", "            ]\n", "        )\n", "\n", "        self.finetune_transform = None\n", "        if self.train:\n", "            self.finetune_transform = transforms.Compose(\n", "                [\n", "                    transforms.RandomCrop(32, padding=4, padding_mode=\"reflect\"),\n", "                    transforms.RandomHorizontalFlip(),\n", "                    transforms.ToTensor(),\n", "                ]\n", "            )\n", "        else:\n", "            self.finetune_transform = transforms.ToTensor()\n", "\n", "    def __call__(self, sample):\n", "        return self.transform(sample), self.transform(sample), self.finetune_transform(sample)"]}, {"cell_type": "markdown", "id": "3618df65", "metadata": {"papermill": {"duration": 0.051171, "end_time": "2023-03-15T10:17:10.195254", "exception": false, "start_time": "2023-03-15T10:17:10.144083", "status": "completed"}, "tags": []}, "source": ["### Dataset\n", "\n", "We select CIFAR10 as the dataset to demonstrate the pre-training process for Barlow Twins. 
CIFAR10 images are 32x32 in size and we do not apply a Gaussian blur transformation on them. In this step, we create the training and validation dataloaders for CIFAR10."]}, {"cell_type": "code", "execution_count": 4, "id": "558fb49d", "metadata": {"execution": {"iopub.execute_input": "2023-03-15T10:17:10.205675Z", "iopub.status.busy": "2023-03-15T10:17:10.205290Z", "iopub.status.idle": "2023-03-15T10:17:56.848780Z", "shell.execute_reply": "2023-03-15T10:17:56.847355Z"}, "papermill": {"duration": 46.652164, "end_time": "2023-03-15T10:17:56.851579", "exception": false, "start_time": "2023-03-15T10:17:10.199415", "status": "completed"}, "tags": []}, "outputs": [{"name": "stdout", "output_type": "stream", "text": ["Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./cifar-10-python.tar.gz\n"]}, {"data": {"application/vnd.jupyter.widget-view+json": {"model_id": "d90dc21138c847f4b380495aaf79540c", "version_major": 2, "version_minor": 0}, "text/plain": ["  0%|          | 0/170498071 [00:00<?, ?it/s]"]}, "metadata": {}, "output_type": "display_data"}, {"name": "stdout", "output_type": "stream", "text": ["Extracting ./cifar-10-python.tar.gz to .\n"]}, {"name": "stdout", "output_type": "stream", "text": ["Files already downloaded and verified\n"]}], "source": ["def cifar10_normalization():\n", "    normalize = transforms.Normalize(\n", "        mean=[x / 255.0 for x in [125.3, 123.0, 113.9]], std=[x / 255.0 for x in [63.0, 62.1, 66.7]]\n", "    )\n", "    return normalize\n", "\n", "\n", "train_transform = BarlowTwinsTransform(\n", "    train=True, input_height=32, gaussian_blur=False, jitter_strength=0.5, normalize=cifar10_normalization()\n", ")\n", "train_dataset = CIFAR10(root=\".\", train=True, download=True, transform=train_transform)\n", "\n", "val_transform = BarlowTwinsTransform(\n", "    train=False, input_height=32, gaussian_blur=False, jitter_strength=0.5, normalize=cifar10_normalization()\n", ")\n", "val_dataset = CIFAR10(root=\".\", 
train=False, download=True, transform=val_transform)\n", "\n", "train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, drop_last=True)\n", "val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, drop_last=True)"]}, {"cell_type": "markdown", "id": "ecb11640", "metadata": {"papermill": {"duration": 0.003955, "end_time": "2023-03-15T10:17:56.863491", "exception": false, "start_time": "2023-03-15T10:17:56.859536", "status": "completed"}, "tags": []}, "source": ["### Plot images\n", "\n", "To see what the CIFAR10 images look like after the data augmentation pipeline, we load a few images from the dataloader and plot them here."]}, {"cell_type": "code", "execution_count": 5, "id": "bb076485", "metadata": {"execution": {"iopub.execute_input": "2023-03-15T10:17:56.873797Z", "iopub.status.busy": "2023-03-15T10:17:56.873355Z", "iopub.status.idle": "2023-03-15T10:17:57.197704Z", "shell.execute_reply": "2023-03-15T10:17:57.196982Z"}, "papermill": {"duration": 0.334869, "end_time": "2023-03-15T10:17:57.202254", "exception": false, "start_time": "2023-03-15T10:17:56.867385", "status": "completed"}, "tags": []}, "outputs": [{"data": {"image/png": "", "text/plain": ["<Figure size 640x480 with 1 Axes>"]}, "metadata": {}, "output_type": "display_data"}], "source": ["for batch in val_loader:\n", "    (img1, img2, _), label = batch\n", "    break\n", "\n", "img_grid = make_grid(img1, normalize=True)\n", "\n", "\n", "def show(imgs):\n", "    if not isinstance(imgs, list):\n", "        imgs = [imgs]\n", "    fig, axs = plt.subplots(ncols=len(imgs), squeeze=False)\n", "    for i, img in enumerate(imgs):\n", "        img = img.detach()\n", "        img = VisionF.to_pil_image(img)\n", "        axs[0, i].imshow(np.asarray(img))\n", "        axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])\n", "\n", "\n", "show(img_grid)"]}, {"cell_type": "markdown", "id": "5fa7d593", "metadata": 
{"papermill": {"duration": 0.006015, "end_time": "2023-03-15T10:17:57.219237", "exception": false, "start_time": "2023-03-15T10:17:57.213222", "status": "completed"}, "tags": []}, "source": ["### Barlow Twins Loss\n", "\n", "Here we define the loss function for Barlow Twins. It first standardizes each of the D dimensions of the projection-head outputs across the batch (zero mean, unit standard deviation) and then computes the DxD cross-correlation matrix between the normalized outputs of the two views of each image.\n", "\n", "The loss then treats this cross-correlation matrix in two parts. The first part drives the diagonal elements toward 1, which makes each latent dimension strongly correlated across the two views of an image and thus makes the backbone invariant to the transformations applied to the views. The second part, weighted by ``lambda_coeff``, pushes the off-diagonal elements of the cross-correlation matrix toward 0. This reduces the redundancy between the different dimensions of the latent vector."]}, {"cell_type": "code", "execution_count": 6, "id": "a58346bc", "metadata": {"execution": {"iopub.execute_input": "2023-03-15T10:17:57.232610Z", "iopub.status.busy": "2023-03-15T10:17:57.232425Z", "iopub.status.idle": "2023-03-15T10:17:57.239404Z", "shell.execute_reply": "2023-03-15T10:17:57.238845Z"}, "papermill": {"duration": 0.016432, "end_time": "2023-03-15T10:17:57.241653", "exception": false, "start_time": "2023-03-15T10:17:57.225221", "status": "completed"}, "tags": []}, "outputs": [], "source": ["class BarlowTwinsLoss(nn.Module):\n", "    def __init__(self, batch_size, lambda_coeff=5e-3, z_dim=128):\n", "        super().__init__()\n", "\n", "        self.z_dim = z_dim\n", "        self.batch_size = batch_size\n", "        self.lambda_coeff = lambda_coeff\n", "\n", "    def off_diagonal_ele(self, x):\n", "        # taken from: https://github.com/facebookresearch/barlowtwins/blob/main/main.py\n", "        # return a flattened view of the off-diagonal elements of a square matrix\n", "        n, m = x.shape\n", "   
     assert n == m\n", "        return x.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten()\n", "\n", "    def forward(self, z1, z2):\n", "        # N x D, where N is the batch size and D is output dim of projection head\n", "        z1_norm = (z1 - torch.mean(z1, dim=0)) / torch.std(z1, dim=0)\n", "        z2_norm = (z2 - torch.mean(z2, dim=0)) / torch.std(z2, dim=0)\n", "\n", "        cross_corr = torch.matmul(z1_norm.T, z2_norm) / self.batch_size\n", "\n", "        on_diag = torch.diagonal(cross_corr).add_(-1).pow_(2).sum()\n", "        off_diag = self.off_diagonal_ele(cross_corr).pow_(2).sum()\n", "\n", "        return on_diag + self.lambda_coeff * off_diag"]}, {"cell_type": "markdown", "id": "1da0791d", "metadata": {"papermill": {"duration": 0.005874, "end_time": "2023-03-15T10:17:57.258039", "exception": false, "start_time": "2023-03-15T10:17:57.252165", "status": "completed"}, "tags": []}, "source": ["### Backbone\n", "\n", "This is a standard ResNet-18 backbone that we pre-train using the Barlow Twins method. To accommodate the 32x32 CIFAR10 images, we replace the first 7x7 convolution of the ResNet backbone with a 3x3 convolution of stride 1. 
We also effectively remove the first max-pooling layer by replacing it with a 1x1 max-pool of stride 1, which leaves its input unchanged."]}, {"cell_type": "code", "execution_count": 7, "id": "58250a78", "metadata": {"execution": {"iopub.execute_input": "2023-03-15T10:17:57.270905Z", "iopub.status.busy": "2023-03-15T10:17:57.270739Z", "iopub.status.idle": "2023-03-15T10:17:57.414693Z", "shell.execute_reply": "2023-03-15T10:17:57.413480Z"}, "papermill": {"duration": 0.153123, "end_time": "2023-03-15T10:17:57.417187", "exception": false, "start_time": "2023-03-15T10:17:57.264064", "status": "completed"}, "tags": []}, "outputs": [], "source": ["encoder = resnet18()\n", "\n", "# for CIFAR10, replace the first 7x7 conv with smaller 3x3 conv and remove the first maxpool\n", "encoder.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)\n", "encoder.maxpool = nn.MaxPool2d(kernel_size=1, stride=1)\n", "\n", "# replace the classification fc layer of ResNet to obtain representations from the backbone\n", "encoder.fc = nn.Identity()"]}, {"cell_type": "markdown", "id": "f886affa", "metadata": {"papermill": {"duration": 0.006082, "end_time": "2023-03-15T10:17:57.434282", "exception": false, "start_time": "2023-03-15T10:17:57.428200", "status": "completed"}, "tags": []}, "source": ["### Projection head\n", "\n", "Unlike SimCLR and BYOL, Barlow Twins benefits greatly in downstream performance from a larger projection head after the backbone network. The paper uses a 3-layer MLP with 8192 hidden dimensions and an output dimension of 8192. For the purposes of this tutorial, we use a smaller, two-layer projection head. 
In practice, however, Barlow Twins should be trained with a larger projection head, since it is highly sensitive to the head's architecture and output dimensionality."]}, {"cell_type": "code", "execution_count": 8, "id": "ea05fb87", "metadata": {"execution": {"iopub.execute_input": "2023-03-15T10:17:57.447305Z", "iopub.status.busy": "2023-03-15T10:17:57.447111Z", "iopub.status.idle": "2023-03-15T10:17:57.452667Z", "shell.execute_reply": "2023-03-15T10:17:57.451898Z"}, "papermill": {"duration": 0.013833, "end_time": "2023-03-15T10:17:57.454000", "exception": false, "start_time": "2023-03-15T10:17:57.440167", "status": "completed"}, "tags": []}, "outputs": [], "source": ["class ProjectionHead(nn.Module):\n", "    def __init__(self, input_dim=2048, hidden_dim=2048, output_dim=128):\n", "        super().__init__()\n", "\n", "        self.projection_head = nn.Sequential(\n", "            nn.Linear(input_dim, hidden_dim, bias=True),\n", "            nn.BatchNorm1d(hidden_dim),\n", "            nn.ReLU(),\n", "            nn.Linear(hidden_dim, output_dim, bias=False),\n", "        )\n", "\n", "    def forward(self, x):\n", "        return self.projection_head(x)"]}, {"cell_type": "markdown", "id": "f732ceef", "metadata": {"papermill": {"duration": 0.005938, "end_time": "2023-03-15T10:17:57.470391", "exception": false, "start_time": "2023-03-15T10:17:57.464453", "status": "completed"}, "tags": []}, "source": ["### Learning rate warmup\n", "\n", "For the purposes of this tutorial, we keep things simple and use the Adam optimizer with a linear warmup schedule that holds the learning rate constant after warmup. 
In our previous experiments, we found that the linear warmup part is much more important for the final performance of a model than the cosine decay component of the schedule, so we omit the decay here."]}, {"cell_type": "code", "execution_count": 9, "id": "9c76696c", "metadata": {"execution": {"iopub.execute_input": "2023-03-15T10:17:57.483473Z", "iopub.status.busy": "2023-03-15T10:17:57.483159Z", "iopub.status.idle": "2023-03-15T10:17:57.487653Z", "shell.execute_reply": "2023-03-15T10:17:57.486894Z"}, "papermill": {"duration": 0.012493, "end_time": "2023-03-15T10:17:57.488968", "exception": false, "start_time": "2023-03-15T10:17:57.476475", "status": "completed"}, "tags": []}, "outputs": [], "source": ["def fn(warmup_steps, step):\n", "    # multiplicative LR factor: ramps linearly from 0 to 1 over warmup_steps, then stays at 1.0\n", "    if step < warmup_steps:\n", "        return float(step) / float(max(1, warmup_steps))\n", "    else:\n", "        return 1.0\n", "\n", "\n", "def linear_warmup_decay(warmup_steps):\n", "    # despite the name, no decay is applied after the warmup phase\n", "    return partial(fn, warmup_steps)"]}, {"cell_type": "markdown", "id": "82a4ff69", "metadata": {"papermill": {"duration": 0.00783, "end_time": "2023-03-15T10:17:57.503250", "exception": false, "start_time": "2023-03-15T10:17:57.495420", "status": "completed"}, "tags": []}, "source": ["### Barlow Twins Lightning Module\n", "\n", "We keep the LightningModule for Barlow Twins neat and simple. It takes in a backbone encoder and initializes the projection head and the loss function. 
We configure the optimizer and the learning rate scheduler in the ``configure_optimizers`` method."]}, {"cell_type": "code", "execution_count": 10, "id": "ec48d4b6", "metadata": {"execution": {"iopub.execute_input": "2023-03-15T10:17:57.516294Z", "iopub.status.busy": "2023-03-15T10:17:57.515981Z", "iopub.status.idle": "2023-03-15T10:17:57.525416Z", "shell.execute_reply": "2023-03-15T10:17:57.524533Z"}, "papermill": {"duration": 0.01854, "end_time": "2023-03-15T10:17:57.527757", "exception": false, "start_time": "2023-03-15T10:17:57.509217", "status": "completed"}, "tags": []}, "outputs": [], "source": ["class BarlowTwins(L.LightningModule):\n", "    def __init__(\n", "        self,\n", "        encoder,\n", "        encoder_out_dim,\n", "        num_training_samples,\n", "        batch_size,\n", "        lambda_coeff=5e-3,\n", "        z_dim=128,\n", "        learning_rate=1e-4,\n", "        warmup_epochs=10,\n", "        max_epochs=200,\n", "    ):\n", "        super().__init__()\n", "\n", "        self.encoder = encoder\n", "        self.projection_head = ProjectionHead(input_dim=encoder_out_dim, hidden_dim=encoder_out_dim, output_dim=z_dim)\n", "        self.loss_fn = BarlowTwinsLoss(batch_size=batch_size, lambda_coeff=lambda_coeff, z_dim=z_dim)\n", "\n", "        self.learning_rate = learning_rate\n", "        self.warmup_epochs = warmup_epochs\n", "        self.max_epochs = max_epochs\n", "\n", "        self.train_iters_per_epoch = num_training_samples // batch_size\n", "\n", "    def forward(self, x):\n", "        return self.encoder(x)\n", "\n", "    def shared_step(self, batch):\n", "        (x1, x2, _), _ = batch\n", "\n", "        z1 = self.projection_head(self.encoder(x1))\n", "        z2 = self.projection_head(self.encoder(x2))\n", "\n", "        return self.loss_fn(z1, z2)\n", "\n", "    def training_step(self, batch, batch_idx):\n", "        loss = self.shared_step(batch)\n", "        self.log(\"train_loss\", loss, on_step=True, on_epoch=False)\n", "  
       return loss\n", "\n", "    def validation_step(self, batch, batch_idx):\n", "        loss = self.shared_step(batch)\n", "        self.log(\"val_loss\", loss, on_step=False, on_epoch=True)\n", "\n", "    def configure_optimizers(self):\n", "        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)\n", "\n", "        warmup_steps = self.train_iters_per_epoch * self.warmup_epochs\n", "\n", "        scheduler = {\n", "            \"scheduler\": torch.optim.lr_scheduler.LambdaLR(\n", "                optimizer,\n", "                linear_warmup_decay(warmup_steps),\n", "            ),\n", "            \"interval\": \"step\",\n", "            \"frequency\": 1,\n", "        }\n", "\n", "        return [optimizer], [scheduler]"]}, {"cell_type": "markdown", "id": "9b77cfb1", "metadata": {"papermill": {"duration": 0.007469, "end_time": "2023-03-15T10:17:57.544592", "exception": false, "start_time": "2023-03-15T10:17:57.537123", "status": "completed"}, "tags": []}, "source": ["### Evaluation\n", "\n", "We define a callback that attaches a linear classification layer on top of the encoder and trains it online, alongside the self-supervised pretraining. The encoder features are computed without gradients and detached, so no gradients from this classification head flow back into the encoder. 
This technique was also used in SimCLR, where the results of this online linear evaluation were found to be very close to the final downstream classification performance as training progresses."]}, {"cell_type": "code", "execution_count": 11, "id": "7370d0e8", "metadata": {"execution": {"iopub.execute_input": "2023-03-15T10:17:57.557764Z", "iopub.status.busy": "2023-03-15T10:17:57.557441Z", "iopub.status.idle": "2023-03-15T10:17:57.569207Z", "shell.execute_reply": "2023-03-15T10:17:57.568339Z"}, "papermill": {"duration": 0.021182, "end_time": "2023-03-15T10:17:57.571787", "exception": false, "start_time": "2023-03-15T10:17:57.550605", "status": "completed"}, "tags": []}, "outputs": [], "source": ["class OnlineFineTuner(Callback):\n", "    def __init__(\n", "        self,\n", "        encoder_output_dim: int,\n", "        num_classes: int,\n", "    ) -> None:\n", "        super().__init__()\n", "\n", "        self.optimizer: torch.optim.Optimizer\n", "\n", "        self.encoder_output_dim = encoder_output_dim\n", "        self.num_classes = num_classes\n", "\n", "    def on_fit_start(self, trainer: L.Trainer, pl_module: L.LightningModule) -> None:\n", "        # add linear_eval layer and optimizer\n", "        pl_module.online_finetuner = nn.Linear(self.encoder_output_dim, self.num_classes).to(pl_module.device)\n", "        self.optimizer = torch.optim.Adam(pl_module.online_finetuner.parameters(), lr=1e-4)\n", "\n", "    def extract_online_finetuning_view(\n", "        self, batch: Sequence, device: Union[str, torch.device]\n", "    ) -> Tuple[Tensor, Tensor]:\n", "        (_, _, finetune_view), y = batch\n", "        finetune_view = finetune_view.to(device)\n", "        y = y.to(device)\n", "\n", "        return finetune_view, y\n", "\n", "    def on_train_batch_end(\n", "        self,\n", "        trainer: L.Trainer,\n", "        pl_module: L.LightningModule,\n", "        outputs: Sequence,\n", "        batch: Sequence,\n", "        batch_idx: int,\n", "    ) 
-> None:\n", "        x, y = self.extract_online_finetuning_view(batch, pl_module.device)\n", "\n", "        # encoder features are computed without gradients, so only the linear head is trained\n", "        with torch.no_grad():\n", "            feats = pl_module(x)\n", "\n", "        feats = feats.detach()\n", "        preds = pl_module.online_finetuner(feats)\n", "        loss = F.cross_entropy(preds, y)\n", "\n", "        loss.backward()\n", "        self.optimizer.step()\n", "        self.optimizer.zero_grad()\n", "\n", "        acc = accuracy(F.softmax(preds, dim=1), y, task=\"multiclass\", num_classes=self.num_classes)\n", "        pl_module.log(\"online_train_acc\", acc, on_step=True, on_epoch=False)\n", "        pl_module.log(\"online_train_loss\", loss, on_step=True, on_epoch=False)\n", "\n", "    def on_validation_batch_end(\n", "        self,\n", "        trainer: L.Trainer,\n", "        pl_module: L.LightningModule,\n", "        outputs: Sequence,\n", "        batch: Sequence,\n", "        batch_idx: int,\n", "    ) -> None:\n", "        x, y = self.extract_online_finetuning_view(batch, pl_module.device)\n", "\n", "        with torch.no_grad():\n", "            feats = pl_module(x)\n", "\n", "        feats = feats.detach()\n", "        preds = pl_module.online_finetuner(feats)\n", "        loss = F.cross_entropy(preds, y)\n", "\n", "        acc = accuracy(F.softmax(preds, dim=1), y, task=\"multiclass\", num_classes=self.num_classes)\n", "        pl_module.log(\"online_val_acc\", acc, on_step=False, on_epoch=True, sync_dist=True)\n", "        pl_module.log(\"online_val_loss\", loss, on_step=False, on_epoch=True, sync_dist=True)"]}, {"cell_type": "markdown", "id": "da4103e9", "metadata": {"papermill": {"duration": 0.006051, "end_time": "2023-03-15T10:17:57.589781", "exception": false, "start_time": "2023-03-15T10:17:57.583730", "status": "completed"}, "tags": []}, "source": ["Finally, we define the trainer for training the model. 
We pass the ``train_loader`` and ``val_loader`` initialized earlier to the ``fit`` method."]}, {"cell_type": "code", "execution_count": 12, "id": "3ba99c80", "metadata": {"execution": {"iopub.execute_input": "2023-03-15T10:17:57.602719Z", "iopub.status.busy": "2023-03-15T10:17:57.602551Z", "iopub.status.idle": "2023-03-15T10:17:58.113146Z", "shell.execute_reply": "2023-03-15T10:17:58.112039Z"}, "papermill": {"duration": 0.51993, "end_time": "2023-03-15T10:17:58.115737", "exception": false, "start_time": "2023-03-15T10:17:57.595807", "status": "completed"}, "tags": []}, "outputs": [{"name": "stderr", "output_type": "stream", "text": ["ModelCheckpoint(save_last=True, save_top_k=-1, monitor=None) will duplicate the last checkpoint saved.\n"]}, {"name": "stderr", "output_type": "stream", "text": ["GPU available: True (cuda), used: True\n"]}, {"name": "stderr", "output_type": "stream", "text": ["TPU available: False, using: 0 TPU cores\n"]}, {"name": "stderr", "output_type": "stream", "text": ["IPU available: False, using: 0 IPUs\n"]}, {"name": "stderr", "output_type": "stream", "text": ["HPU available: False, using: 0 HPUs\n"]}], "source": ["encoder_out_dim = 512  # feature size of ResNet-18 once its fc layer is replaced by nn.Identity\n", "\n", "model = BarlowTwins(\n", "    encoder=encoder,\n", "    encoder_out_dim=encoder_out_dim,\n", "    num_training_samples=len(train_dataset),\n", "    batch_size=batch_size,\n", "    z_dim=z_dim,\n", ")\n", "\n", "online_finetuner = OnlineFineTuner(encoder_output_dim=encoder_out_dim, num_classes=10)\n", "checkpoint_callback = ModelCheckpoint(every_n_epochs=100, save_top_k=-1, save_last=True)\n", "\n", "trainer = L.Trainer(\n", "    max_epochs=max_epochs,\n", "    accelerator=\"auto\",\n", "    devices=1,\n", "    callbacks=[online_finetuner, checkpoint_callback],\n", ")\n", "\n", "# uncomment the line below to train the model\n", "# training is disabled here so that the tutorial notebook runs quickly\n", "# trainer.fit(model, train_loader, val_loader)"]}, {"cell_type": "markdown", "id": "da7ab591", 
"metadata": {"papermill": {"duration": 0.007408, "end_time": "2023-03-15T10:17:58.131109", "exception": false, "start_time": "2023-03-15T10:17:58.123701", "status": "completed"}, "tags": []}, "source": ["### Using the trained encoder for downstream tasks\n", "\n", "Once the encoder is pretrained on CIFAR10, we can use it to compute image embeddings and use them for downstream tasks such as classification, detection and segmentation.\n", "\n", "Since this tutorial does not train the encoder for hundreds of epochs with the Barlow Twins pretraining method, the intended approach is to load pretrained encoder weights from a checkpoint and show the image embeddings obtained from it. The checkpoint-loading lines in the cell below are commented out, so the cell falls back to the encoder held by ``model``, which has not been trained here.\n", "\n", "To create this checkpoint, the encoder was pretrained for 200 epochs and obtained an online finetuning accuracy of x% on CIFAR-10."]}, {"cell_type": "code", "execution_count": 13, "id": "13d028f6", "metadata": {"execution": {"iopub.execute_input": "2023-03-15T10:17:58.152561Z", "iopub.status.busy": "2023-03-15T10:17:58.151484Z", "iopub.status.idle": "2023-03-15T10:17:58.536740Z", "shell.execute_reply": "2023-03-15T10:17:58.535792Z"}, "papermill": {"duration": 0.396014, "end_time": "2023-03-15T10:17:58.539687", "exception": false, "start_time": "2023-03-15T10:17:58.143673", "status": "completed"}, "tags": []}, "outputs": [{"name": "stdout", "output_type": "stream", "text": ["torch.Size([4, 512])\n"]}], "source": ["# ckpt_model = torch.load('')  # upload checkpoint to aws\n", "# encoder = ckpt_model.encoder\n", "encoder = model.encoder\n", "encoder.eval()  # use BatchNorm running statistics when extracting embeddings\n", "\n", "downstream_dataset = CIFAR10(root=\".\", train=False, transform=transforms.ToTensor())\n", "dataloader = DataLoader(downstream_dataset, batch_size=4, shuffle=False)\n", "\n", "for batch in dataloader:\n", "    img, label = batch\n", "    with torch.no_grad():\n", "        print(encoder(img).shape)\n", "    break"]}, {"cell_type": "markdown", "id": "9d6287a2", "metadata": {"papermill": {"duration": 0.006401, "end_time": "2023-03-15T10:17:58.559796", "exception": false, 
"start_time": "2023-03-15T10:17:58.553395", "status": "completed"}, "tags": []}, "source": ["## Congratulations - Time to Join the Community!\n", "\n", "Congratulations on completing this notebook tutorial! If you enjoyed this and would like to join the Lightning\n", "movement, you can do so in the following ways!\n", "\n", "### Star [Lightning](https://github.com/Lightning-AI/lightning) on GitHub\n", "The easiest way to help our community is just by starring the GitHub repos! This helps raise awareness of the cool\n", "tools we're building.\n", "\n", "### Join our [Slack](https://www.pytorchlightning.ai/community)!\n", "The best way to keep up to date on the latest advancements is to join our community! Make sure to introduce yourself\n", "and share your interests in `#general` channel\n", "\n", "\n", "### Contributions !\n", "The best way to contribute to our community is to become a code contributor! At any time you can go to\n", "[Lightning](https://github.com/Lightning-AI/lightning) or [Bolt](https://github.com/Lightning-AI/lightning-bolts)\n", "GitHub Issues page and filter for \"good first issue\".\n", "\n", "* [Lightning good first issue](https://github.com/Lightning-AI/lightning/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)\n", "* [Bolt good first issue](https://github.com/Lightning-AI/lightning-bolts/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)\n", "* You can also contribute your own notebooks with useful examples !\n", "\n", "### Great thanks from the entire Pytorch Lightning Team for your interest !\n", "\n", "[![Pytorch Lightning](){height=\"60px\" width=\"240px\"}](https://pytorchlightning.ai)"]}, {"cell_type": "raw", "metadata": {"raw_mimetype": "text/restructuredtext"}, "source": [".. customcarditem::\n", "   :header: Barlow Twins Tutorial\n", "   :card_description: This notebook describes the self-supervised learning method Barlow Twins. 
Barlow Twins differs from other recently proposed algorithms as it doesn't fall under the category of...\n", "   :tags: Image,Self-Supervised,GPU/TPU,Lightning-Examples"]}], "metadata": {"jupytext": {"cell_metadata_filter": "id,colab,colab_type,-all", "formats": "ipynb,py:percent", "main_language": "python"}, "language_info": {"codemirror_mode": {"name": "ipython", "version": 3}, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.16"}, "papermill": {"default_parameters": {}, "duration": 56.709385, "end_time": "2023-03-15T10:17:59.486021", "environment_variables": {}, "exception": null, "input_path": "lightning_examples/barlow-twins/barlow_twins.ipynb", "output_path": ".notebooks/lightning_examples/barlow-twins.ipynb", "parameters": {}, "start_time": "2023-03-15T10:17:02.776636", "version": "2.4.0"}, "widgets": {"application/vnd.jupyter.widget-state+json": {"state": {"2091c9f87a1a42029f7a7e016384c0bc": {"model_module": "@jupyter-widgets/controls", "model_module_version": "2.0.0", "model_name": "HTMLStyleModel", "state": {"_model_module": "@jupyter-widgets/controls", "_model_module_version": "2.0.0", "_model_name": "HTMLStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "2.0.0", "_view_name": "StyleView", "background": null, "description_width": "", "font_size": null, "text_color": null}}, "3d9362ccb5d34c399157d5398c76a07e": {"model_module": "@jupyter-widgets/base", "model_module_version": "2.0.0", "model_name": "LayoutModel", "state": {"_model_module": "@jupyter-widgets/base", "_model_module_version": "2.0.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "2.0.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border_bottom": null, "border_left": null, "border_right": null, "border_top": null, "bottom": 
null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null}}, "5c7edd67371348a4a3f68eca312db6cf": {"model_module": "@jupyter-widgets/base", "model_module_version": "2.0.0", "model_name": "LayoutModel", "state": {"_model_module": "@jupyter-widgets/base", "_model_module_version": "2.0.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "2.0.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border_bottom": null, "border_left": null, "border_right": null, "border_top": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null}}, "64a7e761a91d455cb347d3bc9f9f2576": {"model_module": "@jupyter-widgets/base", "model_module_version": "2.0.0", "model_name": "LayoutModel", "state": {"_model_module": "@jupyter-widgets/base", 
"_model_module_version": "2.0.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "2.0.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border_bottom": null, "border_left": null, "border_right": null, "border_top": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null}}, "7b9e6b1be4074f1ba6bce914a5034941": {"model_module": "@jupyter-widgets/base", "model_module_version": "2.0.0", "model_name": "LayoutModel", "state": {"_model_module": "@jupyter-widgets/base", "_model_module_version": "2.0.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "2.0.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border_bottom": null, "border_left": null, "border_right": null, "border_top": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, 
"object_position": null, "order": null, "overflow": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null}}, "a5a88b90515f45268fbcf861a1ed2431": {"model_module": "@jupyter-widgets/controls", "model_module_version": "2.0.0", "model_name": "HTMLStyleModel", "state": {"_model_module": "@jupyter-widgets/controls", "_model_module_version": "2.0.0", "_model_name": "HTMLStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "2.0.0", "_view_name": "StyleView", "background": null, "description_width": "", "font_size": null, "text_color": null}}, "bf3026ad59f34865b04917ac5f539d66": {"model_module": "@jupyter-widgets/controls", "model_module_version": "2.0.0", "model_name": "HTMLModel", "state": {"_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "2.0.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "2.0.0", "_view_name": "HTMLView", "description": "", "description_allow_html": false, "layout": "IPY_MODEL_5c7edd67371348a4a3f68eca312db6cf", "placeholder": "\u200b", "style": "IPY_MODEL_2091c9f87a1a42029f7a7e016384c0bc", "tabbable": null, "tooltip": null, "value": " 170498071/170498071 [00:41&lt;00:00, 1597381.59it/s]"}}, "d90dc21138c847f4b380495aaf79540c": {"model_module": "@jupyter-widgets/controls", "model_module_version": "2.0.0", "model_name": "HBoxModel", "state": {"_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "2.0.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "2.0.0", "_view_name": "HBoxView", "box_style": "", "children": ["IPY_MODEL_dc0615b9648b4a26822edffce51c6326", "IPY_MODEL_e80f807dc00e4b4e88c0a38d01741045", "IPY_MODEL_bf3026ad59f34865b04917ac5f539d66"], "layout": "IPY_MODEL_7b9e6b1be4074f1ba6bce914a5034941", "tabbable": null, "tooltip": null}}, 
"d98cdb7e9b0c4730b299cc505f2c6b6c": {"model_module": "@jupyter-widgets/controls", "model_module_version": "2.0.0", "model_name": "ProgressStyleModel", "state": {"_model_module": "@jupyter-widgets/controls", "_model_module_version": "2.0.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "2.0.0", "_view_name": "StyleView", "bar_color": null, "description_width": ""}}, "dc0615b9648b4a26822edffce51c6326": {"model_module": "@jupyter-widgets/controls", "model_module_version": "2.0.0", "model_name": "HTMLModel", "state": {"_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "2.0.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "2.0.0", "_view_name": "HTMLView", "description": "", "description_allow_html": false, "layout": "IPY_MODEL_64a7e761a91d455cb347d3bc9f9f2576", "placeholder": "\u200b", "style": "IPY_MODEL_a5a88b90515f45268fbcf861a1ed2431", "tabbable": null, "tooltip": null, "value": "100%"}}, "e80f807dc00e4b4e88c0a38d01741045": {"model_module": "@jupyter-widgets/controls", "model_module_version": "2.0.0", "model_name": "FloatProgressModel", "state": {"_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "2.0.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "2.0.0", "_view_name": "ProgressView", "bar_style": "success", "description": "", "description_allow_html": false, "layout": "IPY_MODEL_3d9362ccb5d34c399157d5398c76a07e", "max": 170498071.0, "min": 0.0, "orientation": "horizontal", "style": "IPY_MODEL_d98cdb7e9b0c4730b299cc505f2c6b6c", "tabbable": null, "tooltip": null, "value": 170498071.0}}}, "version_major": 2, "version_minor": 0}}}, "nbformat": 4, "nbformat_minor": 5}