T036 · An introduction to E(3)-invariant graph neural networks¶
Note: This talktorial is a part of TeachOpenCADD, a platform that aims to teach domain-specific skills and to provide pipeline templates as starting points for research projects.
Authors:
Joschka Groß, 2022, Chair for Modelling and Simulation, NextAID project, Saarland University
Aim of this talktorial¶
This talktorial serves as an introduction to machine learning on point cloud representations of molecules with 3D conformer information, i.e., molecular graphs embedded into Euclidean space (see Talktorial T033). You will learn why Euclidean equivariance and invariance are important properties of neural networks (NNs) that take point clouds as input, and how to implement and train such NNs. In addition to discussing them in theory, this notebook also demonstrates in practice the shortcomings of plain graph neural networks (GNNs) when working with point clouds.
Contents in Theory¶
Why 3D coordinates?
Molecules as point clouds: mathematical background
Equivariance and Invariance in Euclidean space and why we care
How to construct \(\text{E}(n)\)-invariant and equivariant models
The QM9 dataset
Contents in Practical¶
Visualization of point clouds
Set up and inspect the QM9 dataset
Preprocessing
Atomic number distribution and point cloud size
Data split, distribution of regression target electronic spatial extent
Model implementation
Plain “naive Euclidean” GNN
Demo: Plain GNNs are not E(3)-invariant
EGNN model
Demo: Our EGNN is E(3)-invariant
Training and evaluation
Setup
Training the EGNN
Training the plain GNN
Comparative evaluation
References¶
Theoretical¶
Quantum chemistry structures and properties of 134k molecules (QM9): Scientific Data (2014)
MoleculeNet: a benchmark for molecular machine learning: Chem. Sci., 2018, 9, 513-530
E(n)-Equivariant Graph Neural Networks: International Conference on Machine Learning (2021), 139, 9323-9332
SE(3)-Transformers: 3D roto-translation equivariant attention networks: Advances in Neural Information Processing Systems (2020), 33, 1970-1981
TorchMD-NET: Equivariant Transformers for Neural Network based Molecular Potentials: arXiv preprint (2022)
DiffDock: Diffusion Steps, Twists, and Turns for Molecular Docking: arXiv preprint (2022)
Practical¶
Theory¶
Why 3D coordinates?¶
Some properties are more easily derived when 3D coordinates are known.
Sometimes the task is to predict properties that are directly linked to Euclidean space, e.g. future atom positions or forces that apply to atoms.
Compared to molecular graph representations, we in principle only gain information: covalent bonds can still be inferred from atom types and positions, because bonds can be attributed to overlapping atomic orbitals. Note that one could still include structural (bond) information explicitly, so that the model does not have to learn it itself.
An example CADD application that may require the use of 3D coordinates is protein-ligand docking (see Talktorial T015). Recent work from 2022 (DiffDock, see references) uses \(\text{E}(3)\)-equivariant graph neural networks as the backbone of a generative model that learns to predict potential ligand docking positions (3D coordinates for the atoms of a given ligand) when additionally given protein structures with 3D information as input.
Molecules as point clouds: mathematical background¶
In this talktorial, we will focus on atoms and their 3D positions and ignore structural (bond) information. Our mathematical representation of a molecule is thus a point cloud (also see Talktorial T033), i.e., a tuple \((X, Z)\) where \(Z \in \mathbb{R}^{m \times d}\) is a matrix of \(m\) atoms represented by \(d\) features each and \(X \in \mathbb{R}^{m \times 3}\) captures the atoms' 3D coordinates. We will assume that the coordinates correspond to a specific molecular conformation (see Talktorial T033) of the molecule.
Equivariance and Invariance in Euclidean space and why we care¶
When representing molecules as graphs, equi- and/or invariance w.r.t. node permutations are desirable model properties (Talktorial T033/T035). When working with point clouds, i.e., when atoms/nodes are embedded into Euclidean space, we should also be concerned about Euclidean symmetry groups. These are groups of transformations \(g: \mathbb{R}^n \to \mathbb{R}^n\) that preserve distances, i.e., translations, rotations, reflections, or combinations thereof. For the Euclidean space \(\mathbb{R}^n\) with \(n\) spatial dimensions, one typically distinguishes between
the Euclidean group \(\text{E}(n)\), which consists of all distance-preserving transformations, and
the special Euclidean group \(\text{SE}(n)\), which consists only of translations and rotations.
Say \(\theta\) is a model that learns atom embeddings \(H = \theta(X, Z) \in \mathbb{R}^{m \times q}\), where \(q\) is the number of embedding dimensions. We call \(\theta\) \(\text{E}(n)\)-invariant if for all \(g \in \text{E}(n)\)

\[
\theta(g(X), Z) = \theta(X, Z),
\]

where \(g\) is applied row-wise to \(X\). Put simply, the output of \(\theta\) remains unaffected, no matter how we rotate, translate, or reflect the molecule.
If we consider a model that makes predictions about objects which are coupled to Euclidean space, \(X' = \theta(X, Z) \in \mathbb{R}^{m \times n}\) (e.g. future atom positions in a dynamical system), we can define \(\text{E}(n)\)-equivariance as

\[
\theta(g(X), Z) = g(\theta(X, Z))
\]

for all \(g \in \text{E}(n)\), applied in row-wise fashion. This says that the output of \(\theta\) is transformed in the same way as its input. Note that this definition can easily be extended to arbitrary Euclidean features (velocities, electromagnetic forces, …).
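To make these definitions concrete, here is a small self-contained sketch (an illustrative addition, not part of the models below): the sum of pairwise distances is an \(\text{E}(n)\)-invariant function of a point cloud, while the centroid is \(\text{E}(n)\)-equivariant.

import torch

# A toy "model" built only from pairwise distances is E(n)-invariant ...
def theta_inv(X: torch.Tensor) -> torch.Tensor:
    return torch.cdist(X, X).sum()

# ... while the centroid, which lives in Euclidean space, is E(n)-equivariant
def theta_equi(X: torch.Tensor) -> torch.Tensor:
    return X.mean(dim=0, keepdim=True)

X = torch.randn(5, 3)  # a random point cloud with 5 atoms
Q, _ = torch.linalg.qr(torch.randn(3, 3))  # random orthogonal matrix (rotation/reflection)
t = torch.randn(3)  # random translation
gX = X @ Q.T + t  # g applied row-wise
assert torch.allclose(theta_inv(X), theta_inv(gX), atol=1e-4)  # invariance
assert torch.allclose(theta_equi(X) @ Q.T + t, theta_equi(gX), atol=1e-5)  # equivariance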
So, why do we care about these properties?
Let’s assume our goal was to train a model that predicts the docking position of a ligand when given a fixed protein structure, also with 3D coordinates. Would you trust a model that predicted different relative positions for the ligand atoms when the protein was simply rotated by 180 degrees? If your answer is no, then you should consider using a model that is at least \(\text{SE}(3)\)-equivariant. In addition to being a “natural” choice given such considerations, Euclidean equivariance also empirically improves the sample efficiency of training (fewer training examples are needed) and the model’s ability to generalize to unseen data.
To sum up, it may be helpful to address the problem from a slightly different point of view: point clouds as representations for molecular conformations are not unique. In fact, for one molecular conformation there are infinitely many valid point cloud representations: if \((X, Z)\) is such a representation, then so is \((g(X), Z)\) for any of the infinitely many \(g \in \text{E}(3)\). All that \(\text{E}(3)\)-invariance and equivariance demand is that our machine learning models should not care which of these representations we end up using.
Figure 1: An illustration of a 2D-rotationally equi- and invariant transformation \(\phi\). Taken from the EGNN paper by Satorras et al.
How to construct \(\text{E}(n)\)-invariant and equivariant models¶
Constructing such models is simple if we focus on the fact that all \(g \in \text{E}(n)\) are distance-preserving. We will not give a fully-fledged proof, but it should not come as a great surprise that a model which only considers relative distances between atoms for computing node (atom) embeddings is guaranteed to be \(\text{E}(n)\)-invariant. We can thus define a message passing network \(\theta(Z, X)\) with \(l = 1, \ldots, L\) layers where

\[
\begin{aligned}
h_i^0 &= \psi_0(z_i) \qquad &&(1)\\
d_{ij} &= \lVert x_i - x_j \rVert^2 \qquad &&(2)\\
m_{ij}^l &= \phi_l\big(h_i^l,\, h_j^l,\, d_{ij}\big) \qquad &&(3)\\
h_i^{l+1} &= \psi_l\Big(h_i^l,\, \sum_{j \neq i} m_{ij}^l\Big) \qquad &&(4)
\end{aligned}
\]

and \(\psi_0\) computes the initial node embeddings, the \(\phi_l\) MLPs (multi-layer perceptrons) construct messages, and the \(\psi_l\) MLPs take care of combining previous embeddings and aggregated messages into new embeddings. The final node embeddings \(H = (h_1^L, \ldots, h_m^L)^T\) computed by this scheme are \(\text{E}(n)\)-invariant.
In the practical part, we will only predict properties that are not directly linked to Euclidean space, so this kind of network suffices for our purposes. If your goal is to predict e.g. atom positions, you will need to define additional, slightly more sophisticated transformations to ensure that they are \(\text{E}(3)\)-equivariant, but these usually follow the same principle of only using distances in their computations. If you want to read up on this, take a look at the papers listed in the references (e.g., the EGNN and SE(3)-Transformers papers).
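To give a flavor of such a transformation, here is a minimal sketch in the spirit of the coordinate update from the EGNN paper; the function name is our own, and the invariant edge weights \(w_{ij}\) (which a full model would derive from the messages \(m_{ij}\)) are assumed to be given:

import torch

def equivariant_coord_update(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    """Sketch of an E(n)-equivariant coordinate update.
    x: (m, n) positions, w: (m, m) E(n)-invariant edge weights with w_ii = 0.
    """
    rel = x.unsqueeze(1) - x.unsqueeze(0)  # (m, m, n) pairwise differences x_i - x_j
    # translations cancel in `rel`; rotations/reflections distribute over the sum
    return x + (w.unsqueeze(-1) * rel).sum(dim=1)

Because the update only adds weighted relative positions to \(x\), applying some \(g \in \text{E}(n)\) to the input moves the output in exactly the same way.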
The QM9 dataset¶
The QM9 dataset [1] [2] is part of the MoleculeNet benchmark and consists of ~130k small, organic molecules with up to 9 heavy atoms. It also includes targets for various geometric, energetic, electronic and thermodynamic properties. Crucially, it also includes atom 3D coordinates, which makes it suitable for this talktorial.
Practical¶
For the practical part, we will be working with a version of QM9 that is already included in PyTorch Geometric, as implementing the dataset from scratch would go beyond the scope of this talktorial. We will first inspect the data and briefly discuss how point clouds can be represented by several tensors. Then we will demonstrate how one could use plain GNNs to work with point clouds and why this approach yields models that are not \(\text{E}(3)\)-invariant/equivariant. Finally, you will learn how to implement, train, and evaluate equivariant GNNs.
[1]:
import math
import operator
from itertools import chain, product
from functools import partial
from pathlib import Path
from typing import Any, Optional, Callable, Tuple, Dict, Sequence, NamedTuple
import numpy as np
from tqdm import tqdm
import torch
import torch.nn as nn
from torch import Tensor, LongTensor
[2]:
import sys
if sys.platform.startswith(("linux", "darwin")):
!mamba install -q -y -c pyg pyg
!mamba install -q -y -c conda-forge pytorch_scatter
[3]:
import torch_geometric
from torch_geometric.transforms import BaseTransform, Compose
from torch_geometric.datasets import QM9
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.data import Dataset
from torch_geometric.nn.aggr import SumAggregation
import torch_geometric.nn as geom_nn
import matplotlib as mpl
import matplotlib.pyplot as plt
from torch_scatter import scatter
[4]:
# Set path to this notebook
HERE = Path(_dh[-1])
DATA = HERE / "data"
Visualization of point clouds¶
The following auxiliary function plot_point_cloud_3d will see heavy use later on for the visualization of model inputs and outputs. Note that to visualize molecules rather than the tensor representations used for machine learning, it would be better to use e.g. RDKit or NGLview.
[5]:
def to_perceived_brightness(rgb: np.ndarray) -> np.ndarray:
    """
    Auxiliary function, useful for choosing label colors
    with good visibility
    """
    r, g, b = rgb
    # weighted sum of the color channels, approximating perceived brightness
    return 0.1 * r + 0.8 * g + 0.1 * b
def plot_point_cloud_3d(
fig: mpl.figure.Figure,
ax_pos: int,
color: np.ndarray,
pos: np.ndarray,
cmap: str = "plasma",
point_size: float = 180.0,
label_axes: bool = False,
annotate_points: bool = True,
remove_axes_ticks: bool = True,
cbar_label: str = "",
) -> mpl.axis.Axis:
"""Visualize colored 3D point clouds.
Parameters
----------
fig : mpl.figure.Figure
The figure for which a new axis object is added for plotting
ax_pos : int
Three-digit integer specifying axis layout and position
(see docs for `mpl.figure.Figure.add_subplot`)
color : np.ndarray
The point colors as a float array of shape `(N,)`
pos : np.ndarray
The point xyz-coordinates as an array
    cmap : str, optional
        String identifier for a matplotlib colormap,
        used to map the values in `color` to RGB colors. By default "plasma".
    point_size : float, optional
        The size of plotted points, by default 180.0
    label_axes : bool, optional
        whether to label the x, y and z axes, by default False
    annotate_points : bool, optional
        whether to label points with their index, by default True
    remove_axes_ticks : bool, optional
        whether to remove the axis tick labels, by default True
    cbar_label : str, optional
        label for the colorbar, by default ""
Returns
-------
mpl.axis.Axis
The new axis object for the 3D point cloud plot.
"""
    cmap = mpl.colormaps[cmap]
ax = fig.add_subplot(ax_pos, projection="3d")
x, y, z = pos
if remove_axes_ticks:
ax.set_xticklabels([])
ax.set_yticklabels([])
ax.set_zticklabels([])
if label_axes:
ax.set_xlabel("$x$ coordinate")
ax.set_ylabel("$y$ coordinate")
ax.set_zlabel("$z$ coordinate")
sc = ax.scatter(x, y, z, c=color, cmap=cmap, s=point_size)
plt.colorbar(sc, location="bottom", shrink=0.6, anchor=(0.5, 2), label=cbar_label)
if annotate_points:
_colors = sc.cmap(color)
rgb = _colors[:, :3].transpose()
brightness = to_perceived_brightness(rgb)
for i, (xi, yi, zi, li) in enumerate(zip(x, y, z, brightness)):
ax.text(xi, yi, zi, str(i), None, color=[1 - li] * 3, ha="center", va="center")
return ax
# testing
fig = plt.figure(figsize=(8, 8))
for ax_pos in [221, 222, 223, 224]:
pos = np.random.rand(3, 5)
color = np.random.rand(5)
plot_point_cloud_3d(fig, ax_pos, color, pos)
fig.suptitle("Random test point clouds")
fig.tight_layout()
[6]:
def plot_model_input(data: Data, fig: mpl.figure.Figure, ax_pos: int) -> mpl.axis.Axis:
    """
    Plots a 3D point cloud model input represented by a torch geometric
    `Data` object. Uses atomic numbers as colors.
    Parameters
    ----------
    data : Data
        The 3D point cloud. Must have atomic numbers `z` and 3D coordinates `pos`
        properties that are not `None`.
    fig: mpl.figure.Figure
        The matplotlib figure to plot on.
ax_pos:
Three-digit integer specifying axis layout and position
(see docs for `mpl.figure.Figure.add_subplot`).
Returns
-------
mpl.axis.Axis
The newly created axis object.
"""
color, pos = data.z, data.pos
color = color.flatten().detach().numpy()
pos = pos.T.detach().numpy()
return plot_point_cloud_3d(fig, ax_pos, color, pos, cbar_label="Atomic number")
def plot_model_embedding(
data: Data, model: Callable[[Data], Tensor], fig: mpl.figure.Figure, ax_pos: int
) -> mpl.axis.Axis:
"""
Same as `plot_model_input` but instead of node features as color,
first apply a GNN model to obtain colors from node embeddings.
Parameters
----------
data : Data
        the model input. Must have 3D coordinates `pos`
        and atomic numbers `z` properties that are not `None`.
    model : Callable[[Data], Tensor]
        the model; must take Data objects as input and return node embeddings
        as a Tensor output.
    fig: mpl.figure.Figure
        The matplotlib figure to plot on.
ax_pos:
Three-digit integer specifying axis layout and position
(see docs for `mpl.figure.Figure.add_subplot`).
Returns
-------
mpl.axis.Axis
The newly created axis object.
"""
x = model(data)
pos = data.pos
color = x.flatten().detach().numpy()
pos = pos.T.detach().numpy()
return plot_point_cloud_3d(fig, ax_pos, color, pos, cbar_label="Atom embedding (1D)")
Set up and inspect the QM9 dataset¶
Preprocessing¶
For the sake of this tutorial, we will restrict ourselves to small molecules with at most 8 heavy atoms. Because we ignore structural information and treat molecules as point clouds, in which every atom interacts with every other atom, we also need to extend the torch geometric Data objects with additional adjacency information that represents a complete graph without self-loops.

For performance reasons, both of these steps are performed once when pre-processing the raw data, using the pre_filter and pre_transform keyword arguments of the dataset class.
[7]:
def num_heavy_atoms(qm9_data: Data) -> int:
"""Count the number of heavy atoms in a torch geometric
Data object.
Parameters
----------
qm9_data : Data
A pytorch geometric qm9 data object representing a small molecule
where atomic numbers are stored in a
tensor-valued attribute `qm9_data.z`
Returns
-------
int
The number of heavy atoms in the molecule.
"""
# every atom with atomic number other than 1 is heavy
return (qm9_data.z != 1).sum()
def complete_edge_index(n: int) -> LongTensor:
"""
Constructs a complete edge index.
NOTE: representing complete graphs
with sparse edge tensors is arguably a bad idea
due to performance reasons, but for this tutorial it'll do.
Parameters
----------
n : int
the number of nodes in the graph.
Returns
-------
LongTensor
        A PyTorch `edge_index` tensor representing the complete graph
        with `n` nodes, without self-loops. Shape `(2, n * (n - 1))`.
"""
# filter removes self loops
edges = list(filter(lambda e: e[0] != e[1], product(range(n), range(n))))
return torch.tensor(edges, dtype=torch.long).T
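# Illustrative aside (not part of the original notebook): for n = 3,
# complete_edge_index(3) returns
#   tensor([[0, 0, 1, 1, 2, 2],
#           [1, 2, 0, 2, 0, 1]])
# i.e. 3 * (3 - 1) = 6 directed edges and no self-loops.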
def add_complete_graph_edge_index(data: Data) -> Data:
"""
On top of any edge information already there,
add a second edge index that represents
the complete graph corresponding to a given
torch geometric data object
Parameters
----------
data : Data
The torch geometric data object.
Returns
-------
Data
The torch geometric `Data` object with a new
attribute `complete_edge_index` as described above.
"""
data.complete_edge_index = complete_edge_index(data.num_nodes)
return data
dataset = QM9(
DATA,
# Filter out molecules with more than 8 heavy atoms
pre_filter=lambda data: num_heavy_atoms(data) < 9,
# implement point cloud adjacency as a complete graph
pre_transform=add_complete_graph_edge_index,
)
print(f"Num. examples in QM9 restricted to molecules with at most 8 heavy atoms: {len(dataset)}")
Num. examples in QM9 restricted to molecules with at most 8 heavy atoms: 21800
NOTE: executing the above cell for the first time downloads and then processes the raw data, which might take a while.
Indexing the dataset we just created returns a single PyTorch Geometric Data object representing one molecular graph/point cloud. You can think of these objects as dictionaries with some extra utility methods already implemented.
[8]:
data = dataset[0]
# This displays all named data attributes, and their shapes (in the case of tensors), or values (in the case of other data).
data
[8]:
Data(x=[5, 11], edge_index=[2, 8], edge_attr=[8, 4], y=[1, 19], pos=[5, 3], z=[5], name='gdb_1', idx=[1], complete_edge_index=[2, 20])
For index 0 (name gdb_1) this should be the molecule CH4. We can check this by looking into the atomic numbers stored in the attribute named z.
[9]:
data.z
[9]:
tensor([6, 1, 1, 1, 1])
For molecules with N atoms and M (covalent) bonds, the data objects also contain the following named tensors:

Data.x, shape (N, d_node): node-level features (e.g. formal charge, membership of aromatic rings, chirality, …); we will ignore these here and just use the atomic numbers.
Data.y, shape (1, 19): the regression targets.
Data.edge_index, shape (2, M): edges between atoms derived from covalent bonds, stored as source and target node index pairs.
Data.edge_attr, shape (M, d_edge): bond features (e.g. bond type, ring membership, …).
Data.pos, shape (N, 3): most interesting to us, the atom 3D coordinates.
Data.complete_edge_index, shape (2, N * (N - 1)): the complete graph edge index (without self-loops) we added earlier.
The input to the (point cloud) model we will implement later can be visualized using just Data.z as color and Data.pos as scatter plot positions. Note: the alpha channel of the colors is used to convey depth information.
[10]:
data.pos.round(decimals=2)
[10]:
tensor([[-0.0100, 1.0900, 0.0100],
[ 0.0000, -0.0100, 0.0000],
[ 1.0100, 1.4600, 0.0000],
[-0.5400, 1.4500, -0.8800],
[-0.5200, 1.4400, 0.9100]])
[11]:
fig = plt.figure()
ax = plot_model_input(data, fig, 111)
_ = ax.set_title("CH$_4$ (Methane)")
Atomic number distribution and point cloud size¶
Now that our dataset is set up, and we have a basic understanding of how molecules are represented, we can try to visualize the properties of the entire dataset. Let us first look at the distribution of node-level features (atomic numbers) and the point cloud size (number of atoms) aggregated over the entire dataset.
[12]:
fig, (ax_atoms, ax_graph_size) = plt.subplots(1, 2, figsize=(8, 5))
ax_atoms.hist(dataset.data.z)
ax_atoms.set_xlabel("Atomic number $z$")
ax_atoms.set_ylabel("Count")
num_nodes = [dataset[i].num_nodes for i in range(len(dataset))]
ax_graph_size.hist(num_nodes)
ax_graph_size.set_xlabel("Graph size (#nodes)")
fig.suptitle("Aggregated molecular point cloud properties")
fig.tight_layout()
We can see that while fluorine atoms (atomic number 9) show up in the data, they are heavily underrepresented (the bar at \(z=9\) is barely visible). This is not a nice property, and it is likely a consequence of shrinking the dataset. The number of atoms per molecule appears roughly normally distributed, which is convenient.
Data split, distribution of regression target electronic spatial extent¶
Next, we will implement data splitting, choose a regression target, and visualize the split w.r.t. this target. Out of the 19 regression targets included in QM9, we’ll focus on electronic spatial extent, which, simply put, describes the volume of a molecule, so it should be a good fit for methods that use 3D information. Let us now start with implementing a data module that takes care of train/val/test splits and of indexing the correct target.
[13]:
class QM9DataModule:
def __init__(
self,
train_ratio: float = 0.8,
val_ratio: float = 0.1,
test_ratio: float = 0.1,
target_idx: int = 5,
        seed: int = 420,
) -> None:
"""Encapsulates everything related to the dataset
Parameters
----------
train_ratio : float, optional
fraction of data used for training, by default 0.8
val_ratio : float, optional
fraction of data used for validation, by default 0.1
test_ratio : float, optional
fraction of data used for testing, by default 0.1
target_idx : int, optional
index of the target (see torch geometric docs), by default 5 (electronic spatial extent)
(https://pytorch-geometric.readthedocs.io/en/latest/modules/datasets.html?highlight=qm9#torch_geometric.datasets.QM9)
        seed : int, optional
            random seed for the data split, by default 420
"""
assert sum([train_ratio, val_ratio, test_ratio]) == 1
self.target_idx = target_idx
self.num_examples = len(self.dataset())
rng = np.random.default_rng(seed)
self.shuffled_index = rng.permutation(self.num_examples)
self.train_split = self.shuffled_index[: int(self.num_examples * train_ratio)]
self.val_split = self.shuffled_index[
int(self.num_examples * train_ratio) : int(
self.num_examples * (train_ratio + val_ratio)
)
]
self.test_split = self.shuffled_index[
int(self.num_examples * (train_ratio + val_ratio)) : self.num_examples
]
def dataset(self, transform=None) -> QM9:
dataset = QM9(
DATA,
pre_filter=lambda data: num_heavy_atoms(data) < 9,
pre_transform=add_complete_graph_edge_index,
)
dataset.data.y = dataset.data.y[:, self.target_idx].view(-1, 1)
return dataset
def loader(self, split, **loader_kwargs) -> DataLoader:
dataset = self.dataset()[split]
return DataLoader(dataset, **loader_kwargs)
def train_loader(self, **loader_kwargs) -> DataLoader:
return self.loader(self.train_split, shuffle=True, **loader_kwargs)
def val_loader(self, **loader_kwargs) -> DataLoader:
return self.loader(self.val_split, shuffle=False, **loader_kwargs)
def test_loader(self, **loader_kwargs) -> DataLoader:
return self.loader(self.test_split, shuffle=False, **loader_kwargs)
Now we can easily plot the target across the data split.
[14]:
data_module = QM9DataModule()
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 6), sharey=True)
target = data_module.dataset().data.y.flatten().numpy()
ax1.boxplot(
[
target[data_module.train_split],
target[data_module.val_split],
target[data_module.test_split],
]
)
ax1.set_xticklabels(["Train", "Val", "Test"])
ax1.set_ylabel(r"Electronic spatial extent $\langle R^2 \rangle$")
for label, split in {
"Train": data_module.train_split,
"Val": data_module.val_split,
"Test": data_module.test_split,
}.items():
ax2.scatter(split, target[split], label=label, s=1)
ax2.set_xlabel("Example index")
ax2.legend()
fig.suptitle("Random data split - target distribution")
[14]:
Text(0.5, 0.98, 'Random data split - target distribution')
You should be able to observe that random splits are typically very homogeneous, which means that measuring generalization capabilities with them can yield deceptively good results.
Model implementation¶
Plain “naive Euclidean” GNN¶
A naive way to incorporate 3D coordinates into a GNN for molecular graphs would be to interpret them as atom-level features that are simply combined with the other features. It is easy to implement a simple baseline model which does exactly this (see Talktorial T035). For its message-passing topology, our implementation uses the edges induced by bonds between atoms.
[15]:
class NaiveEuclideanGNN(nn.Module):
def __init__(
self,
hidden_channels: int,
num_layers: int,
num_spatial_dims: int,
final_embedding_size: Optional[int] = None,
act: nn.Module = nn.ReLU(),
) -> None:
super().__init__()
        # NOTE nn.Embedding acts like a lookup table.
        # Here we use it to map each atomic number in [0, 100)
        # to a learnable, fixed-size vector representation
self.f_initial_embed = nn.Embedding(100, hidden_channels)
self.f_pos_embed = nn.Linear(num_spatial_dims, hidden_channels)
self.f_combine = nn.Sequential(nn.Linear(2 * hidden_channels, hidden_channels), act)
if final_embedding_size is None:
final_embedding_size = hidden_channels
        # Graph Isomorphism Network as the main GNN
        # (see Talktorial 034):
        # takes care of message passing and
        # learning node-level embeddings
self.gnn = geom_nn.models.GIN(
in_channels=hidden_channels,
hidden_channels=hidden_channels,
out_channels=final_embedding_size,
num_layers=num_layers,
act=act,
)
# modules required for aggregating node embeddings
# into graph embeddings and making graph-level predictions
self.aggregation = geom_nn.aggr.SumAggregation()
self.f_predict = nn.Sequential(
nn.Linear(final_embedding_size, final_embedding_size),
act,
nn.Linear(final_embedding_size, 1),
)
    def encode(self, data: Data) -> Tensor:
        # initial atomic number embedding and embedding of positional information
        atom_embedding = self.f_initial_embed(data.z)
        pos_embedding = self.f_pos_embed(data.pos)
        # treat both as plain node-level features and combine them into the
        # initial node-level embeddings
        initial_node_embed = self.f_combine(torch.cat((atom_embedding, pos_embedding), dim=-1))
        # message passing
        # NOTE in contrast to the EGNN implemented later, this model uses bond information,
        # i.e., data.edge_index stems from the bond adjacency matrix
        node_embed = self.gnn(initial_node_embed, data.edge_index)
return node_embed
def forward(self, data: Data) -> Tensor:
node_embed = self.encode(data)
aggr = self.aggregation(node_embed, data.batch)
return self.f_predict(aggr)
Demo: Plain GNNs are not \(\text{E}(3)\)-invariant¶
However, this approach is problematic because the corresponding atom embeddings of a regular GNN (from which we would also derive our final predictions) will not be \(\text{E}(3)\)-invariant. This can be demonstrated easily:
[16]:
# use rotations about the z-axis as a demo E(3) transformation
def rotation_matrix_z(theta: float) -> Tensor:
    """Generates a rotation matrix about the $z$-axis and returns
    the corresponding tensor.
    (https://en.wikipedia.org/wiki/Rotation_matrix)
    Parameters
    ----------
    theta : float
        the angle of rotation, in radians.
Returns
-------
Tensor
the rotation matrix as float tensor.
"""
return torch.tensor(
[
[math.cos(theta), -math.sin(theta), 0],
[math.sin(theta), math.cos(theta), 0],
[0, 0, 1],
]
)
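Because \(R^\top R = I\) for a rotation matrix \(R\), this transformation is orthogonal and hence distance-preserving; a quick illustrative check (not part of the original notebook):

R = rotation_matrix_z(0.7)  # an arbitrary angle in radians
assert torch.allclose(R.T @ R, torch.eye(3), atol=1e-6)  # orthogonality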
NOTE: you may need to run the cell below multiple times to find a model initialization for which non-invariance can easily be observed.
[17]:
# Some data points from qm9
sample_data = dataset[800].clone()
# apply an E(3) transformation: a rotation about the z-axis (angle in radians)
rotated_sample_data = sample_data.clone()
rotated_sample_data.pos = rotated_sample_data.pos @ rotation_matrix_z(45)
# initialize a model with 2 hidden layers, 32 hidden channels,
# that outputs 1-dimensional node embeddings
model = NaiveEuclideanGNN(
hidden_channels=32,
num_layers=2,
num_spatial_dims=3,
final_embedding_size=1,
)
# make a plot that demonstrates non-invariance
fig = plt.figure(figsize=(8, 8))
ax1 = plot_model_input(sample_data, fig, 221)
ax1.set_title("Sample input $(X, Z)$")
ax2 = plot_model_input(rotated_sample_data, fig, 222)
ax2.set_title("Rotated input $(X, g(Z))$")
ax3 = plot_model_embedding(sample_data, model.encode, fig, 223)
ax3.set_title("Model output for $(X, Z)$")
ax4 = plot_model_embedding(rotated_sample_data, model.encode, fig, 224)
ax4.set_title("Model output for $(X, g(Z))$")
fig.tight_layout()
When executing the above cells a few times, we can observe that rotating the molecule may significantly alter the atom embeddings obtained from the plain GNN model.
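To quantify this observation (an illustrative addition, not part of the original pipeline), we can compare the two sets of embeddings numerically; for the naive model, the difference is typically far from zero:

def max_embedding_diff(encoder) -> float:
    """Largest absolute difference between embeddings of the original and the rotated input."""
    with torch.no_grad():
        return (encoder(sample_data) - encoder(rotated_sample_data)).abs().max().item()

print(f"Naive GNN: {max_embedding_diff(model.encode):.4f}")  # typically clearly > 0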
EGNN model¶
We now implement an \(\text{E}(n)\)-invariant GNN based on the principles outlined in the theory section.
[18]:
class EquivariantMPLayer(nn.Module):
def __init__(
self,
in_channels: int,
hidden_channels: int,
act: nn.Module,
) -> None:
super().__init__()
self.act = act
self.residual_proj = nn.Linear(in_channels, hidden_channels, bias=False)
# Messages will consist of two (source and target) node embeddings and a scalar distance
message_input_size = 2 * in_channels + 1
# equation (3) "phi_l" NN
self.message_mlp = nn.Sequential(
nn.Linear(message_input_size, hidden_channels),
act,
)
# equation (4) "psi_l" NN
self.node_update_mlp = nn.Sequential(
nn.Linear(in_channels + hidden_channels, hidden_channels),
act,
)
def node_message_function(
self,
source_node_embed: Tensor, # h_i
target_node_embed: Tensor, # h_j
node_dist: Tensor, # d_ij
) -> Tensor:
# implements equation (3)
message_repr = torch.cat((source_node_embed, target_node_embed, node_dist), dim=-1)
return self.message_mlp(message_repr)
def compute_distances(self, node_pos: Tensor, edge_index: LongTensor) -> Tensor:
row, col = edge_index
xi, xj = node_pos[row], node_pos[col]
# relative squared distance
# implements equation (2) ||X_i - X_j||^2
rsdist = (xi - xj).pow(2).sum(1, keepdim=True)
return rsdist
def forward(
self,
node_embed: Tensor,
node_pos: Tensor,
edge_index: Tensor,
) -> Tensor:
row, col = edge_index
dist = self.compute_distances(node_pos, edge_index)
# compute messages "m_ij" from equation (3)
node_messages = self.node_message_function(node_embed[row], node_embed[col], dist)
# message sum aggregation in equation (4)
aggr_node_messages = scatter(node_messages, col, dim=0, reduce="sum")
# compute new node embeddings "h_i^{l+1}"
# (implements rest of equation (4))
new_node_embed = self.residual_proj(node_embed) + self.node_update_mlp(
torch.cat((node_embed, aggr_node_messages), dim=-1)
)
return new_node_embed
class EquivariantGNN(nn.Module):
def __init__(
self,
hidden_channels: int,
final_embedding_size: Optional[int] = None,
target_size: int = 1,
num_mp_layers: int = 2,
) -> None:
super().__init__()
if final_embedding_size is None:
final_embedding_size = hidden_channels
        # non-linear activation func.
        # usually configurable; here we just use ReLU for simplicity
self.act = nn.ReLU()
# equation (1) "psi_0"
self.f_initial_embed = nn.Embedding(100, hidden_channels)
# create stack of message passing layers
self.message_passing_layers = nn.ModuleList()
channels = [hidden_channels] * (num_mp_layers) + [final_embedding_size]
for d_in, d_out in zip(channels[:-1], channels[1:]):
layer = EquivariantMPLayer(d_in, d_out, self.act)
self.message_passing_layers.append(layer)
# modules required for readout of a graph-level
# representation and graph-level property prediction
self.aggregation = SumAggregation()
self.f_predict = nn.Sequential(
nn.Linear(final_embedding_size, final_embedding_size),
self.act,
nn.Linear(final_embedding_size, target_size),
)
def encode(self, data: Data) -> Tensor:
# theory, equation (1)
node_embed = self.f_initial_embed(data.z)
# message passing
# theory, equation (3-4)
for mp_layer in self.message_passing_layers:
# NOTE here we use the complete edge index defined by the transform earlier on
# to implement the sum over $j \neq i$ in equation (4)
node_embed = mp_layer(node_embed, data.pos, data.complete_edge_index)
return node_embed
def _predict(self, node_embed, batch_index) -> Tensor:
aggr = self.aggregation(node_embed, batch_index)
return self.f_predict(aggr)
def forward(self, data: Data) -> Tensor:
node_embed = self.encode(data)
pred = self._predict(node_embed, data.batch)
return pred
Demo: Our EGNN is \(\text{E}(3)\)-invariant¶
We can collect evidence that this model is indeed \(\text{E}(n)\)-invariant by repeating the experiment we conducted earlier.
[19]:
model = EquivariantGNN(hidden_channels=32, final_embedding_size=1, num_mp_layers=2)
[20]:
# Some data points from qm9
sample_data = dataset[800].clone()
# apply an E(3) transformation: a rotation about the z-axis (angle in radians)
rotated_sample_data = sample_data.clone()
rotated_sample_data.pos = rotated_sample_data.pos @ rotation_matrix_z(120)
fig = plt.figure(figsize=(8, 8))
ax1 = plot_model_input(sample_data, fig, 221)
ax1.set_title("Sample input $(X, Z)$")
ax2 = plot_model_input(rotated_sample_data, fig, 222)
ax2.set_title("Rotated input $(X, g(Z))$")
ax3 = plot_model_embedding(sample_data, model.encode, fig, 223)
ax3.set_title("Model output for $(X, Z)$")
ax4 = plot_model_embedding(rotated_sample_data, model.encode, fig, 224)
ax4.set_title("Model output for $(X, g(Z))$")
fig.tight_layout()
You can execute the above cells as often as you like, with whatever input you choose; the atom embeddings will always be unaffected by the rotation applied to the model input.
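Reusing the illustrative max_embedding_diff helper from the naive GNN demo confirms this numerically; for the EGNN, the difference is zero up to floating-point round-off:

print(f"EGNN: {max_embedding_diff(model.encode):.2e}")  # close to machine precision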
Training and evaluation¶
Now that we have set up our data and implemented two different models for point clouds, we can start implementing a training and evaluation pipeline.
We will follow the ubiquitous ML principle of also monitoring a validation loss in addition to the training loss. The validation loss acts as an estimate for how well the model generalizes and can be used for selecting a final model to be tested.
[21]:
# We will be using mean absolute error
# as a metric for validation and testing
def total_absolute_error(pred: Tensor, target: Tensor, batch_dim: int = 0) -> Tensor:
"""Total absolute error, i.e. sums over batch dimension.
Parameters
----------
pred : Tensor
batch of model predictions
target : Tensor
batch of ground truth / target values
batch_dim : int, optional
dimension that indexes batch elements, by default 0
Returns
-------
Tensor
total absolute error
"""
return (pred - target).abs().sum(batch_dim)
[22]:
def run_epoch(
model: nn.Module,
loader: DataLoader,
criterion: Callable[[Tensor, Tensor], Tensor],
pbar: Optional[Any] = None,
optim: Optional[torch.optim.Optimizer] = None,
):
"""Run a single epoch.
Parameters
----------
model : nn.Module
the NN used for regression
loader : DataLoader
an iterable over data batches
criterion : Callable[[Tensor, Tensor], Tensor]
a criterion (loss) that is optimized
pbar : Optional[Any], optional
a tqdm progress bar, by default None
optim : Optional[torch.optim.Optimizer], optional
        an optimizer that optimizes the criterion, by default None
"""
def step(
data_batch: Data,
) -> Tuple[float, float]:
"""Perform a single train/val step on a data batch.
Parameters
----------
data_batch : Data
Returns
-------
        Tuple[float, float]
            Loss (mean squared error) and validation criterion (total absolute error).
"""
pred = model.forward(data_batch)
target = data_batch.y
loss = criterion(pred, target)
if optim is not None:
optim.zero_grad()
loss.backward()
optim.step()
return loss.detach().item(), total_absolute_error(pred.detach(), target.detach())
    if optim is not None:
        model.train()
        # enable PyTorch autograd so that we can compute gradients
        model.requires_grad_(True)
    else:
        model.eval()
        # disable autograd: when evaluating we do not need to track gradients
        model.requires_grad_(False)
total_loss = 0
total_mae = 0
for data in loader:
loss, mae = step(data)
total_loss += loss * data.num_graphs
total_mae += mae
if pbar is not None:
pbar.update(1)
return total_loss / len(loader.dataset), total_mae / len(loader.dataset)
def train_model(
data_module: QM9DataModule,
model: nn.Module,
num_epochs: int = 30,
lr: float = 3e-4,
batch_size: int = 32,
weight_decay: float = 1e-8,
best_model_path: Path = DATA.joinpath("trained_model.pth"),
) -> Dict[str, Any]:
"""Takes data and model as input and runs training, collecting additional validation metrics
while doing so.
Parameters
----------
data_module : QM9DataModule
a data module as defined earlier
model : nn.Module
a gnn model
num_epochs : int, optional
number of epochs to train for, by default 30
    lr : float, optional
        learning rate (optimizer step size), by default 3e-4
batch_size : int, optional
number of examples used for one training step, by default 32
weight_decay : float, optional
L2 regularization parameter, by default 1e-8
best_model_path : Path, optional
path where the model weights with lowest val. error should be stored
, by default DATA.joinpath("trained_model.pth")
Returns
-------
Dict[str, Any]
        the training result, i.e., statistics and info about the model
"""
# create data loaders
train_loader = data_module.train_loader(batch_size=batch_size)
val_loader = data_module.val_loader(batch_size=batch_size)
    # setup optimizer and loss
    optim = torch.optim.Adam(model.parameters(), lr, weight_decay=weight_decay)
loss_fn = nn.MSELoss()
    # keep track of the epoch with the best validation mae
    # so that we can save the "best" model weights
best_val_mae = float("inf")
# Statistics that will be plotted later on
# and model info
result = {
"model": model,
"path_to_best_model": best_model_path,
"train_loss": np.full(num_epochs, float("nan")),
"val_loss": np.full(num_epochs, float("nan")),
"train_mae": np.full(num_epochs, float("nan")),
"val_mae": np.full(num_epochs, float("nan")),
}
# Auxiliary functions for updating and reporting
# Training progress statistics
def update_statistics(i_epoch: int, **kwargs: float):
for key, value in kwargs.items():
result[key][i_epoch] = value
def desc(i_epoch: int) -> str:
return " | ".join(
[f"Epoch {i_epoch + 1:3d} / {num_epochs}"]
+ [
f"{key}: {value[i_epoch]:8.2f}"
for key, value in result.items()
if isinstance(value, np.ndarray)
]
)
# main training loop
for i_epoch in range(0, num_epochs):
progress_bar = tqdm(total=len(train_loader) + len(val_loader))
try:
# tqdm for reporting progress
progress_bar.set_description(desc(i_epoch))
# training epoch
train_loss, train_mae = run_epoch(model, train_loader, loss_fn, progress_bar, optim)
# validation epoch
val_loss, val_mae = run_epoch(model, val_loader, loss_fn, progress_bar)
update_statistics(
i_epoch,
train_loss=train_loss,
val_loss=val_loss,
train_mae=train_mae,
val_mae=val_mae,
)
progress_bar.set_description(desc(i_epoch))
if val_mae < best_val_mae:
best_val_mae = val_mae
torch.save(model.state_dict(), best_model_path)
finally:
progress_bar.close()
return result
[23]:
@torch.no_grad()
def test_model(model: nn.Module, data_module: QM9DataModule) -> Tuple[float, Tensor, Tensor]:
"""
Test a model.
Parameters
----------
model : nn.Module
a trained model
data_module : QM9DataModule
a data module as defined earlier
from which we'll get the test data
Returns
-------
    Tuple[float, Tensor, Tensor]
Test MAE, and model predictions & targets for further processing
"""
test_mae = 0
preds, targets = [], []
loader = data_module.test_loader()
for data in loader:
pred = model(data)
target = data.y
preds.append(pred)
targets.append(target)
test_mae += total_absolute_error(pred, target).item()
test_mae = test_mae / len(data_module.test_split)
return test_mae, torch.cat(preds, dim=0), torch.cat(targets, dim=0)
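As a preview of how test_model will be used in the comparative evaluation, one could restore the best checkpoint and evaluate it on the test split once training has finished; a hypothetical usage sketch:

# Hypothetical usage sketch; assumes the training cells below have been run
# and produced the checkpoint file.
best_model = EquivariantGNN(hidden_channels=64, num_mp_layers=2)
best_model.load_state_dict(torch.load(DATA.joinpath("trained_egnn.pth")))
test_mae, preds, targets = test_model(best_model, data_module)
print(f"EGNN test MAE: {test_mae:.2f}")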
Training the EGNN¶
[24]:
model = EquivariantGNN(hidden_channels=64, num_mp_layers=2)
egnn_train_result = train_model(
data_module,
model,
num_epochs=25,
lr=2e-4,
batch_size=32,
weight_decay=1e-8,
best_model_path=DATA.joinpath("trained_egnn.pth"),
)
Epoch 1 / 25 | train_loss: 83816.98 | val_loss: 8249.58 | train_mae: 177.73 | val_mae: 62.58: 100%|█████████████████████████████████████████████████| 614/614 [00:28<00:00, 21.82it/s]
Epoch 2 / 25 | train_loss: 5193.82 | val_loss: 2884.07 | train_mae: 51.31 | val_mae: 34.33: 100%|█████████████████████████████████████████████████| 614/614 [00:21<00:00, 28.63it/s]
Epoch 3 / 25 | train_loss: 1847.79 | val_loss: 932.89 | train_mae: 29.56 | val_mae: 20.27: 100%|█████████████████████████████████████████████████| 614/614 [00:19<00:00, 31.40it/s]
Epoch 4 / 25 | train_loss: 853.45 | val_loss: 786.19 | train_mae: 20.28 | val_mae: 21.96: 100%|█████████████████████████████████████████████████| 614/614 [00:18<00:00, 32.92it/s]
Epoch 5 / 25 | train_loss: 619.10 | val_loss: 389.76 | train_mae: 17.42 | val_mae: 13.32: 100%|█████████████████████████████████████████████████| 614/614 [00:18<00:00, 32.35it/s]
Epoch 6 / 25 | train_loss: 479.23 | val_loss: 343.42 | train_mae: 15.32 | val_mae: 12.68: 100%|█████████████████████████████████████████████████| 614/614 [00:19<00:00, 31.53it/s]
Epoch 7 / 25 | train_loss: 383.88 | val_loss: 285.21 | train_mae: 13.70 | val_mae: 11.90: 100%|█████████████████████████████████████████████████| 614/614 [00:21<00:00, 29.10it/s]
Epoch 8 / 25 | train_loss: 306.98 | val_loss: 201.01 | train_mae: 12.24 | val_mae: 9.67: 100%|█████████████████████████████████████████████████| 614/614 [00:19<00:00, 30.92it/s]
Epoch 9 / 25 | train_loss: 259.13 | val_loss: 387.28 | train_mae: 11.30 | val_mae: 16.76: 100%|█████████████████████████████████████████████████| 614/614 [00:19<00:00, 30.84it/s]
Epoch 10 / 25 | train_loss: 265.28 | val_loss: 231.35 | train_mae: 11.62 | val_mae: 11.55: 100%|█████████████████████████████████████████████████| 614/614 [00:18<00:00, 33.07it/s]
Epoch 11 / 25 | train_loss: 187.75 | val_loss: 148.88 | train_mae: 9.50 | val_mae: 8.89: 100%|█████████████████████████████████████████████████| 614/614 [00:18<00:00, 32.69it/s]
Epoch 12 / 25 | train_loss: 168.02 | val_loss: 303.56 | train_mae: 9.35 | val_mae: 15.35: 100%|█████████████████████████████████████████████████| 614/614 [00:18<00:00, 32.65it/s]
Epoch 13 / 25 | train_loss: 153.95 | val_loss: 86.36 | train_mae: 9.01 | val_mae: 6.31: 100%|█████████████████████████████████████████████████| 614/614 [00:20<00:00, 30.39it/s]
Epoch 14 / 25 | train_loss: 136.87 | val_loss: 73.61 | train_mae: 8.50 | val_mae: 5.95: 100%|█████████████████████████████████████████████████| 614/614 [00:23<00:00, 25.83it/s]
Epoch 15 / 25 | train_loss: 111.18 | val_loss: 144.24 | train_mae: 7.67 | val_mae: 10.03: 100%|█████████████████████████████████████████████████| 614/614 [00:18<00:00, 32.98it/s]
Epoch 16 / 25 | train_loss: 110.45 | val_loss: 127.40 | train_mae: 7.69 | val_mae: 9.73: 100%|█████████████████████████████████████████████████| 614/614 [00:20<00:00, 30.23it/s]
Epoch 17 / 25 | train_loss: 88.18 | val_loss: 87.97 | train_mae: 6.83 | val_mae: 7.15: 100%|█████████████████████████████████████████████████| 614/614 [00:21<00:00, 28.76it/s]
Epoch 18 / 25 | train_loss: 96.38 | val_loss: 69.41 | train_mae: 7.21 | val_mae: 6.59: 100%|█████████████████████████████████████████████████| 614/614 [00:22<00:00, 27.20it/s]
Epoch 19 / 25 | train_loss: 66.19 | val_loss: 213.51 | train_mae: 5.91 | val_mae: 13.43: 100%|█████████████████████████████████████████████████| 614/614 [00:21<00:00, 28.50it/s]
Epoch 20 / 25 | train_loss: 97.42 | val_loss: 30.31 | train_mae: 6.86 | val_mae: 3.56: 100%|█████████████████████████████████████████████████| 614/614 [00:21<00:00, 29.20it/s]
Epoch 21 / 25 | train_loss: 66.18 | val_loss: 31.33 | train_mae: 5.98 | val_mae: 3.84: 100%|█████████████████████████████████████████████████| 614/614 [00:18<00:00, 32.92it/s]
Epoch 22 / 25 | train_loss: 64.28 | val_loss: 42.74 | train_mae: 5.82 | val_mae: 4.67: 100%|█████████████████████████████████████████████████| 614/614 [00:21<00:00, 28.31it/s]
Epoch 23 / 25 | train_loss: 65.13 | val_loss: 102.03 | train_mae: 5.96 | val_mae: 8.89: 100%|█████████████████████████████████████████████████| 614/614 [00:18<00:00, 32.82it/s]
Epoch 24 / 25 | train_loss: 71.31 | val_loss: 142.06 | train_mae: 6.10 | val_mae: 8.94: 100%|█████████████████████████████████████████████████| 614/614 [00:18<00:00, 33.01it/s]
Epoch 25 / 25 | train_loss: 63.18 | val_loss: 28.93 | train_mae: 5.71 | val_mae: 3.54: 100%|█████████████████████████████████████████████████| 614/614 [00:18<00:00, 32.71it/s]
Training the plain GNN¶
[25]:
gcn_baseline = NaiveEuclideanGNN(64, 4, 3)
gcn_train_result = train_model(
data_module,
gcn_baseline,
num_epochs=100,
lr=3e-4,
batch_size=32,
best_model_path=DATA.joinpath("trained_gnn.pth"),
)
Epoch 1 / 100 | train_loss: 131605.16 | val_loss: 51881.71 | train_mae: 257.84 | val_mae: 178.73: 100%|███████████████████████████████████████████████| 614/614 [00:10<00:00, 59.30it/s]
Epoch 2 / 100 | train_loss: 46252.38 | val_loss: 33482.20 | train_mae: 159.96 | val_mae: 139.05: 100%|████████████████████████████████████████████████| 614/614 [00:07<00:00, 86.94it/s]
Epoch 3 / 100 | train_loss: 31016.87 | val_loss: 25729.41 | train_mae: 131.66 | val_mae: 118.91: 100%|████████████████████████████████████████████████| 614/614 [00:07<00:00, 81.91it/s]
Epoch 4 / 100 | train_loss: 19575.64 | val_loss: 16162.15 | train_mae: 104.25 | val_mae: 95.53: 100%|████████████████████████████████████████████████| 614/614 [00:07<00:00, 82.16it/s]
Epoch 5 / 100 | train_loss: 13219.30 | val_loss: 18588.67 | train_mae: 85.73 | val_mae: 111.78: 100%|████████████████████████████████████████████████| 614/614 [00:07<00:00, 78.85it/s]
Epoch 6 / 100 | train_loss: 10594.13 | val_loss: 9926.80 | train_mae: 76.39 | val_mae: 69.72: 100%|████████████████████████████████████████████████| 614/614 [00:07<00:00, 86.72it/s]
Epoch 7 / 100 | train_loss: 8526.29 | val_loss: 9002.18 | train_mae: 67.99 | val_mae: 67.95: 100%|████████████████████████████████████████████████| 614/614 [00:07<00:00, 86.48it/s]
Epoch 8 / 100 | train_loss: 7094.07 | val_loss: 6584.26 | train_mae: 62.25 | val_mae: 56.78: 100%|████████████████████████████████████████████████| 614/614 [00:07<00:00, 86.69it/s]
Epoch 9 / 100 | train_loss: 6621.97 | val_loss: 6924.00 | train_mae: 59.52 | val_mae: 58.06: 100%|████████████████████████████████████████████████| 614/614 [00:06<00:00, 87.96it/s]
Epoch 10 / 100 | train_loss: 6078.34 | val_loss: 6341.89 | train_mae: 56.62 | val_mae: 56.60: 100%|████████████████████████████████████████████████| 614/614 [00:07<00:00, 86.94it/s]
Epoch 11 / 100 | train_loss: 5703.69 | val_loss: 6478.72 | train_mae: 55.66 | val_mae: 58.99: 100%|████████████████████████████████████████████████| 614/614 [00:07<00:00, 85.74it/s]
Epoch 12 / 100 | train_loss: 5338.27 | val_loss: 5886.17 | train_mae: 53.40 | val_mae: 54.12: 100%|████████████████████████████████████████████████| 614/614 [00:07<00:00, 84.17it/s]
Epoch 13 / 100 | train_loss: 5155.95 | val_loss: 5749.83 | train_mae: 52.08 | val_mae: 52.79: 100%|████████████████████████████████████████████████| 614/614 [00:06<00:00, 87.90it/s]
Epoch 14 / 100 | train_loss: 4964.06 | val_loss: 5171.14 | train_mae: 51.32 | val_mae: 49.19: 100%|████████████████████████████████████████████████| 614/614 [00:07<00:00, 85.07it/s]
Epoch 15 / 100 | train_loss: 4871.12 | val_loss: 5170.29 | train_mae: 50.80 | val_mae: 49.65: 100%|████████████████████████████████████████████████| 614/614 [00:07<00:00, 86.82it/s]
Epoch 16 / 100 | train_loss: 4644.91 | val_loss: 4882.15 | train_mae: 49.54 | val_mae: 46.48: 100%|████████████████████████████████████████████████| 614/614 [00:07<00:00, 85.09it/s]
Epoch 17 / 100 | train_loss: 4630.13 | val_loss: 6608.02 | train_mae: 49.50 | val_mae: 58.90: 100%|████████████████████████████████████████████████| 614/614 [00:07<00:00, 81.44it/s]
Epoch 18 / 100 | train_loss: 4616.51 | val_loss: 5037.78 | train_mae: 49.50 | val_mae: 49.18: 100%|████████████████████████████████████████████████| 614/614 [00:07<00:00, 80.24it/s]
Epoch 19 / 100 | train_loss: 4216.14 | val_loss: 5161.73 | train_mae: 46.83 | val_mae: 47.26: 100%|████████████████████████████████████████████████| 614/614 [00:07<00:00, 77.75it/s]
Epoch 20 / 100 | train_loss: 4141.20 | val_loss: 4744.50 | train_mae: 46.76 | val_mae: 46.50: 100%|████████████████████████████████████████████████| 614/614 [00:08<00:00, 76.75it/s]
Epoch 21 / 100 | train_loss: 4027.40 | val_loss: 4649.81 | train_mae: 46.30 | val_mae: 45.15: 100%|████████████████████████████████████████████████| 614/614 [00:07<00:00, 81.51it/s]
Epoch 22 / 100 | train_loss: 3948.61 | val_loss: 4409.36 | train_mae: 45.76 | val_mae: 45.78: 100%|████████████████████████████████████████████████| 614/614 [00:07<00:00, 81.78it/s]
Epoch 23 / 100 | train_loss: 3980.70 | val_loss: 4948.89 | train_mae: 45.73 | val_mae: 48.64: 100%|████████████████████████████████████████████████| 614/614 [00:07<00:00, 87.25it/s]
Epoch 24 / 100 | train_loss: 3621.07 | val_loss: 4880.79 | train_mae: 43.77 | val_mae: 49.11: 100%|████████████████████████████████████████████████| 614/614 [00:07<00:00, 86.39it/s]
Epoch 25 / 100 | train_loss: 3692.66 | val_loss: 6213.45 | train_mae: 44.01 | val_mae: 55.11: 100%|████████████████████████████████████████████████| 614/614 [00:07<00:00, 87.52it/s]
Epoch 26 / 100 | train_loss: 3850.76 | val_loss: 4153.96 | train_mae: 45.27 | val_mae: 44.38: 100%|████████████████████████████████████████████████| 614/614 [00:07<00:00, 86.56it/s]
Epoch 27 / 100 | train_loss: 3442.45 | val_loss: 4037.86 | train_mae: 42.52 | val_mae: 42.73: 100%|████████████████████████████████████████████████| 614/614 [00:07<00:00, 86.44it/s]
Epoch 28 / 100 | train_loss: 3333.53 | val_loss: 4577.04 | train_mae: 41.92 | val_mae: 47.31: 100%|████████████████████████████████████████████████| 614/614 [00:07<00:00, 85.67it/s]
Epoch 29 / 100 | train_loss: 3418.91 | val_loss: 3798.06 | train_mae: 42.22 | val_mae: 40.75: 100%|████████████████████████████████████████████████| 614/614 [00:06<00:00, 88.26it/s]
Epoch 30 / 100 | train_loss: 3373.68 | val_loss: 4033.93 | train_mae: 42.03 | val_mae: 41.66: 100%|████████████████████████████████████████████████| 614/614 [00:07<00:00, 87.06it/s]
Epoch 31 / 100 | train_loss: 3111.80 | val_loss: 3682.38 | train_mae: 40.34 | val_mae: 40.78: 100%|████████████████████████████████████████████████| 614/614 [00:07<00:00, 87.05it/s]
Epoch 32 / 100 | train_loss: 3117.61 | val_loss: 3806.91 | train_mae: 40.66 | val_mae: 41.78: 100%|████████████████████████████████████████████████| 614/614 [00:07<00:00, 87.66it/s]
Epoch 33 / 100 | train_loss: 3206.69 | val_loss: 3581.51 | train_mae: 40.87 | val_mae: 40.26: 100%|████████████████████████████████████████████████| 614/614 [00:07<00:00, 86.08it/s]
Epoch 34 / 100 | train_loss: 3079.73 | val_loss: 4547.30 | train_mae: 40.19 | val_mae: 45.41: 100%|████████████████████████████████████████████████| 614/614 [00:07<00:00, 79.78it/s]
Epoch 35 / 100 | train_loss: 3149.01 | val_loss: 4522.56 | train_mae: 40.80 | val_mae: 45.55: 100%|████████████████████████████████████████████████| 614/614 [00:06<00:00, 90.56it/s]
Epoch 36 / 100 | train_loss: 2918.20 | val_loss: 3964.84 | train_mae: 39.13 | val_mae: 41.33: 100%|████████████████████████████████████████████████| 614/614 [00:06<00:00, 95.45it/s]
Epoch 37 / 100 | train_loss: 2908.17 | val_loss: 4062.00 | train_mae: 39.27 | val_mae: 42.20: 100%|████████████████████████████████████████████████| 614/614 [00:06<00:00, 92.70it/s]
Epoch 38 / 100 | train_loss: 2858.89 | val_loss: 3632.95 | train_mae: 39.15 | val_mae: 39.16: 100%|████████████████████████████████████████████████| 614/614 [00:06<00:00, 93.83it/s]
Epoch 39 / 100 | train_loss: 2911.77 | val_loss: 4292.41 | train_mae: 39.19 | val_mae: 45.72: 100%|████████████████████████████████████████████████| 614/614 [00:06<00:00, 91.97it/s]
Epoch 40 / 100 | train_loss: 2774.28 | val_loss: 4351.61 | train_mae: 38.52 | val_mae: 44.17: 100%|████████████████████████████████████████████████| 614/614 [00:08<00:00, 74.30it/s]
Epoch 41 / 100 | train_loss: 2828.55 | val_loss: 3983.95 | train_mae: 38.97 | val_mae: 43.67: 100%|████████████████████████████████████████████████| 614/614 [00:08<00:00, 71.27it/s]
Epoch 42 / 100 | train_loss: 2639.83 | val_loss: 3555.43 | train_mae: 37.39 | val_mae: 39.24: 100%|████████████████████████████████████████████████| 614/614 [00:06<00:00, 89.57it/s]
Epoch 43 / 100 | train_loss: 2802.29 | val_loss: 3784.69 | train_mae: 38.28 | val_mae: 40.87: 100%|████████████████████████████████████████████████| 614/614 [00:07<00:00, 84.90it/s]
Epoch 44 / 100 | train_loss: 2669.42 | val_loss: 4253.03 | train_mae: 37.75 | val_mae: 42.04: 100%|████████████████████████████████████████████████| 614/614 [00:07<00:00, 80.92it/s]
Epoch 45 / 100 | train_loss: 2543.59 | val_loss: 3882.21 | train_mae: 36.53 | val_mae: 41.58: 100%|████████████████████████████████████████████████| 614/614 [00:07<00:00, 84.11it/s]
Epoch 46 / 100 | train_loss: 2563.37 | val_loss: 3912.68 | train_mae: 36.96 | val_mae: 41.34: 100%|████████████████████████████████████████████████| 614/614 [00:07<00:00, 85.72it/s]
Epoch 47 / 100 | train_loss: 2559.09 | val_loss: 5493.43 | train_mae: 37.02 | val_mae: 51.46: 100%|████████████████████████████████████████████████| 614/614 [00:07<00:00, 85.71it/s]
Epoch 48 / 100 | train_loss: 2576.10 | val_loss: 4768.31 | train_mae: 36.85 | val_mae: 52.41: 100%|████████████████████████████████████████████████| 614/614 [00:07<00:00, 86.25it/s]
Epoch 49 / 100 | train_loss: 2398.73 | val_loss: 3751.79 | train_mae: 35.63 | val_mae: 38.73: 100%|████████████████████████████████████████████████| 614/614 [00:07<00:00, 86.18it/s]
Epoch 50 / 100 | train_loss: 2565.22 | val_loss: 3555.04 | train_mae: 37.00 | val_mae: 37.67: 100%|████████████████████████████████████████████████| 614/614 [00:07<00:00, 86.84it/s]
Epoch 51 / 100 | train_loss: 2360.89 | val_loss: 3545.71 | train_mae: 35.47 | val_mae: 40.19: 100%|████████████████████████████████████████████████| 614/614 [00:07<00:00, 87.31it/s]
Epoch 52 / 100 | train_loss: 2467.34 | val_loss: 3806.81 | train_mae: 36.17 | val_mae: 40.76: 100%|████████████████████████████████████████████████| 614/614 [00:07<00:00, 87.50it/s]
Epoch 53 / 100 | train_loss: 2312.26 | val_loss: 3497.13 | train_mae: 35.16 | val_mae: 37.88: 100%|████████████████████████████████████████████████| 614/614 [00:07<00:00, 86.50it/s]
Epoch 54 / 100 | train_loss: 2167.41 | val_loss: 3487.33 | train_mae: 34.01 | val_mae: 36.67: 100%|████████████████████████████████████████████████| 614/614 [00:07<00:00, 86.55it/s]
Epoch 55 / 100 | train_loss: 2185.56 | val_loss: 3963.23 | train_mae: 34.25 | val_mae: 40.11: 100%|████████████████████████████████████████████████| 614/614 [00:07<00:00, 77.66it/s]
Epoch 56 / 100 | train_loss: 2280.32 | val_loss: 3320.05 | train_mae: 34.80 | val_mae: 36.47: 100%|████████████████████████████████████████████████| 614/614 [00:07<00:00, 85.87it/s]
Epoch 57 / 100 | train_loss: 2211.96 | val_loss: 3721.73 | train_mae: 34.26 | val_mae: 39.31: 100%|████████████████████████████████████████████████| 614/614 [00:07<00:00, 87.31it/s]
Epoch 58 / 100 | train_loss: 2284.40 | val_loss: 3462.30 | train_mae: 34.87 | val_mae: 38.11: 100%|████████████████████████████████████████████████| 614/614 [00:07<00:00, 85.90it/s]
Epoch 59 / 100 | train_loss: 2304.63 | val_loss: 3297.33 | train_mae: 35.03 | val_mae: 36.64: 100%|████████████████████████████████████████████████| 614/614 [00:07<00:00, 87.13it/s]
Epoch 60 / 100 | train_loss: 2117.74 | val_loss: 3451.63 | train_mae: 33.53 | val_mae: 38.31: 100%|████████████████████████████████████████████████| 614/614 [00:07<00:00, 85.60it/s]
Epoch 61 / 100 | train_loss: 2247.47 | val_loss: 3426.57 | train_mae: 34.57 | val_mae: 39.07: 100%|████████████████████████████████████████████████| 614/614 [00:07<00:00, 87.00it/s]
Epoch 62 / 100 | train_loss: 2037.65 | val_loss: 3057.25 | train_mae: 32.89 | val_mae: 36.22: 100%|████████████████████████████████████████████████| 614/614 [00:07<00:00, 86.02it/s]
Epoch 63 / 100 | train_loss: 2076.61 | val_loss: 3299.59 | train_mae: 33.09 | val_mae: 35.54: 100%|████████████████████████████████████████████████| 614/614 [00:06<00:00, 88.46it/s]
Epoch 64 / 100 | train_loss: 2004.54 | val_loss: 3806.01 | train_mae: 32.72 | val_mae: 41.34: 100%|████████████████████████████████████████████████| 614/614 [00:07<00:00, 86.80it/s]
Epoch 65 / 100 | train_loss: 1974.01 | val_loss: 3369.84 | train_mae: 32.42 | val_mae: 35.64: 100%|████████████████████████████████████████████████| 614/614 [00:07<00:00, 83.84it/s]
Epoch 66 / 100 | train_loss: 1968.44 | val_loss: 3482.79 | train_mae: 32.30 | val_mae: 38.24: 100%|████████████████████████████████████████████████| 614/614 [00:07<00:00, 85.11it/s]
Epoch 67 / 100 | train_loss: 1977.37 | val_loss: 3171.34 | train_mae: 32.50 | val_mae: 36.01: 100%|████████████████████████████████████████████████| 614/614 [00:07<00:00, 86.51it/s]
Epoch 68 / 100 | train_loss: 1863.63 | val_loss: 3157.46 | train_mae: 31.64 | val_mae: 34.47: 100%|████████████████████████████████████████████████| 614/614 [00:07<00:00, 86.52it/s]
Epoch 69 / 100 | train_loss: 1961.02 | val_loss: 3296.61 | train_mae: 32.31 | val_mae: 37.72: 100%|████████████████████████████████████████████████| 614/614 [00:07<00:00, 81.34it/s]
Epoch 70 / 100 | train_loss: 1943.04 | val_loss: 3274.01 | train_mae: 32.25 | val_mae: 34.72: 100%|████████████████████████████████████████████████| 614/614 [00:07<00:00, 84.12it/s]
Epoch 71 / 100 | train_loss: 1889.30 | val_loss: 3842.83 | train_mae: 31.79 | val_mae: 40.47: 100%|████████████████████████████████████████████████| 614/614 [00:07<00:00, 81.92it/s]
Epoch 72 / 100 | train_loss: 1896.79 | val_loss: 4150.15 | train_mae: 31.78 | val_mae: 43.21: 100%|████████████████████████████████████████████████| 614/614 [00:07<00:00, 82.86it/s]
Epoch 73 / 100 | train_loss: 1860.80 | val_loss: 3174.58 | train_mae: 31.54 | val_mae: 35.24: 100%|████████████████████████████████████████████████| 614/614 [00:07<00:00, 87.59it/s]
Epoch 74 / 100 | train_loss: 1870.97 | val_loss: 3089.24 | train_mae: 31.62 | val_mae: 36.22: 100%|████████████████████████████████████████████████| 614/614 [00:07<00:00, 77.78it/s]
Epoch 75 / 100 | train_loss: 1869.31 | val_loss: 3397.72 | train_mae: 31.83 | val_mae: 37.43: 100%|████████████████████████████████████████████████| 614/614 [00:08<00:00, 71.10it/s]
Epoch 76 / 100 | train_loss: 1723.28 | val_loss: 3294.50 | train_mae: 30.57 | val_mae: 34.38: 100%|████████████████████████████████████████████████| 614/614 [00:07<00:00, 82.95it/s]
Epoch 77 / 100 | train_loss: 2041.00 | val_loss: 3031.84 | train_mae: 32.94 | val_mae: 34.34: 100%|████████████████████████████████████████████████| 614/614 [00:07<00:00, 76.83it/s]
Epoch 78 / 100 | train_loss: 1670.29 | val_loss: 3470.53 | train_mae: 29.81 | val_mae: 38.44: 100%|████████████████████████████████████████████████| 614/614 [00:08<00:00, 74.45it/s]
Epoch 79 / 100 | train_loss: 1795.03 | val_loss: 3402.04 | train_mae: 31.12 | val_mae: 38.37: 100%|████████████████████████████████████████████████| 614/614 [00:09<00:00, 65.54it/s]
Epoch 80 / 100 | train_loss: 1769.94 | val_loss: 3149.87 | train_mae: 30.62 | val_mae: 33.53: 100%|████████████████████████████████████████████████| 614/614 [00:08<00:00, 74.66it/s]
Epoch 81 / 100 | train_loss: 1682.90 | val_loss: 3412.26 | train_mae: 30.09 | val_mae: 38.59: 100%|████████████████████████████████████████████████| 614/614 [00:07<00:00, 80.33it/s]
Epoch 82 / 100 | train_loss: 1737.01 | val_loss: 3396.11 | train_mae: 30.41 | val_mae: 35.16: 100%|████████████████████████████████████████████████| 614/614 [00:08<00:00, 69.65it/s]
Epoch 83 / 100 | train_loss: 1967.99 | val_loss: 2990.12 | train_mae: 32.48 | val_mae: 33.36: 100%|████████████████████████████████████████████████| 614/614 [00:08<00:00, 74.43it/s]
Epoch 84 / 100 | train_loss: 1722.19 | val_loss: 3358.17 | train_mae: 30.33 | val_mae: 35.79: 100%|████████████████████████████████████████████████| 614/614 [00:07<00:00, 77.35it/s]
Epoch 85 / 100 | train_loss: 1633.18 | val_loss: 3178.61 | train_mae: 29.64 | val_mae: 34.94: 100%|████████████████████████████████████████████████| 614/614 [00:07<00:00, 83.83it/s]
Epoch 86 / 100 | train_loss: 1691.06 | val_loss: 3589.25 | train_mae: 30.34 | val_mae: 39.16: 100%|████████████████████████████████████████████████| 614/614 [00:07<00:00, 80.62it/s]
Epoch 87 / 100 | train_loss: 1573.87 | val_loss: 3321.35 | train_mae: 29.09 | val_mae: 34.73: 100%|████████████████████████████████████████████████| 614/614 [00:08<00:00, 70.68it/s]
Epoch 88 / 100 | train_loss: 1624.47 | val_loss: 3328.78 | train_mae: 29.65 | val_mae: 36.73: 100%|████████████████████████████████████████████████| 614/614 [00:07<00:00, 80.53it/s]
Epoch 89 / 100 | train_loss: 1566.87 | val_loss: 2961.29 | train_mae: 28.97 | val_mae: 34.25: 100%|████████████████████████████████████████████████| 614/614 [00:07<00:00, 83.01it/s]
Epoch 90 / 100 | train_loss: 1622.22 | val_loss: 3082.79 | train_mae: 29.63 | val_mae: 35.45: 100%|████████████████████████████████████████████████| 614/614 [00:08<00:00, 68.98it/s]
Epoch 91 / 100 | train_loss: 1658.47 | val_loss: 2969.16 | train_mae: 30.08 | val_mae: 33.32: 100%|████████████████████████████████████████████████| 614/614 [00:07<00:00, 79.57it/s]
Epoch 92 / 100 | train_loss: 1495.98 | val_loss: 3389.66 | train_mae: 28.39 | val_mae: 34.39: 100%|████████████████████████████████████████████████| 614/614 [00:07<00:00, 77.54it/s]
Epoch 93 / 100 | train_loss: 1512.13 | val_loss: 3102.99 | train_mae: 28.54 | val_mae: 34.93: 100%|████████████████████████████████████████████████| 614/614 [00:08<00:00, 73.15it/s]
Epoch 94 / 100 | train_loss: 1489.28 | val_loss: 3190.44 | train_mae: 28.52 | val_mae: 35.37: 100%|████████████████████████████████████████████████| 614/614 [00:07<00:00, 85.53it/s]
Epoch 95 / 100 | train_loss: 1651.04 | val_loss: 2940.62 | train_mae: 29.94 | val_mae: 32.10: 100%|████████████████████████████████████████████████| 614/614 [00:07<00:00, 86.13it/s]
Epoch 96 / 100 | train_loss: 1529.36 | val_loss: 3137.99 | train_mae: 28.90 | val_mae: 34.13: 100%|████████████████████████████████████████████████| 614/614 [00:07<00:00, 86.59it/s]
Epoch 97 / 100 | train_loss: 1531.24 | val_loss: 3489.23 | train_mae: 28.92 | val_mae: 38.00: 100%|████████████████████████████████████████████████| 614/614 [00:07<00:00, 87.50it/s]
Epoch 98 / 100 | train_loss: 1580.89 | val_loss: 2815.52 | train_mae: 29.14 | val_mae: 32.56: 100%|████████████████████████████████████████████████| 614/614 [00:07<00:00, 78.90it/s]
Epoch 99 / 100 | train_loss: 1518.98 | val_loss: 3225.39 | train_mae: 28.69 | val_mae: 32.25: 100%|████████████████████████████████████████████████| 614/614 [00:07<00:00, 85.77it/s]
Epoch 100 / 100 | train_loss: 1555.37 | val_loss: 3180.20 | train_mae: 29.11 | val_mae: 35.67: 100%|████████████████████████████████████████████████| 614/614 [00:07<00:00, 86.09it/s]
Comparative evaluation¶
Let us now compare the trained EGNN with the baseline GNN model. First, note that in terms of capacity (measured by the number of trainable parameters) the two models are very similar. Be aware, however, that the comparison is still not completely fair, because
the EGNN is a message-passing neural network while the baseline GNN is a type of graph convolutional neural network
the EGNN is run on complete graphs, whereas the baseline GNN uses the bond adjacency information, which could be an advantage or a disadvantage
[26]:
# Count model parameters as a proxy for capacity
gcn_num_params = sum(p.numel() for p in gcn_train_result["model"].parameters())
egnn_num_params = sum(p.numel() for p in egnn_train_result["model"].parameters())
for key, value in {"GCN": gcn_num_params, "EGNN": egnn_num_params}.items():
    print(f"{key} has {value} parameters")
GCN has 52417 parameters
EGNN has 51969 parameters
Plotting the loss and validation MAE for each epoch, we can observe that the EGNN trains much faster and yields much better results, even though it is trained for fewer epochs (note that loss and MAE are plotted on a log scale).
Surprisingly, the validation loss/MAE of the EGNN is sometimes lower than the train loss/MAE. This might be explained by the fact that the data split is very homogeneous and the validation data contains fewer outliers than the train data (see the box plots from the section on the distribution of the regression target across splits).
[27]:
fig, (loss_ax, mae_ax) = plt.subplots(1, 2, figsize=(8, 4))
loss_ax.set_title("Loss (MSE)")
mae_ax.set_title("MAE")
loss_ax.set_xlabel("Epoch")
mae_ax.set_xlabel("Epoch")
# Plot the train/validation curves of both models on the matching axis
for metric in ["train_loss", "val_loss", "train_mae", "val_mae"]:
    split = metric.split("_")[0]
    ax = loss_ax if "loss" in metric else mae_ax
    ax.plot(egnn_train_result[metric], label=f"EGNN {split}")
    ax.plot(gcn_train_result[metric], label=f"GNN {split}")
mae_ax.legend()
mae_ax.set_yscale("log")
loss_ax.set_yscale("log")
This performance improvement can also be observed on the held-out test data. For testing, we select, for each architecture, the checkpoint that achieved the lowest validation MAE.
[28]:
# Restore the checkpoints with the lowest validation MAE and evaluate on the test set
gcn_model = gcn_train_result["model"]
gcn_model.load_state_dict(torch.load(gcn_train_result["path_to_best_model"]))
gcn_test_mae, gcn_preds, gcn_targets = test_model(gcn_model, data_module)

egnn_model = egnn_train_result["model"]
egnn_model.load_state_dict(torch.load(egnn_train_result["path_to_best_model"]))
egnn_test_mae, egnn_preds, egnn_targets = test_model(egnn_model, data_module)
print(f"EGNN test MAE: {egnn_test_mae}")
print(f"GNN test MAE: {gcn_test_mae}")
EGNN test MAE: 3.4511184202421696
GNN test MAE: 31.881833081726633
[29]:
fig, ax = plt.subplots()
# Diagonal reference line: perfect predictions lie on y = x
ax.plot(gcn_targets, gcn_targets, "--", color="grey")
ax.scatter(gcn_targets, gcn_preds, s=1, label="GNN")
ax.scatter(egnn_targets, egnn_preds, s=1, label="EGNN")
ax.set_ylabel("Model prediction")
ax.set_xlabel(r"Ground truth $\langle R^2 \rangle$")
ax.set_title("Test performance")
ax.legend()
These findings support our initial hypothesis that \(\text{E}(3)\)-invariant models lead to faster learning and improved generalization performance.
Discussion¶
Summary¶
You have now seen, theoretically and practically, why we need \((S)E(3)\)-invariant models to work with point cloud representations of molecules and how to implement, train, and evaluate them. The dataset used here is not directly relevant to CADD, but the practical importance of \((S)E(3)\) equi-/invariance definitely carries over to more relevant applications such as protein-ligand docking. Recent work on molecular representation learning also suggests that 3D point clouds are favored for a broad range of property prediction tasks more relevant to CADD, such as toxicity prediction.
Caveats of our approach¶
At this point, we should also go over some final caveats with the EGNN presented here and our approach in general:
Our model assumes that every atom interacts with every other atom, i.e., the neighborhood of node \(i\), \(N(i) = \{j \mid j \neq i\}\), is complete. This approach has quadratic complexity, meaning it is more computationally expensive (go back to the model training and compare how long one epoch takes compared to the plain GNN) and thus might not scale to larger molecules. In that case, we could restrict interactions (see the first sketch below this list) by instead using
\(k\)-nearest neighborhoods, i.e., \(N(i)\) contains the \(k\) nodes with the smallest Euclidean distance to \(i\) (so \(|N(i)| = k\)),
or spherical neighborhoods with a fixed radius \(\delta\), i.e., \(N(i) = \{j \mid \|X_i - X_j\|_2 \leq \delta\}\)
Our EGNN model is \(E(3)\)-invariant. Note that some molecular properties are sensitive to reflection, e.g., properties that differ between enantiomers. In such settings, \(SE(3)\)-invariance is the preferred model property (see Talktorial T033); the second sketch below this list illustrates why a purely distance-based model cannot distinguish mirror images.
Random data splits are considered bad practice for measuring the ability of a molecular machine learning model to generalize to unseen data (see this paper, which analyzes and discusses this issue in depth for QM9)
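To make the first caveat concrete, here is a minimal sketch of how both restricted neighborhood constructions could be built from a coordinate matrix. It assumes only torch; the helper names knn_edges and radius_edges are illustrative and not part of the code used above (libraries such as torch_cluster provide equivalent knn_graph and radius_graph helpers).

import torch

def knn_edges(X: torch.Tensor, k: int) -> torch.Tensor:
    """Edge index of shape (2, m*k) connecting each atom to its k nearest neighbors."""
    dist = torch.cdist(X, X)  # pairwise Euclidean distances, shape (m, m)
    dist.fill_diagonal_(float("inf"))  # exclude self-loops
    neighbors = dist.topk(k, largest=False).indices  # (m, k) neighbor indices
    src = torch.arange(X.size(0)).repeat_interleave(k)
    return torch.stack([src, neighbors.reshape(-1)])

def radius_edges(X: torch.Tensor, delta: float) -> torch.Tensor:
    """Edge index of all atom pairs within Euclidean distance delta."""
    dist = torch.cdist(X, X)
    dist.fill_diagonal_(float("inf"))  # exclude self-loops
    return (dist <= delta).nonzero().T  # shape (2, num_edges)

X = torch.randn(5, 3)  # toy point cloud with 5 atoms
print(knn_edges(X, k=2).shape, radius_edges(X, delta=1.5).shape)

Both variants reduce the number of messages per layer from \(O(m^2)\) to roughly \(O(mk)\) (for \(k\)-nearest neighborhoods) or to however many pairs fall within the radius.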
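The reflection caveat can also be verified directly: a reflection belongs to \(E(3)\) but not \(SE(3)\), yet it preserves all pairwise distances, so a model that only consumes distances, like our EGNN, necessarily predicts the same value for a molecule and its mirror image. A small sanity check, again assuming only torch:

import torch

X = torch.randn(5, 3)  # toy conformer
reflection = torch.diag(torch.tensor([-1.0, 1.0, 1.0]))  # mirror along the x-axis
X_mirrored = X @ reflection

# All pairwise distances agree, so a purely distance-based model cannot
# distinguish the two mirror images (e.g., two enantiomers).
print(torch.allclose(torch.cdist(X, X), torch.cdist(X_mirrored, X_mirrored)))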
Quiz¶
In addition to 3D coordinates, what is strictly required for inference of covalent bonds between atoms?
What is the difference between equivariance and invariance?
True or false? \(SE(3)\) contains transformations which are not included in \(E(3)\).
True or false? The atom embeddings \(h\) computed by iterating the following message passing scheme for a fixed number of steps are \(E(3)\)-invariant:
\[m_{ij}^{l} = \phi_{l}(h_i^l, h_j^l, X_i - X_j)\]
\[h_{i}^{l+1} = \psi_l(h_{i}^l, \sum_{j \neq i} m_{ij}^l)\]