Source code for gigl.src.common.models.pyg.graph.augmentations
import torch
from torch_geometric.data import Data
from torch_geometric.utils import add_self_loops, dropout_edge, remove_self_loops
[docs]
def drop_feature(x: torch.Tensor, drop_prob: float) -> torch.Tensor:
"""GRACE feature dropping function with probability drop_prob.
From: https://github.com/CRIPAC-DIG/GRACE/blob/51b44961b68b2f38c60f85cf83db13bed8fd0780/model.py#L120
"""
if drop_prob == 0:
return x
elif drop_prob < 0 or drop_prob > 1:
raise ValueError(f"Invalid probability provided for Feat Drop, got {drop_prob}")
drop_mask = (
torch.empty((x.size(1),), dtype=torch.float32, device=x.device).uniform_(0, 1)
< drop_prob
)
x = x.clone()
x[:, drop_mask] = 0
return x
[docs]
def get_augmented_graph(
graph: Data,
edge_drop_ratio: float = 0.3,
feat_drop_ratio: float = 0.3,
graph_perm: bool = False,
) -> Data:
"""
PyG implementation of DGL transformations. Supports augmentations such as dropping random edges (edge_drop_ratio), dropping random feature components (feat_drop_ratio),
and graph permutation (shuffling the nodes and edges of the graph randomly)
https://docs.dgl.ai/en/0.9.x/api/python/transforms.html
"""
if edge_drop_ratio < 0 or edge_drop_ratio > 1:
raise ValueError(
f"Invalid probability provided for Edge Drop, got {edge_drop_ratio}"
)
if feat_drop_ratio < 0 or feat_drop_ratio > 1:
raise ValueError(
f"Invalid probability provided for Feat Drop, got {feat_drop_ratio}"
)
data = graph.clone()
if graph_perm:
row_perm = torch.randperm(data.x.size(0))
data.x = data.x[row_perm, :]
data.edge_index = torch.randint_like(data.edge_index, data.num_nodes - 1)
_, edge_mask = dropout_edge(data.edge_index, p=edge_drop_ratio, training=True)
if data.edge_attr is not None:
data.edge_attr = data.edge_attr[edge_mask]
data.edge_index = data.edge_index[:, edge_mask]
data.edge_index, _ = remove_self_loops(data.edge_index)
data.edge_index, _ = add_self_loops(data.edge_index, num_nodes=data.x.size(0))
data.x = drop_feature(data.x, feat_drop_ratio)
return data