"""Source code for gigl.distributed.dist_range_partitioner."""

import gc
import time
from typing import Dict, Optional, Union

import torch
from graphlearn_torch.distributed.rpc import all_gather
from graphlearn_torch.partition import PartitionBook, RangePartitionBook
from graphlearn_torch.utils import convert_to_tensor

from gigl.common.logger import Logger
from gigl.distributed.dist_partitioner import DistPartitioner
from gigl.distributed.utils.partition_book import get_ids_on_rank
from gigl.src.common.types.graph_data import EdgeType, NodeType
from gigl.types.graph import FeaturePartitionData, GraphPartitionData, to_homogeneous

# Module-level logger used by the partitioner for progress and partition-bound messages.
logger = Logger()
class DistRangePartitioner(DistPartitioner):
    """
    This class is responsible for implementing range-based partitioning.

    Rather than using a tensor-based partition book, this approach stores the upper bound of ids
    for each rank. For example, a range partition book [4, 8, 12] stores edge ids 0-3 on the 0th
    rank, 4-7 on the 1st rank, and 8-11 on the 2nd rank.

    While keeping the same id-indexing pattern for rank lookup as the tensor-based partitioning,
    this partition book does a search through these partition bounds to fetch the ranks, rather
    than using a direct index lookup. For example, to get the rank of node ids 1 and 6 by doing
    node_pb[[1, 6]], the range partition book uses torch.searchsorted on the partition bounds to
    return [0, 1], the ranks of each of these ids.

    As a result, the range-based partition book trades off more efficient memory storage for a
    slower lookup time for indices.
    """

    def register_edge_index(
        self, edge_index: Union[torch.Tensor, Dict[EdgeType, torch.Tensor]]
    ) -> None:
        """
        Registers the edge_index to the partitioner. Unlike the tensor-based partitioner, this
        register pattern does not automatically infer edge ids, as they are not needed for
        partitioning.

        For optimal memory management, it is recommended that the reference to edge_index tensor
        be deleted after calling this function using del <tensor>, as maintaining both original
        and intermediate tensors can cause OOM concerns.

        Args:
            edge_index (Union[torch.Tensor, Dict[EdgeType, torch.Tensor]]): Input edge index which
                is either a torch.Tensor if homogeneous or a Dict if heterogeneous
        """
        self._assert_and_get_rpc_setup()

        logger.info("Registering Edge Indices ...")

        input_edge_index = self._convert_edge_entity_to_heterogeneous_format(
            input_edge_entity=edge_index
        )

        assert (
            input_edge_index
        ), "Edge Index is an empty dictionary. Please provide edge indices to register."

        # Sorted so that edge types are later iterated in a deterministic order.
        self._edge_types = sorted(input_edge_index.keys())

        self._edge_index = convert_to_tensor(input_edge_index, dtype=torch.int64)

    def _partition_node(self, node_type: NodeType) -> PartitionBook:
        """
        Partition graph nodes of a specific node type. For range-based partitioning, we partition
        all the nodes into continuous ranges so that the diff between lengths of any two ranges
        is no greater than 1.

        This function gets called by the `partition_node` API from the parent class, which
        handles the node partitioning across all node types.

        Args:
            node_type (NodeType): The node type for input nodes
        Returns:
            PartitionBook: The partition book of graph nodes.
        """
        assert (
            self._num_nodes is not None
        ), "Must have registered nodes prior to partitioning them"

        num_nodes = self._num_nodes[node_type]

        per_node_num, remainder = divmod(num_nodes, self._world_size)

        # We set `remainder` number of partitions to have at most one more item.
        start = 0
        partition_ranges: list[tuple[int, int]] = []
        for partition_index in range(self._world_size):
            if partition_index < remainder:
                end = start + per_node_num + 1
            else:
                end = start + per_node_num
            partition_ranges.append((start, end))
            start = end

        # Store and return partitioned ranges as GLT's RangePartitionBook
        node_partition_book = RangePartitionBook(
            partition_ranges=partition_ranges, partition_idx=self._rank
        )

        logger.info(
            f"Got node range-based partition book for node type {node_type} on rank {self._rank} with partition bounds: {node_partition_book.partition_bounds}"
        )

        return node_partition_book

    def _partition_node_features(
        self,
        node_partition_book: Dict[NodeType, PartitionBook],
        node_type: NodeType,
    ) -> FeaturePartitionData:
        """
        Partitions node features according to the node partition book. We rely on the
        functionality from the parent tensor-based partitioner here, and add logic to sort the
        node features by node indices which is specific to range-based partitioning. This is done
        so that the range-based id2idx corresponds correctly to the node features.

        Args:
            node_partition_book (Dict[NodeType, PartitionBook]): The partition book of nodes
            node_type (NodeType): Node type of input data
        Returns:
            FeaturePartitionData: Ids and Features of input nodes
        """
        features_partition_data = super()._partition_node_features(
            node_partition_book=node_partition_book, node_type=node_type
        )

        # The parent class always returns ids in the feature_partition_data, but we don't need to
        # store the partitioned node feature ids for range-based partitioning, since this is
        # available from the node partition book.
        assert features_partition_data.ids is not None

        # Reorder the feature rows so that they follow ascending node id order.
        sorted_node_ids_indices = torch.argsort(features_partition_data.ids)
        partitioned_node_features = features_partition_data.feats[
            sorted_node_ids_indices
        ]

        return FeaturePartitionData(feats=partitioned_node_features, ids=None)

    def _partition_edge_index_and_edge_features(
        self,
        node_partition_book: Dict[NodeType, PartitionBook],
        edge_type: EdgeType,
    ) -> tuple[GraphPartitionData, Optional[FeaturePartitionData], PartitionBook]:
        """
        Partition graph topology of a specific edge type. For range-based partitioning, we
        partition edges and edge features (if they exist) together. Once they have been
        partitioned across machines, we build the edge partition book based on the number of
        edges assigned to each machine. Then, we infer the edge IDs from the edge partition
        book's ranges.

        Args:
            node_partition_book (Dict[NodeType, PartitionBook]): The partition books of all graph
                nodes.
            edge_type (EdgeType): The edge type for input edges
        Returns:
            GraphPartitionData: The graph data of the current partition.
            Optional[FeaturePartitionData]: The edge features on the current partition, or None
                if no edge features are registered for this edge type.
            PartitionBook: The partition book of graph edges.
        """
        assert (
            self._edge_index is not None
        ), "Must have registered edges prior to partitioning them"

        edge_index = self._edge_index[edge_type]
        # Remember the dtypes now: the source tensors are released below, and a rank that ends up
        # with no edges must still produce tensors of the same dtype as every other rank.
        edge_index_dtype = edge_index.dtype
        edge_feat_dtype: Optional[torch.dtype] = None

        input_data: tuple[torch.Tensor, ...]

        if self._edge_feat is None or edge_type not in self._edge_feat:
            logger.info(
                f"No edge features detected for edge type {edge_type}, will only partition edge indices for this edge type."
            )
            edge_feat = None
            edge_feat_dim = None
            input_data = (edge_index[0], edge_index[1])
        else:
            assert self._edge_feat_dim is not None and edge_type in self._edge_feat_dim
            edge_feat = self._edge_feat[edge_type]
            edge_feat_dim = self._edge_feat_dim[edge_type]
            edge_feat_dtype = edge_feat.dtype
            input_data = (edge_index[0], edge_index[1], edge_feat)

        # Each edge is sent to the rank that owns either its source or its destination node.
        if self._should_assign_edges_by_src_node:
            target_node_partition_book = node_partition_book[edge_type.src_node_type]
            target_indices = edge_index[0]
        else:
            target_node_partition_book = node_partition_book[edge_type.dst_node_type]
            target_indices = edge_index[1]

        def edge_partition_fn(rank_indices, _):
            return target_node_partition_book[rank_indices]

        res_list, _ = self._partition_by_chunk(
            input_data=input_data,
            rank_indices=target_indices,
            partition_function=edge_partition_fn,
        )

        # Drop every reference to the un-partitioned tensors so they can be reclaimed.
        del input_data, edge_index, target_indices, edge_feat
        del self._edge_index[edge_type]
        if self._edge_feat is not None and edge_type in self._edge_feat:
            del self._edge_feat[edge_type]

        # We check if edge_index or edge_feat dict is empty after deleting the tensor. If so, we
        # set these fields to None.
        # NOTE(review): entries of `_edge_feat_dim` are not removed in this method, so the second
        # condition only holds when `_edge_feat_dim` is already empty or None -- confirm intended.
        if not self._edge_index:
            self._edge_index = None
        if not self._edge_feat and not self._edge_feat_dim:
            self._edge_feat = None
            self._edge_feat_dim = None

        gc.collect()

        if not res_list:
            # No edges were assigned to this rank; keep the integer dtype of the registered edge
            # index rather than torch's default floating point dtype.
            partitioned_edge_index = torch.empty((2, 0), dtype=edge_index_dtype)
        else:
            partitioned_edge_index = torch.stack(
                (
                    torch.cat([r[0] for r in res_list]),
                    torch.cat([r[1] for r in res_list]),
                ),
                dim=0,
            )

        if edge_feat_dim is not None:
            if not res_list:
                # Empty placeholder that matches the dtype of the registered edge features.
                partitioned_edge_features = torch.empty(
                    (0, edge_feat_dim), dtype=edge_feat_dtype
                )
            else:
                partitioned_edge_features = torch.cat([r[2] for r in res_list])

        res_list.clear()

        gc.collect()

        # Generating edge partition book: gather (rank, num_edges) from every rank and order the
        # entries by rank so that the ranges are laid out consecutively.
        num_edges_on_each_rank: list[tuple[int, int]] = sorted(
            all_gather((self._rank, partitioned_edge_index.size(1))).values(),
            key=lambda x: x[0],
        )

        partition_ranges: list[tuple[int, int]] = []
        start = 0
        for _, num_edges in num_edges_on_each_rank:
            end = start + num_edges
            partition_ranges.append((start, end))
            start = end

        edge_partition_book = RangePartitionBook(
            partition_ranges=partition_ranges, partition_idx=self._rank
        )

        # Edge ids are not registered up front; they are derived from this rank's range.
        partitioned_edge_ids = get_ids_on_rank(
            partition_book=edge_partition_book, rank=self._rank
        )

        current_graph_part = GraphPartitionData(
            edge_index=partitioned_edge_index,
            edge_ids=partitioned_edge_ids,
        )

        if edge_feat_dim is None:
            current_feat_part = None
        else:
            current_feat_part = FeaturePartitionData(
                feats=partitioned_edge_features, ids=None
            )

        logger.info(
            f"Got edge range-based partition book for edge type {edge_type} on rank {self._rank} with partition bounds: {edge_partition_book.partition_bounds}"
        )

        return current_graph_part, current_feat_part, edge_partition_book

    def partition_edge_index_and_edge_features(
        self, node_partition_book: Union[PartitionBook, Dict[NodeType, PartitionBook]]
    ) -> Union[
        tuple[GraphPartitionData, Optional[FeaturePartitionData], PartitionBook],
        tuple[
            Dict[EdgeType, GraphPartitionData],
            Optional[Dict[EdgeType, FeaturePartitionData]],
            Dict[EdgeType, PartitionBook],
        ],
    ]:
        """
        Partitions edges of a graph, including edge indices and edge features. If heterogeneous,
        partitions edges for all edge types. You must call `partition_node` first to get the node
        partition book as input. The difference between this function and its parent is that we
        no longer need to check that the `edge_ids` have been pre-computed as a prerequisite for
        partitioning edges and edge features.

        Args:
            node_partition_book (Union[PartitionBook, Dict[NodeType, PartitionBook]]): The
                computed Node Partition Book
        Returns:
            Union[
                Tuple[GraphPartitionData, FeaturePartitionData, PartitionBook],
                Tuple[Dict[EdgeType, GraphPartitionData], Dict[EdgeType, FeaturePartitionData], Dict[EdgeType, PartitionBook]],
            ]: Partitioned Graph Data, Feature Data, and corresponding edge partition book, is a
                dictionary if heterogeneous. The feature entry is None when no edge type has
                edge features.
        """
        self._assert_and_get_rpc_setup()

        assert (
            self._edge_index is not None
        ), "Must have registered edges prior to partitioning them"

        logger.info("Partitioning Edges ...")
        start_time = time.time()

        transformed_node_partition_book = (
            self._convert_node_entity_to_heterogeneous_format(
                input_node_entity=node_partition_book
            )
        )

        self._assert_data_type_consistency(
            input_entity=transformed_node_partition_book,
            is_node_entity=True,
            is_subset=False,
        )

        self._assert_data_type_consistency(
            input_entity=self._edge_index, is_node_entity=False, is_subset=False
        )

        # Edge features are checked with is_subset=True, unlike the edge index above.
        if self._edge_feat is not None:
            self._assert_data_type_consistency(
                input_entity=self._edge_feat, is_node_entity=False, is_subset=True
            )

        edge_partition_book: Dict[EdgeType, PartitionBook] = {}
        partitioned_edge_index: Dict[EdgeType, GraphPartitionData] = {}
        partitioned_edge_features: Dict[EdgeType, FeaturePartitionData] = {}
        for edge_type in self._edge_types:
            (
                partitioned_edge_index_per_edge_type,
                partitioned_edge_features_per_edge_type,
                edge_partition_book_per_edge_type,
            ) = self._partition_edge_index_and_edge_features(
                node_partition_book=transformed_node_partition_book, edge_type=edge_type
            )
            partitioned_edge_index[edge_type] = partitioned_edge_index_per_edge_type
            edge_partition_book[edge_type] = edge_partition_book_per_edge_type
            if partitioned_edge_features_per_edge_type is not None:
                partitioned_edge_features[
                    edge_type
                ] = partitioned_edge_features_per_edge_type

        elapsed_time = time.time() - start_time
        logger.info(f"Edge Partitioning finished, took {elapsed_time:.3f}s")

        # An empty feature dict is reported as None rather than as an empty mapping.
        return_edge_features = (
            partitioned_edge_features if partitioned_edge_features else None
        )

        if self._is_input_homogeneous:
            return (
                to_homogeneous(partitioned_edge_index),
                to_homogeneous(return_edge_features),
                to_homogeneous(edge_partition_book),
            )
        else:
            return (
                partitioned_edge_index,
                return_edge_features,
                edge_partition_book,
            )