# Source code for torch_geometric.nn.models.gpse

import logging
import os
import os.path as osp
import time
from collections import OrderedDict
from typing import List, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Module
from tqdm import trange

import torch_geometric.transforms as T
from torch_geometric.data import Data, Dataset, download_url
from torch_geometric.loader import DataLoader, NeighborLoader
from torch_geometric.nn import (
    ResGatedGraphConv,
    global_add_pool,
    global_max_pool,
    global_mean_pool,
)
from torch_geometric.nn.resolver import activation_resolver
from torch_geometric.utils import to_dense_batch


class Linear(torch.nn.Module):
    """Thin wrapper around :class:`torch.nn.Linear` that accepts either a
    plain tensor or a graph batch object carrying node features in ``x``.
    """
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        bias: bool,
    ) -> None:
        super().__init__()
        self.model = torch.nn.Linear(in_channels, out_channels, bias=bias)

    def forward(self, batch):
        # Plain tensors are mapped directly; batch objects are updated
        # in-place on their `.x` attribute.
        if not isinstance(batch, torch.Tensor):
            batch.x = self.model(batch.x)
            return batch
        return self.model(batch)


class ResGatedGCNConv(torch.nn.Module):
    """Graph-batch wrapper around
    :class:`~torch_geometric.nn.ResGatedGraphConv` that reads node features
    and connectivity from ``batch.x`` / ``batch.edge_index`` and writes the
    convolved features back to ``batch.x``.
    """
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        bias: bool,
        **kwargs,
    ) -> None:
        super().__init__()
        self.model = ResGatedGraphConv(in_channels, out_channels, bias=bias,
                                       **kwargs)

    def forward(self, batch):
        batch.x = self.model(batch.x, batch.edge_index)
        return batch


class GeneralLayer(torch.nn.Module):
    """A single encoder layer: a named core layer (``linear`` or
    ``resgatedgcnconv``) followed by optional batch-norm, dropout,
    activation, and L2 feature normalization.
    """
    def __init__(
        self,
        name: str,
        in_channels: int,
        out_channels: int,
        has_batch_norm: bool,
        has_l2_norm: bool,
        dropout: float,
        act: Optional[str],
        **kwargs,
    ):
        super().__init__()
        self.has_l2_norm = has_l2_norm

        # The bias term is redundant when batch-norm follows the layer.
        self.layer = {
            'linear': Linear,
            'resgatedgcnconv': ResGatedGCNConv,
        }[name](in_channels, out_channels, bias=not has_batch_norm, **kwargs)

        post = []
        if has_batch_norm:
            post.append(
                torch.nn.BatchNorm1d(out_channels, eps=1e-5, momentum=0.1))
        if dropout > 0:
            post.append(torch.nn.Dropout(p=dropout, inplace=False))
        if act is not None:
            post.append(activation_resolver(act))
        self.post_layer = nn.Sequential(*post)

    def forward(self, batch):
        batch = self.layer(batch)
        if isinstance(batch, torch.Tensor):
            out = self.post_layer(batch)
            return F.normalize(out, p=2, dim=1) if self.has_l2_norm else out
        batch.x = self.post_layer(batch.x)
        if self.has_l2_norm:
            batch.x = F.normalize(batch.x, p=2, dim=1)
        return batch


class GeneralMultiLayer(torch.nn.Module):
    """A stack of :class:`GeneralLayer` modules registered as
    ``Layer_0 ... Layer_{num_layers - 1}``. The activation on the final
    layer is dropped unless :obj:`final_act` is set.
    """
    def __init__(
        self,
        name: str,
        in_channels: int,
        out_channels: int,
        hidden_channels: Optional[int],
        num_layers: int,
        has_batch_norm: bool,
        has_l2_norm: bool,
        dropout: float,
        act: str,
        final_act: bool,
        **kwargs,
    ) -> None:
        super().__init__()
        hidden_channels = hidden_channels or out_channels

        for i in range(num_layers):
            is_last = i == num_layers - 1
            self.add_module(
                f'Layer_{i}',
                GeneralLayer(
                    name=name,
                    in_channels=in_channels if i == 0 else hidden_channels,
                    out_channels=out_channels if is_last else hidden_channels,
                    has_batch_norm=has_batch_norm,
                    has_l2_norm=has_l2_norm,
                    dropout=dropout,
                    # Only the last layer may skip its activation:
                    act=act if (final_act or not is_last) else None,
                    **kwargs,
                ))

    def forward(self, batch):
        for module in self.children():
            batch = module(batch)
        return batch


class BatchNorm1dNode(torch.nn.Module):
    """Applies 1D batch normalization to the node features ``batch.x``."""
    def __init__(self, channels: int) -> None:
        super().__init__()
        self.bn = torch.nn.BatchNorm1d(channels, eps=1e-5, momentum=0.1)

    def forward(self, batch):
        batch.x = self.bn(batch.x)
        return batch


class BatchNorm1dEdge(torch.nn.Module):
    """Applies 1D batch normalization to the edge features
    ``batch.edge_attr``."""
    def __init__(self, channels: int) -> None:
        super().__init__()
        self.bn = torch.nn.BatchNorm1d(channels, eps=1e-5, momentum=0.1)

    def forward(self, batch):
        batch.edge_attr = self.bn(batch.edge_attr)
        return batch


class MLP(torch.nn.Module):
    """A multi-layer perceptron: ``num_layers - 1`` fully-connected
    :class:`GeneralMultiLayer` blocks (with norm/dropout/activation) followed
    by a final plain :class:`Linear` projection.
    """
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        hidden_channels: Optional[int],
        num_layers: int,
        has_batch_norm: bool = True,
        has_l2_norm: bool = True,
        dropout: float = 0.2,
        act: str = 'relu',
        **kwargs,
    ):
        super().__init__()
        hidden_channels = hidden_channels or in_channels

        modules = []
        if num_layers > 1:
            modules.append(
                GeneralMultiLayer(
                    'linear',
                    in_channels,
                    hidden_channels,
                    hidden_channels,
                    num_layers - 1,
                    has_batch_norm,
                    has_l2_norm,
                    dropout,
                    act,
                    final_act=True,
                    **kwargs,
                ))
        # Final projection always carries a bias and no post-processing:
        modules.append(Linear(hidden_channels, out_channels, bias=True))
        self.model = nn.Sequential(*modules)

    def forward(self, batch):
        if isinstance(batch, torch.Tensor):
            return self.model(batch)
        batch.x = self.model(batch.x)
        return batch


class GNNStackStage(torch.nn.Module):
    """Stacks ``num_layers`` message-passing :class:`GeneralLayer` modules,
    optionally with residual-sum (``skipsum``) or dense-concat
    (``skipconcat``) skip connections, and an optional final L2
    normalization of node features.
    """
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        num_layers: int,
        layer_type: str,
        stage_type: str = 'skipsum',
        final_l2_norm: bool = True,
        has_batch_norm: bool = True,
        has_l2_norm: bool = True,
        dropout: float = 0.2,
        act: Optional[str] = 'relu',
    ):
        super().__init__()
        self.num_layers = num_layers
        self.stage_type = stage_type
        self.final_l2_norm = final_l2_norm

        for i in range(num_layers):
            # 'skipconcat' grows the input width by out_channels per layer.
            if stage_type == 'skipconcat' and i > 0:
                d_in = in_channels + i * out_channels
            else:
                d_in = in_channels if i == 0 else out_channels
            self.add_module(
                f'layer{i}',
                GeneralLayer(layer_type, d_in, out_channels, has_batch_norm,
                             has_l2_norm, dropout, act))

    def forward(self, batch):
        last = self.num_layers - 1
        for i, module in enumerate(self.children()):
            residual = batch.x
            batch = module(batch)
            if self.stage_type == 'skipsum':
                batch.x = residual + batch.x
            elif self.stage_type == 'skipconcat' and i < last:
                batch.x = torch.cat([residual, batch.x], dim=1)

        if self.final_l2_norm:
            batch.x = F.normalize(batch.x, p=2, dim=-1)

        return batch


class GNNInductiveHybridMultiHead(torch.nn.Module):
    r"""GNN prediction head for inductive node and graph prediction tasks
    using an individual MLP for each task.

    Args:
        dim_in (int): Input dimension.
        dim_out (int): Output dimension. Not used, as the dimension is
            determined by :obj:`num_node_targets` and :obj:`num_graph_targets`
            instead.
        num_node_targets (int): Number of individual PSEs used as node-level
            targets in pretraining :class:`GPSE`.
        num_graph_targets (int): Number of graph-level targets used in
            pretraining :class:`GPSE`.
        layers_post_mp (int): Number of MLP layers after GNN message-passing.
        virtual_node (bool, optional): Whether a virtual node is added to
            graphs in :class:`GPSE` computation. (default: :obj:`True`)
        multi_head_dim_inner (int, optional): Width of MLPs for PSE target
            prediction heads. (default: :obj:`32`)
        graph_pooling (str, optional): Type of graph pooling applied before
            post_mp. Options are :obj:`add`, :obj:`max`, :obj:`mean`.
            (default: :obj:`add`)
        has_bn (bool, optional): Whether to apply batch normalization to
            layer outputs. (default: :obj:`True`)
        has_l2norm (bool, optional): Whether to apply L2 normalization to
            the layer outputs. (default: :obj:`True`)
        dropout (float, optional): Dropout ratio at layer output.
            (default: :obj:`0.2`)
        act (str, optional): Activation to apply to layer outputs if
            :obj:`has_act` is :obj:`True`. (default: :obj:`relu`)
    """
    def __init__(
        self,
        dim_in: int,
        dim_out: int,
        num_node_targets: int,
        num_graph_targets: int,
        layers_post_mp: int,
        virtual_node: bool = True,
        multi_head_dim_inner: int = 32,
        graph_pooling: str = 'add',
        has_bn: bool = True,
        has_l2norm: bool = True,
        dropout: float = 0.2,
        act: str = 'relu',
    ):
        super().__init__()
        self.node_target_dim = num_node_targets
        self.graph_target_dim = num_graph_targets
        self.virtual_node = virtual_node
        num_layers = layers_post_mp

        # One small MLP per node-level PSE target (each predicts a scalar):
        self.node_post_mps = nn.ModuleList([
            MLP(dim_in, 1, multi_head_dim_inner, num_layers, has_bn,
                has_l2norm, dropout, act) for _ in range(self.node_target_dim)
        ])

        self.graph_pooling = {
            'add': global_add_pool,
            'max': global_max_pool,
            'mean': global_mean_pool,
        }[graph_pooling]

        # A single shared MLP for all graph-level targets:
        self.graph_post_mp = MLP(dim_in, self.graph_target_dim, dim_in,
                                 num_layers, has_bn, has_l2norm, dropout, act)

    def _pad_and_stack(self, x1: torch.Tensor, x2: torch.Tensor, pad1: int,
                       pad2: int):
        # Node rows are padded on the right, graph rows on the left, so the
        # two target groups occupy disjoint column ranges after stacking.
        return torch.vstack([
            nn.functional.pad(x1, (0, pad2)),
            nn.functional.pad(x2, (pad1, 0)),
        ])

    def _apply_index(self, batch, virtual_node: bool, pad_node: int,
                     pad_graph: int):
        graph_pred, graph_true = batch.graph_feature, batch.y_graph
        node_pred, node_true = batch.node_feature, batch.y
        if virtual_node:
            # Drop the last node of every graph (the virtual node).
            num_graphs = batch.batch.max().item() + 1
            keep = torch.concat([
                torch.where(batch.batch == g)[0][:-1]
                for g in range(num_graphs)
            ])
            node_pred, node_true = node_pred[keep], node_true[keep]

        # Stack node predictions on top of graph predictions, zero-padded
        # into disjoint column ranges.
        pred = self._pad_and_stack(node_pred, graph_pred, pad_node, pad_graph)
        true = self._pad_and_stack(node_true, graph_true, pad_node, pad_graph)

        return pred, true

    def forward(self, batch):
        batch.node_feature = torch.hstack(
            [head(batch.x) for head in self.node_post_mps])
        batch.graph_feature = self.graph_post_mp(
            self.graph_pooling(batch.x, batch.batch))
        return self._apply_index(batch, self.virtual_node,
                                 self.node_target_dim, self.graph_target_dim)


class IdentityHead(torch.nn.Module):
    """Pass-through prediction head: returns node features and targets
    unchanged. Used in place of the pretraining head when hidden
    representations are taken as the final encodings."""
    def forward(self, batch):
        return batch.x, batch.y


class GPSE(torch.nn.Module):
    r"""The Graph Positional and Structural Encoder (GPSE) model from the
    `"Graph Positional and Structural Encoder"
    <https://arxiv.org/abs/2307.07107>`_ paper.

    The GPSE model consists of a (1) deep GNN that consists of stacked
    message passing layers, and a (2) prediction head to predict pre-computed
    positional and structural encodings (PSE). When used on downstream
    datasets, these prediction heads are removed and the final
    fully-connected layer outputs are used as learned PSE embeddings.

    GPSE also provides a static method :meth:`from_pretrained` to load
    pre-trained GPSE models trained on a variety of molecular datasets.

    .. code-block:: python

        from torch_geometric.nn import GPSE, GPSENodeEncoder
        from torch_geometric.transforms import AddGPSE
        from torch_geometric.nn.models.gpse import precompute_GPSE

        gpse_model = GPSE.from_pretrained('molpcba')

        # Option 1: Precompute GPSE encodings in-place for a given dataset
        dataset = ZINC(path, subset=True, split='train')
        precompute_gpse(gpse_model, dataset)

        # Option 2: Use the GPSE model with AddGPSE as a pre_transform to
        # save the encodings
        dataset = ZINC(path, subset=True, split='train',
                       pre_transform=AddGPSE(gpse_model, vn=True,
                                             rand_type='NormalSE'))

    Both approaches append the generated encodings to the :obj:`pestat_GPSE`
    attribute of :class:`~torch_geometric.data.Data` objects.

    To use the GPSE encodings for a downstream task, one may need to add
    these encodings to the :obj:`x` attribute of the
    :class:`~torch_geometric.data.Data` objects. To do so, one can use the
    :class:`GPSENodeEncoder` provided to map these encodings to a desired
    dimension before appending them to :obj:`x`.

    Let's say we have a graph dataset with 64 original node features, and we
    have generated GPSE encodings of dimension 32, i.e.
    :obj:`data.pestat_GPSE` = 32. Additionally, we want to use a GNN with an
    inner dimension of 128. To do so, we can map the 32-dimensional GPSE
    encodings to a higher dimension of 64, and then append them to the
    :obj:`x` attribute of the :class:`~torch_geometric.data.Data` objects to
    obtain a 128-dimensional node feature representation.
    :class:`~torch_geometric.nn.GPSENodeEncoder` handles both this mapping
    and concatenation to :obj:`x`, the outputs of which can be used as input
    to a GNN:

    .. code-block:: python

        encoder = GPSENodeEncoder(dim_emb=128, dim_pe_in=32, dim_pe_out=64,
                                  expand_x=False)
        gnn = GNN(...)

        for batch in loader:
            x = encoder(batch.x, batch.pestat_GPSE)
            out = gnn(x, batch.edge_index)

    Args:
        dim_in (int, optional): Input dimension. (default: :obj:`20`)
        dim_out (int, optional): Output dimension. (default: :obj:`51`)
        dim_inner (int, optional): Width of the encoder layers.
            (default: :obj:`512`)
        layer_type (str, optional): Type of graph convolutional layer for
            message-passing. (default: :obj:`resgatedgcnconv`)
        layers_pre_mp (int, optional): Number of MLP layers before
            message-passing. (default: :obj:`1`)
        layers_mp (int, optional): Number of layers for message-passing.
            (default: :obj:`20`)
        layers_post_mp (int, optional): Number of MLP layers after
            message-passing. (default: :obj:`2`)
        num_node_targets (int, optional): Number of individual PSEs used as
            node-level targets in pretraining :class:`GPSE`.
            (default: :obj:`51`)
        num_graph_targets (int, optional): Number of graph-level targets
            used in pretraining :class:`GPSE`. (default: :obj:`11`)
        stage_type (str, optional): The type of staging to apply. Possible
            values are: :obj:`skipsum`, :obj:`skipconcat`. Any other value
            will default to no skip connections. (default: :obj:`skipsum`)
        has_bn (bool, optional): Whether to apply batch normalization in the
            layer. (default: :obj:`True`)
        head_bn (bool, optional): Whether to apply batch normalization in
            the prediction head. (default: :obj:`False`)
        final_l2norm (bool, optional): Whether to apply L2 normalization to
            the outputs. (default: :obj:`True`)
        has_l2norm (bool, optional): Whether to apply L2 normalization after
            the layer. (default: :obj:`True`)
        dropout (float, optional): Dropout ratio at layer output.
            (default: :obj:`0.2`)
        has_act (bool, optional): Whether has activation after the layer.
            (default: :obj:`True`)
        final_act (bool, optional): Whether to apply activation after the
            layer stack. (default: :obj:`True`)
        act (str, optional): Activation to apply to layer output if
            :obj:`has_act` is :obj:`True`. (default: :obj:`relu`)
        virtual_node (bool, optional): Whether a virtual node is added to
            graphs in :class:`GPSE` computation. (default: :obj:`True`)
        multi_head_dim_inner (int, optional): Width of MLPs for PSE target
            prediction heads. (default: :obj:`32`)
        graph_pooling (str, optional): Type of graph pooling applied before
            post_mp. Options are :obj:`add`, :obj:`max`, :obj:`mean`.
            (default: :obj:`add`)
        use_repr (bool, optional): Whether to use the hidden representation
            of the final layer as :class:`GPSE` encodings.
            (default: :obj:`True`)
        repr_type (str, optional): Type of representation to use. Options
            are :obj:`no_post_mp`, :obj:`one_layer_before`.
            (default: :obj:`no_post_mp`)
        bernoulli_threshold (float, optional): Threshold for Bernoulli
            sampling of virtual nodes. (default: :obj:`0.5`)
    """
    url_dict = {
        'molpcba':
        'https://zenodo.org/record/8145095/files/gpse_model_molpcba_1.0.pt',
        'zinc':
        'https://zenodo.org/record/8145095/files/gpse_model_zinc_1.0.pt',
        'pcqm4mv2':
        'https://zenodo.org/record/8145095/files/gpse_model_pcqm4mv2_1.0.pt',
        'geom':
        'https://zenodo.org/record/8145095/files/gpse_model_geom_1.0.pt',
        'chembl':
        'https://zenodo.org/record/8145095/files/gpse_model_chembl_1.0.pt',
    }

    def __init__(
        self,
        dim_in: int = 20,
        dim_out: int = 51,
        dim_inner: int = 512,
        layer_type: str = 'resgatedgcnconv',
        layers_pre_mp: int = 1,
        layers_mp: int = 20,
        layers_post_mp: int = 2,
        num_node_targets: int = 51,
        num_graph_targets: int = 11,
        stage_type: str = 'skipsum',
        has_bn: bool = True,
        head_bn: bool = False,
        final_l2norm: bool = True,
        has_l2norm: bool = True,
        dropout: float = 0.2,
        has_act: bool = True,
        final_act: bool = True,
        act: str = 'relu',
        virtual_node: bool = True,
        multi_head_dim_inner: int = 32,
        graph_pooling: str = 'add',
        use_repr: bool = True,
        repr_type: str = 'no_post_mp',
        bernoulli_threshold: float = 0.5,
    ):
        super().__init__()
        self.use_repr = use_repr
        self.repr_type = repr_type
        self.bernoulli_threshold = bernoulli_threshold

        if layers_pre_mp > 0:
            self.pre_mp = GeneralMultiLayer(
                name='linear',
                in_channels=dim_in,
                out_channels=dim_inner,
                hidden_channels=dim_inner,
                num_layers=layers_pre_mp,
                has_batch_norm=has_bn,
                has_l2_norm=has_l2norm,
                dropout=dropout,
                act=act,
                final_act=final_act,
            )
            # Message-passing now consumes the encoder width:
            dim_in = dim_inner

        if layers_mp > 0:
            self.mp = GNNStackStage(
                in_channels=dim_in,
                out_channels=dim_inner,
                num_layers=layers_mp,
                layer_type=layer_type,
                stage_type=stage_type,
                final_l2_norm=final_l2norm,
                has_batch_norm=has_bn,
                has_l2_norm=has_l2norm,
                dropout=dropout,
                act=act if has_act else None,
            )

        self.post_mp = GNNInductiveHybridMultiHead(
            dim_inner,
            dim_out,
            num_node_targets,
            num_graph_targets,
            layers_post_mp,
            virtual_node,
            multi_head_dim_inner,
            graph_pooling,
            head_bn,
            has_l2norm,
            dropout,
            act,
        )

        self.reset_parameters()

    def reset_parameters(self):
        pass

    @classmethod
    def from_pretrained(cls, name: str, root: str = 'GPSE_pretrained'):
        r"""Returns a pretrained :class:`GPSE` model on a dataset.

        Args:
            name (str): The name of the dataset (:obj:`"molpcba"`,
                :obj:`"zinc"`, :obj:`"pcqm4mv2"`, :obj:`"geom"`,
                :obj:`"chembl"`).
            root (str, optional): The root directory to save the pre-trained
                model. (default: :obj:`"GPSE_pretrained"`)
        """
        root = osp.expanduser(osp.normpath(root))
        os.makedirs(root, exist_ok=True)
        path = download_url(cls.url_dict[name], root)

        model = GPSE()  # All pretrained models use the default arguments.
        model_state = torch.load(path, map_location='cpu')['model_state']
        # Strip the leading wrapper prefix (e.g. 'model.') from every key:
        model_state_new = OrderedDict([(k.split('.', 1)[1], v)
                                       for k, v in model_state.items()])
        model.load_state_dict(model_state_new)

        # Set the final linear layer to identity if we use hidden reprs:
        if model.use_repr:
            if model.repr_type == 'one_layer_before':
                model.post_mp.layer_post_mp.model[-1] = torch.nn.Identity()
            elif model.repr_type == 'no_post_mp':
                model.post_mp = IdentityHead()
            else:
                raise ValueError(f"Unknown type '{model.repr_type}'")

        model.eval()
        return model

    def forward(self, batch):
        # Work on a copy so the caller's batch is left untouched:
        batch = batch.clone()
        for module in self.children():
            batch = module(batch)
        return batch
class GPSENodeEncoder(torch.nn.Module):
    r"""A helper linear/MLP encoder that takes the :class:`GPSE` encodings
    (based on the `"Graph Positional and Structural Encoder"
    <https://arxiv.org/abs/2307.07107>`_ paper) precomputed as
    :obj:`batch.pestat_GPSE` in the input graphs, maps them to a desired
    dimension defined by :obj:`dim_pe_out` and appends them to node features.

    Let's say we have a graph dataset with 64 original node features, and we
    have generated GPSE encodings of dimension 32, i.e.
    :obj:`data.pestat_GPSE` = 32. Additionally, we want to use a GNN with an
    inner dimension of 128. To do so, we can map the 32-dimensional GPSE
    encodings to a higher dimension of 64, and then append them to the
    :obj:`x` attribute of the :class:`~torch_geometric.data.Data` objects to
    obtain a 128-dimensional node feature representation.
    :class:`~torch_geometric.nn.GPSENodeEncoder` handles both this mapping
    and concatenation to :obj:`x`, the outputs of which can be used as input
    to a GNN:

    .. code-block:: python

        encoder = GPSENodeEncoder(dim_emb=128, dim_pe_in=32, dim_pe_out=64,
                                  expand_x=False)
        gnn = GNN(...)

        for batch in loader:
            x = encoder(batch.x, batch.pestat_GPSE)
            batch = gnn(x, batch.edge_index)

    Args:
        dim_emb (int): Size of final node embedding.
        dim_pe_in (int): Original dimension of :obj:`batch.pestat_GPSE`.
        dim_pe_out (int): Desired dimension of :class:`GPSE` after the
            encoder.
        dim_in (int, optional): Original dimension of input node features,
            required only if :obj:`expand_x` is set to :obj:`True`.
            (default: :obj:`None`)
        expand_x (bool, optional): Expand node features :obj:`x` from
            :obj:`dim_in` to (:obj:`dim_emb` - :obj:`dim_pe_out`)
        norm_type (str, optional): Type of normalization to apply.
            (default: :obj:`batchnorm`)
        model_type (str, optional): Type of encoder, either :obj:`mlp` or
            :obj:`linear`. (default: :obj:`mlp`)
        n_layers (int, optional): Number of MLP layers if :obj:`model_type`
            is :obj:`mlp`. (default: :obj:`2`)
        dropout_be (float, optional): Dropout ratio of inputs to encoder,
            i.e. before encoding. (default: :obj:`0.5`)
        dropout_ae (float, optional): Dropout ratio of outputs, i.e. after
            encoding. (default: :obj:`0.2`)
    """
    def __init__(self, dim_emb: int, dim_pe_in: int, dim_pe_out: int,
                 dim_in: Optional[int] = None, expand_x=False,
                 norm_type='batchnorm', model_type='mlp', n_layers=2,
                 dropout_be=0.5, dropout_ae=0.2):
        super().__init__()

        assert dim_emb > dim_pe_out, ('Desired GPSE dimension (dim_pe_out) '
                                      'must be smaller than the final node '
                                      'embedding dimension (dim_emb).')

        if expand_x:
            self.linear_x = nn.Linear(dim_in, dim_emb - dim_pe_out)
        self.expand_x = expand_x

        self.raw_norm = None
        if norm_type == 'batchnorm':
            self.raw_norm = nn.BatchNorm1d(dim_pe_in)

        self.dropout_be = nn.Dropout(p=dropout_be)
        self.dropout_ae = nn.Dropout(p=dropout_ae)

        activation = nn.ReLU  # register.act_dict[cfg.gnn.act]
        if model_type == 'mlp':
            layers = []
            if n_layers == 1:
                layers.append(torch.nn.Linear(dim_pe_in, dim_pe_out))
                layers.append(activation())
            else:
                # Hidden layers run at twice the output width:
                layers.append(torch.nn.Linear(dim_pe_in, 2 * dim_pe_out))
                layers.append(activation())
                for _ in range(n_layers - 2):
                    layers.append(
                        torch.nn.Linear(2 * dim_pe_out, 2 * dim_pe_out))
                    layers.append(activation())
                layers.append(torch.nn.Linear(2 * dim_pe_out, dim_pe_out))
                layers.append(activation())
            self.pe_encoder = nn.Sequential(*layers)
        elif model_type == 'linear':
            self.pe_encoder = nn.Linear(dim_pe_in, dim_pe_out)
        else:
            raise ValueError(f"{self.__class__.__name__}: Does not support "
                             f"'{model_type}' encoder model.")

    def forward(self, x, pos_enc):
        pos_enc = self.dropout_be(pos_enc)
        pos_enc = self.raw_norm(pos_enc) if self.raw_norm else pos_enc
        pos_enc = self.pe_encoder(pos_enc)  # (Num nodes) x dim_pe
        pos_enc = self.dropout_ae(pos_enc)

        # Expand node features if needed
        h = self.linear_x(x) if self.expand_x else x

        # Concatenate final PEs to input embedding
        return torch.cat((h, pos_enc), 1)
def _sample_rand_input(rand_type: str, n: int, dim_in: int,
                       bernoulli_thresh: float) -> torch.Tensor:
    r"""Draws the random input node features that seed the :class:`GPSE`
    encoder.

    Args:
        rand_type (str): Type of random features to use. Options are
            :obj:`NormalSE`, :obj:`UniformSE`, :obj:`BernoulliSE`.
        n (int): Number of nodes to draw features for.
        dim_in (int): Input feature dimension expected by the encoder.
        bernoulli_thresh (float): Threshold for Bernoulli sampling.

    Returns:
        torch.Tensor: A :obj:`[n, dim_in]` :obj:`float32` feature tensor.

    Raises:
        ValueError: If :obj:`rand_type` is not recognized.
    """
    if rand_type == 'NormalSE':
        rand = np.random.normal(loc=0, scale=1.0, size=(n, dim_in))
    elif rand_type == 'UniformSE':
        rand = np.random.uniform(low=0.0, high=1.0, size=(n, dim_in))
    elif rand_type == 'BernoulliSE':
        rand = np.random.uniform(low=0.0, high=1.0, size=(n, dim_in))
        rand = (rand < bernoulli_thresh)
    else:
        raise ValueError(f'Unknown {rand_type=!r}')
    return torch.from_numpy(rand.astype('float32'))


def _resolve_num_neighbors(layers_mp: Optional[int],
                           num_neighbors: Optional[List[int]],
                           fillval: int) -> List[int]:
    r"""Builds the per-layer fan-out list for :class:`NeighborLoader`.

    Pads :obj:`num_neighbors` with :obj:`fillval` until it covers all
    :obj:`layers_mp` message-passing layers. Works on a copy, so the caller's
    list is never mutated.

    Raises:
        ValueError: If :obj:`layers_mp` is :obj:`None`.
    """
    if layers_mp is None:
        raise ValueError('Please provide the number of message-passing '
                         'layers as "layers_mp".')
    num_neighbors = list(num_neighbors) if num_neighbors else [30, 20, 10]
    diff = layers_mp - len(num_neighbors)
    if fillval > 0 and diff > 0:
        num_neighbors += [fillval] * diff
    return num_neighbors


@torch.no_grad()
def gpse_process(
    model: Module,
    data: Data,
    rand_type: str,
    use_vn: bool = True,
    bernoulli_thresh: float = 0.5,
    neighbor_loader: bool = False,
    num_neighbors: Optional[List[int]] = None,
    fillval: int = 5,
    layers_mp: Optional[int] = None,
    **kwargs,
) -> torch.Tensor:
    r"""Processes the data using the :class:`GPSE` model to generate and
    append GPSE encodings.

    Identical to :obj:`gpse_process_batch`, but operates on a single
    :class:`~torch_geometric.data.Data` object. Unlike transform-based GPSE
    processing (i.e. :class:`~torch_geometric.transforms.AddGPSE`), the
    :obj:`use_vn` argument does not append virtual nodes if set to
    :obj:`True`, and instead assumes the input graphs to :obj:`gpse_process`
    already have virtual nodes. Under normal circumstances, one does not need
    to call this function; running :obj:`precompute_GPSE` on your whole
    dataset is advised instead.

    Args:
        model (Module): The :class:`GPSE` model.
        data (torch_geometric.data.Data): A
            :class:`~torch_geometric.data.Data` object.
        rand_type (str, optional): Type of random features to use. Options
            are :obj:`NormalSE`, :obj:`UniformSE`, :obj:`BernoulliSE`.
            (default: :obj:`NormalSE`)
        use_vn (bool, optional): Whether the input graphs have virtual nodes.
            (default: :obj:`True`)
        bernoulli_thresh (float, optional): Threshold for Bernoulli sampling
            of virtual nodes. (default: :obj:`0.5`)
        neighbor_loader (bool, optional): Whether to use
            :obj:`NeighborLoader`. (default: :obj:`False`)
        num_neighbors (List[int], optional): Number of neighbors to consider
            for each message-passing layer. (default: :obj:`[30, 20, 10]`)
        fillval (int, optional): Value to fill for missing
            :obj:`num_neighbors`. (default: :obj:`5`)
        layers_mp (int, optional): Number of message-passing layers.
            (default: :obj:`None`)
        **kwargs (optional): Additional arguments for :obj:`NeighborLoader`.

    Returns:
        torch.Tensor: A tensor corresponding to the original
            :class:`~torch_geometric.data.Data` object, with :class:`GPSE`
            encodings appended as :obj:`out.pestat_GPSE` attribute.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # The input dimension expected by the pretrained encoder is the number of
    # columns of its first weight matrix.
    n = data.num_nodes
    dim_in = model.state_dict()[list(model.state_dict())[0]].shape[1]

    # Prepare random input features for the encoder.
    data.x = _sample_rand_input(rand_type, n, dim_in, bernoulli_thresh)
    if use_vn:
        # Virtual node is appended last; reset its features to zeros to match
        # the pretraining setting.
        data.x[-1] = 0

    model, data = model.to(device), data.to(device)

    # Generate encodings using the pretrained encoder.
    if neighbor_loader:
        num_neighbors = _resolve_num_neighbors(layers_mp, num_neighbors,
                                               fillval)
        loader = NeighborLoader(data, num_neighbors=num_neighbors,
                                shuffle=False, pin_memory=True, **kwargs)

        out_list = []
        pbar = trange(data.num_nodes, position=2)
        for batch in loader:
            out, _ = model(batch.to(device))
            # Only the first `batch_size` nodes are seed nodes; the rest are
            # sampled neighbors and must be discarded.
            out = out[:batch.batch_size].to('cpu', non_blocking=True)
            out_list.append(out)
            pbar.update(batch.batch_size)
        pbar.close()  # was previously leaked
        out = torch.vstack(out_list)
    else:
        out, _ = model(data)
        out = out.to('cpu')

    return out


@torch.no_grad()
def gpse_process_batch(
    model: GPSE,
    batch,
    rand_type: str,
    use_vn: bool = True,
    bernoulli_thresh: float = 0.5,
    neighbor_loader: bool = False,
    num_neighbors: Optional[List[int]] = None,
    fillval: int = 5,
    layers_mp: Optional[int] = None,
    **kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""Process a batch of data using the :class:`GPSE` model to generate and
    append :class:`GPSE` encodings.

    Identical to `gpse_process`, but operates on a batch of
    :class:`~torch_geometric.data.Data` objects. Unlike transform-based GPSE
    processing (i.e. :class:`~torch_geometric.transforms.AddGPSE`), the
    :obj:`use_vn` argument does not append virtual nodes if set to
    :obj:`True`, and instead assumes the input graphs to :obj:`gpse_process`
    already have virtual nodes. This is because the virtual nodes are already
    added to graphs before the call to :obj:`gpse_process_batch` in
    :obj:`precompute_GPSE` for better efficiency. Under normal circumstances,
    one does not need to call this function; running :obj:`precompute_GPSE`
    on your whole dataset is advised instead.

    Args:
        model (GPSE): The :class:`GPSE` model.
        batch: A batch of PyG Data objects.
        rand_type (str, optional): Type of random features to use. Options
            are :obj:`NormalSE`, :obj:`UniformSE`, :obj:`BernoulliSE`.
            (default: :obj:`NormalSE`)
        use_vn (bool, optional): Whether the input graphs have virtual nodes.
            (default: :obj:`True`)
        bernoulli_thresh (float, optional): Threshold for Bernoulli sampling
            of virtual nodes. (default: :obj:`0.5`)
        neighbor_loader (bool, optional): Whether to use
            :obj:`NeighborLoader`. (default: :obj:`False`)
        num_neighbors (List[int], optional): Number of neighbors to consider
            for each message-passing layer. (default: :obj:`[30, 20, 10]`)
        fillval (int, optional): Value to fill for missing
            :obj:`num_neighbors`. (default: :obj:`5`)
        layers_mp (int, optional): Number of message-passing layers.
            (default: :obj:`None`)
        **kwargs: Additional keyword arguments for :obj:`NeighborLoader`.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: A two-tuple of tensors
            corresponding to the stacked :class:`GPSE` encodings and the
            pointers indicating individual graphs.
    """
    n = batch.num_nodes
    dim_in = model.state_dict()[list(model.state_dict())[0]].shape[1]

    # Prepare random input features for the encoder.
    batch.x = _sample_rand_input(rand_type, n, dim_in, bernoulli_thresh)

    if use_vn:
        # HACK: We need to reset virtual node features to zeros to match the
        # pretraining setting (virtual node applied after random node features
        # are set, and the default node features for the virtual node are all
        # zeros). Can potentially test if initializing virtual node features
        # to random features is better than setting them to zeros.
        for i in batch.ptr[1:]:
            batch.x[i - 1] = 0

    # Generate encodings using the pretrained encoder.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    if neighbor_loader:
        num_neighbors = _resolve_num_neighbors(layers_mp, num_neighbors,
                                               fillval)
        loader = NeighborLoader(batch, num_neighbors=num_neighbors,
                                shuffle=False, pin_memory=True, **kwargs)

        out_list = []
        pbar = trange(batch.num_nodes, position=2)
        # NOTE: do not shadow `batch` with the loop variable — its `ptr` is
        # returned below and must refer to the *input* batch, not the last
        # sampled mini-batch (previous behavior was a bug).
        for sub_batch in loader:
            sub_out, _ = model(sub_batch.to(device))
            sub_out = sub_out[:sub_batch.batch_size].to('cpu',
                                                        non_blocking=True)
            out_list.append(sub_out)
            pbar.update(sub_batch.batch_size)
        pbar.close()  # was previously leaked
        out = torch.vstack(out_list)
    else:
        out, _ = model(batch.to(device))
        out = out.to('cpu')

    return out, batch.ptr


@torch.no_grad()
def precompute_GPSE(model: GPSE, dataset: Dataset, use_vn: bool = True,
                    rand_type: str = 'NormalSE', **kwargs):
    r"""Precomputes :class:`GPSE` encodings in-place for a given dataset
    using a :class:`GPSE` model.

    Args:
        model (GPSE): The :class:`GPSE` model.
        dataset (Dataset): A PyG Dataset.
        use_vn (bool, optional): Whether to append virtual nodes to graphs in
            :class:`GPSE` computation. Should match the setting used when
            pre-training the :class:`GPSE` model. (default :obj:`True`)
        rand_type (str, optional): The type of randomization to use.
            (default :obj:`NormalSE`)
        **kwargs (optional): Additional arguments for
            :class:`~torch_geometric.data.DataLoader`.
    """
    # Temporarily replace the transformation
    orig_dataset_transform = dataset.transform
    dataset.transform = None
    if use_vn:
        dataset.transform = T.VirtualNode()

    # Remove split indices, to be recovered at the end of the precomputation
    tmp_store = {}
    for name in [
            'train_mask', 'val_mask', 'test_mask', 'train_graph_index',
            'val_graph_index', 'test_graph_index', 'train_edge_index',
            'val_edge_index', 'test_edge_index'
    ]:
        if (name in dataset.data) and (dataset.slices is None
                                       or name in dataset.slices):
            tmp_store_data = dataset.data.pop(name)
            tmp_store_slices = dataset.slices.pop(name) \
                if dataset.slices else None
            tmp_store[name] = (tmp_store_data, tmp_store_slices)

    loader = DataLoader(dataset, shuffle=False, pin_memory=True, **kwargs)

    # Batched GPSE precomputation loop
    data_list = []
    curr_idx = 0
    pbar = trange(len(dataset), desc='Pre-computing GPSE')
    tic = time.perf_counter()
    for batch in loader:
        # Forward `use_vn` explicitly — previously it was dropped, so the
        # batch processing always assumed virtual nodes regardless of the
        # caller's setting.
        batch_out, batch_ptr = gpse_process_batch(model, batch, rand_type,
                                                  use_vn=use_vn, **kwargs)

        batch_out = batch_out.to('cpu', non_blocking=True)
        # Need to wait for batch_ptr to finish transferring so that start and
        # end indices are ready to use
        batch_ptr = batch_ptr.to('cpu', non_blocking=False)

        for start, end in zip(batch_ptr[:-1], batch_ptr[1:]):
            data = dataset.get(curr_idx)
            if use_vn:
                # Drop the trailing virtual node's encoding.
                end = end - 1
            data.pestat_GPSE = batch_out[start:end]
            data_list.append(data)
            curr_idx += 1

        pbar.update(len(batch_ptr) - 1)
    pbar.close()

    # Collate dataset and reset indices and data list
    dataset.transform = orig_dataset_transform
    dataset._indices = None
    dataset._data_list = data_list
    dataset.data, dataset.slices = dataset.collate(data_list)

    # Recover split indices
    for name, (tmp_store_data, tmp_store_slices) in tmp_store.items():
        dataset.data[name] = tmp_store_data
        if tmp_store_slices is not None:
            dataset.slices[name] = tmp_store_slices
    dataset._data_list = None

    timestr = time.strftime('%H:%M:%S',
                            time.gmtime(time.perf_counter() - tic))
    logging.info(f'Finished GPSE pre-computation, took {timestr}')

    # Release resource and recover original configs
    del model
    torch.cuda.empty_cache()


def cosim_col_sep(pred: torch.Tensor, true: torch.Tensor,
                  batch_idx: Optional[torch.Tensor]) -> torch.Tensor:
    r"""Calculates the average cosine similarity loss between predicted and
    true features on a batch of graphs.

    Args:
        pred (torch.Tensor): Predicted outputs.
        true (torch.Tensor): Value of ground truths.
        batch_idx (torch.Tensor): Batch indices to separate the graphs.

    Returns:
        torch.Tensor: Average cosine similarity loss per graph in batch.

    Raises:
        ValueError: If batch_idx is not specified.
    """
    if batch_idx is None:
        raise ValueError("mae_cosim_col_sep requires batch index as "
                         "input to distinguish different graphs.")
    # Graph-level targets are padded with index -1; shift so that
    # `to_dense_batch` accepts the indices.
    batch_idx = batch_idx + 1 if batch_idx.min() == -1 else batch_idx
    pred_dense = to_dense_batch(pred, batch_idx)[0]
    true_dense = to_dense_batch(true, batch_idx)[0]
    mask = (true_dense == 0).all(1)  # exclude trivial features from loss
    loss = 1 - F.cosine_similarity(pred_dense, true_dense, dim=1)[~mask].mean()
    return loss


def gpse_loss(pred: torch.Tensor, true: torch.Tensor,
              batch_idx: Optional[torch.Tensor] = None) \
        -> Tuple[torch.Tensor, torch.Tensor]:
    r"""Calculates :class:`GPSE` loss as the sum of MAE loss and cosine
    similarity loss over a batch of graphs.

    Args:
        pred (torch.Tensor): Predicted outputs.
        true (torch.Tensor): Value of ground truths.
        batch_idx (torch.Tensor): Batch indices to separate the graphs.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: A two-tuple of tensors
            corresponding to the :class:`GPSE` loss and the predicted
            node-and-graph level outputs.

    Raises:
        ValueError: If batch_idx is not specified.
    """
    if batch_idx is None:
        raise ValueError("mae_cosim_col_sep requires batch index as "
                         "input to distinguish different graphs.")
    mae_loss = F.l1_loss(pred, true)
    cosim_loss = cosim_col_sep(pred, true, batch_idx)
    loss = mae_loss + cosim_loss
    return loss, pred


def process_batch_idx(batch_idx: Optional[torch.Tensor], true: torch.Tensor,
                      use_vn: bool = True) -> Optional[torch.Tensor]:
    r"""Processes batch indices to adjust for the removal of virtual nodes,
    and pads batch index for hybrid tasks.

    Args:
        batch_idx (torch.Tensor, optional): Batch indices to separate the
            graphs.
        true (torch.Tensor): Value of ground truths.
        use_vn (bool, optional): If input graphs have virtual nodes that need
            to be removed. (default: :obj:`True`)

    Returns:
        torch.Tensor: Batch indices that separate the graphs, or :obj:`None`
            if :obj:`batch_idx` was :obj:`None`.
    """
    if batch_idx is None:
        return None
    if use_vn:
        # Remove the virtual node, which is the last node of each graph.
        batch_idx = torch.concat([
            batch_idx[batch_idx == i][:-1]
            for i in range(batch_idx.max().item() + 1)
        ])
    # Pad batch index for hybrid tasks (set batch index for graph heads to -1)
    if (pad := true.shape[0] - batch_idx.shape[0]) > 0:
        pad_idx = -torch.ones(pad, dtype=torch.long, device=batch_idx.device)
        batch_idx = torch.hstack([batch_idx, pad_idx])
    return batch_idx