Source code for gigl.distributed.graph_store.remote_dist_dataset

from typing import Literal, Optional, Union

import torch
from graphlearn_torch.distributed import async_request_server, request_server

from gigl.common.logger import Logger
from gigl.distributed.graph_store.remote_dataset import (
    get_edge_dir,
    get_edge_feature_info,
    get_node_feature_info,
    get_node_ids_for_rank,
)
from gigl.distributed.utils.networking import get_free_ports
from gigl.env.distributed import GraphStoreInfo
from gigl.src.common.types.graph_data import EdgeType, NodeType
from gigl.types.graph import FeatureInfo

logger = Logger()
class RemoteDistDataset:
    def __init__(self, cluster_info: GraphStoreInfo, local_rank: int):
        """
        Represents a dataset that is stored on a different (remote) storage cluster.

        This class *must* be used in the GiGL graph-store distributed setup, and
        *must* be used on the compute (client) side of that setup.

        Args:
            cluster_info (GraphStoreInfo): The cluster information.
            local_rank (int): The local rank of the process on the compute node.
        """
        self._cluster_info = cluster_info
        self._local_rank = local_rank

    @property
    def cluster_info(self) -> GraphStoreInfo:
        return self._cluster_info

    def get_node_feature_info(
        self,
    ) -> Union[FeatureInfo, dict[NodeType, FeatureInfo], None]:
        """Get node feature information from the registered dataset.

        Returns:
            Node feature information, which can be:
            - A single FeatureInfo object for homogeneous graphs
            - A dict mapping NodeType to FeatureInfo for heterogeneous graphs
            - None if no node features are available
        """
        return request_server(
            0,
            get_node_feature_info,
        )

    def get_edge_feature_info(
        self,
    ) -> Union[FeatureInfo, dict[EdgeType, FeatureInfo], None]:
        """Get edge feature information from the registered dataset.

        Returns:
            Edge feature information, which can be:
            - A single FeatureInfo object for homogeneous graphs
            - A dict mapping EdgeType to FeatureInfo for heterogeneous graphs
            - None if no edge features are available
        """
        return request_server(
            0,
            get_edge_feature_info,
        )

    def get_edge_dir(self) -> Union[str, Literal["in", "out"]]:
        """Get the edge direction from the registered dataset.

        Returns:
            The edge direction.
        """
        return request_server(
            0,
            get_edge_dir,
        )

    def get_node_ids(
        self,
        node_type: Optional[NodeType] = None,
    ) -> list[torch.Tensor]:
        """
        Fetches node ids from the storage nodes for the current compute node (machine).

        The returned list contains the node ids for the current compute node, with one
        tensor per storage rank.

        For example, suppose there are two storage ranks, two compute ranks, and 16 total
        nodes. In this scenario, the node ids are sharded across storage as follows:

            Storage rank 0: [0, 1, 2, 3, 4, 5, 6, 7]
            Storage rank 1: [8, 9, 10, 11, 12, 13, 14, 15]

        Then, for compute rank 0 (node 0, process 0), the returned list will be:

            [
                [0, 1, 2, 3],    # From storage rank 0
                [8, 9, 10, 11],  # From storage rank 1
            ]

        NOTE: The GLT sampling engine expects all processes on a given compute machine to
        have the same sampling input (node ids). As such, the input tensors will be
        duplicated across all processes on a given compute machine.
        TODO(kmonte): Come up with a solution to avoid this duplication.

        Args:
            node_type (Optional[NodeType]): The type of nodes to get. Must be provided
                for heterogeneous datasets.

        Returns:
            list[torch.Tensor]: A list of node IDs for the given node type, one tensor
                per storage rank.
        """
        futures: list[torch.futures.Future[torch.Tensor]] = []
        rank = self.cluster_info.compute_node_rank
        world_size = self.cluster_info.num_storage_nodes
        logger.info(
            f"Getting node ids for rank {rank} / {world_size} with node type {node_type}"
        )
        for server_rank in range(self.cluster_info.num_storage_nodes):
            futures.append(
                async_request_server(
                    server_rank,
                    get_node_ids_for_rank,
                    rank=rank,
                    world_size=world_size,
                    node_type=node_type,
                )
            )
        node_ids = torch.futures.wait_all(futures)
        return node_ids
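
    # Illustrative sketch (not part of the original class): a caller that wants a single
    # flat tensor of sampling inputs for this compute node can concatenate the
    # per-storage-rank tensors returned by `get_node_ids`. `dataset` and `my_node_type`
    # below are hypothetical placeholders for illustration.
    #
    #     tensors_per_storage_rank = dataset.get_node_ids(node_type=my_node_type)
    #     all_node_ids = torch.cat(tensors_per_storage_rank)  # single 1-D tensor of node ids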

    def get_free_ports_on_storage_cluster(self, num_ports: int) -> list[int]:
        """
        Get free ports from the storage master node.

        This *must* be called with a torch.distributed process group initialized for the
        *entire* training cluster. All compute ranks will receive the same free ports.

        Args:
            num_ports (int): Number of free ports to get.

        Returns:
            list[int]: The free port numbers found on the storage master node.
        """
        if not torch.distributed.is_initialized():
            raise ValueError(
                "torch.distributed process group must be initialized for the entire training cluster"
            )
        compute_cluster_rank = (
            self.cluster_info.compute_node_rank
            * self.cluster_info.num_processes_per_compute
            + self._local_rank
        )
        if compute_cluster_rank == 0:
            ports = request_server(
                0,
                get_free_ports,
                num_ports=num_ports,
            )
            logger.info(
                f"Compute rank {compute_cluster_rank} found free ports: {ports}"
            )
        else:
            ports = [None] * num_ports
        torch.distributed.broadcast_object_list(ports, src=0)
        logger.info(
            f"Compute rank {compute_cluster_rank} received free ports: {ports}"
        )
        return ports
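

# ---------------------------------------------------------------------------
# Illustrative usage sketch, not part of the original module: a minimal example of how a
# compute-side process might use RemoteDistDataset. It assumes the GiGL launcher has
# already initialized a torch.distributed process group spanning the entire training
# cluster and provides `cluster_info` and `local_rank`; the function name and the
# `num_ports=2` value are assumptions for illustration only.
# ---------------------------------------------------------------------------
def _example_compute_side_usage(
    cluster_info: GraphStoreInfo, local_rank: int
) -> None:
    dataset = RemoteDistDataset(cluster_info, local_rank=local_rank)

    # Feature metadata and edge direction are served by storage rank 0.
    node_feature_info = dataset.get_node_feature_info()
    edge_feature_info = dataset.get_edge_feature_info()
    edge_dir = dataset.get_edge_dir()
    logger.info(
        f"Node features: {node_feature_info}, edge features: {edge_feature_info}, "
        f"edge dir: {edge_dir}"
    )

    # One tensor per storage rank; concatenate them into a single sampling input
    # (homogeneous graph assumed, so no node_type is passed).
    node_ids = torch.cat(dataset.get_node_ids())

    # Ports agreed upon by every compute rank, e.g. for sampling-worker RPC.
    ports = dataset.get_free_ports_on_storage_cluster(num_ports=2)
    logger.info(f"Got {node_ids.numel()} node ids and free ports {ports}")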