T035 · GNN-based molecular property prediction

Note: This talktorial is a part of TeachOpenCADD, a platform that aims to teach domain-specific skills and to provide pipeline templates as starting points for research projects.

Authors:

Aim of this talktorial

In this talktorial, we first explain the basic concepts of graph neural networks (GNNs) and present two different GNN architectures. We then apply both networks to the QM9 dataset of small molecules to predict a molecular property, and demonstrate step by step how to train and evaluate GNNs using PyTorch Geometric.

Contents in Theory

  • GNN Tasks

  • Message Passing

  • Graph Convolutional Network (GCN)

  • Graph Isomorphism Network (GIN)

  • Training a GNN

  • Applications of GNNs

Contents in Practical

  • Dataset

  • Defining a GCN and GIN

  • Training a GNN

  • Evaluating the model

References

Theory

Graph Neural Networks

There are several ways to represent molecules, which are explained and discussed in Talktorial T033. An intuitive approach to applying deep learning to molecules is to exploit their graph structure, since graph neural networks can operate directly on graphs. Molecules can easily be represented as graphs, as shown in Figure 1. In a graph \(G=(V, E)\), \(V\) is the set of vertices (nodes) and \(E\) is the set of edges. In molecular graphs, each node \(v_i\) represents an atom and is described by a feature vector in \(\mathbb{R}^{d_v}\), i.e. by \(d_v\) features such as the atomic number and chirality. Edges usually correspond to covalent bonds between the atoms; each edge \(e_{ij}\) is described by a feature vector in \(\mathbb{R}^{d_e}\), which usually encodes the bond type. A graph neural network consists of learnable, differentiable functions whose graph-level outputs are invariant to permutations of the nodes. GNNs are built from so-called message-passing layers, which are explained in more detail below, followed by two specific GNN architectures.


Figure 1: Molecular graph overview. Figure taken from [1]

GNN Tasks

We can perform different tasks with a GNN:

  • Graph-level tasks: one application would be to predict a specific property of the entire graph. This can be a classification task such as toxicity prediction or a regression task. In this tutorial, we will implement a regression task to predict molecular properties. Another graph-level task would be to predict entirely new graphs/molecules. This is especially relevant in the area of drug discovery, where new drug candidates are of interest.

  • Node-level tasks: we can predict a property of each node in the graph, e.g. the atomic charge of each atom. We could also predict a new node to be added to the graph. This is often done in molecule generation, where atoms are added one after the other to build up a new molecule.

  • Edge-level tasks: we can predict edge properties, e.g. intramolecular forces between atoms, or a new edge in the graph. In the molecule generation context, we want to predict potential bonds between the atoms. Edge prediction can also be used to infer connections/interactions e.g. in a gene regulatory network.

Message Passing

Instead of the fully connected (MLP) layers of standard neural networks, GNNs use message-passing layers, which collect information from neighboring nodes. For each node \(v\), we gather information (messages) from its direct neighbors \(N(v)\). These messages are then aggregated, for example by summation, and the embedding of node \(v\) is updated by combining it with the aggregated message. After one such aggregate-and-combine step, each node contains information about its direct neighbors (1-hop). After \(n\) steps, each node contains information about all nodes that are at most \(n\) edges away (\(n\)-hop).

\[a_v^{(k)} = \text{aggregate}^{(k)} (\{ h_u^{(k-1)}: u \in N(v) \})\]
\[h_v^{(k)} = \text{combine}^{(k)} (h_v^{(k-1)}, a_v^{(k)})\]

where \(h_v^{(k)}\) is the embedding of node \(v\) at layer \(k\), \(a_v^{(k)}\) is the aggregated message and \(N(v)\) is the set of neighbors of node \(v\).
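The two equations above can be illustrated with a minimal NumPy sketch of one message-passing step. It assumes sum aggregation and a hypothetical, non-learnable combine step that simply adds the aggregated message to the node's own embedding (real GNN layers use learnable functions here). The graph is stored as a directed edge list (source, target), in the same spirit as PyTorch Geometric's `edge_index`.

```python
import numpy as np


def message_passing_step(h, edge_index):
    """One round of message passing with sum aggregation.

    h: (num_nodes, num_features) node embeddings h^(k-1)
    edge_index: (2, num_edges) array; column (u, v) sends a message from u to v
    """
    src, dst = edge_index
    # aggregate: a_v = sum of h_u over all neighbors u of v
    a = np.zeros_like(h)
    np.add.at(a, dst, h[src])
    # combine: hypothetical, non-learnable choice h_v^(k) = h_v^(k-1) + a_v
    return h + a


# path graph 0 - 1 - 2, every undirected bond stored in both directions
edge_index = np.array([[0, 1, 1, 2],
                       [1, 0, 2, 1]])
h0 = np.array([[1.0], [10.0], [100.0]])

h1 = message_passing_step(h0, edge_index)  # node 0 now "sees" node 1 (1-hop)
h2 = message_passing_step(h1, edge_index)  # node 0 now also "sees" node 2 (2-hop)
print(h1.ravel())
print(h2.ravel())
```

After the first step, the embedding of node 0 depends only on nodes 0 and 1; the contribution of node 2 (value 100) only reaches node 0 in the second step.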


Figure 2: Message passing overview. Figure taken from [2]

One important property of a GNN is permutation invariance: changing the order of the nodes in the graph should not affect the outcome. For example, when a graph is stored as an adjacency matrix, reordering the nodes swaps rows and columns of that matrix. This does not change any property of the graph, but the input to the model differs. GNNs are designed to be insensitive to such reorderings. The aggregation function must therefore be permutation invariant, i.e. independent of the order in which the neighbors are visited; examples are the mean, the maximum and the sum. Node embeddings then only get reordered together with the nodes (permutation equivariance), and a permutation-invariant readout over all nodes makes the graph-level outputs invariant as well. In this tutorial, we focus on graph-level regression tasks. In the following, we present two different GNN architectures.
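To make permutation invariance concrete, the following sketch relabels the nodes of a small graph with a permutation matrix \(P\) (so \(A \mapsto PAP^\top\) and \(X \mapsto PX\)), runs one sum-aggregation step on both versions and compares the results. The graph, the random features and the permutation are arbitrary illustrative choices.

```python
import numpy as np

rng = np.random.default_rng(0)

# adjacency matrix of a small undirected graph (triangle 0-1-2 plus pendant node 3)
A = np.array([[0, 1, 1, 0],
              [1, 0, 1, 0],
              [1, 1, 0, 1],
              [0, 0, 1, 0]], dtype=float)
X = rng.normal(size=(4, 3))  # random node features

# permutation matrix that reorders the nodes: new node i is old node perm[i]
perm = [2, 0, 3, 1]
P = np.eye(4)[perm]

A_perm = P @ A @ P.T
X_perm = P @ X


def layer(A, X):
    # one sum-aggregation step: row v of A @ X is the sum over the neighbors of v
    return X + A @ X


node_out = layer(A, X)
node_out_perm = layer(A_perm, X_perm)

# node-level outputs are reordered in the same way as the input (equivariance) ...
print(np.allclose(P @ node_out, node_out_perm))
# ... and a sum readout over all nodes is unchanged (invariance)
print(np.allclose(node_out.sum(axis=0), node_out_perm.sum(axis=0)))
```

Both checks print `True`, although the adjacency matrix fed to the layer is different after the permutation.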

GCN

One of the simplest GNNs is the Graph Convolutional Network (GCN). In a GCN, we sum over all neighbors of node \(v\), including node \(v\) itself, and weight each contribution by the inverse square root of the degrees of both nodes. This normalization keeps the scale of the aggregated values comparable between nodes with few and with many neighbors. The node-wise update rule for layer \(k\) is

\[h_v^{(k)} = \Theta^{\top} \sum_{u \in N(v) \cup \{v\}} \frac{1}{\sqrt{d_v d_u}} \cdot h_u^{(k-1)}\]

where \(d_v\) and \(d_u\) denote the degrees of nodes \(v\) and \(u\), respectively (counting the self-loop), and \(\Theta\) represents the trainable weights.
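For all nodes at once, the GCN update can be written as \(H^{(k)} = \hat{D}^{-1/2} \hat{A} \hat{D}^{-1/2} H^{(k-1)} \Theta\), where \(\hat{A} = A + I\) adds the self-loops and \(\hat{D}\) is the diagonal degree matrix of \(\hat{A}\). The NumPy sketch below follows this dense formulation with a random, untrained \(\Theta\); libraries such as PyTorch Geometric implement the same propagation with sparse edge lists instead of dense matrices.

```python
import numpy as np


def gcn_layer(A, H, Theta):
    """Dense GCN propagation: D^-1/2 (A + I) D^-1/2 H Theta."""
    A_hat = A + np.eye(A.shape[0])            # add self-loops: N(v) plus v itself
    deg = A_hat.sum(axis=1)                   # d_v, counting the self-loop
    D_inv_sqrt = np.diag(1.0 / np.sqrt(deg))
    A_norm = D_inv_sqrt @ A_hat @ D_inv_sqrt  # entry (v, u) = 1 / sqrt(d_v * d_u)
    return A_norm @ H @ Theta


# star graph: node 0 bonded to nodes 1, 2 and 3 (e.g. a central atom with three neighbors)
A = np.array([[0, 1, 1, 1],
              [1, 0, 0, 0],
              [1, 0, 0, 0],
              [1, 0, 0, 0]], dtype=float)
H = np.ones((4, 2))                                    # identical input features
Theta = np.random.default_rng(0).normal(size=(2, 2))  # untrained weights

print(gcn_layer(A, H, Theta))
```

With \(\Theta = I\), the central node (degree 4 including its self-loop) receives \(1/4 + 3/\sqrt{8}\) per feature and each leaf (degree 2) receives \(1/2 + 1/\sqrt{8}\).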

One disadvantage of GCNs is that their degree-normalized aggregation behaves like a weighted mean, and such a function is not injective. This means that different graphs can lead to the same graph embedding, so the network can no longer distinguish between them. One example is visualized in Figure 3 below: assuming identical node and edge features, a GCN could produce the same hidden embedding for both graphs.


Figure 3: Two indistinguishable graphs using GCNs

GIN

Another type of GNN is the Graph Isomorphism Network (GIN), which was proposed to overcome the disadvantage of GCNs explained above. Its update rule is defined as follows:

\[h_v^{(k)} = h_\Theta((1+ \epsilon) \cdot h_v^{(k-1)} + \sum_{u \in N(v)} h_u^{(k-1)} )\]

Here, the aggregation function is a sum. The parameter \(\epsilon\) (fixed or learnable) controls the importance of node \(v\) relative to its neighbors, and \(h_\Theta\) is a neural network shared by all nodes, for example an MLP. The sum aggregation is more powerful than a mean aggregation (as used in the GCN above), because it also reflects how many neighbors carry a given feature. It can therefore distinguish graphs that a mean maps to the same embedding, for example the two graphs in Figure 3.
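A tiny numerical example of why the sum is more expressive than the mean: a node whose two neighbors carry the same feature vector and a node with a single such neighbor receive the same mean message, but different sum messages. The GIN-style update below uses \(\epsilon = 0\) and replaces \(h_\Theta\) with a hypothetical fixed map (a ReLU of a linear transformation) purely for illustration.

```python
import numpy as np


def gin_update(h_v, neighbor_feats, W, eps=0.0):
    """GIN-style update: h_Theta((1 + eps) * h_v + sum of neighbor features).

    h_Theta is replaced here by a fixed, untrained map ReLU(W x) for illustration.
    """
    z = (1.0 + eps) * h_v + np.sum(neighbor_feats, axis=0)
    return np.maximum(W @ z, 0.0)


h_v = np.array([1.0, 0.0])
same_neighbor = np.array([0.0, 1.0])

one_neighbor = np.stack([same_neighbor])                  # multiset {x}
two_neighbors = np.stack([same_neighbor, same_neighbor])  # multiset {x, x}

# mean aggregation cannot tell the two neighborhoods apart ...
print(np.allclose(one_neighbor.mean(axis=0), two_neighbors.mean(axis=0)))  # True
# ... whereas the sum (and hence the GIN update) can
W = np.eye(2)
print(gin_update(h_v, one_neighbor, W), gin_update(h_v, two_neighbors, W))
```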

GINs are a good example of a simple yet powerful network, as they are good at distinguishing between non-isomorphic graphs. Two graphs are isomorphic if they are identical up to a permutation of the nodes. While this may be easy to see for small graphs, deciding isomorphism is a difficult problem for larger graphs. When working with GNNs, we would like the model to give the same output for isomorphic input graphs. At the same time, we want the model to differentiate between non-isomorphic graphs and output (possibly) different results.

GINs can differentiate between non-isomorphic graphs much better than other simple GNNs such as GCN and GraphSAGE. For example, the two graphs in Figure 3 have different embeddings when using a GIN, since the sum-based aggregation involves no scaling or averaging. It has been proven that GINs are at most as powerful as the Weisfeiler-Lehman (WL) test, a common (but not perfect) graph isomorphism test, and that they reach this power when their aggregation and update functions are injective. If you are interested in the WL test or more details on GINs, have a look at the original publication about GINs or this blog post about the WL test. However, GINs cannot distinguish between all non-isomorphic graphs; one example is shown in Figure 4. Assuming identical node features, every node in both graphs has the same number of neighbors, so \(h_v\) is the same for all nodes \(v\) in both graphs.


Figure 4: Two indistinguishable graphs using GINs
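The 1-dimensional Weisfeiler-Lehman test mentioned above can be sketched as iterative color refinement: every node starts with the same color, and in each round its new color is determined by its old color and the multiset of its neighbors' colors. If the color histograms of two graphs ever differ, the graphs are certainly non-isomorphic; if they never differ, the test is inconclusive. The example graphs below (a 6-cycle and two disjoint triangles) are a standard illustration and are not necessarily the graphs drawn in Figure 4; in both, every node has exactly two neighbors.

```python
from collections import Counter


def wl_refine(adj, colors):
    """Signature of each node: its own color plus the sorted colors of its neighbors."""
    return {v: (colors[v], tuple(sorted(colors[u] for u in adj[v]))) for v in adj}


def wl_test(adj_a, adj_b, iterations=3):
    """Return False if 1-WL proves the graphs non-isomorphic, True if inconclusive."""
    colors_a = {v: 0 for v in adj_a}  # identical initial node features
    colors_b = {v: 0 for v in adj_b}
    for _ in range(iterations):
        sig_a, sig_b = wl_refine(adj_a, colors_a), wl_refine(adj_b, colors_b)
        # shared palette so that equal signatures get equal colors in both graphs
        palette = {s: i for i, s in enumerate(sorted(set(sig_a.values()) | set(sig_b.values())))}
        colors_a = {v: palette[s] for v, s in sig_a.items()}
        colors_b = {v: palette[s] for v, s in sig_b.items()}
        if Counter(colors_a.values()) != Counter(colors_b.values()):
            return False
    return True


hexagon = {i: [(i - 1) % 6, (i + 1) % 6] for i in range(6)}
two_triangles = {0: [1, 2], 1: [0, 2], 2: [0, 1], 3: [4, 5], 4: [3, 5], 5: [3, 4]}
path = {0: [1], 1: [0, 2], 2: [1]}
triangle = {0: [1, 2], 1: [0, 2], 2: [0, 1]}

print(wl_test(hexagon, two_triangles))  # True: 1-WL cannot tell them apart
print(wl_test(path, triangle))          # False: the node degrees already differ
```

Since message-passing GNNs such as the GIN are at most as powerful as this test, graphs that 1-WL cannot separate (like the first pair) also receive identical graph embeddings from such a GNN when all node features are identical.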

Training a GNN

As with standard neural networks, training a GNN requires several design choices and hyperparameters. We briefly present some concepts commonly used in neural networks that can also be applied to GNNs. Loss functions and activation functions were already discussed in Talktorial T022; in this talktorial, we use the mean squared error (MSE) loss and the ReLU activation function.

Batching

GNNs are commonly trained on batches of graphs, so that many graphs can be processed in parallel. The batch size indicates how many samples from the training data are fed to the neural network before the model parameters are updated. Choosing the batch size is a trade-off between computational cost and generalization: with larger batches, the model is updated fewer times per epoch and training is much faster, whereas models trained with smaller batches can generalize better, i.e. reach a lower test error. Since the batch size is not the only hyperparameter, its choice is also linked to the learning rate, the number of training epochs, etc.

One way to implement batching for GNNs is to stack the adjacency matrices of all graphs in the batch along the diagonal (forming one large, disconnected graph) and to concatenate their node feature matrices. However, graphs, and especially molecular graphs, tend to have very sparse adjacency matrices, so it is more efficient to use a sparse representation of the edges. PyTorch Geometric, for example, uses edge lists that store only the indices of existing edges. During batching, these lists are concatenated, with the node indices of each graph shifted by the number of nodes in the preceding graphs.
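The edge-list batching described above can be sketched in NumPy: node feature matrices are concatenated, each graph's edge list is shifted by the number of nodes that precede it, and a `batch` vector records which graph every node belongs to. This mirrors the `x`, `edge_index` and `batch` attributes produced by PyTorch Geometric's `DataLoader`, but the `collate` function below is a hypothetical re-implementation for illustration, not the library's code.

```python
import numpy as np


def collate(graphs):
    """Merge a list of (x, edge_index) pairs into one large disconnected graph."""
    xs, edge_indices, batch = [], [], []
    offset = 0
    for graph_id, (x, edge_index) in enumerate(graphs):
        xs.append(x)
        edge_indices.append(edge_index + offset)  # shift node indices of this graph
        batch.append(np.full(x.shape[0], graph_id))
        offset += x.shape[0]
    return np.concatenate(xs), np.concatenate(edge_indices, axis=1), np.concatenate(batch)


# graph A: 2 nodes, 1 bond; graph B: 3 nodes in a path, 2 bonds (both directions stored)
graph_a = (np.ones((2, 4)), np.array([[0, 1], [1, 0]]))
graph_b = (np.zeros((3, 4)), np.array([[0, 1, 1, 2], [1, 0, 2, 1]]))

x, edge_index, batch = collate([graph_a, graph_b])
print(x.shape)  # (5, 4)
print(edge_index)
print(batch)
```

The nodes of graph B become nodes 2 to 4 of the batch, so no edge connects the two graphs and message passing stays within each molecule.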


Figure 5: Batching in GNNs. Figure taken from [3]

Pooling

Pooling layers reduce the dimensionality of the representation, which also makes the model more robust to small variations in the input. In graphs, global pooling layers compute a graph embedding from the individual node embeddings. The most common pooling operations are the mean, the maximum and the sum; since these are permutation invariant, the resulting pooling layers are permutation invariant as well. For our GCN, we use a global mean pooling layer, and for our GIN, a global sum pooling layer, as proposed in the original publications listed in the references. Because global pooling maps a variable number of node embeddings to a fixed-size graph representation, global pooling layers are also referred to as readout layers.
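Given a `batch` vector that assigns every node to its graph, global pooling reduces the node embeddings of each graph to one vector. The sketch below implements mean, sum and max readouts in NumPy; in PyTorch Geometric, the first two correspond to `global_mean_pool` and `global_add_pool`, which are used in the models of the practical part. The function name `global_pool` is hypothetical.

```python
import numpy as np


def global_pool(x, batch, mode="sum"):
    """Pool node embeddings x (num_nodes, dim) into one row per graph."""
    num_graphs = batch.max() + 1
    if mode == "max":
        out = np.full((num_graphs, x.shape[1]), -np.inf)
        np.maximum.at(out, batch, x)  # element-wise max per graph
        return out
    out = np.zeros((num_graphs, x.shape[1]))
    np.add.at(out, batch, x)  # sum per graph
    if mode == "mean":
        counts = np.bincount(batch, minlength=num_graphs)
        out /= counts[:, None]  # divide by the number of nodes per graph
    return out


# five nodes: the first two belong to graph 0, the last three to graph 1
x = np.array([[1.0], [3.0], [2.0], [2.0], [5.0]])
batch = np.array([0, 0, 1, 1, 1])

for mode in ("sum", "mean", "max"):
    print(mode, global_pool(x, batch, mode).ravel())
```

The output always has one row per graph, independent of how many atoms each molecule has, which is what allows a fixed-size linear layer to follow the readout.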

Dropout (Regularization)

One common problem in deep learning is overfitting: the model fits the training data too closely and does not generalize well, which often happens when the training dataset is small relative to the capacity of the model. Applying an overfitted network to a different dataset then leads to high prediction errors. One approach to reduce overfitting is to use dropout layers, which can improve the generalization of the model. During training, units (neurons, not graph nodes) are randomly dropped, i.e. their outputs are set to zero. The dropout probability is another hyperparameter that has to be chosen. Since a different random subset of units is active in each iteration, dropout adds noise to training and forces the network not to rely on individual units, which helps it generalize. At evaluation time, no units are dropped.
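A minimal sketch of (inverted) dropout, the variant used by common frameworks such as PyTorch: during training, each unit is zeroed with probability \(p\) and the surviving units are scaled by \(1/(1-p)\), so that the expected activation stays the same; at evaluation time, the layer does nothing. This is why the models in the practical part pass `training=self.training` to `Fun.dropout`. The NumPy function below is an illustration, not PyTorch's implementation.

```python
import numpy as np


def dropout(x, p=0.5, training=True, rng=None):
    """Inverted dropout: zero units with probability p, rescale the rest by 1 / (1 - p)."""
    if not training or p == 0.0:
        return x  # evaluation mode: identity
    rng = np.random.default_rng() if rng is None else rng
    keep_mask = rng.random(x.shape) >= p  # True with probability 1 - p
    return x * keep_mask / (1.0 - p)


rng = np.random.default_rng(0)
x = np.ones((1000, 64))

y_train = dropout(x, p=0.5, training=True, rng=rng)
y_eval = dropout(x, p=0.5, training=False)

print(np.unique(y_train))         # every entry is either dropped (0) or rescaled (2)
print(round(y_train.mean(), 2))   # close to the original mean of 1.0
print(np.array_equal(y_eval, x))  # True
```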

Applications of GNNs

GNNs can be applied to a wide variety of tasks involving graphs. These can be based on small molecules (as in this tutorial), but also on proteins (see Talktorial T038), gene regulatory networks and many more. Some applications are:

  • Property prediction of molecules, such as toxicity and solubility (see Wieder, Oliver, et al. "A compact review of molecular property prediction with graph neural networks." Drug Discovery Today: Technologies 37 (2020): 1-12; and Wu, Zhenqin, et al. "MoleculeNet: a benchmark for molecular machine learning." Chemical Science 9.2 (2018): 513-530)

  • Generating new molecules, which is especially relevant in the field of drug discovery (for more details, see the review by Tong, Xiaochu, et al. "Generative models for De Novo drug design." Journal of Medicinal Chemistry 64.19 (2021): 14011-14027)

  • Inferring new interactions/associations in biological networks, such as gene regulatory networks or protein-protein interaction networks

For a more detailed overview of GNNs and their applications, see Zhang, Xiao-Meng, et al. "Graph Neural Networks and Their Current Applications in Bioinformatics." Frontiers in Genetics 12 (2021).

Practical

For the practical section, we use PyTorch and PyTorch Geometric, a library that helps us handle graph data efficiently; for example, PyTorch Geometric uses sparse edge representations and implements efficient graph batching. Other graph libraries for Python exist as well, such as the Deep Graph Library (DGL), which is not covered in this tutorial.

[1]:
import math
import numpy
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator
from pathlib import Path

import torch
import torch.nn.functional as Fun
from torch.nn import Linear, Sequential, BatchNorm1d, ReLU
[2]:
import sys
if sys.platform.startswith(("linux", "darwin")):
    !mamba install -q -y -c pyg pyg
[3]:
from torch_geometric.datasets import QM9
from torch_geometric.nn import GCNConv, GINConv
from torch_geometric.loader import DataLoader
from torch_geometric.nn import global_mean_pool, global_add_pool
[4]:
# specify the local data path
HERE = Path(_dh[-1])
DATA = HERE / "data"

Dataset

For this tutorial, we use the QM9 dataset, which can be imported with torch_geometric. The dataset is part of a benchmarking collection called MoleculeNet. It contains around \(130,000\) small molecules with at most 9 heavy atoms as well as various molecular properties. We will choose one property which we will then try to predict.

[5]:
# load dataset
qm9 = QM9(root=DATA)
qm9[0]
Downloading https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/molnet_publish/qm9.zip
Extracting ~/teachopencadd/teachopencadd/talktorials/T035_graph_neural_networks/data/raw/qm9.zip
Downloading https://ndownloader.figshare.com/files/3195404
Processing...
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 133885/133885 [03:03<00:00, 728.09it/s]
Done!
[5]:
Data(x=[5, 11], edge_index=[2, 8], edge_attr=[8, 4], y=[1, 19], pos=[5, 3], z=[5], name='gdb_1', idx=[1])

If you are running this tutorial for the first time, the dataset is downloaded in the cell above, which also shows the first molecule of the dataset as an example. Each molecule contains the following information:

  • x: node features (11 per atom), such as a one-hot encoding of the atom type, the atomic number, aromaticity, hybridization and the number of attached hydrogens

  • edge_index: the connectivity of the graph, stored as a list of (directed) edges rather than as a dense adjacency matrix; each edge represents a covalent bond between two atoms

  • edge_attr: edge features (4 per bond), a one-hot encoding of the bond type (single, double, triple, aromatic)

  • pos: 3D atom coordinates (not used in this tutorial)

  • z: atomic numbers

  • y: target values; the dataset contains 19 different properties per molecule, such as the dipole moment, different molecular energies, enthalpy and rotational constants
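`edge_index` is a sparse edge list of shape `[2, num_edges]`: each column holds the source and target atom index of one directed edge, so every covalent bond appears twice. The first QM9 molecule (`gdb_1`, methane) has 5 atoms and 4 bonds, hence `edge_index=[2, 8]` in the output above. The sketch below converts such an edge list into a dense adjacency matrix with NumPy; the atom ordering (carbon as atom 0) and the helper name `to_dense_adjacency` are assumptions for illustration and may differ from the dataset.

```python
import numpy as np

# hypothetical edge list for a methane-like star graph: atom 0 (C) bonded to atoms 1-4 (H),
# each bond stored once per direction
edge_index = np.array([[0, 1, 0, 2, 0, 3, 0, 4],
                       [1, 0, 2, 0, 3, 0, 4, 0]])


def to_dense_adjacency(edge_index, num_nodes):
    """Convert a [2, num_edges] edge list into a dense num_nodes x num_nodes matrix."""
    adj = np.zeros((num_nodes, num_nodes), dtype=int)
    adj[edge_index[0], edge_index[1]] = 1
    return adj


adj = to_dense_adjacency(edge_index, num_nodes=5)
print(adj)
print(adj.sum(axis=1))  # node degrees
```

For molecules, most entries of the dense matrix are zero, which is why the edge list is the more compact format. PyTorch Geometric provides `torch_geometric.utils.to_dense_adj` for the same conversion on tensors.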

In this tutorial, we only use x, edge_index and y to keep things simple. While the dataset provides many regression targets, we focus on a single task: the prediction of the dipole moment \(\mu\) (the first target). To keep the runtime low, we only use a random subset of 30,000 molecules, which is still enough to show first results. This subset is split into training, test and validation sets with an \(80:10:10\) ratio. In addition, we normalize the target values to zero mean and unit standard deviation using the mean and standard deviation of the training set, and apply the same transformation to the test and validation sets.

[6]:
# keep only the first regression target: the dipole moment mu
qm9.data.y = qm9.data.y[:, 0]

qm9 = qm9.shuffle()

# data split (80:10:10) on a subset of 30,000 molecules
data_size = 30000
train_index = int(data_size * 0.8)
test_index = train_index + int(data_size * 0.1)
val_index = test_index + int(data_size * 0.1)

# normalizing the data with the statistics of the training molecules;
# qm9.data is the full, unshuffled storage, so index it with the shuffled order
train_ids = torch.as_tensor(qm9.indices()[0:train_index])
data_mean = qm9.data.y[train_ids].mean()
data_std = qm9.data.y[train_ids].std()

qm9.data.y = (qm9.data.y - data_mean) / data_std

# datasets into DataLoader (only the training data needs to be shuffled)
train_loader = DataLoader(qm9[0:train_index], batch_size=64, shuffle=True)
test_loader = DataLoader(qm9[train_index:test_index], batch_size=64, shuffle=False)
val_loader = DataLoader(qm9[test_index:val_index], batch_size=64, shuffle=False)
~/.miniconda3/envs/teachopencadd/lib/python3.9/site-packages/torch_geometric/data/in_memory_dataset.py:157: UserWarning: It is not recommended to directly access the internal storage format `data` of an 'InMemoryDataset'. The data of the dataset is already cached, so any modifications to `data` will not be reflected when accessing its elements. Clearing the cache now by removing all elements in `dataset._data_list`. If you are absolutely certain what you are doing, access the internal storage via `InMemoryDataset._data` instead to suppress this warning. Alternatively, you can access stacked individual attributes of every graph via `dataset.{attr_name}`.
  warnings.warn(msg)
~/.miniconda3/envs/teachopencadd/lib/python3.9/site-packages/torch_geometric/data/in_memory_dataset.py:157: UserWarning: It is not recommended to directly access the internal storage format `data` of an 'InMemoryDataset'. If you are absolutely certain what you are doing, access the internal storage via `InMemoryDataset._data` instead to suppress this warning. Alternatively, you can access stacked individual attributes of every graph via `dataset.{attr_name}`.
  warnings.warn(msg)
~/.miniconda3/envs/teachopencadd/lib/python3.9/site-packages/torch_geometric/data/in_memory_dataset.py:157: UserWarning: It is not recommended to directly access the internal storage format `data` of an 'InMemoryDataset'. The given 'InMemoryDataset' only references a subset of examples of the full dataset, but 'data' will contain information of the full dataset. If you are absolutely certain what you are doing, access the internal storage via `InMemoryDataset._data` instead to suppress this warning. Alternatively, you can access stacked individual attributes of every graph via `dataset.{attr_name}`.
  warnings.warn(msg)

Defining a GCN and a GIN

The following two Python classes are the two GNNs we will consider in this tutorial. Both have 3 convolutional layers, one global pooling layer, linear layers, ReLU activation functions between the layers and a dropout layer.

[7]:
class GCN(torch.nn.Module):
    """Graph Convolutional Network class with 3 convolutional layers and a linear layer"""

    def __init__(self, dim_h):
        """init method for GCN

        Args:
            dim_h (int): the dimension of hidden layers
        """
        super().__init__()
        self.conv1 = GCNConv(qm9.num_features, dim_h)
        self.conv2 = GCNConv(dim_h, dim_h)
        self.conv3 = GCNConv(dim_h, dim_h)
        self.lin = torch.nn.Linear(dim_h, 1)

    def forward(self, data):
        e = data.edge_index
        x = data.x

        x = self.conv1(x, e)
        x = x.relu()
        x = self.conv2(x, e)
        x = x.relu()
        x = self.conv3(x, e)
        x = global_mean_pool(x, data.batch)

        x = Fun.dropout(x, p=0.5, training=self.training)
        x = self.lin(x)

        return x
[8]:
class GIN(torch.nn.Module):
    """Graph Isomorphism Network class with 3 GINConv layers and 2 linear layers"""

    def __init__(self, dim_h):
        """Initializing GIN class

        Args:
            dim_h (int): the dimension of hidden layers
        """
        super().__init__()
        self.conv1 = GINConv(
            Sequential(
                Linear(qm9.num_features, dim_h), BatchNorm1d(dim_h), ReLU(), Linear(dim_h, dim_h), ReLU()
            )
        )
        self.conv2 = GINConv(
            Sequential(
                Linear(dim_h, dim_h), BatchNorm1d(dim_h), ReLU(), Linear(dim_h, dim_h), ReLU()
            )
        )
        self.conv3 = GINConv(
            Sequential(
                Linear(dim_h, dim_h), BatchNorm1d(dim_h), ReLU(), Linear(dim_h, dim_h), ReLU()
            )
        )
        self.lin1 = Linear(dim_h, dim_h)
        self.lin2 = Linear(dim_h, 1)

    def forward(self, data):
        x = data.x
        edge_index = data.edge_index
        batch = data.batch

        # Node embeddings
        h = self.conv1(x, edge_index)
        h = h.relu()
        h = self.conv2(h, edge_index)
        h = h.relu()
        h = self.conv3(h, edge_index)

        # Graph-level readout
        h = global_add_pool(h, batch)

        h = self.lin1(h)
        h = h.relu()
        h = Fun.dropout(h, p=0.5, training=self.training)
        h = self.lin2(h)

        return h

Training a GNN

When training a GNN (or any neural network), we use a training set, a validation set and a test set. The training set is used to update the model weights. The validation set is used to compute the loss after each epoch on data that is not used for the updates, in order to monitor generalization and to select the best model. The test set is used to estimate the error of the fully trained model on data that was not used at any point during training or model selection.

[9]:
def training(loader, model, loss, optimizer):
    """Training one epoch

    Args:
        loader (DataLoader): training data divided into batches
        model (nn.Module): GNN model to train on
        loss (nn.Module): loss function to use during training
        optimizer (torch.optim): optimizer during training

    Returns:
        tuple: average training loss (detached tensor) and the trained model
    """
    model.train()

    current_loss = 0
    for d in loader:
        optimizer.zero_grad()
        d.x = d.x.float()

        out = model(d)

        l = loss(out, torch.reshape(d.y, (len(d.y), 1)))
        # detach so that the running sum does not keep every batch's autograd graph alive
        current_loss += l.detach() / len(loader)
        l.backward()
        optimizer.step()
    return current_loss, model
[10]:
@torch.no_grad()
def validation(loader, model, loss):
    """Validation

    Args:
        loader (DataLoader): validation set in batches
        model (nn.Module): current trained model
        loss (nn.Module): loss function

    Returns:
        torch.Tensor: average validation loss
    """
    model.eval()
    val_loss = 0
    for d in loader:
        out = model(d)
        l = loss(out, torch.reshape(d.y, (len(d.y), 1)))
        val_loss += l / len(loader)
    return val_loss
[11]:
@torch.no_grad()
def testing(loader, model):
    """Testing

    Args:
        loader (DataLoader): test dataset
        model (nn.Module): trained model

    Returns:
        tuple: average test loss, predicted values and ground truth values
    """
    model.eval()  # disable dropout and use the stored BatchNorm statistics
    loss = torch.nn.MSELoss()
    test_loss = 0
    test_target = numpy.empty((0))
    test_y_target = numpy.empty((0))
    for d in loader:
        out = model(d)
        l = loss(out, torch.reshape(d.y, (len(d.y), 1)))
        test_loss += l / len(loader)

        # save prediction vs ground truth values for plotting
        test_target = numpy.concatenate((test_target, out.detach().numpy()[:, 0]))
        test_y_target = numpy.concatenate((test_y_target, d.y.detach().numpy()))

    return test_loss, test_target, test_y_target
[12]:
def train_epochs(epochs, model, train_loader, val_loader, path):
    """Training over all epochs

    Args:
        epochs (int): number of epochs to train for
        model (nn.Module): the current model
        train_loader (DataLoader): training data in batches
        val_loader (DataLoader): validation data in batches
        path (string): path to save the best model

    Returns:
        array: returning train and validation losses over all epochs, prediction and ground truth values for training data in the last epoch
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4)
    loss = torch.nn.MSELoss()

    train_target = numpy.empty((0))
    train_y_target = numpy.empty((0))
    train_loss = numpy.empty(epochs)
    val_loss = numpy.empty(epochs)
    best_loss = math.inf

    for epoch in range(epochs):
        epoch_loss, model = training(train_loader, model, loss, optimizer)
        v_loss = validation(val_loader, model, loss)
        # save the parameters whenever the validation loss improves
        if v_loss < best_loss:
            best_loss = v_loss
            torch.save(model.state_dict(), path)

        if epoch == epochs - 1:
            # record true vs predicted values for the training data in the last epoch
            with torch.no_grad():
                for d in train_loader:
                    out = model(d)
                    train_target = numpy.concatenate((train_target, out.detach().numpy()[:, 0]))
                    train_y_target = numpy.concatenate((train_y_target, d.y.detach().numpy()))

        train_loss[epoch] = epoch_loss.detach().numpy()
        val_loss[epoch] = v_loss.detach().numpy()

        # print current train and val loss every second epoch
        if epoch % 2 == 0:
            print(f"Epoch: {epoch}, Train loss: {epoch_loss.item()}, Val loss: {v_loss.item()}")
    return train_loss, val_loss, train_target, train_y_target

We trained both models for 100 epochs and saved the best models (according to the validation loss) as GCN_best-model-parameters.pt and GIN_best-model-parameters.pt. Since this takes some time, we reduced the number of epochs to 10 in this notebook for demonstration purposes. The test results below are based on the pre-trained 100-epoch models, whereas the loss curves and the training-set predictions come from the 10-epoch runs in the following cells. If you want to train your own model, you can change the number of epochs and any other parameter of our models (such as the learning rate or the batch size).

[13]:
# training GCN for 10 epochs
epochs = 10

model = GCN(dim_h=128)

# Remember to change the path if you want to keep the previously trained model
gcn_train_loss, gcn_val_loss, gcn_train_target, gcn_train_y_target = train_epochs(
    epochs, model, train_loader, val_loader, "GCN_model.pt"
)
Epoch: 0, Train loss: 0.9262555241584778, Val loss: 0.7875796556472778
Epoch: 2, Train loss: 0.8586100339889526, Val loss: 0.7610003352165222
Epoch: 4, Train loss: 0.831976056098938, Val loss: 0.7375475168228149
Epoch: 6, Train loss: 0.8072418570518494, Val loss: 0.6950475573539734
Epoch: 8, Train loss: 0.7751282453536987, Val loss: 0.6763118505477905
[14]:
# Training GIN for 10 epochs
model = GIN(dim_h=64)

# Remember to change the path if you want to keep the previously trained model
gin_train_loss, gin_val_loss, gin_train_target, gin_train_y_target = train_epochs(
    epochs, model, train_loader, val_loader, "GIN_model.pt"
)
Epoch: 0, Train loss: 0.702804684638977, Val loss: 0.6012566685676575
Epoch: 2, Train loss: 0.5309209823608398, Val loss: 0.45285511016845703
Epoch: 4, Train loss: 0.5022059082984924, Val loss: 0.4036828875541687
Epoch: 6, Train loss: 0.46470388770103455, Val loss: 0.4122920632362366
Epoch: 8, Train loss: 0.4486030638217926, Val loss: 0.36034974455833435

Evaluating the model

For evaluation, we use the validation set to select the best model and the test set to assess the model on unseen data. First, we plot the training and validation losses of both models. As expected, the GIN model has lower training and validation losses.

[15]:
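def plot_loss(gcn_train_loss, gcn_val_loss, gin_train_loss, gin_val_loss):
    """Plot the training and validation losses of the GCN and GIN for each epoch

    Args:
        gcn_train_loss (array): GCN training losses for each epoch
        gcn_val_loss (array): GCN validation losses for each epoch
        gin_train_loss (array): GIN training losses for each epoch
        gin_val_loss (array): GIN validation losses for each epoch
    """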
    plt.plot(gcn_train_loss, label="Train loss (GCN)")
    plt.plot(gcn_val_loss, label="Val loss (GCN)")
    plt.plot(gin_train_loss, label="Train loss (GIN)")
    plt.plot(gin_val_loss, label="Val loss (GIN)")
    plt.legend()
    plt.ylabel("loss")
    plt.xlabel("epoch")
    plt.title("Model Loss")
    plt.gca().xaxis.set_major_locator(MaxNLocator(integer=True))
    plt.show()
[16]:
def plot_targets(pred, ground_truth):
    """Plot true vs predicted value in a scatter plot

    Args:
        pred (array): predicted values
        ground_truth (array): ground truth values
    """
    f, ax = plt.subplots(figsize=(6, 6))
    ax.scatter(pred, ground_truth, s=0.5)
    plt.xlim(-2, 7)
    plt.ylim(-2, 7)
    ax.axline((1, 1), slope=1)
    plt.xlabel("Predicted Value")
    plt.ylabel("Ground truth")
    plt.title("Ground truth vs prediction")
    plt.show()

Looking at the losses for each epoch, we can see that the GIN model performs better overall. The training loss is averaged over an epoch while the model is still being updated, and it is computed with dropout active; the validation loss is computed after the epoch in evaluation mode, so it can even be lower than the training loss, as in the logs above. On the other hand, the validation set is not used for updating the model weights, so once a model starts to overfit, its validation loss is typically higher than its training loss. Since the validation set is much smaller, its loss may also fluctuate more from epoch to epoch. As long as both losses show a decreasing tendency, this is not problematic; the goal is a low training loss and a low validation loss.

[17]:
# Plot overall losses of GIN and GCN

plot_loss(gcn_train_loss, gcn_val_loss, gin_train_loss, gin_val_loss)
Plot: Model Loss (training and validation loss per epoch for GCN and GIN)

Next, we plot the predictions of the GIN model for the training data (from the last training epoch) against the ground truth, since this model performs better.

[18]:
# Plot target and prediction for training data

plot_targets(gin_train_target, gin_train_y_target)
Plot: Ground truth vs prediction (GIN, training data)

Below, we calculate the test loss for both the GCN and the GIN and plot the predicted dipole moment against the ground truth for both models. To obtain predictions in the original units of the dipole moment, the normalization applied during preprocessing has to be reversed, i.e. the values have to be multiplied by the training standard deviation and the training mean has to be added. Since we only compare the models and visualize the predictions here, this does not make a difference. The GIN model performs considerably better than the GCN, as its test error (MSE on the normalized target) is clearly lower.
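Reversing the normalization means computing \(y = y_{\text{norm}} \cdot s + m\), where \(m\) and \(s\) are the mean and standard deviation of the training targets (`data_mean` and `data_std` in the code above). Because the MSE is computed on normalized values, the root mean squared error in the original units is \(s \sqrt{\text{MSE}}\). The sketch below checks both relations with made-up statistics and values, not the ones computed in this notebook; the helper name `denormalize` is hypothetical.

```python
import numpy as np


def denormalize(y_norm, mean, std):
    """Invert z-score normalization: y = y_norm * std + mean."""
    return y_norm * std + mean


# made-up training statistics and dipole moments, for illustration only
mean, std = 2.7, 1.5
y_true = np.array([0.0, 1.2, 3.4, 5.1])
y_norm = (y_true - mean) / std

# a hypothetical prediction in normalized units
pred_norm = y_norm + np.array([0.1, -0.2, 0.05, 0.3])
mse_norm = np.mean((pred_norm - y_norm) ** 2)

pred = denormalize(pred_norm, mean, std)
rmse_original = np.sqrt(np.mean((pred - y_true) ** 2))

print(np.allclose(denormalize(y_norm, mean, std), y_true))  # True
print(np.isclose(rmse_original, std * np.sqrt(mse_norm)))   # True
```

Applied to the test losses above, this conversion would express the errors of both models in the units of the dipole moment instead of standard deviations.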

[19]:
# Calculate test loss from the best GCN model (according to validation loss)

# load our model
model = GCN(dim_h=128)
model.load_state_dict(torch.load("GCN_best-model-parameters.pt"))

# calculate test loss
gcn_test_loss, gcn_test_target, gcn_test_y = testing(test_loader, model)

print("Test Loss for GCN: " + str(gcn_test_loss.item()))

# plot prediction vs ground truth
plot_targets(gcn_test_target, gcn_test_y)
Test Loss for GCN: 0.5251887440681458
Plot: Ground truth vs prediction (GCN, test data)
[20]:
# Calculate test loss from the best GIN model (according to validation loss)

# load our model
model = GIN(dim_h=64)
model.load_state_dict(torch.load("GIN_best-model-parameters.pt"))

# calculate test loss
gin_test_loss, gin_test_target, gin_test_y = testing(test_loader, model)

print("Test Loss for GIN: " + str(gin_test_loss.item()))

# plot prediction vs ground truth
plot_targets(gin_test_target, gin_test_y)
Test Loss for GIN: 0.3028673827648163
Plot: Ground truth vs prediction (GIN, test data)

Discussion

In this talktorial, we first presented two different graph neural network architectures and then applied them to a molecular dataset to predict a molecular property. We showed how to train and evaluate a simple GNN using PyTorch and PyTorch Geometric. The same model can be used for any type of graph-level regression task and, with small changes (such as a different loss function), also for graph-level classification.

One disadvantage of GNNs is that the quality of the model depends strongly on the data: the more of the chemical space is covered by the training set, the better the model is expected to perform on new, unseen data. In addition, training a model can be rather complex, since many parameters influence the result. Hyperparameters such as the learning rate, the batch size and the number of hidden dimensions could be evaluated more thoroughly, and the model architecture itself could be adapted. To apply this approach to real tasks, a bigger dataset is needed first; using the whole QM9 dataset instead of a small subset is expected to improve the performance. Since all of these changes increase the runtime, we have chosen this simplified version for demonstration purposes.

Quiz

  1. What is the difference between a GCN and GIN?

  2. How would you change the model for a classification task?

  3. What other parameters can be tuned for better model performance?