Source code for gigl.types.graph

from collections import abc
from dataclasses import dataclass
from typing import Optional, TypeVar, Union, overload

import torch
from graphlearn_torch.partition import PartitionBook

from gigl.common.data.dataloaders import SerializedTFRecordInfo
from gigl.common.logger import Logger

# TODO(kmonte) - we should move gigl.src.common.types.graph_data to this file.
from gigl.src.common.types.graph_data import EdgeType, NodeType, Relation

logger = Logger()

# Key used by `to_heterogeneous_node` when a homogeneous (non-dict) value is
# lifted into a dict keyed by node type.
DEFAULT_HOMOGENEOUS_NODE_TYPE = NodeType("default_homogeneous_node_type")

# Key used by `to_heterogeneous_edge` when a homogeneous (non-dict) value is
# lifted into a dict keyed by edge type: a "to" relation from the default node
# type back onto itself.
DEFAULT_HOMOGENEOUS_EDGE_TYPE = EdgeType(
    src_node_type=DEFAULT_HOMOGENEOUS_NODE_TYPE,
    relation=Relation("to"),
    dst_node_type=DEFAULT_HOMOGENEOUS_NODE_TYPE,
)
# Suffixes appended to a message passing relation to build the names of its
# label edge types (see `message_passing_to_positive_label` and
# `message_passing_to_negative_label`).
_POSITIVE_LABEL_TAG = "gigl_positive"
_NEGATIVE_LABEL_TAG = "gigl_negative"

# We really should support PyG EdgeType natively, but since we type-ignore it
# that's not ideal at the moment. This TypeVar lets helpers accept either a GiGL
# `EdgeType` or a plain (src, relation, dst) string tuple, and return the same
# kind they were given (hopefully stemming the bleeding).
_EdgeType = TypeVar("_EdgeType", EdgeType, tuple[str, str, str])

# TODO(kmonte, mkolodner): Move SerializedGraphMetadata and maybe convert_pb_to_serialized_graph_metadata here.


@dataclass(frozen=True)
class FeaturePartitionData:
    """Data and indexing info of a node/edge feature partition."""

    # Node/edge feature tensor.
    feats: torch.Tensor
    # Node/edge ids corresponding to `feats`. Optional because range-based
    # partitioning does not need this field.
    ids: Optional[torch.Tensor]
@dataclass(frozen=True)
class GraphPartitionData:
    """Data and indexing info of a graph partition."""

    # Edge index, stored as (rows, cols).
    edge_index: torch.Tensor
    # Edge ids, aligned with `edge_index`.
    edge_ids: torch.Tensor
    # Optional edge weights, aligned with `edge_index`; None when unweighted.
    weights: Optional[torch.Tensor] = None
# Intentionally not frozen: callers are expected to delete these partition
# outputs once they have been registered inside of a GLT DistDataset, in order
# to save memory.
@dataclass
class PartitionOutput:
    """Partition books plus the partitioned graph data held on the current rank.

    Each field is either a single value (homogeneous graph) or a dict keyed by
    node/edge type (heterogeneous graph).
    """

    # Node partition book(s).
    node_partition_book: Union[PartitionBook, dict[NodeType, PartitionBook]]
    # Edge partition book(s).
    edge_partition_book: Union[PartitionBook, dict[EdgeType, PartitionBook]]
    # Edge index partitioned onto the current rank. Always populated by
    # partitioning, but dataset.build() may reset it to None to lower peak
    # memory usage, hence Optional.
    partitioned_edge_index: Optional[
        Union[GraphPartitionData, dict[EdgeType, GraphPartitionData]]
    ]
    # Node features on the current rank; None when node features are not partitioned.
    partitioned_node_features: Optional[
        Union[FeaturePartitionData, dict[NodeType, FeaturePartitionData]]
    ]
    # Edge features on the current rank; None when edge features are not partitioned.
    partitioned_edge_features: Optional[
        Union[FeaturePartitionData, dict[EdgeType, FeaturePartitionData]]
    ]
    # Positive edge indices on the current rank; None when positive labels are not partitioned.
    partitioned_positive_labels: Optional[
        Union[torch.Tensor, dict[EdgeType, torch.Tensor]]
    ]
    # Negative edge indices on the current rank; None when negative labels are not partitioned.
    partitioned_negative_labels: Optional[
        Union[torch.Tensor, dict[EdgeType, torch.Tensor]]
    ]
# This dataclass should not be frozen, as we are expected to delete its members
# once they have been registered inside of the partitioner in order to save memory.
@dataclass
class LoadedGraphTensors:
    """Unpartitioned graph tensors, as loaded before partitioning.

    Each field holds either a single value (homogeneous graph) or a dict keyed by
    node/edge type (heterogeneous graph).
    """

    # Unpartitioned Node Ids
    node_ids: Union[torch.Tensor, dict[NodeType, torch.Tensor]]
    # Unpartitioned Node Features
    node_features: Optional[Union[torch.Tensor, dict[NodeType, torch.Tensor]]]
    # Unpartitioned Edge Index
    edge_index: Union[torch.Tensor, dict[EdgeType, torch.Tensor]]
    # Unpartitioned Edge Features
    edge_features: Optional[Union[torch.Tensor, dict[EdgeType, torch.Tensor]]]
    # Unpartitioned Positive Edge Label
    positive_label: Optional[Union[torch.Tensor, dict[EdgeType, torch.Tensor]]]
    # Unpartitioned Negative Edge Label
    negative_label: Optional[Union[torch.Tensor, dict[EdgeType, torch.Tensor]]]

    def treat_labels_as_edges(self) -> None:
        """Convert positive and negative labels to edges.

        Converts this object in-place to a "heterogeneous" representation:

        - each label tensor is added to ``edge_index`` under its label edge type
          (see ``message_passing_to_positive_label`` /
          ``message_passing_to_negative_label``);
        - node ids/features and edge features are converted to dicts keyed by type;
        - ``positive_label`` and ``negative_label`` are set to None.

        Homogeneous (bare tensor) labels are attached to the single edge type of
        ``edge_index``. If ``edge_index`` does not have exactly one edge type,
        labels must be provided as a dict keyed by edge type.

        Raises:
            ValueError: If ``positive_label`` is None, or if a label is a bare
                tensor while ``edge_index`` does not have exactly one edge type.
                In either case ``self`` is left unmodified.
        """
        if self.positive_label is None:
            raise ValueError(
                "Cannot treat labels as edges when positive label is None."
            )
        # Work on a shallow copy. When self.edge_index is already a dict,
        # to_heterogeneous_edge returns that same object. Inserting label edges
        # into it directly would mutate self before the negative labels are
        # validated, leaving self inconsistent if that validation raises.
        edge_index_with_labels = dict(to_heterogeneous_edge(self.edge_index))
        if len(edge_index_with_labels) == 1:
            main_edge_type = next(iter(edge_index_with_labels.keys()))
            logger.info(
                f"Basing positive and negative labels on edge types on edge type: {main_edge_type}."
            )
        else:
            # Bare-tensor labels cannot be attributed to an edge type in this case.
            main_edge_type = None

        # Positive labels are handled before negative ones, so they are inserted
        # (and logged) first.
        edge_index_with_labels.update(
            self._label_edges(
                self.positive_label,
                main_edge_type,
                message_passing_to_positive_label,
                "positive",
            )
        )
        edge_index_with_labels.update(
            self._label_edges(
                self.negative_label,
                main_edge_type,
                message_passing_to_negative_label,
                "negative",
            )
        )

        # All validation has passed; only now is self mutated.
        self.node_ids = to_heterogeneous_node(self.node_ids)
        self.node_features = to_heterogeneous_node(self.node_features)
        self.edge_index = edge_index_with_labels
        self.edge_features = to_heterogeneous_edge(self.edge_features)
        self.positive_label = None
        self.negative_label = None

    @staticmethod
    def _label_edges(
        labels: Optional[Union[torch.Tensor, dict[EdgeType, torch.Tensor]]],
        main_edge_type: Optional[EdgeType],
        to_label_edge_type: abc.Callable[[EdgeType], EdgeType],
        polarity: str,
    ) -> dict[EdgeType, torch.Tensor]:
        """Map labels to a dict of {label edge type: label tensor}.

        Args:
            labels: A bare tensor (attached to ``main_edge_type``), a dict keyed
                by message passing edge type, or None (yields an empty dict).
            main_edge_type: The single edge type of the edge index, or None if
                there is not exactly one.
            to_label_edge_type: Converts a message passing edge type into its
                label edge type.
            polarity: "positive" or "negative"; used only in log and error messages.

        Returns:
            dict[EdgeType, torch.Tensor]: Label edge type -> label tensor, in
            insertion order.

        Raises:
            ValueError: If ``labels`` is a bare tensor and ``main_edge_type`` is None.
        """
        label_edges: dict[EdgeType, torch.Tensor] = {}
        if isinstance(labels, torch.Tensor):
            if main_edge_type is None:
                raise ValueError(
                    "Detected multiple edge types in provided edge_index, but no edge types "
                    f"specified for provided {polarity} label."
                )
            label_edge_type = to_label_edge_type(main_edge_type)
            logger.info(
                f"Treating homogeneous {polarity} labels as edge type {label_edge_type}."
            )
            label_edges[label_edge_type] = labels
        elif isinstance(labels, dict):
            for message_passing_edge_type, label_tensor in labels.items():
                label_edge_type = to_label_edge_type(message_passing_edge_type)
                logger.info(
                    f"Treating heterogeneous {polarity} labels {message_passing_edge_type} as edge type {label_edge_type}."
                )
                label_edges[label_edge_type] = label_tensor
        return label_edges
def message_passing_to_positive_label(
    message_passing_edge_type: _EdgeType,
) -> _EdgeType:
    """Derive the positive-label edge type for a message passing edge type.

    Source and destination node types are kept as-is, and ``_POSITIVE_LABEL_TAG``
    is appended to the relation with an underscore, e.g.
    ``("user", "to", "item")`` -> ``("user", "to_gigl_positive", "item")``.

    Args:
        message_passing_edge_type (EdgeType): The message passing edge type,
            either a GiGL ``EdgeType`` or a plain (src, relation, dst) tuple.

    Returns:
        EdgeType: The positive label edge type. It is an ``EdgeType`` when the
        input is one, otherwise a plain tuple of strings.
    """
    src_node_type = str(message_passing_edge_type[0])
    label_relation = f"{message_passing_edge_type[1]}_{_POSITIVE_LABEL_TAG}"
    dst_node_type = str(message_passing_edge_type[2])
    if not isinstance(message_passing_edge_type, EdgeType):
        return (src_node_type, label_relation, dst_node_type)
    return EdgeType(
        NodeType(src_node_type), Relation(label_relation), NodeType(dst_node_type)
    )
def message_passing_to_negative_label(
    message_passing_edge_type: _EdgeType,
) -> _EdgeType:
    """Derive the negative-label edge type for a message passing edge type.

    Source and destination node types are kept as-is, and ``_NEGATIVE_LABEL_TAG``
    is appended to the relation with an underscore, e.g.
    ``("user", "to", "item")`` -> ``("user", "to_gigl_negative", "item")``.

    Args:
        message_passing_edge_type (EdgeType): The message passing edge type,
            either a GiGL ``EdgeType`` or a plain (src, relation, dst) tuple.

    Returns:
        EdgeType: The negative label edge type. It is an ``EdgeType`` when the
        input is one, otherwise a plain tuple of strings.
    """
    src_node_type = str(message_passing_edge_type[0])
    label_relation = f"{message_passing_edge_type[1]}_{_NEGATIVE_LABEL_TAG}"
    dst_node_type = str(message_passing_edge_type[2])
    if not isinstance(message_passing_edge_type, EdgeType):
        return (src_node_type, label_relation, dst_node_type)
    return EdgeType(
        NodeType(src_node_type), Relation(label_relation), NodeType(dst_node_type)
    )
def select_label_edge_types(
    message_passing_edge_type: _EdgeType, edge_entities: abc.Iterable[_EdgeType]
) -> tuple[_EdgeType, Optional[_EdgeType]]:
    """Select label edge types for a given message passing edge type.

    Args:
        message_passing_edge_type (EdgeType): The message passing edge type.
        edge_entities (abc.Iterable[EdgeType]): The edge entities to select from.
            Any iterable is accepted, including one-shot iterators; it is
            consumed exactly once.

    Returns:
        tuple[EdgeType, Optional[EdgeType]]: The positive label edge type and,
        if present in ``edge_entities``, the negative label edge type (else None).
        If an entity matches more than once, the last match wins.

    Raises:
        ValueError: If no positive label edge type is found in ``edge_entities``.
    """
    # Both targets depend only on the message passing edge type: compute once.
    positive_target = message_passing_to_positive_label(message_passing_edge_type)
    negative_target = message_passing_to_negative_label(message_passing_edge_type)
    # Materialize so that a one-shot iterator (e.g. a generator) still renders
    # its contents in the error message below instead of `<generator object>`.
    candidates = list(edge_entities)

    positive_label_type: Optional[_EdgeType] = None
    negative_label_type: Optional[_EdgeType] = None
    for edge_type in candidates:
        if positive_target == edge_type:
            positive_label_type = edge_type
        if negative_target == edge_type:
            negative_label_type = edge_type

    if positive_label_type is None:
        raise ValueError(
            f"Could not find positive label edge type for message passing edge type {message_passing_edge_type} from edge entities {candidates}."
        )
    return positive_label_type, negative_label_type
# Entities that represent a graph, somehow.
# Ideally, this would be anything, e.g. `_T = TypeVar("_T")`, but we need to be more specific:
# if we typed `to_homogeneous(x: _T | dict[NodeType, _T] | dict[EdgeType, _T]) -> _T`,
# then `_T` would capture the "dict" types, and the output type would not be correctly narrowed.
# e.g. `reveal_type(to_homogeneous(d: Tensor | dict[..., Tensor] | None))` is `object`.
# Instead, we enumerate these types, as MyPy does not allow "not" in a TypeVar.
# Extend this as necessary, but make sure to *never* add any Mapping types.
# NOTE: We have `Optional[SerializedTFRecordInfo]` in the type,
# as adding `None` and `SerializedTFRecordInfo` separately does not accomplish the equivalent thing.
# I believe this is because the constraints on a `TypeVar` are not
# treated as a union of the types, but rather each as its own case.
_GraphEntity = TypeVar(
    "_GraphEntity",
    torch.Tensor,
    GraphPartitionData,
    FeaturePartitionData,
    SerializedTFRecordInfo,
    Optional[SerializedTFRecordInfo],
    list,
    # TODO(kmonte): Add GLT Partition book here
    # We cannot at the moment as we mypy ignore GLT
    # And adding it as a type here will break mypy.
    # PartitionBook
)


@overload
def to_heterogeneous_node(x: None) -> None:
    ...


@overload
def to_heterogeneous_node(
    x: Union[_GraphEntity, dict[NodeType, _GraphEntity]]
) -> dict[NodeType, _GraphEntity]:
    ...


def to_heterogeneous_node(
    x: Optional[Union[_GraphEntity, dict[NodeType, _GraphEntity]]]
) -> Optional[dict[NodeType, _GraphEntity]]:
    """Convert a value to a heterogeneous node representation.

    If the input is None, return None.
    If the input is a dictionary, return it as is (the same object, not a copy).
    If the input is a single value, return it as a dictionary with the default
    homogeneous node type as the key.

    Args:
        x (Optional[Union[_GraphEntity, dict[NodeType, _GraphEntity]]]): The input value to convert.

    Returns:
        Optional[dict[NodeType, _GraphEntity]]: The converted heterogeneous node representation.
    """
    if x is None:
        return None
    if isinstance(x, dict):
        return x
    return {DEFAULT_HOMOGENEOUS_NODE_TYPE: x}


@overload
def to_heterogeneous_edge(x: None) -> None:
    ...


@overload
def to_heterogeneous_edge(
    x: Union[_GraphEntity, dict[EdgeType, _GraphEntity]]
) -> dict[EdgeType, _GraphEntity]:
    ...


def to_heterogeneous_edge(
    x: Optional[Union[_GraphEntity, dict[EdgeType, _GraphEntity]]]
) -> Optional[dict[EdgeType, _GraphEntity]]:
    """Convert a value to a heterogeneous edge representation.

    If the input is None, return None.
    If the input is a dictionary, return it as is (the same object, not a copy).
    If the input is a single value, return it as a dictionary with the default
    homogeneous edge type as the key.

    Args:
        x (Optional[Union[_GraphEntity, dict[EdgeType, _GraphEntity]]]): The input value to convert.

    Returns:
        Optional[dict[EdgeType, _GraphEntity]]: The converted heterogeneous edge representation.
    """
    if x is None:
        return None
    if isinstance(x, dict):
        return x
    return {DEFAULT_HOMOGENEOUS_EDGE_TYPE: x}


@overload
def to_homogeneous(x: None) -> None:
    ...


@overload
def to_homogeneous(x: abc.Mapping[NodeType, _GraphEntity]) -> _GraphEntity:
    ...


@overload
def to_homogeneous(x: abc.Mapping[EdgeType, _GraphEntity]) -> _GraphEntity:
    ...


@overload
def to_homogeneous(x: _GraphEntity) -> _GraphEntity:
    ...


def to_homogeneous(
    x: Optional[
        Union[
            _GraphEntity,
            abc.Mapping[NodeType, _GraphEntity],
            abc.Mapping[EdgeType, _GraphEntity],
        ]
    ]
) -> Optional[_GraphEntity]:
    """Convert a value to a homogeneous representation.

    If the input is None, return None.
    If the input is a mapping, return its single value.
    If the input is a single value, return it as is.

    Args:
        x (Optional[Union[_GraphEntity, abc.Mapping[Union[NodeType, EdgeType], _GraphEntity]]]):
            The input value to convert.

    Returns:
        Optional[_GraphEntity]: The converted homogeneous representation.

    Raises:
        ValueError: If ``x`` is a mapping that does not contain exactly one entry
            (this includes an empty mapping).
    """
    if x is None:
        return None
    if isinstance(x, abc.Mapping):
        if len(x) != 1:
            # Report the actual count: the mapping may be empty, not just "multiple".
            raise ValueError(
                f"Expected a single value in the dictionary, but got {len(x)} keys: {list(x.keys())}"
            )
        return next(iter(x.values()))
    return x