{"cells": [{"cell_type": "markdown", "id": "4f979e51", "metadata": {"papermill": {"duration": 0.014037, "end_time": "2023-10-11T16:03:06.630612", "exception": false, "start_time": "2023-10-11T16:03:06.616575", "status": "completed"}, "tags": []}, "source": ["\n", "# Tutorial 6: Basics of Graph Neural Networks\n", "\n", "* **Author:** Phillip Lippe\n", "* **License:** CC BY-SA\n", "* **Generated:** 2023-10-11T16:02:31.112587\n", "\n", "In this tutorial, we will discuss the application of neural networks on graphs.\n", "Graph Neural Networks (GNNs) have recently gained increasing popularity in both applications and research,\n", "including domains such as social networks, knowledge graphs, recommender systems, and bioinformatics.\n", "While the theory and math behind GNNs might first seem complicated,\n", "the implementation of those models is quite simple and helps in understanding the methodology.\n", "Therefore, we will discuss the implementation of basic network layers of a GNN,\n", "namely graph convolutions, and attention layers.\n", "Finally, we will apply a GNN on semi-supervised node classification and molecule categorization.\n", "This notebook is part of a lecture series on Deep Learning at the University of Amsterdam.\n", "The full list of tutorials can be found at https://uvadlc-notebooks.rtfd.io.\n", "\n", "\n", "---\n", "Open in [![Open In Colab](){height=\"20px\" width=\"117px\"}](https://colab.research.google.com/github/PytorchLightning/lightning-tutorials/blob/publication/.notebooks/course_UvA-DL/06-graph-neural-networks.ipynb)\n", "\n", "Give us a \u2b50 [on Github](https://www.github.com/Lightning-AI/lightning/)\n", "| Check out [the documentation](https://pytorch-lightning.readthedocs.io/en/stable/)\n", "| Join us [on Slack](https://www.pytorchlightning.ai/community)"]}, {"cell_type": "markdown", "id": "5214d4fa", "metadata": {"papermill": {"duration": 0.012019, "end_time": "2023-10-11T16:03:06.661853", "exception": false, "start_time": 
"2023-10-11T16:03:06.649834", "status": "completed"}, "tags": []}, "source": ["## Setup\n", "This notebook requires some packages besides pytorch-lightning."]}, {"cell_type": "code", "execution_count": 1, "id": "1c5351a8", "metadata": {"colab": {}, "colab_type": "code", "execution": {"iopub.execute_input": "2023-10-11T16:03:06.687521Z", "iopub.status.busy": "2023-10-11T16:03:06.687127Z", "iopub.status.idle": "2023-10-11T16:03:10.716383Z", "shell.execute_reply": "2023-10-11T16:03:10.715453Z"}, "id": "LfrJLKPFyhsK", "lines_to_next_cell": 0, "papermill": {"duration": 4.044652, "end_time": "2023-10-11T16:03:10.718369", "exception": false, "start_time": "2023-10-11T16:03:06.673717", "status": "completed"}, "tags": []}, "outputs": [{"name": "stdout", "output_type": "stream", "text": ["\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\r\n", "\u001b[0m"]}], "source": ["! pip install --quiet \"torch-geometric\" \"ipython[notebook]>=8.0.0, <8.17.0\" \"lightning>=2.0.0\" \"torch-sparse\" \"torch-cluster\" \"torch-scatter\" \"torch-spline-conv\" \"pytorch-lightning>=1.4, <2.1.0\" \"torchmetrics>=0.7, <1.3\" \"setuptools>=68.0.0, <68.3.0\" \"matplotlib>=3.0.0, <3.9.0\" \"torch>=1.8.1, <2.1.0\" \"urllib3\""]}, {"cell_type": "markdown", "id": "648a5e0d", "metadata": {"papermill": {"duration": 0.009288, "end_time": "2023-10-11T16:03:10.737471", "exception": false, "start_time": "2023-10-11T16:03:10.728183", "status": "completed"}, "tags": []}, "source": ["
\n", "We start by importing our standard libraries below."]}, {"cell_type": "code", "execution_count": 2, "id": "c7750212", "metadata": {"execution": {"iopub.execute_input": "2023-10-11T16:03:10.757524Z", "iopub.status.busy": "2023-10-11T16:03:10.756995Z", "iopub.status.idle": "2023-10-11T16:03:15.530236Z", "shell.execute_reply": "2023-10-11T16:03:15.529272Z"}, "papermill": {"duration": 4.790934, "end_time": "2023-10-11T16:03:15.537517", "exception": false, "start_time": "2023-10-11T16:03:10.746583", "status": "completed"}, "tags": []}, "outputs": [{"name": "stderr", "output_type": "stream", "text": ["Global seed set to 42\n"]}], "source": ["# Standard libraries\n", "import os\n", "\n", "# For downloading pre-trained models\n", "import urllib.request\n", "from urllib.error import HTTPError\n", "\n", "# PyTorch Lightning\n", "import lightning as L\n", "\n", "# PyTorch\n", "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "import torch.optim as optim\n", "\n", "# PyTorch geometric\n", "import torch_geometric\n", "import torch_geometric.data as geom_data\n", "import torch_geometric.nn as geom_nn\n", "\n", "# PL callbacks\n", "from lightning.pytorch.callbacks import ModelCheckpoint\n", "from torch import Tensor\n", "\n", "AVAIL_GPUS = min(1, torch.cuda.device_count())\n", "BATCH_SIZE = 256 if AVAIL_GPUS else 64\n", "# Path to the folder where the datasets are/should be downloaded\n", "DATASET_PATH = os.environ.get(\"PATH_DATASETS\", \"data/\")\n", "# Path to the folder where the pretrained models are saved\n", "CHECKPOINT_PATH = os.environ.get(\"PATH_CHECKPOINT\", \"saved_models/GNNs/\")\n", "\n", "# Setting the seed\n", "L.seed_everything(42)\n", "\n", "# Ensure that all operations are deterministic on GPU (if used) for reproducibility\n", "torch.backends.cudnn.deterministic = True\n", "torch.backends.cudnn.benchmark = False"]}, {"cell_type": "markdown", "id": "32a3eca2", "metadata": {"papermill": {"duration": 0.009097, "end_time": 
"2023-10-11T16:03:15.557823", "exception": false, "start_time": "2023-10-11T16:03:15.548726", "status": "completed"}, "tags": []}, "source": ["We also have a few pre-trained models we download below."]}, {"cell_type": "code", "execution_count": 3, "id": "2a6a3f6a", "metadata": {"execution": {"iopub.execute_input": "2023-10-11T16:03:15.582822Z", "iopub.status.busy": "2023-10-11T16:03:15.581987Z", "iopub.status.idle": "2023-10-11T16:03:16.134259Z", "shell.execute_reply": "2023-10-11T16:03:16.133238Z"}, "papermill": {"duration": 0.564879, "end_time": "2023-10-11T16:03:16.136013", "exception": false, "start_time": "2023-10-11T16:03:15.571134", "status": "completed"}, "tags": []}, "outputs": [{"name": "stdout", "output_type": "stream", "text": ["Downloading https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial7/NodeLevelMLP.ckpt...\n", "Downloading https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial7/NodeLevelGNN.ckpt...\n"]}, {"name": "stdout", "output_type": "stream", "text": ["Downloading https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial7/GraphLevelGraphConv.ckpt...\n"]}], "source": ["# Github URL where saved models are stored for this tutorial\n", "base_url = \"https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial7/\"\n", "# Files to download\n", "pretrained_files = [\"NodeLevelMLP.ckpt\", \"NodeLevelGNN.ckpt\", \"GraphLevelGraphConv.ckpt\"]\n", "\n", "# Create checkpoint path if it doesn't exist yet\n", "os.makedirs(CHECKPOINT_PATH, exist_ok=True)\n", "\n", "# For each file, check whether it already exists. 
If not, try downloading it.\n", "for file_name in pretrained_files:\n", "    file_path = os.path.join(CHECKPOINT_PATH, file_name)\n", "    if \"/\" in file_name:\n", "        os.makedirs(file_path.rsplit(\"/\", 1)[0], exist_ok=True)\n", "    if not os.path.isfile(file_path):\n", "        file_url = base_url + file_name\n", "        print(\"Downloading %s...\" % file_url)\n", "        try:\n", "            urllib.request.urlretrieve(file_url, file_path)\n", "        except HTTPError as e:\n", "            print(\n", "                \"Something went wrong. Please try to download the file manually from the URL above,\"\n", "                \" or contact the author with the full output including the following error:\\n\",\n", "                e,\n", "            )"]}, {"cell_type": "markdown", "id": "23531974", "metadata": {"papermill": {"duration": 0.04981, "end_time": "2023-10-11T16:03:16.195921", "exception": false, "start_time": "2023-10-11T16:03:16.146111", "status": "completed"}, "tags": []}, "source": ["## Graph Neural Networks"]}, {"cell_type": "markdown", "id": "123ed1b4", "metadata": {"papermill": {"duration": 0.009251, "end_time": "2023-10-11T16:03:16.214598", "exception": false, "start_time": "2023-10-11T16:03:16.205347", "status": "completed"}, "tags": []}, "source": ["### Graph representation\n", "\n", "Before starting the discussion of specific neural network operations on graphs, we should consider how to represent a graph.\n", "Mathematically, a graph $\\mathcal{G}$ is defined as a tuple of a set of nodes/vertices $V$ and a set of edges/links $E$: $\\mathcal{G}=(V,E)$.\n", "Each edge is a pair of two vertices and represents a connection between them.\n", "For instance, let's look at the following graph:\n", "\n", "
\n", "\n", "The vertices are $V=\\{1,2,3,4\\}$, and edges $E=\\{(1,2), (2,3), (2,4), (3,4)\\}$.\n", "Note that for simplicity, we assume the graph to be undirected and hence don't add mirrored pairs like $(2,1)$.\n", "In applications, vertices and edges often have specific attributes, and edges can even be directed.\n", "The question is how we could represent this diversity in an efficient way for matrix operations.\n", "Usually, for the edges, we decide between two variants: an adjacency matrix or a list of paired vertex indices.\n", "\n", "The **adjacency matrix** $A$ is a square matrix whose elements indicate whether pairs of vertices are adjacent,\n", "i.e. connected, or not.\n", "In the simplest case, $A_{ij}$ is 1 if there is a connection from node $i$ to $j$, and otherwise 0.\n", "If we have edge attributes or different categories of edges in a graph, this information can be added to the matrix as well.\n", "For an undirected graph, keep in mind that $A$ is a symmetric matrix ($A_{ij}=A_{ji}$).\n", "For the example graph above, we have the following adjacency matrix:\n", "\n", "$$\n", "A = \\begin{bmatrix}\n", " 0 & 1 & 0 & 0\\\\\n", " 1 & 0 & 1 & 1\\\\\n", " 0 & 1 & 0 & 1\\\\\n", " 0 & 1 & 1 & 0\n", "\\end{bmatrix}\n", "$$\n", "\n", "While expressing a graph as a list of edges is more efficient in terms of memory and (possibly) computation,\n", "using an adjacency matrix is more intuitive and simpler to implement.\n", "In our implementations below, we will rely on the adjacency matrix to keep the code simple.\n", "However, common libraries use edge lists, which we will discuss in more detail later.\n", "Alternatively, we could also use the list of edges to define a sparse adjacency matrix, which we can work with\n", "as if it were a dense matrix while allowing more memory-efficient operations.\n", "PyTorch supports this with the sub-package `torch.sparse`\n", "([documentation](https://pytorch.org/docs/stable/sparse.html)), which is, however, still in beta\n", 
"(API might change in future)."]}, {"cell_type": "markdown", "id": "d28b1897", "metadata": {"lines_to_next_cell": 2, "papermill": {"duration": 0.012645, "end_time": "2023-10-11T16:03:16.236437", "exception": false, "start_time": "2023-10-11T16:03:16.223792", "status": "completed"}, "tags": []}, "source": ["### Graph Convolutions\n", "\n", "Graph Convolutional Networks have been introduced by [Kipf et al. ](https://openreview.net/pdf?id=SJU4ayYgl)\n", "in 2016 at the University of Amsterdam.\n", "He also wrote a great [blog post](https://tkipf.github.io/graph-convolutional-networks/) about this topic,\n", "which is recommended if you want to read about GCNs from a different perspective.\n", "GCNs are similar to convolutions in images in the sense that the \"filter\" parameters are typically shared over all locations in the graph.\n", "At the same time, GCNs rely on message passing methods, which means that vertices exchange information with the neighbors,\n", "and send \"messages\" to each other.\n", "Before looking at the math, we can try to visually understand how GCNs work.\n", "The first step is that each node creates a feature vector that represents the message it wants to send to all its neighbors.\n", "In the second step, the messages are sent to the neighbors, so that a node receives one message per adjacent node.\n", "Below we have visualized the two steps for our example graph.\n", "\n", "
\n", "\n", "If we want to formulate that in more mathematical terms, we first need to decide how to combine\n", "all the messages a node receives.\n", "As the number of messages varies across nodes, we need an operation that works for any number.\n", "Hence, the usual choice is to sum or take the mean.\n", "Given the previous features of nodes $H^{(l)}$, the GCN layer is defined as follows:\n", "\n", "$$H^{(l+1)} = \\sigma\\left(\\hat{D}^{-1/2}\\hat{A}\\hat{D}^{-1/2}H^{(l)}W^{(l)}\\right)$$\n", "\n", "$W^{(l)}$ are the weight parameters with which we transform the input features into messages ($H^{(l)}W^{(l)}$).\n", "To the adjacency matrix $A$ we add the identity matrix so that each node also sends its own message to itself:\n", "$\\hat{A}=A+I$.\n", "Finally, to normalize the summed messages by the node degree instead of only summing them, we calculate the matrix $\\hat{D}$, which is a diagonal\n", "matrix with $\\hat{D}_{ii}$ denoting the number of neighbors of node $i$ (including node $i$ itself, since the self-connections are part of $\\hat{A}$).\n", "The symmetric normalization $\\hat{D}^{-1/2}\\hat{A}\\hat{D}^{-1/2}$ is the variant used by Kipf et al.; the implementation below uses the simpler row normalization $\\hat{D}^{-1}\\hat{A}$, which takes exactly the mean over all received messages.\n", "$\\sigma$ represents an arbitrary activation function, and not necessarily the sigmoid (usually a ReLU-based\n", "activation function is used in GNNs).\n", "\n", "When implementing the GCN layer in PyTorch, we can take advantage of the flexible operations on tensors.\n", "Instead of defining a matrix $\\hat{D}$, we can simply divide the summed messages by the number of neighbors afterward.\n", "Additionally, we replace the weight matrix with a linear layer, which also allows us to add a bias.\n", "Written as a PyTorch module, the GCN layer is defined as follows:"]}, {"cell_type": "code", "execution_count": 4, "id": "6ce21fc3", "metadata": {"execution": {"iopub.execute_input": "2023-10-11T16:03:16.257386Z", "iopub.status.busy": "2023-10-11T16:03:16.256846Z", "iopub.status.idle": "2023-10-11T16:03:16.268948Z", "shell.execute_reply": "2023-10-11T16:03:16.267658Z"}, "papermill": {"duration": 0.025196, "end_time": "2023-10-11T16:03:16.271397", "exception": false, "start_time": "2023-10-11T16:03:16.246201", "status": "completed"}, 
"tags": []}, "outputs": [], "source": ["class GCNLayer(nn.Module):\n", " def __init__(self, c_in, c_out):\n", " super().__init__()\n", " self.projection = nn.Linear(c_in, c_out)\n", "\n", " def forward(self, node_feats, adj_matrix):\n", " \"\"\"Forward.\n", "\n", " Args:\n", " node_feats: Tensor with node features of shape [batch_size, num_nodes, c_in]\n", " adj_matrix: Batch of adjacency matrices of the graph. If there is an edge from i to j,\n", " adj_matrix[b,i,j]=1 else 0. Supports directed edges by non-symmetric matrices.\n", " Assumes to already have added the identity connections.\n", " Shape: [batch_size, num_nodes, num_nodes]\n", " \"\"\"\n", " # Num neighbours = number of incoming edges\n", " num_neighbours = adj_matrix.sum(dim=-1, keepdims=True)\n", " node_feats = self.projection(node_feats)\n", " node_feats = torch.bmm(adj_matrix, node_feats)\n", " node_feats = node_feats / num_neighbours\n", " return node_feats"]}, {"cell_type": "markdown", "id": "f465968d", "metadata": {"papermill": {"duration": 0.009388, "end_time": "2023-10-11T16:03:16.290193", "exception": false, "start_time": "2023-10-11T16:03:16.280805", "status": "completed"}, "tags": []}, "source": ["To further understand the GCN layer, we can apply it to our example graph above.\n", "First, let's specify some node features and the adjacency matrix with added self-connections:"]}, {"cell_type": "code", "execution_count": 5, "id": "ae773b51", "metadata": {"execution": {"iopub.execute_input": "2023-10-11T16:03:16.310470Z", "iopub.status.busy": "2023-10-11T16:03:16.309858Z", "iopub.status.idle": "2023-10-11T16:03:16.324044Z", "shell.execute_reply": "2023-10-11T16:03:16.323176Z"}, "papermill": {"duration": 0.026083, "end_time": "2023-10-11T16:03:16.325522", "exception": false, "start_time": "2023-10-11T16:03:16.299439", "status": "completed"}, "tags": []}, "outputs": [{"name": "stdout", "output_type": "stream", "text": ["Node features:\n", " tensor([[[0., 1.],\n", " [2., 3.],\n", " [4., 5.],\n", " 
[6., 7.]]])\n", "\n", "Adjacency matrix:\n", " tensor([[[1., 1., 0., 0.],\n", " [1., 1., 1., 1.],\n", " [0., 1., 1., 1.],\n", " [0., 1., 1., 1.]]])\n"]}], "source": ["node_feats = torch.arange(8, dtype=torch.float32).view(1, 4, 2)\n", "adj_matrix = Tensor([[[1, 1, 0, 0], [1, 1, 1, 1], [0, 1, 1, 1], [0, 1, 1, 1]]])\n", "\n", "print(\"Node features:\\n\", node_feats)\n", "print(\"\\nAdjacency matrix:\\n\", adj_matrix)"]}, {"cell_type": "markdown", "id": "85a343c9", "metadata": {"papermill": {"duration": 0.015853, "end_time": "2023-10-11T16:03:16.350759", "exception": false, "start_time": "2023-10-11T16:03:16.334906", "status": "completed"}, "tags": []}, "source": ["Next, let's apply a GCN layer to it.\n", "For simplicity, we initialize the linear weight matrix as an identity matrix so that the input features are equal to the messages.\n", "This makes it easier for us to verify the message passing operation."]}, {"cell_type": "code", "execution_count": 6, "id": "f3352c18", "metadata": {"execution": {"iopub.execute_input": "2023-10-11T16:03:16.373984Z", "iopub.status.busy": "2023-10-11T16:03:16.373472Z", "iopub.status.idle": "2023-10-11T16:03:16.381312Z", "shell.execute_reply": "2023-10-11T16:03:16.380405Z"}, "papermill": {"duration": 0.020795, "end_time": "2023-10-11T16:03:16.382866", "exception": false, "start_time": "2023-10-11T16:03:16.362071", "status": "completed"}, "tags": []}, "outputs": [{"name": "stdout", "output_type": "stream", "text": ["Adjacency matrix tensor([[[1., 1., 0., 0.],\n", " [1., 1., 1., 1.],\n", " [0., 1., 1., 1.],\n", " [0., 1., 1., 1.]]])\n", "Input features tensor([[[0., 1.],\n", " [2., 3.],\n", " [4., 5.],\n", " [6., 7.]]])\n", "Output features tensor([[[1., 2.],\n", " [3., 4.],\n", " [4., 5.],\n", " [4., 5.]]])\n"]}], "source": ["layer = GCNLayer(c_in=2, c_out=2)\n", "layer.projection.weight.data = Tensor([[1.0, 0.0], [0.0, 1.0]])\n", "layer.projection.bias.data = Tensor([0.0, 0.0])\n", "\n", "with torch.no_grad():\n", " out_feats = 
layer(node_feats, adj_matrix)\n", "\n", "print(\"Adjacency matrix\", adj_matrix)\n", "print(\"Input features\", node_feats)\n", "print(\"Output features\", out_feats)"]}, {"cell_type": "markdown", "id": "3556a93a", "metadata": {"papermill": {"duration": 0.009481, "end_time": "2023-10-11T16:03:16.401826", "exception": false, "start_time": "2023-10-11T16:03:16.392345", "status": "completed"}, "tags": []}, "source": ["As we can see, the first node's output values $[1, 2]$ are the average of its own features $[0, 1]$ and those of the second node $[2, 3]$.\n", "Similarly, we can verify all other nodes; for instance, the second node is connected to all four nodes, so its output $[3, 4]$ is the mean over all input features.\n", "However, in a GNN, we would also want to allow feature exchange between nodes beyond their direct neighbors.\n", "This can be achieved by stacking multiple GCN layers, which gives us the final layout of a GNN.\n", "The GNN can be built up from a sequence of GCN layers and non-linearities such as ReLU.\n", "For a visualization, see below (figure credit - [Thomas Kipf, 2016](https://tkipf.github.io/graph-convolutional-networks/)).\n", "\n", "
\n", "\n", "However, one issue we can see from the example above is that the output features for nodes 3 and 4 are\n", "the same because they have the same adjacent nodes (including themselves).\n", "Therefore, GCN layers can make the network forget node-specific information if we just take a mean over all messages.\n", "Multiple possible improvements have been proposed.\n", "While the simplest option might be using residual connections, the more common approach is to either weight\n", "the self-connections more heavily or define a separate weight matrix for the self-connections.\n", "Alternatively, we can use a well-known concept: attention."]}, {"cell_type": "markdown", "id": "43d1a129", "metadata": {"lines_to_next_cell": 2, "papermill": {"duration": 0.009353, "end_time": "2023-10-11T16:03:16.420575", "exception": false, "start_time": "2023-10-11T16:03:16.411222", "status": "completed"}, "tags": []}, "source": ["### Graph Attention\n", "\n", "Attention describes a weighted average of multiple elements with the weights dynamically computed based on an input\n", "query and the elements' keys (if you don't know what attention is, we recommend at least going through\n", "the very first section called [What is Attention?](https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/tutorial6/Transformers_and_MHAttention.html#What-is-Attention?)).\n", "This concept can be applied to graphs as well; one example is the Graph Attention Network\n", "(called GAT, proposed by [Velickovic et al., 2017](https://arxiv.org/abs/1710.10903)).\n", "Similarly to the GCN, the graph attention layer creates a message for each node using a linear layer/weight matrix.\n", "For the attention part, it uses the message from the node itself as a query, and the messages to average as both\n", "keys and values (note that this also includes the message to itself).\n", "The score function $f_{attn}$ is implemented as a one-layer MLP which maps the query and key to a single value.\n", 
"The MLP looks as follows (figure credit - [Velickovic et al. ](https://arxiv.org/abs/1710.10903)):\n", "\n", "
\n", "\n", "$h_i$ and $h_j$ are the original features of nodes $i$ and $j$, respectively, and $\\mathbf{W}h_i$ and $\\mathbf{W}h_j$ are the corresponding messages\n", "of the layer, with $\\mathbf{W}$ as weight matrix.\n", "$\\mathbf{a}$ is the weight matrix of the MLP, which has the shape $[1,2\\times d_{\\text{message}}]$,\n", "and $\\alpha_{ij}$ is the final attention weight from node $i$ to $j$.\n", "The calculation can be described as follows:\n", "\n", "$$\\alpha_{ij} = \\frac{\\exp\\left(\\text{LeakyReLU}\\left(\\mathbf{a}\\left[\\mathbf{W}h_i||\\mathbf{W}h_j\\right]\\right)\\right)}{\\sum_{k\\in\\mathcal{N}_i} \\exp\\left(\\text{LeakyReLU}\\left(\\mathbf{a}\\left[\\mathbf{W}h_i||\\mathbf{W}h_k\\right]\\right)\\right)}$$\n", "\n", "The operator $||$ represents the concatenation, and $\\mathcal{N}_i$ the indices of the neighbors of node $i$.\n", "Note that, in contrast to the usual dot-product attention in Transformers, we apply a non-linearity (here LeakyReLU) before the softmax over elements.\n", "Although it seems like a minor change at first, it is crucial for the attention to depend on the original input.\n", "Specifically, let's remove the non-linearity for a moment and try to simplify the expression.\n", "Writing $d=2\\times d_{\\text{message}}$, the slices $\\mathbf{a}_{:,:d/2}$ and $\\mathbf{a}_{:,d/2:}$ denote the parts of $\\mathbf{a}$ that are applied to $\\mathbf{W}h_i$ and $\\mathbf{W}h_j$, respectively:\n", "\n", "$$\n", "\\begin{split}\n", " \\alpha_{ij} & = \\frac{\\exp\\left(\\mathbf{a}\\left[\\mathbf{W}h_i||\\mathbf{W}h_j\\right]\\right)}{\\sum_{k\\in\\mathcal{N}_i} \\exp\\left(\\mathbf{a}\\left[\\mathbf{W}h_i||\\mathbf{W}h_k\\right]\\right)}\\\\[5pt]\n", " & = \\frac{\\exp\\left(\\mathbf{a}_{:,:d/2}\\mathbf{W}h_i+\\mathbf{a}_{:,d/2:}\\mathbf{W}h_j\\right)}{\\sum_{k\\in\\mathcal{N}_i} \\exp\\left(\\mathbf{a}_{:,:d/2}\\mathbf{W}h_i+\\mathbf{a}_{:,d/2:}\\mathbf{W}h_k\\right)}\\\\[5pt]\n", " & = \\frac{\\exp\\left(\\mathbf{a}_{:,:d/2}\\mathbf{W}h_i\\right)\\cdot\\exp\\left(\\mathbf{a}_{:,d/2:}\\mathbf{W}h_j\\right)}{\\sum_{k\\in\\mathcal{N}_i} \\exp\\left(\\mathbf{a}_{:,:d/2}\\mathbf{W}h_i\\right)\\cdot\\exp\\left(\\mathbf{a}_{:,d/2:}\\mathbf{W}h_k\\right)}\\\\[5pt]\n", " & = 
\\frac{\\exp\\left(\\mathbf{a}_{:,d/2:}\\mathbf{W}h_j\\right)}{\\sum_{k\\in\\mathcal{N}_i} \\exp\\left(\\mathbf{a}_{:,d/2:}\\mathbf{W}h_k\\right)}\\\\\n", "\\end{split}\n", "$$\n", "\n", "We can see that without the non-linearity, the attention term with $h_i$ actually cancels itself out,\n", "resulting in the attention being independent of the node itself.\n", "Hence, we would have the same issue as the GCN of creating the same output features for nodes with the same neighbors.\n", "This is why the LeakyReLU is crucial and adds some dependency on $h_i$ to the attention.\n", "\n", "Once we obtain all attention factors, we can calculate the output features for each node by performing\n", "the weighted average:\n", "\n", "$$h_i'=\\sigma\\left(\\sum_{j\\in\\mathcal{N}_i}\\alpha_{ij}\\mathbf{W}h_j\\right)$$\n", "\n", "$\\sigma$ is yet another non-linearity, as in the GCN layer.\n", "Visually, we can represent the full message passing in an attention layer as follows\n", "(figure credit - [Velickovic et al. ](https://arxiv.org/abs/1710.10903)):\n", "\n", "
\n", "\n", "To increase the expressiveness of the graph attention network, [Velickovic et al.](https://arxiv.org/abs/1710.10903)\n", "proposed to extend it to multiple heads, similar to the Multi-Head Attention block in Transformers.\n", "This results in $N$ attention mechanisms (heads) being applied in parallel.\n", "In the image above, the heads are visualized as arrows of three different colors (green, blue, and purple),\n", "whose outputs are afterward concatenated.\n", "Averaging over the heads is only applied in the very final prediction layer of a network.\n", "\n", "After having discussed the graph attention layer in detail, we can implement it below:"]}, {"cell_type": "code", "execution_count": 7, "id": "7c3c6d13", "metadata": {"execution": {"iopub.execute_input": "2023-10-11T16:03:16.441063Z", "iopub.status.busy": "2023-10-11T16:03:16.440489Z", "iopub.status.idle": "2023-10-11T16:03:16.456587Z", "shell.execute_reply": "2023-10-11T16:03:16.455511Z"}, "papermill": {"duration": 0.028039, "end_time": "2023-10-11T16:03:16.458052", "exception": false, "start_time": "2023-10-11T16:03:16.430013", "status": "completed"}, "tags": []}, "outputs": [], "source": ["class GATLayer(nn.Module):\n", "    def __init__(self, c_in, c_out, num_heads=1, concat_heads=True, alpha=0.2):\n", "        \"\"\"\n", "        Args:\n", "            c_in: Dimensionality of input features\n", "            c_out: Dimensionality of output features\n", "            num_heads: Number of heads, i.e. attention mechanisms to apply in parallel. The\n", "                output features are equally split up over the heads if concat_heads=True.\n", "            concat_heads: If True, the output of the different heads is concatenated instead of averaged.\n", "            alpha: Negative slope of the LeakyReLU activation.\n", "        \"\"\"\n", "        super().__init__()\n", "        self.num_heads = num_heads\n", "        self.concat_heads = concat_heads\n", "        if self.concat_heads:\n", "            assert c_out % num_heads == 0, \"Number of output features must be a multiple of the count of heads.\"\n", "            c_out = c_out // num_heads\n", "\n", "        # Sub-modules and parameters needed in the layer\n", "        self.projection = nn.Linear(c_in, c_out * num_heads)\n", "        self.a = nn.Parameter(Tensor(num_heads, 2 * c_out))  # One per head\n", "        self.leakyrelu = nn.LeakyReLU(alpha)\n", "\n", "        # Initialization from the original implementation\n", "        nn.init.xavier_uniform_(self.projection.weight.data, gain=1.414)\n", "        nn.init.xavier_uniform_(self.a.data, gain=1.414)\n", "\n", "    def forward(self, node_feats, adj_matrix, print_attn_probs=False):\n", "        \"\"\"Forward.\n", "\n", "        Args:\n", "            node_feats: Input features of the nodes. Shape: [batch_size, num_nodes, c_in]\n", "            adj_matrix: Adjacency matrix including self-connections. Shape: [batch_size, num_nodes, num_nodes]\n", "            print_attn_probs: If True, the attention weights are printed during the forward pass\n", "                (for debugging purposes)\n", "        \"\"\"\n", "        batch_size, num_nodes = node_feats.size(0), node_feats.size(1)\n", "\n", "        # Apply linear layer and sort nodes by head\n", "        node_feats = self.projection(node_feats)\n", "        node_feats = node_feats.view(batch_size, num_nodes, self.num_heads, -1)\n", "\n", "        # We need to calculate the attention logits for every edge in the adjacency matrix\n", "        # Doing this on all possible combinations of nodes is very expensive\n", "        # => Create a tensor of [W*h_i||W*h_j] with i and j being the indices of all edges\n", "        # Returns indices where the adjacency matrix is not 0 => edges\n", "        edges = adj_matrix.nonzero(as_tuple=False)\n", "        node_feats_flat = node_feats.view(batch_size * num_nodes, self.num_heads, -1)\n", "        edge_indices_row = edges[:, 0] * num_nodes + edges[:, 1]\n", "        edge_indices_col = edges[:, 0] * num_nodes + edges[:, 2]\n", "        a_input = torch.cat(\n", "            [\n", "                torch.index_select(input=node_feats_flat, index=edge_indices_row, dim=0),\n", "                torch.index_select(input=node_feats_flat, index=edge_indices_col, dim=0),\n", "            ],\n", "            dim=-1,\n", "        )  # Index select returns a tensor with node_feats_flat being indexed at the desired positions\n", "\n", "        # Calculate attention MLP output (independent for each head)\n", "        attn_logits = torch.einsum(\"bhc,hc->bh\", a_input, self.a)\n", "        attn_logits = self.leakyrelu(attn_logits)\n", "\n", "        # Map list of attention values back into a matrix\n", "        attn_matrix = attn_logits.new_zeros(adj_matrix.shape + (self.num_heads,)).fill_(-9e15)\n", "        attn_matrix[adj_matrix[..., None].repeat(1, 1, 1, self.num_heads) == 1] = attn_logits.reshape(-1)\n", "\n", "        # Weighted average of attention\n", "        attn_probs = F.softmax(attn_matrix, dim=2)\n", "        if print_attn_probs:\n", "            print(\"Attention probs\\n\", attn_probs.permute(0, 3, 1, 2))\n", "        node_feats = torch.einsum(\"bijh,bjhc->bihc\", attn_probs, node_feats)\n", "\n", "        # If heads should be concatenated, we can do this by reshaping. Otherwise, take mean\n", "        if self.concat_heads:\n", "            node_feats = node_feats.reshape(batch_size, num_nodes, -1)\n", "        else:\n", "            node_feats = node_feats.mean(dim=2)\n", "\n", "        return node_feats"]}, 
{"cell_type": "markdown", "id": "bf5f5993", "metadata": {"papermill": {"duration": 0.009437, "end_time": "2023-10-11T16:03:16.477084", "exception": false, "start_time": "2023-10-11T16:03:16.467647", "status": "completed"}, "tags": []}, "source": ["Again, we can apply the graph attention layer to our example graph above to understand the dynamics better.\n", "As before, the projection weights are initialized as an identity matrix, but we set $\\mathbf{a}$\n", "to arbitrary numbers (one vector per head) to obtain different attention values.\n", "We use two heads to show the parallel, independent attention mechanisms working in the layer."]}, {"cell_type": "code", "execution_count": 8, "id": "4d348ba1", "metadata": {"execution": {"iopub.execute_input": "2023-10-11T16:03:16.497519Z", "iopub.status.busy": "2023-10-11T16:03:16.496931Z", "iopub.status.idle": "2023-10-11T16:03:16.566018Z", "shell.execute_reply": "2023-10-11T16:03:16.565240Z"}, "papermill": {"duration": 0.084686, "end_time": "2023-10-11T16:03:16.571182", "exception": false, "start_time": "2023-10-11T16:03:16.486496", "status": "completed"}, "tags": []}, "outputs": [{"name": "stdout", "output_type": "stream", "text": ["Attention probs\n", " tensor([[[[0.3543, 0.6457, 0.0000, 0.0000],\n", " [0.1096, 0.1450, 0.2642, 0.4813],\n", " [0.0000, 0.1858, 0.2885, 0.5257],\n", " [0.0000, 0.2391, 0.2696, 0.4913]],\n", "\n", " [[0.5100, 0.4900, 0.0000, 0.0000],\n", " [0.2975, 0.2436, 0.2340, 0.2249],\n", " [0.0000, 0.3838, 0.3142, 0.3019],\n", " [0.0000, 0.4018, 0.3289, 0.2693]]]])\n", "Adjacency matrix tensor([[[1., 1., 0., 0.],\n", " [1., 1., 1., 1.],\n", " [0., 1., 1., 1.],\n", " [0., 1., 1., 1.]]])\n", 
"Input features tensor([[[0., 1.],\n", " [2., 3.],\n", " [4., 5.],\n", " [6., 7.]]])\n", "Output features tensor([[[1.2913, 1.9800],\n", " [4.2344, 3.7725],\n", " [4.6798, 4.8362],\n", " [4.5043, 4.7351]]])\n"]}], "source": ["layer = GATLayer(2, 2, num_heads=2)\n", "layer.projection.weight.data = Tensor([[1.0, 0.0], [0.0, 1.0]])\n", "layer.projection.bias.data = Tensor([0.0, 0.0])\n", "layer.a.data = Tensor([[-0.2, 0.3], [0.1, -0.1]])\n", "\n", "with torch.no_grad():\n", " out_feats = layer(node_feats, adj_matrix, print_attn_probs=True)\n", "\n", "print(\"Adjacency matrix\", adj_matrix)\n", "print(\"Input features\", node_feats)\n", "print(\"Output features\", out_feats)"]}, {"cell_type": "markdown", "id": "2ab15650", "metadata": {"papermill": {"duration": 0.015782, "end_time": "2023-10-11T16:03:16.610501", "exception": false, "start_time": "2023-10-11T16:03:16.594719", "status": "completed"}, "tags": []}, "source": ["We recommend that you try to calculate the attention matrix at least for one head and one node for yourself.\n", "The entries are 0 where there does not exist an edge between $i$ and $j$.\n", "For the others, we see a diverse set of attention probabilities.\n", "Moreover, the output features of node 3 and 4 are now different although they have the same neighbors."]}, {"cell_type": "markdown", "id": "6d1738ad", "metadata": {"papermill": {"duration": 0.009688, "end_time": "2023-10-11T16:03:16.636046", "exception": false, "start_time": "2023-10-11T16:03:16.626358", "status": "completed"}, "tags": []}, "source": ["## PyTorch Geometric\n", "\n", "We had mentioned before that implementing graph networks with adjacency matrix is simple and straight-forward\n", "but can be computationally expensive for large graphs.\n", "Many real-world graphs can reach over 200k nodes, for which adjacency matrix-based implementations fail.\n", "There are a lot of optimizations possible when implementing GNNs, and luckily, there exist packages that provide such layers.\n", 
"The most popular packages for PyTorch are [PyTorch Geometric](https://pytorch-geometric.readthedocs.io/en/latest/)\n", "and the [Deep Graph Library](https://www.dgl.ai/) (the latter is actually framework-agnostic).\n", "Which one to use depends on the project you are planning and on personal taste.\n", "In this tutorial, we will look at PyTorch Geometric as part of the PyTorch family.\n", "\n", "PyTorch Geometric provides us with a set of common graph layers, including the GCN and GAT layers we implemented above.\n", "Additionally, similar to PyTorch's torchvision, it provides common graph datasets and transformations\n", "on them to simplify training.\n", "In contrast to our implementation above, PyTorch Geometric represents the edges as a list of index pairs instead of an adjacency matrix.\n", "The details of this library will be explored further in our experiments.\n", "\n", "In our tasks below, we want to be able to pick from a multitude of graph layers.\n", "Thus, we again define a dictionary to access them by name:"]}, {"cell_type": "code", "execution_count": 9, "id": "3ef60900", "metadata": {"execution": {"iopub.execute_input": "2023-10-11T16:03:16.663161Z", "iopub.status.busy": "2023-10-11T16:03:16.662658Z", "iopub.status.idle": "2023-10-11T16:03:16.685582Z", "shell.execute_reply": "2023-10-11T16:03:16.681240Z"}, "papermill": {"duration": 0.039783, "end_time": "2023-10-11T16:03:16.689352", "exception": false, "start_time": "2023-10-11T16:03:16.649569", "status": "completed"}, "tags": []}, "outputs": [], "source": ["gnn_layer_by_name = {\"GCN\": geom_nn.GCNConv, \"GAT\": geom_nn.GATConv, \"GraphConv\": geom_nn.GraphConv}"]}, {"cell_type": "markdown", "id": "13e41e62", "metadata": {"papermill": {"duration": 0.009584, "end_time": "2023-10-11T16:03:16.708583", "exception": false, "start_time": "2023-10-11T16:03:16.698999", "status": "completed"}, "tags": []}, "source": ["In addition to GCN and GAT, we added the layer `geom_nn.GraphConv`\n", 
"([documentation](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.conv.GraphConv)).\n", "GraphConv is a GCN with a separate weight matrix for the self-connections.\n", "Mathematically, this would be:\n", "\n", "$$\n", "\\mathbf{x}_i^{(l+1)} = \\mathbf{W}^{(l+1)}_1 \\mathbf{x}_i^{(l)} + \\mathbf{W}^{(l+1)}_2 \\sum_{j \\in \\mathcal{N}_i} \\mathbf{x}_j^{(l)}\n", "$$\n", "\n", "In this formula, the neighbors' messages are summed instead of averaged.\n", "However, PyTorch Geometric provides the argument `aggr` to switch between summing, averaging, and max pooling."]}, {"cell_type": "markdown", "id": "034c4dbc", "metadata": {"papermill": {"duration": 0.009575, "end_time": "2023-10-11T16:03:16.727815", "exception": false, "start_time": "2023-10-11T16:03:16.718240", "status": "completed"}, "tags": []}, "source": ["## Experiments on graph structures\n", "\n", 
"Tasks on graph-structured data can be divided into three groups: node-level, edge-level, and graph-level.\n", "The level determines what we perform classification/regression on: single nodes, edges, or entire graphs.\n", "We will discuss all three types in more detail below."]}, {"cell_type": "markdown", "id": "f25835f5", "metadata": {"papermill": {"duration": 0.009577, "end_time": "2023-10-11T16:03:16.747148", "exception": false, "start_time": "2023-10-11T16:03:16.737571", "status": "completed"}, "tags": []}, "source": ["### Node-level tasks: Semi-supervised node classification\n", "\n", "Node-level tasks aim to classify the nodes in a graph.\n", "Usually, we are given a single, large graph with >1000 nodes, of which only a subset is labeled.\n", "We learn to classify those labeled examples during training and try to generalize to the unlabeled nodes.\n", "\n", "A popular example that we will use in this tutorial is the Cora dataset, a citation network among papers.\n", "Cora consists of 2708 scientific publications, with links between them representing\n", "the citation of one paper by another.\n", "The task is to classify each publication into one of seven classes.\n", "Each publication is represented by a bag-of-words vector.\n", "This means that we have a vector of 1433 elements for each publication, where a 1 at feature $i$ indicates\n", "that the $i$-th word of a predefined dictionary appears in the article.\n", "Binary bag-of-words representations are commonly used when we need very simple encodings\n", "and already have an intuition of which words to expect in a network.\n", "There are much better approaches, but we will leave those for the NLP courses to discuss.\n", "\n", "We will load the dataset below:"]}, {"cell_type": "code", "execution_count": 10, "id": "64e4c45d", "metadata": {"execution": {"iopub.execute_input": "2023-10-11T16:03:16.769275Z", "iopub.status.busy": "2023-10-11T16:03:16.768197Z", "iopub.status.idle": 
"2023-10-11T16:03:18.147101Z", "shell.execute_reply": "2023-10-11T16:03:18.146012Z"}, "papermill": {"duration": 1.39751, "end_time": "2023-10-11T16:03:18.154239", "exception": false, "start_time": "2023-10-11T16:03:16.756729", "status": "completed"}, "tags": []}, "outputs": [{"name": "stderr", "output_type": "stream", "text": ["Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.x\n", "Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.tx\n"]}, {"name": "stderr", "output_type": "stream", "text": ["Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.allx\n", "Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.y\n"]}, {"name": "stderr", "output_type": "stream", "text": ["Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ty\n", "Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ally\n"]}, {"name": "stderr", "output_type": "stream", "text": ["Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.graph\n", "Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.test.index\n"]}, {"name": "stderr", "output_type": "stream", "text": ["Processing...\n"]}, {"name": "stderr", "output_type": "stream", "text": ["Done!\n"]}], "source": ["cora_dataset = torch_geometric.datasets.Planetoid(root=DATASET_PATH, name=\"Cora\")"]}, {"cell_type": "markdown", "id": "b46bad32", "metadata": {"papermill": {"duration": 0.011189, "end_time": "2023-10-11T16:03:18.180670", "exception": false, "start_time": "2023-10-11T16:03:18.169481", "status": "completed"}, "tags": []}, "source": ["Let's look at how PyTorch Geometric represents the graph data.\n", "Note that although we have a single graph, PyTorch Geometric returns a dataset for compatibility with other datasets."]}, {"cell_type": "code", "execution_count": 11, "id": "065b8b71", "metadata": {"execution": {"iopub.execute_input": 
"2023-10-11T16:03:18.210861Z", "iopub.status.busy": "2023-10-11T16:03:18.210003Z", "iopub.status.idle": "2023-10-11T16:03:18.219012Z", "shell.execute_reply": "2023-10-11T16:03:18.218544Z"}, "papermill": {"duration": 0.033004, "end_time": "2023-10-11T16:03:18.228178", "exception": false, "start_time": "2023-10-11T16:03:18.195174", "status": "completed"}, "tags": []}, "outputs": [{"data": {"text/plain": ["Data(x=[2708, 1433], edge_index=[2, 10556], y=[2708], train_mask=[2708], val_mask=[2708], test_mask=[2708])"]}, "execution_count": 11, "metadata": {}, "output_type": "execute_result"}], "source": ["cora_dataset[0]"]}, {"cell_type": "markdown", "id": "41f9f836", "metadata": {"lines_to_next_cell": 2, "papermill": {"duration": 0.010081, "end_time": "2023-10-11T16:03:18.248995", "exception": false, "start_time": "2023-10-11T16:03:18.238914", "status": "completed"}, "tags": []}, "source": ["The graph is represented by a `Data` object\n", "([documentation](https://pytorch-geometric.readthedocs.io/en/latest/modules/data.html#torch_geometric.data.Data))\n", "which we can access as a standard Python namespace.\n", "The edge index tensor is the list of edges in the graph and contains the mirrored version of each edge for undirected graphs.\n", "The `train_mask`, `val_mask`, and `test_mask` are boolean masks that indicate which nodes we should use for training,\n", "validation, and testing.\n", "The `x` tensor is the feature tensor of our 2708 publications, and `y` the labels for all nodes.\n", "\n", "After having seen the data, we can implement a simple graph neural network.\n", "The GNN applies a sequence of graph layers (GCN, GAT, or GraphConv), ReLU as activation function,\n", "and dropout for regularization.\n", "See below for the specific implementation."]}, {"cell_type": "code", "execution_count": 12, "id": "bd92f2e4", "metadata": {"execution": {"iopub.execute_input": "2023-10-11T16:03:18.274053Z", "iopub.status.busy": "2023-10-11T16:03:18.273754Z", "iopub.status.idle": 
"2023-10-11T16:03:18.281026Z", "shell.execute_reply": "2023-10-11T16:03:18.280532Z"}, "lines_to_next_cell": 2, "papermill": {"duration": 0.023402, "end_time": "2023-10-11T16:03:18.282495", "exception": false, "start_time": "2023-10-11T16:03:18.259093", "status": "completed"}, "tags": []}, "outputs": [], "source": ["class GNNModel(nn.Module):\n", " def __init__(\n", " self,\n", " c_in,\n", " c_hidden,\n", " c_out,\n", " num_layers=2,\n", " layer_name=\"GCN\",\n", " dp_rate=0.1,\n", " **kwargs,\n", " ):\n", " \"\"\"GNNModel.\n", "\n", " Args:\n", " c_in: Dimension of input features\n", " c_hidden: Dimension of hidden features\n", " c_out: Dimension of the output features. Usually number of classes in classification\n", " num_layers: Number of \"hidden\" graph layers\n", " layer_name: String of the graph layer to use\n", " dp_rate: Dropout rate to apply throughout the network\n", " kwargs: Additional arguments for the graph layer (e.g. number of heads for GAT)\n", " \"\"\"\n", " super().__init__()\n", " gnn_layer = gnn_layer_by_name[layer_name]\n", "\n", " layers = []\n", " in_channels, out_channels = c_in, c_hidden\n", " for l_idx in range(num_layers - 1):\n", " layers += [\n", " gnn_layer(in_channels=in_channels, out_channels=out_channels, **kwargs),\n", " nn.ReLU(inplace=True),\n", " nn.Dropout(dp_rate),\n", " ]\n", " in_channels = c_hidden\n", " layers += [gnn_layer(in_channels=in_channels, out_channels=c_out, **kwargs)]\n", " self.layers = nn.ModuleList(layers)\n", "\n", " def forward(self, x, edge_index):\n", " \"\"\"Forward.\n", "\n", " Args:\n", " x: Input features per node\n", " edge_index: List of vertex index pairs representing the edges in the graph (PyTorch geometric notation)\n", " \"\"\"\n", " for layer in self.layers:\n", " # For graph layers, we need to add the \"edge_index\" tensor as additional input\n", " # All PyTorch Geometric graph layer inherit the class \"MessagePassing\", hence\n", " # we can simply check the class type.\n", " if 
isinstance(layer, geom_nn.MessagePassing):\n", " x = layer(x, edge_index)\n", " else:\n", " x = layer(x)\n", " return x"]}, {"cell_type": "markdown", "id": "cdf52d9d", "metadata": {"lines_to_next_cell": 2, "papermill": {"duration": 0.010266, "end_time": "2023-10-11T16:03:18.302816", "exception": false, "start_time": "2023-10-11T16:03:18.292550", "status": "completed"}, "tags": []}, "source": ["Good practice in node-level tasks is to create an MLP baseline that is applied to each node independently.\n", "This way we can verify whether adding the graph information to the model indeed improves the prediction, or not.\n", "It might also be that the features per node are already expressive enough to clearly point towards a specific class.\n", "To check this, we implement a simple MLP below."]}, {"cell_type": "code", "execution_count": 13, "id": "4877e955", "metadata": {"execution": {"iopub.execute_input": "2023-10-11T16:03:18.324957Z", "iopub.status.busy": "2023-10-11T16:03:18.324275Z", "iopub.status.idle": "2023-10-11T16:03:18.330174Z", "shell.execute_reply": "2023-10-11T16:03:18.329347Z"}, "lines_to_next_cell": 2, "papermill": {"duration": 0.018324, "end_time": "2023-10-11T16:03:18.331570", "exception": false, "start_time": "2023-10-11T16:03:18.313246", "status": "completed"}, "tags": []}, "outputs": [], "source": ["class MLPModel(nn.Module):\n", " def __init__(self, c_in, c_hidden, c_out, num_layers=2, dp_rate=0.1):\n", " \"\"\"MLPModel.\n", "\n", " Args:\n", " c_in: Dimension of input features\n", " c_hidden: Dimension of hidden features\n", " c_out: Dimension of the output features. 
Usually number of classes in classification\n", " num_layers: Number of hidden layers\n", " dp_rate: Dropout rate to apply throughout the network\n", " \"\"\"\n", " super().__init__()\n", " layers = []\n", " in_channels, out_channels = c_in, c_hidden\n", " for l_idx in range(num_layers - 1):\n", " layers += [nn.Linear(in_channels, out_channels), nn.ReLU(inplace=True), nn.Dropout(dp_rate)]\n", " in_channels = c_hidden\n", " layers += [nn.Linear(in_channels, c_out)]\n", " self.layers = nn.Sequential(*layers)\n", "\n", " def forward(self, x, *args, **kwargs):\n", " \"\"\"Forward.\n", "\n", " Args:\n", " x: Input features per node\n", " \"\"\"\n", " return self.layers(x)"]}, {"cell_type": "markdown", "id": "447b52a0", "metadata": {"lines_to_next_cell": 2, "papermill": {"duration": 0.010195, "end_time": "2023-10-11T16:03:18.352006", "exception": false, "start_time": "2023-10-11T16:03:18.341811", "status": "completed"}, "tags": []}, "source": ["Finally, we can merge the models into a PyTorch Lightning module which handles the training,\n", "validation, and testing for us."]}, {"cell_type": "code", "execution_count": 14, "id": "d1281945", "metadata": {"execution": {"iopub.execute_input": "2023-10-11T16:03:18.377636Z", "iopub.status.busy": "2023-10-11T16:03:18.377121Z", "iopub.status.idle": "2023-10-11T16:03:18.393331Z", "shell.execute_reply": "2023-10-11T16:03:18.392651Z"}, "lines_to_next_cell": 2, "papermill": {"duration": 0.02932, "end_time": "2023-10-11T16:03:18.395619", "exception": false, "start_time": "2023-10-11T16:03:18.366299", "status": "completed"}, "tags": []}, "outputs": [], "source": ["class NodeLevelGNN(L.LightningModule):\n", " def __init__(self, model_name, **model_kwargs):\n", " super().__init__()\n", " # Saving hyperparameters\n", " self.save_hyperparameters()\n", "\n", " if model_name == \"MLP\":\n", " self.model = MLPModel(**model_kwargs)\n", " else:\n", " self.model = GNNModel(**model_kwargs)\n", " self.loss_module = nn.CrossEntropyLoss()\n", "\n", " 
def forward(self, data, mode=\"train\"):\n", " x, edge_index = data.x, data.edge_index\n", " x = self.model(x, edge_index)\n", "\n", " # Only calculate the loss on the nodes corresponding to the mask\n", " if mode == \"train\":\n", " mask = data.train_mask\n", " elif mode == \"val\":\n", " mask = data.val_mask\n", " elif mode == \"test\":\n", " mask = data.test_mask\n", " else:\n", " assert False, \"Unknown forward mode: %s\" % mode\n", "\n", " loss = self.loss_module(x[mask], data.y[mask])\n", " acc = (x[mask].argmax(dim=-1) == data.y[mask]).sum().float() / mask.sum()\n", " return loss, acc\n", "\n", " def configure_optimizers(self):\n", " # We use SGD here, but Adam works as well\n", " optimizer = optim.SGD(self.parameters(), lr=0.1, momentum=0.9, weight_decay=2e-3)\n", " return optimizer\n", "\n", " def training_step(self, batch, batch_idx):\n", " loss, acc = self.forward(batch, mode=\"train\")\n", " self.log(\"train_loss\", loss)\n", " self.log(\"train_acc\", acc)\n", " return loss\n", "\n", " def validation_step(self, batch, batch_idx):\n", " _, acc = self.forward(batch, mode=\"val\")\n", " self.log(\"val_acc\", acc)\n", "\n", " def test_step(self, batch, batch_idx):\n", " _, acc = self.forward(batch, mode=\"test\")\n", " self.log(\"test_acc\", acc)"]}, {"cell_type": "markdown", "id": "2aa96907", "metadata": {"lines_to_next_cell": 2, "papermill": {"duration": 0.010188, "end_time": "2023-10-11T16:03:18.415853", "exception": false, "start_time": "2023-10-11T16:03:18.405665", "status": "completed"}, "tags": []}, "source": ["Additionally to the Lightning module, we define a training function below.\n", "As we have a single graph, we use a batch size of 1 for the data loader and share the same data loader for the train,\n", "validation, and test set (the mask is picked inside the Lightning module).\n", "Besides, we set the argument `enable_progress_bar` to False as it usually shows the progress per epoch,\n", "but an epoch only consists of a single step.\n", "If 
you have downloaded the pre-trained models in the beginning of the tutorial, we load those instead of training from scratch.\n", "Finally, we test the model and return the results."]}, {"cell_type": "code", "execution_count": 15, "id": "47ae5b35", "metadata": {"execution": {"iopub.execute_input": "2023-10-11T16:03:18.438295Z", "iopub.status.busy": "2023-10-11T16:03:18.437639Z", "iopub.status.idle": "2023-10-11T16:03:18.445846Z", "shell.execute_reply": "2023-10-11T16:03:18.445108Z"}, "lines_to_next_cell": 2, "papermill": {"duration": 0.020995, "end_time": "2023-10-11T16:03:18.447153", "exception": false, "start_time": "2023-10-11T16:03:18.426158", "status": "completed"}, "tags": []}, "outputs": [], "source": ["def train_node_classifier(model_name, dataset, **model_kwargs):\n", " L.seed_everything(42)\n", " node_data_loader = geom_data.DataLoader(dataset, batch_size=1)\n", "\n", " # Create a PyTorch Lightning trainer\n", " root_dir = os.path.join(CHECKPOINT_PATH, \"NodeLevel\" + model_name)\n", " os.makedirs(root_dir, exist_ok=True)\n", " trainer = L.Trainer(\n", " default_root_dir=root_dir,\n", " callbacks=[ModelCheckpoint(save_weights_only=True, mode=\"max\", monitor=\"val_acc\")],\n", " accelerator=\"auto\",\n", " devices=AVAIL_GPUS,\n", " max_epochs=200,\n", " enable_progress_bar=False,\n", " ) # 0 because epoch size is 1\n", " trainer.logger._default_hp_metric = None # Optional logging argument that we don't need\n", "\n", " # Check whether pretrained model exists. 
If yes, load it and skip training\n", " pretrained_filename = os.path.join(CHECKPOINT_PATH, \"NodeLevel%s.ckpt\" % model_name)\n", " if os.path.isfile(pretrained_filename):\n", " print(\"Found pretrained model, loading...\")\n", " model = NodeLevelGNN.load_from_checkpoint(pretrained_filename)\n", " else:\n", " L.seed_everything()\n", " model = NodeLevelGNN(\n", " model_name=model_name, c_in=dataset.num_node_features, c_out=dataset.num_classes, **model_kwargs\n", " )\n", " trainer.fit(model, node_data_loader, node_data_loader)\n", " model = NodeLevelGNN.load_from_checkpoint(trainer.checkpoint_callback.best_model_path)\n", "\n", " # Test best model on the test set\n", " test_result = trainer.test(model, dataloaders=node_data_loader, verbose=False)\n", " batch = next(iter(node_data_loader))\n", " batch = batch.to(model.device)\n", " _, train_acc = model.forward(batch, mode=\"train\")\n", " _, val_acc = model.forward(batch, mode=\"val\")\n", " result = {\"train\": train_acc, \"val\": val_acc, \"test\": test_result[0][\"test_acc\"]}\n", " return model, result"]}, {"cell_type": "markdown", "id": "63d4b255", "metadata": {"lines_to_next_cell": 2, "papermill": {"duration": 0.010246, "end_time": "2023-10-11T16:03:18.467747", "exception": false, "start_time": "2023-10-11T16:03:18.457501", "status": "completed"}, "tags": []}, "source": ["Now, we can train our models. 
First, let's train the simple MLP:"]}, {"cell_type": "code", "execution_count": 16, "id": "a871d384", "metadata": {"execution": {"iopub.execute_input": "2023-10-11T16:03:18.489049Z", "iopub.status.busy": "2023-10-11T16:03:18.488856Z", "iopub.status.idle": "2023-10-11T16:03:18.492689Z", "shell.execute_reply": "2023-10-11T16:03:18.492181Z"}, "papermill": {"duration": 0.015812, "end_time": "2023-10-11T16:03:18.493806", "exception": false, "start_time": "2023-10-11T16:03:18.477994", "status": "completed"}, "tags": []}, "outputs": [], "source": ["# Small function for printing the test scores\n", "def print_results(result_dict):\n", " if \"train\" in result_dict:\n", " print(\"Train accuracy: %4.2f%%\" % (100.0 * result_dict[\"train\"]))\n", " if \"val\" in result_dict:\n", " print(\"Val accuracy: %4.2f%%\" % (100.0 * result_dict[\"val\"]))\n", " print(\"Test accuracy: %4.2f%%\" % (100.0 * result_dict[\"test\"]))"]}, {"cell_type": "code", "execution_count": 17, "id": "6d78fad1", "metadata": {"execution": {"iopub.execute_input": "2023-10-11T16:03:18.515575Z", "iopub.status.busy": "2023-10-11T16:03:18.515071Z", "iopub.status.idle": "2023-10-11T16:03:19.423216Z", "shell.execute_reply": "2023-10-11T16:03:19.422172Z"}, "papermill": {"duration": 0.920887, "end_time": "2023-10-11T16:03:19.425008", "exception": false, "start_time": "2023-10-11T16:03:18.504121", "status": "completed"}, "tags": []}, "outputs": [{"name": "stderr", "output_type": "stream", "text": ["Global seed set to 42\n"]}, {"name": "stderr", "output_type": "stream", "text": ["/usr/local/lib/python3.10/dist-packages/torch_geometric/deprecation.py:22: UserWarning: 'data.DataLoader' is deprecated, use 'loader.DataLoader' instead\n", " warnings.warn(out)\n", "GPU available: True (cuda), used: True\n"]}, {"name": "stderr", "output_type": "stream", "text": ["TPU available: False, using: 0 TPU cores\n"]}, {"name": "stderr", "output_type": "stream", "text": ["IPU available: False, using: 0 IPUs\n"]}, {"name": "stderr", 
"output_type": "stream", "text": ["HPU available: False, using: 0 HPUs\n"]}, {"name": "stderr", "output_type": "stream", "text": ["/usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py:67: UserWarning: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `lightning.pytorch` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default\n", " warning_cache.warn(\n"]}, {"name": "stdout", "output_type": "stream", "text": ["Found pretrained model, loading...\n"]}, {"name": "stderr", "output_type": "stream", "text": ["Lightning automatically upgraded your loaded checkpoint from v1.0.2 to v2.0.9.post0. To apply the upgrade to your files permanently, run `python -m lightning.pytorch.utilities.upgrade_checkpoint --file saved_models/GNNs/NodeLevelMLP.ckpt`\n"]}, {"name": "stderr", "output_type": "stream", "text": ["You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision\n"]}, {"name": "stderr", "output_type": "stream", "text": ["LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n"]}, {"name": "stderr", "output_type": "stream", "text": ["/usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/connectors/data_connector.py:442: PossibleUserWarning: The dataloader, test_dataloader, does not have many workers which may be a bottleneck. 
Consider increasing the value of the `num_workers` argument` (try 64 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.\n", " rank_zero_warn(\n"]}, {"name": "stderr", "output_type": "stream", "text": ["/usr/local/lib/python3.10/dist-packages/lightning/pytorch/utilities/data.py:76: UserWarning: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 2708. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.\n", " warning_cache.warn(\n"]}, {"name": "stdout", "output_type": "stream", "text": ["Train accuracy: 97.14%\n", "Val accuracy: 54.60%\n", "Test accuracy: 60.60%\n"]}], "source": ["node_mlp_model, node_mlp_result = train_node_classifier(\n", " model_name=\"MLP\", dataset=cora_dataset, c_hidden=16, num_layers=2, dp_rate=0.1\n", ")\n", "\n", "print_results(node_mlp_result)"]}, {"cell_type": "markdown", "id": "e4fda810", "metadata": {"papermill": {"duration": 0.011495, "end_time": "2023-10-11T16:03:19.448936", "exception": false, "start_time": "2023-10-11T16:03:19.437441", "status": "completed"}, "tags": []}, "source": ["Although the MLP can overfit the training set (97.14% accuracy) because of the high-dimensional input features,\n", "it performs poorly on the validation and test sets (54.60% and 60.60% accuracy).\n", "Let's see if we can beat this score with our graph networks:"]}, {"cell_type": "code", "execution_count": 18, "id": "0e0fd1c4", "metadata": {"execution": {"iopub.execute_input": "2023-10-11T16:03:19.474172Z", "iopub.status.busy": "2023-10-11T16:03:19.473856Z", "iopub.status.idle": "2023-10-11T16:03:20.736499Z", "shell.execute_reply": "2023-10-11T16:03:20.731572Z"}, "papermill": {"duration": 1.280765, "end_time": "2023-10-11T16:03:20.740917", "exception": false, "start_time": "2023-10-11T16:03:19.460152", "status": "completed"}, "tags": []}, "outputs": [{"name": "stderr", "output_type": "stream", "text": ["Global seed set to 42\n"]}, {"name": "stderr", "output_type": "stream", "text": 
["/usr/local/lib/python3.10/dist-packages/torch_geometric/deprecation.py:22: UserWarning: 'data.DataLoader' is deprecated, use 'loader.DataLoader' instead\n", " warnings.warn(out)\n", "GPU available: True (cuda), used: True\n"]}, {"name": "stderr", "output_type": "stream", "text": ["TPU available: False, using: 0 TPU cores\n"]}, {"name": "stderr", "output_type": "stream", "text": ["IPU available: False, using: 0 IPUs\n"]}, {"name": "stderr", "output_type": "stream", "text": ["HPU available: False, using: 0 HPUs\n"]}, {"name": "stderr", "output_type": "stream", "text": ["Lightning automatically upgraded your loaded checkpoint from v1.0.2 to v2.0.9.post0. To apply the upgrade to your files permanently, run `python -m lightning.pytorch.utilities.upgrade_checkpoint --file saved_models/GNNs/NodeLevelGNN.ckpt`\n"]}, {"name": "stderr", "output_type": "stream", "text": ["LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n"]}, {"name": "stdout", "output_type": "stream", "text": ["Found pretrained model, loading...\n"]}, {"name": "stdout", "output_type": "stream", "text": ["Train accuracy: 100.00%\n", "Val accuracy: 78.60%\n", "Test accuracy: 82.40%\n"]}], "source": ["node_gnn_model, node_gnn_result = train_node_classifier(\n", " model_name=\"GNN\", layer_name=\"GCN\", dataset=cora_dataset, c_hidden=16, num_layers=2, dp_rate=0.1\n", ")\n", "print_results(node_gnn_result)"]}, {"cell_type": "markdown", "id": "fa912878", "metadata": {"papermill": {"duration": 0.011993, "end_time": "2023-10-11T16:03:20.777315", "exception": false, "start_time": "2023-10-11T16:03:20.765322", "status": "completed"}, "tags": []}, "source": ["As we had hoped, the GNN model outperforms the MLP by a large margin (82.40% vs. 60.60% test accuracy).\n", "This shows that using the graph information indeed improves our predictions and lets us generalize better.\n", "\n", "The hyperparameters in the model have been chosen to create a relatively small network.\n", "This is because the first layer with an input dimension of 1433 can be 
"relatively expensive to perform for large graphs.\n", "In general, GNNs can become relatively expensive for very big graphs.\n", "This is why GNNs for such graphs either have a small hidden size or use a special batching strategy\n", "in which we sample a connected subgraph of the big, original graph."]}, {"cell_type": "markdown", "id": "5dcd5632", "metadata": {"papermill": {"duration": 0.014035, "end_time": "2023-10-11T16:03:20.803784", "exception": false, "start_time": "2023-10-11T16:03:20.789749", "status": "completed"}, "tags": []}, "source": ["### Edge-level tasks: Link prediction\n", "\n", "In some applications, we might have to make predictions at the edge level instead of the node level.\n", "The most common edge-level task for GNNs is link prediction.\n", "Link prediction means that, given a graph, we want to predict whether there will be (or should be) an edge between two nodes.\n", "For example, social networks such as Facebook use this to suggest new friends to you.\n", "Again, the information in the graph structure can be crucial to perform this task.\n", "The output prediction is usually obtained by applying a similarity metric to the pair of node features,\n", "which should be close to 1 if there should be a link and close to 0 otherwise.\n", "To keep the tutorial short, we will not implement this task ourselves.\n", "Nevertheless, there are many good resources out there if you are interested in taking a closer look at this task.\n", "Tutorials and papers for this topic include:\n", "\n", "* [PyTorch Geometric example](https://github.com/rusty1s/pytorch_geometric/blob/master/examples/link_pred.py)\n", "* [Graph Neural Networks: A Review of Methods and Applications](https://arxiv.org/pdf/1812.08434.pdf), Zhou et al., 2019.\n", "* [Link Prediction Based on Graph Neural Networks](https://papers.nips.cc/paper/2018/file/53f0d7c537d99b3824f0f99d62ea2428-Paper.pdf), Zhang and Chen, 2018."]}, {"cell_type": "markdown", "id": "3212535a", "metadata": {"papermill": {"duration": 0.011616, "end_time": 
"2023-10-11T16:03:20.826996", "exception": false, "start_time": "2023-10-11T16:03:20.815380", "status": "completed"}, "tags": []}, "source": ["### Graph-level tasks: Graph classification\n", "\n", "Finally, in this part of the tutorial, we will have a closer look at how to apply GNNs to the task of graph classification.\n", "The goal is to classify an entire graph instead of single nodes or edges.\n", "Therefore, we are given a dataset of multiple graphs that we need to classify based on some structural graph properties.\n", "The most common graph classification task is molecular property prediction, in which molecules are represented as graphs.\n", "Each atom corresponds to a node, and the edges in the graph are the bonds between atoms.\n", "For example, look at the figure below.\n", "\n", 
"On the left, we have an arbitrary, small molecule with different atoms, whereas the right part of the image shows the graph representation.\n", "The atom types are abstracted as node features (e.g. a one-hot vector), and the different bond types are used as edge features.\n", "For simplicity, we will neglect the edge attributes in this tutorial, but you can include them by using methods like the\n", "[Relational Graph Convolution](https://arxiv.org/abs/1703.06103), which uses a different weight matrix for each edge type.\n", "\n", "The dataset we will use below is called the MUTAG dataset.\n", "It is a common small benchmark for graph classification algorithms and contains 188 graphs, with 18 nodes\n", "and 20 edges per graph on average.\n", "The graph nodes have 7 different labels/atom types, and the binary graph labels represent \"their mutagenic effect\n", "on a specific gram negative bacterium\" (the specific meaning of the labels is not too important here).\n", "The dataset is part of a large collection of different graph classification datasets, known as the\n", "[TUDatasets](https://chrsmrrs.github.io/datasets/), which is directly accessible\n", "via `torch_geometric.datasets.TUDataset` ([documentation](https://pytorch-geometric.readthedocs.io/en/latest/modules/datasets.html#torch_geometric.datasets.TUDataset)) in PyTorch Geometric.\n", "We can load the dataset below."]}, {"cell_type": "code", "execution_count": 19, "id": "916022ac", "metadata": {"execution": {"iopub.execute_input": "2023-10-11T16:03:20.856516Z", "iopub.status.busy": "2023-10-11T16:03:20.856076Z", "iopub.status.idle": "2023-10-11T16:03:21.889799Z", "shell.execute_reply": "2023-10-11T16:03:21.889240Z"}, "papermill": {"duration": 1.053683, "end_time": "2023-10-11T16:03:21.893346", "exception": false, "start_time": "2023-10-11T16:03:20.839663", "status": "completed"}, "tags": []}, "outputs": [{"name": "stderr", "output_type": "stream", "text": ["Downloading 
https://www.chrsmrrs.com/graphkerneldatasets/MUTAG.zip\n"]}, {"name": "stderr", "output_type": "stream", "text": ["Extracting /__w/15/s/.datasets/MUTAG/MUTAG.zip\n", "Processing...\n"]}, {"name": "stderr", "output_type": "stream", "text": ["Done!\n"]}], "source": ["tu_dataset = torch_geometric.datasets.TUDataset(root=DATASET_PATH, name=\"MUTAG\")"]}, {"cell_type": "markdown", "id": "cc5fd7be", "metadata": {"papermill": {"duration": 0.012581, "end_time": "2023-10-11T16:03:21.923932", "exception": false, "start_time": "2023-10-11T16:03:21.911351", "status": "completed"}, "tags": []}, "source": ["Let's look at some statistics for the dataset:"]}, {"cell_type": "code", "execution_count": 20, "id": "f1857455", "metadata": {"execution": {"iopub.execute_input": "2023-10-11T16:03:21.949045Z", "iopub.status.busy": "2023-10-11T16:03:21.948703Z", "iopub.status.idle": "2023-10-11T16:03:21.957830Z", "shell.execute_reply": "2023-10-11T16:03:21.957240Z"}, "papermill": {"duration": 0.023647, "end_time": "2023-10-11T16:03:21.959503", "exception": false, "start_time": "2023-10-11T16:03:21.935856", "status": "completed"}, "tags": []}, "outputs": [{"name": "stdout", "output_type": "stream", "text": ["Data object: Data(x=[3371, 7], edge_index=[2, 7442], edge_attr=[7442, 4], y=[188])\n", "Length: 188\n", "Average label: 0.66\n"]}, {"name": "stderr", "output_type": "stream", "text": ["/usr/local/lib/python3.10/dist-packages/torch_geometric/data/in_memory_dataset.py:157: UserWarning: It is not recommended to directly access the internal storage format `data` of an 'InMemoryDataset'. If you are absolutely certain what you are doing, access the internal storage via `InMemoryDataset._data` instead to suppress this warning. 
Alternatively, you can access stacked individual attributes of every graph via `dataset.{attr_name}`.\n", " warnings.warn(msg)\n"]}], "source": ["print(\"Data object:\", tu_dataset.data)\n", "print(\"Length:\", len(tu_dataset))\n", "print(\"Average label: %4.2f\" % (tu_dataset.data.y.float().mean().item()))"]}, {"cell_type": "markdown", "id": "68a2b8c5", "metadata": {"papermill": {"duration": 0.012255, "end_time": "2023-10-11T16:03:21.988097", "exception": false, "start_time": "2023-10-11T16:03:21.975842", "status": "completed"}, "tags": []}, "source": ["The first line shows how the dataset stores different graphs.\n", "The nodes, edges, and labels of each graph are concatenated into one tensor, and the dataset stores the indices\n", "needed to split the tensors back into the individual graphs.\n", "The length of the dataset is the number of graphs we have, and the \"average label\"\n", "denotes the fraction of graphs with label 1.\n", "As long as this fraction is close to 0.5, we have a relatively balanced dataset.\n", "Graph datasets are quite often very imbalanced, so checking the class balance\n", "is always a good idea.\n", "\n", "Next, we will split our dataset into a training and a test part.\n", "Note that we do not use a separate validation set this time because of the small size of the dataset,\n", "and instead use the test set for validation as well.\n", "Therefore, our model might overfit slightly to the test set due to the noise of the evaluation,\n", "but we still get an estimate of the performance on unseen data."]}, {"cell_type": "code", "execution_count": 21, "id": "29d0a9c6", "metadata": {"execution": {"iopub.execute_input": "2023-10-11T16:03:22.014982Z", "iopub.status.busy": "2023-10-11T16:03:22.014390Z", "iopub.status.idle": "2023-10-11T16:03:22.018493Z", "shell.execute_reply": "2023-10-11T16:03:22.017953Z"}, "papermill": {"duration": 0.021347, "end_time": "2023-10-11T16:03:22.022150", "exception": false, "start_time": "2023-10-11T16:03:22.000803", "status": "completed"}, "tags": []}, 
"outputs": [], "source": ["torch.manual_seed(42)\n", "tu_dataset = tu_dataset.shuffle() # shuffle() returns a shuffled copy, so the result must be reassigned\n", "train_dataset = tu_dataset[:150]\n", "test_dataset = tu_dataset[150:]"]}, {"cell_type": "markdown", "id": "0558ab51", "metadata": {"papermill": {"duration": 0.012813, "end_time": "2023-10-11T16:03:22.047078", "exception": false, "start_time": "2023-10-11T16:03:22.034265", "status": "completed"}, "tags": []}, "source": ["When using a data loader, we encounter a problem with batching $N$ graphs.\n", "Each graph in the batch can have a different number of nodes and edges, and hence we would require a lot of padding to obtain a single tensor.\n", "PyTorch Geometric uses a different, more efficient approach: we can view the $N$ graphs in a batch as a single large graph with concatenated node and edge lists.\n", "As there are no edges between the $N$ graphs, running GNN layers on the large graph gives us the same output as running the GNN on each graph separately.\n", "This batching strategy is visualized below (figure credit: PyTorch Geometric team,\n", "[tutorial here](https://colab.research.google.com/drive/1I8a0DfQ3fI7Njc62__mVXUlcAleUclnb)).\n", "\n", "
\n", "\n", "The adjacency matrix of the batch is therefore block-diagonal: it is zero for any pair of nodes that come from two different graphs, and each diagonal block equals the adjacency matrix of an individual graph.\n", "Luckily, this strategy is already implemented in PyTorch Geometric, and hence we can use the corresponding data loader:"]}, {"cell_type": "code", "execution_count": 22, "id": "137c9f19", "metadata": {"execution": {"iopub.execute_input": "2023-10-11T16:03:22.077104Z", "iopub.status.busy": "2023-10-11T16:03:22.072997Z", "iopub.status.idle": "2023-10-11T16:03:22.081741Z", "shell.execute_reply": "2023-10-11T16:03:22.081006Z"}, "papermill": {"duration": 0.02352, "end_time": "2023-10-11T16:03:22.083057", "exception": false, "start_time": "2023-10-11T16:03:22.059537", "status": "completed"}, "tags": []}, "outputs": [], "source": ["graph_train_loader = geom_data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)\n", "graph_val_loader = geom_data.DataLoader(test_dataset, batch_size=BATCH_SIZE) # Additional loader for larger datasets (here, it reuses the test set)\n", "graph_test_loader = geom_data.DataLoader(test_dataset, batch_size=BATCH_SIZE)"]}, {"cell_type": "markdown", "id": "cfbc9eee", "metadata": {"papermill": {"duration": 0.012036, "end_time": "2023-10-11T16:03:22.106505", "exception": false, "start_time": "2023-10-11T16:03:22.094469", "status": "completed"}, "tags": []}, "source": ["Let's load a batch below to see the batching in action:"]}, {"cell_type": "code", "execution_count": 23, "id": "30662c4c", "metadata": {"execution": {"iopub.execute_input": "2023-10-11T16:03:22.132189Z", "iopub.status.busy": "2023-10-11T16:03:22.132000Z", "iopub.status.idle": "2023-10-11T16:03:22.144650Z", "shell.execute_reply": "2023-10-11T16:03:22.143918Z"}, "papermill": {"duration": 0.026959, "end_time": "2023-10-11T16:03:22.145991", "exception": false, "start_time": "2023-10-11T16:03:22.119032", "status": "completed"}, "tags": []}, "outputs": [{"name": "stdout", "output_type": "stream", "text": ["Batch: DataBatch(edge_index=[2, 1512], 
x=[687, 7], edge_attr=[1512, 4], y=[38], batch=[687], ptr=[39])\n", "Labels: tensor([1, 1, 1, 0, 0, 0, 1, 1, 1, 0])\n", "Batch indices: tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2])\n"]}], "source": ["batch = next(iter(graph_test_loader))\n", "print(\"Batch:\", batch)\n", "print(\"Labels:\", batch.y[:10])\n", "print(\"Batch indices:\", batch.batch[:40])"]}, {"cell_type": "markdown", "id": "21440b10", "metadata": {"lines_to_next_cell": 2, "papermill": {"duration": 0.012337, "end_time": "2023-10-11T16:03:22.171681", "exception": false, "start_time": "2023-10-11T16:03:22.159344", "status": "completed"}, "tags": []}, "source": ["The batch contains all 38 graphs of the test dataset, stacked together.\n", "The batch indices, stored in `batch`, show that the first 12 nodes belong to the first graph,\n", "the next 22 to the second graph, and so on.\n", "These indices are important for performing the final prediction.\n", "To perform a prediction over a whole graph, we usually perform a pooling operation over all nodes after running the GNN model.\n", "Here, we use average pooling,\n", "so we need to know which nodes should be included in which average pool.\n", "Using this pooling, we can already create our graph network below.\n", "Specifically, we reuse our class `GNNModel` from before,\n", "and simply add an average pool and a single linear layer for the graph prediction task."]}, {"cell_type": "code", "execution_count": 24, "id": "a0b2ff13", "metadata": {"execution": {"iopub.execute_input": "2023-10-11T16:03:22.200230Z", "iopub.status.busy": "2023-10-11T16:03:22.199343Z", "iopub.status.idle": "2023-10-11T16:03:22.206433Z", "shell.execute_reply": "2023-10-11T16:03:22.205598Z"}, "lines_to_next_cell": 2, "papermill": {"duration": 0.028572, "end_time": "2023-10-11T16:03:22.211969", "exception": false, "start_time": "2023-10-11T16:03:22.183397", "status": "completed"}, "tags": []}, 
"outputs": [], "source": ["class GraphGNNModel(nn.Module):\n", " def __init__(self, c_in, c_hidden, c_out, dp_rate_linear=0.5, **kwargs):\n", " \"\"\"GraphGNNModel.\n", "\n", " Args:\n", " c_in: Dimension of input features\n", " c_hidden: Dimension of hidden features\n", " c_out: Dimension of output features (usually number of classes)\n", " dp_rate_linear: Dropout rate before the linear layer (usually much higher than inside the GNN)\n", " kwargs: Additional arguments for the GNNModel object\n", " \"\"\"\n", " super().__init__()\n", " self.GNN = GNNModel(c_in=c_in, c_hidden=c_hidden, c_out=c_hidden, **kwargs) # Not our prediction output yet!\n", " self.head = nn.Sequential(nn.Dropout(dp_rate_linear), nn.Linear(c_hidden, c_out))\n", "\n", " def forward(self, x, edge_index, batch_idx):\n", " \"\"\"Forward.\n", "\n", " Args:\n", " x: Input features per node\n", " edge_index: List of vertex index pairs representing the edges in the graph (PyTorch geometric notation)\n", " batch_idx: Index of batch element for each node\n", " \"\"\"\n", " x = self.GNN(x, edge_index)\n", " x = geom_nn.global_mean_pool(x, batch_idx) # Average pooling\n", " x = self.head(x)\n", " return x"]}, {"cell_type": "markdown", "id": "b553c870", "metadata": {"lines_to_next_cell": 2, "papermill": {"duration": 0.013432, "end_time": "2023-10-11T16:03:22.242667", "exception": false, "start_time": "2023-10-11T16:03:22.229235", "status": "completed"}, "tags": []}, "source": ["Finally, we can create a PyTorch Lightning module to handle the training.\n", "It is similar to the modules we have seen before and does nothing surprising in terms of training.\n", "As we have a binary classification task, we use the Binary Cross Entropy loss."]}, {"cell_type": "code", "execution_count": 25, "id": "033fde2f", "metadata": {"execution": {"iopub.execute_input": "2023-10-11T16:03:22.269885Z", "iopub.status.busy": "2023-10-11T16:03:22.269100Z", "iopub.status.idle": "2023-10-11T16:03:22.284187Z", "shell.execute_reply": 
"2023-10-11T16:03:22.283350Z"}, "lines_to_next_cell": 2, "papermill": {"duration": 0.030592, "end_time": "2023-10-11T16:03:22.285650", "exception": false, "start_time": "2023-10-11T16:03:22.255058", "status": "completed"}, "tags": []}, "outputs": [], "source": ["class GraphLevelGNN(L.LightningModule):\n", " def __init__(self, **model_kwargs):\n", " super().__init__()\n", " # Saving hyperparameters\n", " self.save_hyperparameters()\n", "\n", " self.model = GraphGNNModel(**model_kwargs)\n", " self.loss_module = nn.BCEWithLogitsLoss() if self.hparams.c_out == 1 else nn.CrossEntropyLoss()\n", "\n", " def forward(self, data, mode=\"train\"):\n", " x, edge_index, batch_idx = data.x, data.edge_index, data.batch\n", " x = self.model(x, edge_index, batch_idx)\n", " x = x.squeeze(dim=-1)\n", "\n", " if self.hparams.c_out == 1:\n", " preds = (x > 0).float()\n", " data.y = data.y.float()\n", " else:\n", " preds = x.argmax(dim=-1)\n", " loss = self.loss_module(x, data.y)\n", " acc = (preds == data.y).sum().float() / preds.shape[0]\n", " return loss, acc\n", "\n", " def configure_optimizers(self):\n", " # High lr because of small dataset and small model\n", " optimizer = optim.AdamW(self.parameters(), lr=1e-2, weight_decay=0.0)\n", " return optimizer\n", "\n", " def training_step(self, batch, batch_idx):\n", " loss, acc = self.forward(batch, mode=\"train\")\n", " # Pass the number of graphs as batch_size, since Lightning cannot reliably infer it from a PyG batch\n", " self.log(\"train_loss\", loss, batch_size=batch.num_graphs)\n", " self.log(\"train_acc\", acc, batch_size=batch.num_graphs)\n", " return loss\n", "\n", " def validation_step(self, batch, batch_idx):\n", " _, acc = self.forward(batch, mode=\"val\")\n", " self.log(\"val_acc\", acc, batch_size=batch.num_graphs)\n", "\n", " def test_step(self, batch, batch_idx):\n", " _, acc = self.forward(batch, mode=\"test\")\n", " self.log(\"test_acc\", acc, batch_size=batch.num_graphs)"]}, {"cell_type": "markdown", "id": "e1c854a2", "metadata": {"lines_to_next_cell": 2, "papermill": {"duration": 0.012575, "end_time": "2023-10-11T16:03:22.311353", "exception": false, "start_time": "2023-10-11T16:03:22.298778", "status": "completed"}, "tags": []}, 
"source": ["Below we train the model on our dataset. It resembles the typical training functions we have seen so far."]}, {"cell_type": "code", "execution_count": 26, "id": "d031d1d4", "metadata": {"execution": {"iopub.execute_input": "2023-10-11T16:03:22.337771Z", "iopub.status.busy": "2023-10-11T16:03:22.337052Z", "iopub.status.idle": "2023-10-11T16:03:22.353709Z", "shell.execute_reply": "2023-10-11T16:03:22.352823Z"}, "papermill": {"duration": 0.031695, "end_time": "2023-10-11T16:03:22.355074", "exception": false, "start_time": "2023-10-11T16:03:22.323379", "status": "completed"}, "tags": []}, "outputs": [], "source": ["def train_graph_classifier(model_name, **model_kwargs):\n", " L.seed_everything(42)\n", "\n", " # Create a PyTorch Lightning trainer that checkpoints the model with the best validation accuracy\n", " root_dir = os.path.join(CHECKPOINT_PATH, \"GraphLevel\" + model_name)\n", " os.makedirs(root_dir, exist_ok=True)\n", " trainer = L.Trainer(\n", " default_root_dir=root_dir,\n", " callbacks=[ModelCheckpoint(save_weights_only=True, mode=\"max\", monitor=\"val_acc\")],\n", " accelerator=\"cuda\",\n", " devices=AVAIL_GPUS,\n", " max_epochs=500,\n", " enable_progress_bar=False,\n", " )\n", " trainer.logger._default_hp_metric = None\n", "\n", " # Check whether pretrained model exists. 
If yes, load it and skip training\n", " pretrained_filename = os.path.join(CHECKPOINT_PATH, \"GraphLevel%s.ckpt\" % model_name)\n", " if os.path.isfile(pretrained_filename):\n", " print(\"Found pretrained model, loading...\")\n", " model = GraphLevelGNN.load_from_checkpoint(pretrained_filename)\n", " else:\n", " L.seed_everything(42)\n", " model = GraphLevelGNN(\n", " c_in=tu_dataset.num_node_features,\n", " c_out=1 if tu_dataset.num_classes == 2 else tu_dataset.num_classes,\n", " **model_kwargs,\n", " )\n", " trainer.fit(model, graph_train_loader, graph_val_loader)\n", " model = GraphLevelGNN.load_from_checkpoint(trainer.checkpoint_callback.best_model_path)\n", "\n", " # Test best model on training and test set\n", " train_result = trainer.test(model, dataloaders=graph_train_loader, verbose=False)\n", " test_result = trainer.test(model, dataloaders=graph_test_loader, verbose=False)\n", " result = {\"test\": test_result[0][\"test_acc\"], \"train\": train_result[0][\"test_acc\"]}\n", " return model, result"]}, {"cell_type": "markdown", "id": "97cb8ad3", "metadata": {"papermill": {"duration": 0.012021, "end_time": "2023-10-11T16:03:22.379407", "exception": false, "start_time": "2023-10-11T16:03:22.367386", "status": "completed"}, "tags": []}, "source": ["Finally, let's perform the training and testing.\n", "Feel free to experiment with different GNN layers, hyperparameters, etc."]}, {"cell_type": "code", "execution_count": 27, "id": "b139207e", "metadata": {"execution": {"iopub.execute_input": "2023-10-11T16:03:22.405188Z", "iopub.status.busy": "2023-10-11T16:03:22.404648Z", "iopub.status.idle": "2023-10-11T16:03:22.515903Z", "shell.execute_reply": "2023-10-11T16:03:22.510992Z"}, "papermill": {"duration": 0.125803, "end_time": "2023-10-11T16:03:22.517426", "exception": false, "start_time": "2023-10-11T16:03:22.391623", "status": "completed"}, "tags": []}, "outputs": [{"name": "stderr", "output_type": "stream", "text": ["Global seed set to 42\n"]}, {"name": 
"stderr", "output_type": "stream", "text": ["GPU available: True (cuda), used: True\n"]}, {"name": "stderr", "output_type": "stream", "text": ["TPU available: False, using: 0 TPU cores\n"]}, {"name": "stderr", "output_type": "stream", "text": ["IPU available: False, using: 0 IPUs\n"]}, {"name": "stderr", "output_type": "stream", "text": ["HPU available: False, using: 0 HPUs\n"]}, {"name": "stderr", "output_type": "stream", "text": ["Lightning automatically upgraded your loaded checkpoint from v1.0.2 to v2.0.9.post0. To apply the upgrade to your files permanently, run `python -m lightning.pytorch.utilities.upgrade_checkpoint --file saved_models/GNNs/GraphLevelGraphConv.ckpt`\n"]}, {"name": "stderr", "output_type": "stream", "text": ["LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n"]}, {"name": "stderr", "output_type": "stream", "text": ["/usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/connectors/data_connector.py:490: PossibleUserWarning: Your `test_dataloader`'s sampler has shuffling enabled, it is strongly recommended that you turn shuffling off for val/test dataloaders.\n", " rank_zero_warn(\n", "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/utilities/data.py:76: UserWarning: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 2. 
To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.\n", " warning_cache.warn(\n", "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n"]}, {"name": "stdout", "output_type": "stream", "text": ["Found pretrained model, loading...\n"]}], "source": ["model, result = train_graph_classifier(\n", " model_name=\"GraphConv\", c_hidden=256, layer_name=\"GraphConv\", num_layers=3, dp_rate_linear=0.5, dp_rate=0.0\n", ")"]}, {"cell_type": "code", "execution_count": 28, "id": "1f6c10e3", "metadata": {"execution": {"iopub.execute_input": "2023-10-11T16:03:22.551971Z", "iopub.status.busy": "2023-10-11T16:03:22.551207Z", "iopub.status.idle": "2023-10-11T16:03:22.556279Z", "shell.execute_reply": "2023-10-11T16:03:22.555416Z"}, "papermill": {"duration": 0.026658, "end_time": "2023-10-11T16:03:22.557748", "exception": false, "start_time": "2023-10-11T16:03:22.531090", "status": "completed"}, "tags": []}, "outputs": [{"name": "stdout", "output_type": "stream", "text": ["Train performance: 92.67%\n", "Test performance: 92.11%\n"]}], "source": ["print(\"Train performance: %4.2f%%\" % (100.0 * result[\"train\"]))\n", "print(\"Test performance: %4.2f%%\" % (100.0 * result[\"test\"]))"]}, {"cell_type": "markdown", "id": "ce1459df", "metadata": {"papermill": {"duration": 0.016168, "end_time": "2023-10-11T16:03:22.587068", "exception": false, "start_time": "2023-10-11T16:03:22.570900", "status": "completed"}, "tags": []}, "source": ["The test performance shows that we obtain good scores on a part of the dataset not seen during training.\n", "It should be noted that, as we have been using the test set for validation as well, we might have overfitted slightly to this set.\n", "Nevertheless, the experiment shows that GNNs can indeed be powerful for predicting the properties of graphs and/or molecules."]}, {"cell_type": "markdown", "id": "9a2e0201", "metadata": {"papermill": {"duration": 0.016881, "end_time": "2023-10-11T16:03:22.617035", "exception": false, "start_time": 
"2023-10-11T16:03:22.600154", "status": "completed"}, "tags": []}, "source": ["## Conclusion\n", "\n", "In this tutorial, we have seen the application of neural networks to graph structures.\n", "We looked at how a graph can be represented (adjacency matrix or edge list),\n", "and discussed the implementation of common graph layers: GCN and GAT.\n", "The implementations showed the practical side of the layers, which is often easier than the theory.\n", "Finally, we discussed tasks on node-, edge-, and graph-level, and experimented with node and graph classification.\n", "Overall, we have seen that including graph information in the predictions can be crucial for achieving high performance.\n", "There are a lot of applications that benefit from GNNs,\n", "and the importance of these networks will likely increase in the coming years."]}, {"cell_type": "markdown", "id": "1ef601bc", "metadata": {"papermill": {"duration": 0.013155, "end_time": "2023-10-11T16:03:22.643763", "exception": false, "start_time": "2023-10-11T16:03:22.630608", "status": "completed"}, "tags": []}, "source": ["## Congratulations - Time to Join the Community!\n", "\n", "Congratulations on completing this notebook tutorial! If you enjoyed this and would like to join the Lightning\n", "movement, you can do so in the following ways!\n", "\n", "### Star [Lightning](https://github.com/Lightning-AI/lightning) on GitHub\n", "The easiest way to help our community is just by starring the GitHub repos! This helps raise awareness of the cool\n", "tools we're building.\n", "\n", "### Join our [Slack](https://www.pytorchlightning.ai/community)!\n", "The best way to keep up to date on the latest advancements is to join our community! Make sure to introduce yourself\n", "and share your interests in the `#general` channel.\n", "\n", "\n", "### Contributions!\n", "The best way to contribute to our community is to become a code contributor! 
At any time you can go to\n", "[Lightning](https://github.com/Lightning-AI/lightning) or [Bolt](https://github.com/Lightning-AI/lightning-bolts)\n", "GitHub Issues page and filter for \"good first issue\".\n", "\n", "* [Lightning good first issue](https://github.com/Lightning-AI/lightning/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)\n", "* [Bolt good first issue](https://github.com/Lightning-AI/lightning-bolts/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)\n", "* You can also contribute your own notebooks with useful examples!\n", "\n", "### Great thanks from the entire PyTorch Lightning Team for your interest!\n", "\n", "[![Pytorch Lightning](){height=\"60px\" width=\"240px\"}](https://pytorchlightning.ai)"]}, {"cell_type": "raw", "metadata": {"raw_mimetype": "text/restructuredtext"}, "source": [".. customcarditem::\n", " :header: Tutorial 6: Basics of Graph Neural Networks\n", " :card_description: In this tutorial, we will discuss the application of neural networks on graphs. 
Graph Neural Networks (GNNs) have recently gained increasing popularity in both applications...\n", " :tags: Graph,GPU/TPU,UvA-DL-Course\n", " :image: _static/images/course_UvA-DL/06-graph-neural-networks.jpg"]}], "metadata": {"jupytext": {"cell_metadata_filter": "colab_type,colab,id,-all", "formats": "ipynb,py:percent", "main_language": "python"}, "language_info": {"codemirror_mode": {"name": "ipython", "version": 3}, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12"}, "papermill": {"default_parameters": {}, "duration": 18.450614, "end_time": "2023-10-11T16:03:23.982888", "environment_variables": {}, "exception": null, "input_path": "course_UvA-DL/06-graph-neural-networks/GNN_overview.ipynb", "output_path": ".notebooks/course_UvA-DL/06-graph-neural-networks.ipynb", "parameters": {}, "start_time": "2023-10-11T16:03:05.532274", "version": "2.4.0"}}, "nbformat": 4, "nbformat_minor": 5}