{"cells": [{"cell_type": "markdown", "id": "3c99ec5d", "metadata": {"papermill": {"duration": 0.033967, "end_time": "2021-12-04T16:50:09.385288", "exception": false, "start_time": "2021-12-04T16:50:09.351321", "status": "completed"}, "tags": []}, "source": ["\n",
"# Finetune Transformers Models with PyTorch Lightning\n",
"\n",
"* **Author:** PL team\n",
"* **License:** CC BY-SA\n",
"* **Generated:** 2021-12-04T16:53:11.286202\n",
"\n",
"This notebook uses Hugging Face's `datasets` library to load the data, which is wrapped in a `LightningDataModule`.\n",
"We then write a class to perform text classification on any dataset from the [GLUE Benchmark](https://gluebenchmark.com/).\n",
"(We show only CoLA and MRPC because of compute and disk constraints.)\n",
"\n",
"\n",
"---\n",
"[Open in Colab](https://colab.research.google.com/github/PytorchLightning/lightning-tutorials/blob/publication/.notebooks/lightning_examples/text-transformers.ipynb)\n",
"\n",
"Give us a \u2b50 [on GitHub](https://www.github.com/PytorchLightning/pytorch-lightning/)\n",
"| Check out [the documentation](https://pytorch-lightning.readthedocs.io/en/latest/)\n",
"| Join us [on Slack](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-pw5v393p-qRaDgEk24~EjiZNBpSQFgQ)"]}, {"cell_type": "markdown", "id": "0e4d7209", "metadata": {"papermill": {"duration": 0.028692, "end_time": "2021-12-04T16:50:09.445776", "exception": false, "start_time": "2021-12-04T16:50:09.417084", "status": "completed"}, "tags": []}, "source": ["## Setup\n",
"This notebook requires some packages besides pytorch-lightning."]}, {"cell_type": "code", "execution_count": 1, "id": "49b84d9c", "metadata": {"colab": {}, "colab_type": "code", "execution": {"iopub.execute_input": "2021-12-04T16:50:09.510470Z", "iopub.status.busy": "2021-12-04T16:50:09.505774Z", "iopub.status.idle": "2021-12-04T16:50:12.925383Z", "shell.execute_reply": "2021-12-04T16:50:12.924802Z"}, "id": "LfrJLKPFyhsK", "lines_to_next_cell": 0, "papermill": {"duration": 3.45077, "end_time": "2021-12-04T16:50:12.925533", "exception": false, "start_time": "2021-12-04T16:50:09.474763", "status": "completed"}, "tags": []}, "outputs": [], "source": ["! 
pip install --quiet \"datasets\" \"pytorch-lightning>=1.3\" \"scipy\" \"transformers\" \"torchmetrics>=0.3\" \"scikit-learn\" \"torchtext>=0.9\" \"torch>=1.6, <1.9\""]}, {"cell_type": "code", "execution_count": 2, "id": "8c24af2f", "metadata": {"execution": {"iopub.execute_input": "2021-12-04T16:50:12.990546Z", "iopub.status.busy": "2021-12-04T16:50:12.990070Z", "iopub.status.idle": "2021-12-04T16:50:17.320862Z", "shell.execute_reply": "2021-12-04T16:50:17.320412Z"}, "papermill": {"duration": 4.364813, "end_time": "2021-12-04T16:50:17.320999", "exception": false, "start_time": "2021-12-04T16:50:12.956186", "status": "completed"}, "tags": []}, "outputs": [], "source": ["from datetime import datetime\n",
"from typing import Optional\n",
"\n",
"import datasets\n",
"import torch\n",
"from pytorch_lightning import LightningDataModule, LightningModule, Trainer, seed_everything\n",
"from torch.utils.data import DataLoader\n",
"from transformers import (\n",
"    AdamW,\n",
"    AutoConfig,\n",
"    AutoModelForSequenceClassification,\n",
"    AutoTokenizer,\n",
"    get_linear_schedule_with_warmup,\n",
")\n",
"\n",
"AVAIL_GPUS = min(1, torch.cuda.device_count())"]}, {"cell_type": "markdown", "id": "2134c12c", "metadata": {"papermill": {"duration": 0.029134, "end_time": "2021-12-04T16:50:17.381739", "exception": false, "start_time": "2021-12-04T16:50:17.352605", "status": "completed"}, "tags": []}, "source": ["## Training BERT with Lightning"]}, {"cell_type": "markdown", "id": "c827ca58", "metadata": {"lines_to_next_cell": 2, "papermill": {"duration": 0.029129, "end_time": "2021-12-04T16:50:17.441085", "exception": false, "start_time": "2021-12-04T16:50:17.411956", "status": "completed"}, "tags": []}, "source": ["### Lightning DataModule for GLUE"]}, {"cell_type": "code", "execution_count": 3, "id": "5a856234", "metadata": {"execution": {"iopub.execute_input": "2021-12-04T16:50:17.515779Z", "iopub.status.busy": "2021-12-04T16:50:17.512843Z", "iopub.status.idle": "2021-12-04T16:50:17.517417Z", "shell.execute_reply": "2021-12-04T16:50:17.517796Z"}, "papermill": {"duration": 0.047785, "end_time": "2021-12-04T16:50:17.517928", "exception": false, "start_time": "2021-12-04T16:50:17.470143", "status": "completed"}, "tags": []}, "outputs": [], "source": ["class GLUEDataModule(LightningDataModule):\n",
"\n",
"    task_text_field_map = {\n",
"        \"cola\": [\"sentence\"],\n",
"        \"sst2\": [\"sentence\"],\n",
"        \"mrpc\": [\"sentence1\", \"sentence2\"],\n",
"        \"qqp\": [\"question1\", \"question2\"],\n",
"        \"stsb\": [\"sentence1\", \"sentence2\"],\n",
"        \"mnli\": [\"premise\", \"hypothesis\"],\n",
"        \"qnli\": [\"question\", \"sentence\"],\n",
"        \"rte\": [\"sentence1\", \"sentence2\"],\n",
"        \"wnli\": [\"sentence1\", \"sentence2\"],\n",
"        \"ax\": [\"premise\", \"hypothesis\"],\n",
"    }\n",
"\n",
"    glue_task_num_labels = {\n",
"        \"cola\": 2,\n",
"        \"sst2\": 2,\n",
"        \"mrpc\": 2,\n",
"        \"qqp\": 2,\n",
"        \"stsb\": 1,\n",
"        \"mnli\": 3,\n",
"        \"qnli\": 2,\n",
"        \"rte\": 2,\n",
"        \"wnli\": 2,\n",
"        \"ax\": 3,\n",
"    }\n",
"\n",
"    loader_columns = [\n",
"        \"datasets_idx\",\n",
"        \"input_ids\",\n",
"        \"token_type_ids\",\n",
"        \"attention_mask\",\n",
"        \"start_positions\",\n",
"        \"end_positions\",\n",
"        \"labels\",\n",
"    ]\n",
"\n",
"    def __init__(\n",
"        self,\n",
"        model_name_or_path: str,\n",
"        task_name: str = \"mrpc\",\n",
"        max_seq_length: int = 128,\n",
"        train_batch_size: int = 32,\n",
"        eval_batch_size: int = 32,\n",
"        **kwargs,\n",
"    ):\n",
"        super().__init__()\n",
"        self.model_name_or_path = model_name_or_path\n",
"        self.task_name = task_name\n",
"        self.max_seq_length = max_seq_length\n",
"        self.train_batch_size = train_batch_size\n",
"        self.eval_batch_size = eval_batch_size\n",
"\n",
"        self.text_fields = self.task_text_field_map[task_name]\n",
"        self.num_labels = self.glue_task_num_labels[task_name]\n",
"        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=True)\n",
"\n",
"    def setup(self, stage: Optional[str] = None):\n",
"        self.dataset = datasets.load_dataset(\"glue\", self.task_name)\n",
"\n",
"        for split in self.dataset.keys():\n",
"            self.dataset[split] = self.dataset[split].map(\n",
"                self.convert_to_features,\n",
"                batched=True,\n",
"                remove_columns=[\"label\"],\n",
"            )\n",
"            self.columns = [c for c in self.dataset[split].column_names if c in self.loader_columns]\n",
"            self.dataset[split].set_format(type=\"torch\", columns=self.columns)\n",
"\n",
"        self.eval_splits = [x for x in self.dataset.keys() if \"validation\" in x]\n",
"\n",
"    def prepare_data(self):\n",
"        datasets.load_dataset(\"glue\", self.task_name)\n",
"        AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=True)\n",
"\n",
"    def train_dataloader(self):\n",
"        return DataLoader(self.dataset[\"train\"], batch_size=self.train_batch_size)\n",
"\n",
"    def val_dataloader(self):\n",
"        if len(self.eval_splits) == 1:\n",
"            return DataLoader(self.dataset[\"validation\"], batch_size=self.eval_batch_size)\n",
"        elif len(self.eval_splits) > 1:\n",
"            return [DataLoader(self.dataset[x], batch_size=self.eval_batch_size) for x in self.eval_splits]\n",
"\n",
"    def test_dataloader(self):\n",
"        if len(self.eval_splits) == 1:\n",
"            return DataLoader(self.dataset[\"test\"], batch_size=self.eval_batch_size)\n",
"        elif len(self.eval_splits) > 1:\n",
"            # Use the test splits here, not the validation splits\n",
"            test_splits = [x for x in self.dataset.keys() if \"test\" in x]\n",
"            return [DataLoader(self.dataset[x], batch_size=self.eval_batch_size) for x in test_splits]\n",
"\n",
"    def convert_to_features(self, example_batch, indices=None):\n",
"\n",
"        # Either encode single sentence or sentence pairs\n",
"        if len(self.text_fields) > 1:\n",
"            texts_or_text_pairs = list(zip(example_batch[self.text_fields[0]], example_batch[self.text_fields[1]]))\n",
"        else:\n",
"            texts_or_text_pairs = example_batch[self.text_fields[0]]\n",
"\n",
"        # Tokenize the text/text pairs (padding=\"max_length\" replaces the deprecated pad_to_max_length=True)\n",
"        features = self.tokenizer.batch_encode_plus(\n",
"            texts_or_text_pairs, max_length=self.max_seq_length, padding=\"max_length\", truncation=True\n",
"        )\n",
"\n",
"        # Rename label to labels to 
make it easier to pass to model forward\n",
"        features[\"labels\"] = example_batch[\"label\"]\n",
"\n",
"        return features"]}, {"cell_type": "markdown", "id": "532d63d6", "metadata": {"papermill": {"duration": 0.029073, "end_time": "2021-12-04T16:50:17.576485", "exception": false, "start_time": "2021-12-04T16:50:17.547412", "status": "completed"}, "tags": []}, "source": ["**You could use this datamodule with standalone PyTorch if you wanted...**"]}, {"cell_type": "code", "execution_count": 4, "id": "f3f27b26", "metadata": {"execution": {"iopub.execute_input": "2021-12-04T16:50:17.637888Z", "iopub.status.busy": "2021-12-04T16:50:17.637424Z", "iopub.status.idle": "2021-12-04T16:50:24.396830Z", "shell.execute_reply": "2021-12-04T16:50:24.396334Z"}, "papermill": {"duration": 6.791378, "end_time": "2021-12-04T16:50:24.396956", "exception": false, "start_time": "2021-12-04T16:50:17.605578", "status": "completed"}, "tags": []}, "outputs": [{"data": {"application/vnd.jupyter.widget-view+json": {"model_id": "232400085c33454e97f802179183d1ed", "version_major": 2, "version_minor": 0}, "text/plain": ["Downloading: 0%| | 0.00/28.0 [00:00= 1:\n",
"            preds = torch.argmax(logits, axis=1)\n",
"        elif self.hparams.num_labels == 1:\n",
"            preds = logits.squeeze()\n",
"\n",
"        labels = batch[\"labels\"]\n",
"\n",
"        return {\"loss\": val_loss, \"preds\": preds, \"labels\": labels}\n",
"\n",
"    def validation_epoch_end(self, outputs):\n",
"        if self.hparams.task_name == \"mnli\":\n",
"            for i, output in enumerate(outputs):\n",
"                # matched or mismatched\n",
"                split = self.hparams.eval_splits[i].split(\"_\")[-1]\n",
"                preds = torch.cat([x[\"preds\"] for x in output]).detach().cpu().numpy()\n",
"                labels = torch.cat([x[\"labels\"] for x in output]).detach().cpu().numpy()\n",
"                loss = torch.stack([x[\"loss\"] for x in output]).mean()\n",
"                self.log(f\"val_loss_{split}\", loss, prog_bar=True)\n",
"                split_metrics = {\n",
"                    f\"{k}_{split}\": v for k, v in self.metric.compute(predictions=preds, 
references=labels).items()\n",
"                }\n",
"                self.log_dict(split_metrics, prog_bar=True)\n",
"            return loss\n",
"\n",
"        preds = torch.cat([x[\"preds\"] for x in outputs]).detach().cpu().numpy()\n",
"        labels = torch.cat([x[\"labels\"] for x in outputs]).detach().cpu().numpy()\n",
"        loss = torch.stack([x[\"loss\"] for x in outputs]).mean()\n",
"        self.log(\"val_loss\", loss, prog_bar=True)\n",
"        self.log_dict(self.metric.compute(predictions=preds, references=labels), prog_bar=True)\n",
"        return loss\n",
"\n",
"    def setup(self, stage=None) -> None:\n",
"        if stage != \"fit\":\n",
"            return\n",
"        # Get dataloader by calling it - train_dataloader() is called after setup() by default\n",
"        train_loader = self.trainer.datamodule.train_dataloader()\n",
"\n",
"        # Calculate total optimizer steps: steps per epoch times the number of epochs\n",
"        tb_size = self.hparams.train_batch_size * max(1, self.trainer.gpus)\n",
"        steps_per_epoch = len(train_loader.dataset) // tb_size // self.trainer.accumulate_grad_batches\n",
"        self.total_steps = steps_per_epoch * self.trainer.max_epochs\n",
"\n",
"    def configure_optimizers(self):\n",
"        \"\"\"Prepare optimizer and schedule (linear warmup and decay)\"\"\"\n",
"        model = self.model\n",
"        no_decay = [\"bias\", \"LayerNorm.weight\"]\n",
"        optimizer_grouped_parameters = [\n",
"            {\n",
"                \"params\": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],\n",
"                \"weight_decay\": self.hparams.weight_decay,\n",
"            },\n",
"            {\n",
"                \"params\": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],\n",
"                \"weight_decay\": 0.0,\n",
"            },\n",
"        ]\n",
"        optimizer = AdamW(optimizer_grouped_parameters, lr=self.hparams.learning_rate, eps=self.hparams.adam_epsilon)\n",
"\n",
"        scheduler = get_linear_schedule_with_warmup(\n",
"            optimizer,\n",
"            num_warmup_steps=self.hparams.warmup_steps,\n",
"            num_training_steps=self.total_steps,\n",
"        )\n",
"        scheduler = {\"scheduler\": scheduler, \"interval\": \"step\", \"frequency\": 1}\n",
"        return [optimizer], [scheduler]"]}, {"cell_type": 
"markdown", "id": "be59b890", "metadata": {"papermill": {"duration": 0.055801, "end_time": "2021-12-04T16:50:24.863114", "exception": false, "start_time": "2021-12-04T16:50:24.807313", "status": "completed"}, "tags": []}, "source": ["## Training"]}, {"cell_type": "markdown", "id": "c1a5d61a", "metadata": {"papermill": {"duration": 0.054576, "end_time": "2021-12-04T16:50:24.972560", "exception": false, "start_time": "2021-12-04T16:50:24.917984", "status": "completed"}, "tags": []}, "source": ["### CoLA\n", "\n", "See an interactive view of the\n", "CoLA dataset in [NLP Viewer](https://huggingface.co/nlp/viewer/?dataset=glue&config=cola)"]}, {"cell_type": "code", "execution_count": 6, "id": "774115a0", "metadata": {"execution": {"iopub.execute_input": "2021-12-04T16:50:25.086490Z", "iopub.status.busy": "2021-12-04T16:50:25.086020Z", "iopub.status.idle": "2021-12-04T16:51:46.067948Z", "shell.execute_reply": "2021-12-04T16:51:46.068322Z"}, "papermill": {"duration": 81.04138, "end_time": "2021-12-04T16:51:46.068501", "exception": false, "start_time": "2021-12-04T16:50:25.027121", "status": "completed"}, "tags": []}, "outputs": [{"name": "stderr", "output_type": "stream", "text": ["Global seed set to 42\n"]}, {"data": {"application/vnd.jupyter.widget-view+json": {"model_id": "0071178f4db54b2a9c7a0c792adbfde9", "version_major": 2, "version_minor": 0}, "text/plain": ["Downloading: 0%| | 0.00/684 [00:00