# Source code for gigl.src.common.graph_builder.pyg_graph_data

from __future__ import annotations

from typing import Any, Dict, List, Optional

import torch
from torch_geometric.data.hetero_data import HeteroData
from torch_geometric.data.storage import EdgeStorage

from gigl.common.collections.frozen_dict import FrozenDict
from gigl.src.common.graph_builder.gbml_graph_protocol import GbmlGraphDataProtocol
from gigl.src.common.types.graph_data import Edge, EdgeType, Node, NodeId, NodeType


class PygGraphData(HeteroData, GbmlGraphDataProtocol):
    """
    Extends pytorch geometric graph data objects to provide support for more
    functionality, i.e. providing functionality to do equality checks and to
    translate subgraph-local node ids back to global node ids.

    The global -> subgraph node mapping is stored privately and exposed through
    the ``global_node_to_subgraph_node_mapping`` property. Its inverse is built
    lazily and cached until the forward mapping is reassigned.
    """

    def __init__(self, **kwargs) -> None:
        super().__init__(
            **kwargs,
        )
        # Forward mapping: global node -> node as indexed inside this subgraph.
        self.__global_node_to_subgraph_node_mapping: FrozenDict[
            Node, Node
        ] = FrozenDict({})
        # Lazily computed inverse of the mapping above; None means "not built yet".
        self.__subgraph_node_to_global_node_mapping: Optional[
            FrozenDict[Node, Node]
        ] = None

    @property
    def edge_types_to_be_registered(
        self,
    ) -> List[EdgeType]:
        """
        Return every edge type that has an edge store in this graph.

        Each ``(src_node_type, relation, dst_node_type)`` key of the edge store
        dict is converted to an ``EdgeType``. An empty list is returned when no
        edge store dict exists.
        """
        if not hasattr(self, "_edge_store_dict"):
            return []
        return [
            EdgeType(
                src_node_type=src_node_type,
                relation=relation,
                dst_node_type=dst_node_type,
            )
            for src_node_type, relation, dst_node_type in self._edge_store_dict.keys()
        ]

    @property
    def global_node_to_subgraph_node_mapping(
        self,
    ) -> FrozenDict[Node, Node]:
        """Mapping from global nodes to their subgraph-local counterparts."""
        return self.__global_node_to_subgraph_node_mapping

    @global_node_to_subgraph_node_mapping.setter
    def global_node_to_subgraph_node_mapping(
        self, global_node_to_subgraph_node_mapping: FrozenDict[Node, Node]
    ) -> None:
        """
        Replace the global -> subgraph mapping (stored as a FrozenDict copy).

        The cached inverse mapping is invalidated so it will be rebuilt from
        the new value on next access.
        """
        self.__global_node_to_subgraph_node_mapping = FrozenDict(
            global_node_to_subgraph_node_mapping
        )
        self.__subgraph_node_to_global_node_mapping = None

    @property
    def subgraph_node_to_global_node_mapping(self) -> FrozenDict[Node, Node]:
        """
        Inverse of ``global_node_to_subgraph_node_mapping``, computed on first
        access and cached.

        If two global nodes map to the same subgraph node, the one iterated
        last wins (plain dict-comprehension overwrite).
        """
        if self.__subgraph_node_to_global_node_mapping is None:
            self.__subgraph_node_to_global_node_mapping = FrozenDict(
                {v: k for k, v in self.global_node_to_subgraph_node_mapping.items()}
            )
        return self.__subgraph_node_to_global_node_mapping

    def __eq__(self, other: object) -> bool:
        """
        Structural equality with another PygGraphData.

        Two graphs are equal when they have the same global -> subgraph node
        mapping, the same node-feature tensors per node type (``x_dict``), and
        the same set of edge types. Each edge store must also hold the same
        attribute names with equal tensors. The comparison is symmetric.
        """
        if not isinstance(other, PygGraphData):
            return False
        if (
            self.global_node_to_subgraph_node_mapping
            != other.global_node_to_subgraph_node_mapping
        ):
            return False
        if not (
            hasattr(self, "x_dict") == hasattr(other, "x_dict")
            and hasattr(self, "_edge_store_dict") == hasattr(other, "_edge_store_dict")
        ):
            return False

        if hasattr(self, "x_dict"):
            # Read each x_dict once; accessing it rebuilds the dict every time.
            self_x_dict = self.x_dict
            other_x_dict = other.x_dict
            if len(self_x_dict) != len(other_x_dict):
                return False
            for self_x_key, self_x_val in self_x_dict.items():
                if self_x_key not in other_x_dict:
                    return False
                if not torch.equal(self_x_val, other_x_dict[self_x_key]):
                    return False

        if hasattr(self, "_edge_store_dict"):
            self_edge_stores = self._edge_store_dict
            other_edge_stores = other._edge_store_dict
            if len(self_edge_stores) != len(other_edge_stores):
                return False
            self_edge_store: EdgeStorage
            for self_edge_type, self_edge_store in self_edge_stores.items():
                if self_edge_type not in other_edge_stores:
                    return False
                other_edge_store = other_edge_stores[self_edge_type]
                # Compare attribute names (e.g. edge_index, edge_attr) in both
                # directions; checking only self's keys made `==` asymmetric
                # when only `other` carried an extra attribute.
                if set(self_edge_store.keys()) != set(other_edge_store.keys()):
                    return False
                for key, tensor in self_edge_store.items():
                    if not torch.equal(tensor, other_edge_store[key]):
                        return False
        return True

    def get_global_node_features_dict(self) -> FrozenDict[Node, torch.Tensor]:
        """
        Map every node (in global id space) to its feature row.

        Row ``i`` of ``x`` for node type ``T`` is treated as subgraph node
        ``Node(T, i)``. It is translated to its global node when it appears in
        ``subgraph_node_to_global_node_mapping``; otherwise the subgraph node
        itself is used as the key.

        Returns:
            FrozenDict from Node to that node's feature tensor; empty when the
            graph has no ``x_dict``.
        """
        if not hasattr(self, "x_dict"):
            return FrozenDict({})
        subgraph_to_global = self.subgraph_node_to_global_node_mapping
        global_node_to_features_map: Dict[Node, torch.Tensor] = {}
        for self_node_type, all_node_features_for_node_type in self.x_dict.items():
            for subgraph_node_id, node_features in enumerate(
                all_node_features_for_node_type
            ):
                subgraph_node = Node(
                    type=NodeType(self_node_type), id=NodeId(subgraph_node_id)
                )
                global_node = (
                    subgraph_to_global[subgraph_node]
                    if subgraph_node in subgraph_to_global
                    else subgraph_node
                )
                global_node_to_features_map[global_node] = node_features
        return FrozenDict(global_node_to_features_map)

    def get_global_edge_features_dict(self) -> FrozenDict[Edge, torch.Tensor]:
        """
        Map every edge (with endpoints in global id space) to its feature row.

        Edges are read column by column from each store's ``edge_index``. For
        example, ``[[10, 20], [20, 30]]`` means edges 10 -> 20 and 20 -> 30.

        When ``subgraph_node_to_global_node_mapping`` is empty, ids are taken
        as already global. Otherwise both endpoints are looked up in it, and a
        missing endpoint is not tolerated (the lookup raises).

        Returns:
            FrozenDict from Edge to the matching ``edge_attr`` row. The value
            is None for edge types whose store has no ``edge_attr``.
        """
        global_edge_to_features_map: Dict[Edge, torch.Tensor] = {}
        subgraph_to_global = self.subgraph_node_to_global_node_mapping
        is_graph_data_in_global_space: bool = not subgraph_to_global
        if not hasattr(self, "_edge_store_dict"):
            return FrozenDict(global_edge_to_features_map)

        for edge_type, edge_store in self._edge_store_dict.items():
            edge_index = edge_store["edge_index"]
            edge_attr = edge_store.get("edge_attr", None)
            src_node_type, relation, dst_node_type = edge_type
            # Convert each row to Python ints once instead of calling .item()
            # on every element.
            src_ids = edge_index[0].tolist()
            dst_ids = edge_index[1].tolist()
            for edge_number, (subgraph_src_node_id, subgraph_dst_node_id) in enumerate(
                zip(src_ids, dst_ids)
            ):
                subgraph_src_node = Node(
                    type=NodeType(src_node_type), id=NodeId(subgraph_src_node_id)
                )
                subgraph_dst_node = Node(
                    type=NodeType(dst_node_type), id=NodeId(subgraph_dst_node_id)
                )
                if is_graph_data_in_global_space:
                    global_src_node = subgraph_src_node
                    global_dst_node = subgraph_dst_node
                else:
                    global_src_node = subgraph_to_global[subgraph_src_node]
                    global_dst_node = subgraph_to_global[subgraph_dst_node]
                edge = Edge(
                    src_node_id=global_src_node.id,
                    dst_node_id=global_dst_node.id,
                    edge_type=EdgeType(src_node_type, relation, dst_node_type),
                )
                edge_feature = (
                    edge_attr[edge_number] if edge_attr is not None else None
                )
                global_edge_to_features_map[edge] = edge_feature
        return FrozenDict(global_edge_to_features_map)

    def to_hetero_data(self) -> HeteroData:
        """
        Convert the PygGraphData object back to a PyG HeteroData object.

        A fresh HeteroData is populated via ``HeteroData.update`` with this
        object. NOTE(review): the private node-id mappings are probably not
        carried over by ``update`` — confirm if callers rely on them.

        returns:
            HeteroData: The converted HeteroData object
        """
        hetero_data = HeteroData()
        hetero_data.update(data=self)
        return hetero_data

    @classmethod
    def from_hetero_data(cls, data: HeteroData) -> PygGraphData:
        """
        Build a PygGraphData from a PyG HeteroData.

        Only these attributes are copied: ``x`` per node type, and
        ``edge_index`` (plus ``edge_attr`` when present) per edge type. Other
        attributes on ``data`` are not copied. The node-id mapping starts out
        empty.
        """
        pyg_graph_data = cls()
        if hasattr(data, "x_dict"):
            for x_key, x_val in data.x_dict.items():
                pyg_graph_data[x_key].x = x_val
        if hasattr(data, "_edge_store_dict"):
            for edge_type, edge_store in data._edge_store_dict.items():
                pyg_graph_data[edge_type].edge_index = edge_store.edge_index
                if hasattr(edge_store, "edge_attr"):
                    pyg_graph_data[edge_type].edge_attr = edge_store.edge_attr
        return pyg_graph_data

    def __repr__(self) -> str:
        return f"""PygGraphData( global_node_to_subgraph_node_mapping={self.global_node_to_subgraph_node_mapping} x_dict={self.x_dict if hasattr(self, "x_dict") else {}} _edge_store_dict={self._edge_store_dict if hasattr(self, "_edge_store_dict") else {}} ) """

    def __setattr__(self, key: str, value: Any) -> None:
        """
        Need to override functionality because HeteroData's own `__setattr__`
        does not honour `@property.setter` descriptors.

        Names defined on PygGraphData, or on any subclass of it, are assigned
        via ``object.__setattr__`` so their descriptors (e.g. the
        ``global_node_to_subgraph_node_mapping`` setter) run. Every other name
        goes through HeteroData.
        """
        # Walk the MRO from the concrete class up to PygGraphData so that
        # subclasses keep working; checking only type(self).__dict__ would miss
        # properties inherited from PygGraphData.
        mro = type(self).__mro__
        own_classes = mro[: mro.index(PygGraphData) + 1]
        if any(key in klass.__dict__ for klass in own_classes):
            object.__setattr__(self, key, value)
            return
        super().__setattr__(key, value)