"""Graph augmentations for PyG graphs (gigl.src.common.models.pyg.graph.augmentations)."""

import torch
from torch_geometric.data import Data
from torch_geometric.utils import add_self_loops, dropout_edge, remove_self_loops


def drop_feature(x: torch.Tensor, drop_prob: float) -> torch.Tensor:
    """Zero out randomly selected feature columns of ``x`` (GRACE-style masking).

    Every column (feature dimension) of ``x`` is independently selected with
    probability ``drop_prob``, and each selected column is set to 0 across all
    rows.

    Adapted from:
    https://github.com/CRIPAC-DIG/GRACE/blob/51b44961b68b2f38c60f85cf83db13bed8fd0780/model.py#L120

    Args:
        x: Feature matrix. It is indexed as ``x[:, mask]`` with a mask of
            length ``x.size(1)``, so it must have at least two dimensions.
        drop_prob: Per-column drop probability, expected to be in ``[0, 1]``.

    Returns:
        ``x`` itself (not a copy) when ``drop_prob == 0``. Otherwise, a clone of
        ``x`` with the dropped columns zeroed. The input tensor is never
        modified.

    Raises:
        ValueError: If ``drop_prob`` is below 0 or above 1.
    """
    # Fast path: nothing to drop, so skip both the RNG draw and the clone.
    if drop_prob == 0:
        return x
    if drop_prob < 0 or drop_prob > 1:
        raise ValueError(f"Invalid probability provided for Feat Drop, got {drop_prob}")

    num_features = x.size(1)
    # One uniform sample per feature column; a column is dropped when its
    # sample falls below drop_prob.
    uniform_draws = torch.empty(
        (num_features,), dtype=torch.float32, device=x.device
    ).uniform_(0, 1)
    column_mask = uniform_draws < drop_prob

    # Work on a copy so the caller's tensor is left untouched.
    masked = x.clone()
    masked[:, column_mask] = 0
    return masked


def get_augmented_graph(
    graph: Data,
    edge_drop_ratio: float = 0.3,
    feat_drop_ratio: float = 0.3,
    graph_perm: bool = False,
) -> Data:
    """Return a randomly augmented copy of ``graph``.

    PyG implementation of DGL-style transformations
    (https://docs.dgl.ai/en/0.9.x/api/python/transforms.html). All work is
    done on ``graph.clone()``. The steps are applied in this order:

    1. Graph permutation (only when ``graph_perm`` is True). The rows of ``x``
       are shuffled with a random permutation. ``edge_index`` is replaced by
       uniformly random node indices in ``[0, num_nodes)``. This randomly
       rewires the graph rather than relabeling existing edges. Other
       node-level attributes and the row order of ``edge_attr`` are not
       permuted.
    2. Edge dropping. Each edge is dropped independently with probability
       ``edge_drop_ratio``. ``edge_attr``, if present, is filtered with the
       same mask.
    3. Self-loop normalization. Existing self-loops are removed, then one
       self-loop per node is added. ``edge_attr``, if present, is kept aligned
       with ``edge_index``. Attributes for the new loops come from
       ``add_self_loops``' default ``fill_value``.
    4. Feature dropping. Feature columns of ``x`` are zeroed with probability
       ``feat_drop_ratio`` (see ``drop_feature``).

    Args:
        graph: Input graph. ``graph.x`` must be set, because its first
            dimension is used as the node count.
        edge_drop_ratio: Per-edge drop probability in ``[0, 1]``.
        feat_drop_ratio: Per-feature-column drop probability in ``[0, 1]``.
        graph_perm: Whether to apply the random permutation/rewiring step.

    Returns:
        The augmented ``Data`` object, which is a clone of ``graph``.

    Raises:
        ValueError: If either ratio lies outside ``[0, 1]``.
    """
    if edge_drop_ratio < 0 or edge_drop_ratio > 1:
        raise ValueError(
            f"Invalid probability provided for Edge Drop, got {edge_drop_ratio}"
        )
    if feat_drop_ratio < 0 or feat_drop_ratio > 1:
        raise ValueError(
            f"Invalid probability provided for Feat Drop, got {feat_drop_ratio}"
        )

    data = graph.clone()

    if graph_perm:
        row_perm = torch.randperm(data.x.size(0))
        data.x = data.x[row_perm, :]
        # ``randint_like``'s upper bound is exclusive. Passing ``num_nodes``
        # (not ``num_nodes - 1``) lets every node index, including the last,
        # be drawn. It also avoids an empty range for single-node graphs.
        data.edge_index = torch.randint_like(data.edge_index, data.num_nodes)

    edge_index = data.edge_index
    edge_attr = data.edge_attr

    # Only the boolean mask from dropout_edge is used, so that edge_index and
    # edge_attr are filtered with exactly the same selection.
    _, edge_mask = dropout_edge(edge_index, p=edge_drop_ratio, training=True)
    edge_index = edge_index[:, edge_mask]
    if edge_attr is not None:
        edge_attr = edge_attr[edge_mask]

    # Thread edge_attr through both calls. Otherwise, removing self-loops and
    # adding one loop per node would change the number of edges without
    # updating edge_attr, leaving the two out of sync.
    edge_index, edge_attr = remove_self_loops(edge_index, edge_attr)
    edge_index, edge_attr = add_self_loops(
        edge_index, edge_attr, num_nodes=data.x.size(0)
    )

    data.edge_index = edge_index
    if edge_attr is not None:
        data.edge_attr = edge_attr

    data.x = drop_feature(data.x, feat_drop_ratio)
    return data