Source code for gigl.distributed.distributed_neighborloader

import time
from collections import Counter, abc
from typing import Optional, Tuple, Union

import torch
from graphlearn_torch.channel import SampleMessage
from graphlearn_torch.distributed import (
    DistLoader,
    MpDistSamplingWorkerOptions,
    RemoteDistSamplingWorkerOptions,
)
from graphlearn_torch.sampler import NodeSamplerInput, SamplingConfig, SamplingType
from torch_geometric.data import Data, HeteroData
from torch_geometric.typing import EdgeType

import gigl.distributed.utils
from gigl.common.logger import Logger
from gigl.distributed.constants import DEFAULT_MASTER_INFERENCE_PORT
from gigl.distributed.dist_context import DistributedContext
from gigl.distributed.dist_dataset import DistDataset
from gigl.distributed.graph_store.remote_dist_dataset import RemoteDistDataset
from gigl.distributed.utils.neighborloader import (
    DatasetSchema,
    SamplingClusterSetup,
    labeled_to_homogeneous,
    patch_fanout_for_sampling,
    set_missing_features,
    shard_nodes_by_process,
    strip_label_edges,
)
from gigl.src.common.types.graph_data import (
    NodeType,  # TODO (mkolodner-sc): Change to use torch_geometric.typing
)
from gigl.types.graph import (
    DEFAULT_HOMOGENEOUS_EDGE_TYPE,
    DEFAULT_HOMOGENEOUS_NODE_TYPE,
)

logger = Logger()
# When using CPU-based inference/training, we default the number of CPU threads used for
# neighbor loading on top of the per-process parallelism.
DEFAULT_NUM_CPU_THREADS = 2
class DistNeighborLoader(DistLoader):
    def __init__(
        self,
        dataset: Union[DistDataset, RemoteDistDataset],
        num_neighbors: Union[list[int], dict[EdgeType, list[int]]],
        input_nodes: Optional[
            Union[
                torch.Tensor,
                Tuple[NodeType, torch.Tensor],
                abc.Mapping[int, torch.Tensor],
                Tuple[NodeType, abc.Mapping[int, torch.Tensor]],
            ]
        ] = None,
        num_workers: int = 1,
        batch_size: int = 1,
        context: Optional[DistributedContext] = None,  # TODO: (svij) Deprecate this
        local_process_rank: Optional[int] = None,  # TODO: (svij) Deprecate this
        local_process_world_size: Optional[int] = None,  # TODO: (svij) Deprecate this
        pin_memory_device: Optional[torch.device] = None,
        worker_concurrency: int = 4,
        channel_size: str = "4GB",
        process_start_gap_seconds: float = 60.0,
        num_cpu_threads: Optional[int] = None,
        shuffle: bool = False,
        drop_last: bool = False,
    ):
        """
        Note: We try to adhere to the PyG DataLoader API as much as possible. See the following for reference:
        https://pytorch-geometric.readthedocs.io/en/2.5.2/_modules/torch_geometric/loader/node_loader.html#NodeLoader
        https://pytorch-geometric.readthedocs.io/en/2.5.2/_modules/torch_geometric/distributed/dist_neighbor_loader.html#DistNeighborLoader
        An illustrative usage sketch is included at the end of this module.

        Args:
            dataset (DistDataset | RemoteDistDataset): The dataset to sample from. If this is a
                `RemoteDistDataset`, then we are assumed to be in "Graph Store" mode.
            num_neighbors (list[int] or dict[Tuple[str, str, str], list[int]]): The number of neighbors
                to sample for each node in each iteration. If an entry is set to `-1`, all neighbors
                will be included. In heterogeneous graphs, this may also be a dictionary denoting the
                number of neighbors to sample for each individual edge type.
            context (deprecated - will be removed soon) (DistributedContext): Distributed context
                information of the current process.
            local_process_rank (deprecated - will be removed soon) (int): Required if context is provided.
                The local rank of the current process within a node.
            local_process_world_size (deprecated - will be removed soon) (int): Required if context is
                provided. The total number of processes within a node.
            input_nodes (Tensor | Tuple[NodeType, Tensor] | dict[int, Tensor] | Tuple[NodeType, dict[int, Tensor]]):
                The nodes to start sampling from. It is of type `torch.LongTensor` for homogeneous graphs.
                If set to `None` for homogeneous settings, all nodes will be considered.
                In heterogeneous graphs, this flag must be passed in as a tuple that holds the node type
                and node indices. (default: `None`)
                For Graph Store mode, this must be a tuple of (NodeType, dict[int, Tensor]) or a
                dict[int, Tensor], where each Tensor in the dict holds the node ids to sample from,
                keyed by server rank. e.g. {0: [10, 20], 1: [30, 40]} means sample from nodes 10 and 20
                on server 0, and nodes 30 and 40 on server 1.
                If a Graph Store input (e.g. dict[int, Tensor]) is provided in colocated mode, or a
                colocated input (e.g. Tensor) is provided in Graph Store mode, then an error will be raised.
            num_workers (int): How many workers to use (subprocesses to spawn) for distributed neighbor
                sampling of the current process. (default: ``1``).
            batch_size (int, optional): How many samples per batch to load. (default: ``1``).
            pin_memory_device (str, optional): The target device that the sampled results should be
                copied to. If set to ``None``, the device is inferred via
                ``gigl.distributed.utils.device.get_available_device``, which uses the local_process_rank
                and torch.cuda.device_count() to assign the device. If CUDA is not available, the CPU
                device will be used. (default: ``None``).
            worker_concurrency (int): The max sampling concurrency for each sampling worker. Load testing
                has shown that setting worker_concurrency to 4 yields the best performance for sampling,
                although you may wish to explore higher/lower settings when performance tuning. (default: `4`).
            channel_size (int or str): The shared-memory buffer size (bytes) allocated for the channel.
                Can be modified for performance tuning; a good starting point is ``num_workers * 64MB``.
                (default: "4GB").
            process_start_gap_seconds (float): Delay between each process for initializing the neighbor
                loader. At large scales, it is recommended to set this value to be between 60 and 120
                seconds -- otherwise multiple processes may attempt to initialize dataloaders at
                overlapping times, which can cause CPU memory OOM.
            num_cpu_threads (Optional[int]): Number of CPU threads PyTorch should use for CPU
                training/inference neighbor loading, on top of the per-process parallelism.
                Defaults to `2` if set to `None` when using CPU training/inference.
            shuffle (bool): Whether to shuffle the input nodes. (default: ``False``).
            drop_last (bool): Whether to drop the last incomplete batch. (default: ``False``).
        """
        # Set self._shutdowned right away; that way, if we throw here and __del__ is called,
        # we can properly clean up and don't get extraneous error messages.
        # We set it to `True` as we don't need to clean up right away, and it will get set
        # to `False` in `super().__init__()`, e.g.
        # https://github.com/alibaba/graphlearn-for-pytorch/blob/26fe3d4e050b081bc51a79dc9547f244f5d314da/graphlearn_torch/python/distributed/dist_loader.py#L125C1-L126C1
        self._shutdowned = True

        node_world_size: int
        node_rank: int
        rank: int
        world_size: int
        local_rank: int
        local_world_size: int
        master_ip_address: str
        should_cleanup_distributed_context: bool = False

        if context:
            assert (
                local_process_world_size is not None
            ), "context: DistributedContext provided, so local_process_world_size must be provided."
            assert (
                local_process_rank is not None
            ), "context: DistributedContext provided, so local_process_rank must be provided."
            master_ip_address = context.main_worker_ip_address
            node_world_size = context.global_world_size
            node_rank = context.global_rank
            local_world_size = local_process_world_size
            local_rank = local_process_rank
            rank = node_rank * local_world_size + local_rank
            world_size = node_world_size * local_world_size

            if not torch.distributed.is_initialized():
                logger.info(
                    "Process group is not available; trying torch.distributed.init_process_group to communicate necessary setup information."
                )
                should_cleanup_distributed_context = True
                logger.info(
                    f"Initializing process group with master ip address: {master_ip_address}, rank: {rank}, world size: {world_size}, local_rank: {local_rank}, local_world_size: {local_world_size}."
                )
                torch.distributed.init_process_group(
                    backend="gloo",  # We just default to gloo for this temporary process group.
                    init_method=f"tcp://{master_ip_address}:{DEFAULT_MASTER_INFERENCE_PORT}",
                    rank=rank,
                    world_size=world_size,
                )
        else:
            assert (
                torch.distributed.is_initialized()
            ), f"context: DistributedContext is None, so the process group must be initialized before constructing this object {self.__class__.__name__}."
            world_size = torch.distributed.get_world_size()
            rank = torch.distributed.get_rank()
            rank_ip_addresses = gigl.distributed.utils.get_internal_ip_from_all_ranks()
            master_ip_address = rank_ip_addresses[0]
            count_ranks_per_ip_address = Counter(rank_ip_addresses)
            local_world_size = count_ranks_per_ip_address[master_ip_address]
            for rank_ip_address, count in count_ranks_per_ip_address.items():
                if count != local_world_size:
                    raise ValueError(
                        f"All ranks must have the same number of processes, but found {count} processes for rank {rank} on ip {rank_ip_address}, expected {local_world_size}. "
                        + f"count_ranks_per_ip_address = {count_ranks_per_ip_address}"
                    )
            node_world_size = len(count_ranks_per_ip_address)
            local_rank = rank % local_world_size
            node_rank = rank // local_world_size

        del (
            context,
            local_process_rank,
            local_process_world_size,
        )  # Delete deprecated vars so we don't accidentally use them.

        if isinstance(dataset, RemoteDistDataset):
            self._sampling_cluster_setup = SamplingClusterSetup.GRAPH_STORE
        else:
            self._sampling_cluster_setup = SamplingClusterSetup.COLOCATED
        logger.info(f"Sampling cluster setup: {self._sampling_cluster_setup.value}")

        device = (
            pin_memory_device
            if pin_memory_device
            else gigl.distributed.utils.get_available_device(
                local_process_rank=local_rank
            )
        )

        # Determines if the node ids passed in are heterogeneous or homogeneous.
        if self._sampling_cluster_setup == SamplingClusterSetup.COLOCATED:
            assert isinstance(
                dataset, DistDataset
            ), "When using colocated mode, dataset must be a DistDataset."
            input_data, worker_options, dataset_metadata = self._setup_for_colocated(
                input_nodes,
                dataset,
                local_rank,
                local_world_size,
                device,
                master_ip_address,
                node_rank,
                node_world_size,
                num_workers,
                worker_concurrency,
                channel_size,
                num_cpu_threads,
            )
        else:  # Graph Store mode
            assert isinstance(
                dataset, RemoteDistDataset
            ), "When using Graph Store mode, dataset must be a RemoteDistDataset."
            input_data, worker_options, dataset_metadata = self._setup_for_graph_store(
                input_nodes,
                dataset,
                num_workers,
            )

        self._is_labeled_heterogeneous = dataset_metadata.is_labeled_heterogeneous
        self._node_feature_info = dataset_metadata.node_feature_info
        self._edge_feature_info = dataset_metadata.edge_feature_info

        logger.info(f"num_neighbors before patch: {num_neighbors}")
        num_neighbors = patch_fanout_for_sampling(
            edge_types=dataset_metadata.edge_types,
            num_neighbors=num_neighbors,
        )
        logger.info(
            f"num_neighbors: {num_neighbors}, edge_types: {dataset_metadata.edge_types}"
        )

        sampling_config = SamplingConfig(
            sampling_type=SamplingType.NODE,
            num_neighbors=num_neighbors,
            batch_size=batch_size,
            shuffle=shuffle,
            drop_last=drop_last,
            with_edge=True,
            collect_features=True,
            with_neg=False,
            with_weight=False,
            edge_dir=dataset_metadata.edge_dir,
            seed=None,  # It's actually optional - None means random.
        )

        if should_cleanup_distributed_context and torch.distributed.is_initialized():
            logger.info(
                f"Cleaning up process group as it was initialized inside {self.__class__.__name__}.__init__."
            )
            torch.distributed.destroy_process_group()

        if self._sampling_cluster_setup == SamplingClusterSetup.COLOCATED:
            # When initiating data loader(s), there will be a spike of memory usage lasting for ~30s.
            # The current hypothesis is that making connections across machines requires a lot of memory.
            # If we start all data loaders in all processes simultaneously, the spikes of memory
            # usage will add up and cause CPU memory OOM. Hence, we initiate the data loaders group by group
            # to smooth out the memory usage.
            # The definition of a group is discussed in init_neighbor_loader_worker.
            logger.info(
                f"---Machine {rank} local process number {local_rank} preparing to sleep for {process_start_gap_seconds * local_rank} seconds"
            )
            time.sleep(process_start_gap_seconds * local_rank)
            super().__init__(
                dataset,  # Pass in the dataset for colocated mode.
                input_data,
                sampling_config,
                device,
                worker_options,
            )
        else:
            # For Graph Store mode, we need to start the communication between compute and storage nodes sequentially, by compute node.
            # E.g. initialize connections between compute node 0 and storage nodes 0, 1, 2, 3, then compute node 1 and storage nodes 0, 1, 2, 3, etc.
            # Note that each compute node may have multiple connections to each storage node, one per compute process.
            # It's important to distinguish "compute node" (e.g. physical compute machine) from "compute process" (e.g. process running on the compute node),
            # since in practice we have multiple compute processes per compute node, and each compute process needs to initialize its connection to the storage nodes.
            # E.g. if there are 4 GPUs per compute node, then there will be 4 connections from each compute node to each storage node.
            # We need to do this because if we don't, then there is a race condition when initializing the samplers on the storage nodes [1].
            # Since the lock is per *server* (e.g. per storage node), if we try to start one connection from compute node 0 and one from compute node 1,
            # then we deadlock and fail.
            # Specifically, the race condition happens in `DistLoader.__init__` when it initializes the sampling producers on the storage nodes. [2]
            # [1]: https://github.com/alibaba/graphlearn-for-pytorch/blob/main/graphlearn_torch/python/distributed/dist_server.py#L129-L167
            # [2]: https://github.com/alibaba/graphlearn-for-pytorch/blob/88ff111ac0d9e45c6c9d2d18cfc5883dca07e9f9/graphlearn_torch/python/distributed/dist_loader.py#L187-L193
            # See below for a diagram of the connection setup.
            # ╔══════════════════════════════════════════════════════════════════╗
            # ║                COMPUTE TO STORAGE NODE CONNECTIONS                ║
            # ╚══════════════════════════════════════════════════════════════════╝
            #
            #        COMPUTE NODES                            STORAGE NODES
            #        ═════════════                            ═════════════
            #
            #   ┌──────────────────────┐        (1)         ┌───────────────┐
            #   │    COMPUTE NODE 0    │ ══════════════════ │               │
            #   │  ┌────┬────┬────┬────┤                 ╱  │   STORAGE 0   │
            #   │  │GPU │GPU │GPU │GPU │ ═══════════╲   ╱   │               │
            #   │  │ 0  │ 1  │ 2  │ 3  │     (2)     ╲ ╱    └───────────────┘
            #   │  └────┴────┴────┴────┤              ╳
            #   └──────────────────────┘     (3)     ╱ ╲     (4)
            #
            #   ┌──────────────────────┐            ╱   ╲    ┌───────────────┐
            #   │    COMPUTE NODE 1    │ ══════════╱     ╲══ │               │
            #   │  ┌────┬────┬────┬────┤                     │   STORAGE 1   │
            #   │  │GPU │GPU │GPU │GPU │ ═══════════════════ │               │
            #   │  │ 0  │ 1  │ 2  │ 3  │                     └───────────────┘
            #   │  └────┴────┴────┴────┤
            #   └──────────────────────┘
            #
            #   ┌───────────────────────────────────────────────────────────────┐
            #   │ (1) Compute Node 0 → Storage 0  (4 connections, one per GPU)   │
            #   │ (2) Compute Node 0 → Storage 1  (4 connections, one per GPU)   │
            #   │ (3) Compute Node 1 → Storage 0  (4 connections, one per GPU)   │
            #   │ (4) Compute Node 1 → Storage 1  (4 connections, one per GPU)   │
            #   └───────────────────────────────────────────────────────────────┘
            node_rank = dataset.cluster_info.compute_node_rank
            for target_node_rank in range(dataset.cluster_info.num_compute_nodes):
                if node_rank == target_node_rank:
                    # TODO: (kmontemayor2-sc) Evaluate if we need to stagger the initialization of the data loaders
                    # to smooth the memory usage.
                    super().__init__(
                        None,  # Pass in None for Graph Store mode.
                        input_data,
                        sampling_config,
                        device,
                        worker_options,
                    )
                    logger.info(f"node_rank {node_rank} initialized the dist loader")
                torch.distributed.barrier()
            torch.distributed.barrier()
            logger.info("All node ranks initialized the dist loader")

    def _setup_for_graph_store(
        self,
        input_nodes: Optional[
            Union[
                torch.Tensor,
                Tuple[NodeType, torch.Tensor],
                abc.Mapping[int, torch.Tensor],
                Tuple[NodeType, abc.Mapping[int, torch.Tensor]],
            ]
        ],
        dataset: RemoteDistDataset,
        num_workers: int,
    ) -> tuple[list[NodeSamplerInput], RemoteDistSamplingWorkerOptions, DatasetSchema]:
        if input_nodes is None:
            raise ValueError(
                f"When using Graph Store mode, input nodes must be provided, received {input_nodes}"
            )
        elif isinstance(input_nodes, torch.Tensor):
            raise ValueError(
                f"When using Graph Store mode, input nodes must be of type (abc.Mapping[int, torch.Tensor] | (NodeType, abc.Mapping[int, torch.Tensor])), received {type(input_nodes)}"
            )
        elif isinstance(input_nodes, tuple) and isinstance(
            input_nodes[1], torch.Tensor
        ):
            raise ValueError(
                f"When using Graph Store mode, input nodes must be of type (dict[int, torch.Tensor] | (NodeType, dict[int, torch.Tensor])), received {type(input_nodes)} ({type(input_nodes[0])}, {type(input_nodes[1])})"
            )

        is_labeled_heterogeneous = False
        node_feature_info = dataset.get_node_feature_info()
        edge_feature_info = dataset.get_edge_feature_info()
        edge_types = dataset.get_edge_types()
        node_rank = dataset.cluster_info.compute_node_rank

        # Get sampling ports for compute-storage connections.
        sampling_ports = dataset.get_free_ports_on_storage_cluster(
            num_ports=dataset.cluster_info.num_compute_nodes
        )
        sampling_port = sampling_ports[node_rank]
        worker_options = RemoteDistSamplingWorkerOptions(
            server_rank=list(range(dataset.cluster_info.num_storage_nodes)),
            num_workers=num_workers,
            worker_devices=[torch.device("cpu") for _ in range(num_workers)],
            master_addr=dataset.cluster_info.storage_cluster_master_ip,
            master_port=sampling_port,
            worker_key=f"compute_rank_{node_rank}",
        )
        logger.info(
            f"Rank {torch.distributed.get_rank()}! init for sampling rpc: {f'tcp://{dataset.cluster_info.storage_cluster_master_ip}:{sampling_port}'}"
        )

        # Set up input data for the dataloader.
        # Determine the nodes mapping and the fallback input_type based on the input_nodes structure.
        if isinstance(input_nodes, abc.Mapping):
            nodes = input_nodes
            fallback_input_type = None
            require_edge_feature_info = False
        elif isinstance(input_nodes, tuple) and isinstance(input_nodes[1], abc.Mapping):
            nodes = input_nodes[1]
            fallback_input_type = input_nodes[0]
            require_edge_feature_info = True
        else:
            raise ValueError(
                f"When using Graph Store mode, input nodes must be of type (abc.Mapping[int, torch.Tensor] | (NodeType, abc.Mapping[int, torch.Tensor])), received {type(input_nodes)}"
            )

        # Determine input_type based on the edge types.
        if isinstance(edge_types, list):
            if edge_types == [DEFAULT_HOMOGENEOUS_EDGE_TYPE]:
                input_type: Optional[NodeType] = DEFAULT_HOMOGENEOUS_NODE_TYPE
            else:
                input_type = fallback_input_type
        elif require_edge_feature_info:
            raise ValueError(
                "When using Graph Store mode, edge types must be provided for heterogeneous graphs."
            )
        else:
            input_type = None

        # Convert from dict to list, which is what the GLT DistNeighborLoader expects.
        servers = nodes.keys()
        if max(servers) >= dataset.cluster_info.num_storage_nodes or min(servers) < 0:
            raise ValueError(
                f"When using Graph Store mode, the server ranks must be less than the number of storage nodes and greater than or equal to 0, received inputs for servers: {list(nodes.keys())}"
            )
        input_data: list[NodeSamplerInput] = []
        for server_rank in range(dataset.cluster_info.num_storage_nodes):
            if server_rank in nodes:
                input_data.append(
                    NodeSamplerInput(node=nodes[server_rank], input_type=input_type)
                )
            else:
                input_data.append(
                    NodeSamplerInput(
                        node=torch.empty(0, dtype=torch.long), input_type=input_type
                    )
                )
        return (
            input_data,
            worker_options,
            DatasetSchema(
                is_labeled_heterogeneous=is_labeled_heterogeneous,
                edge_types=edge_types,
                node_feature_info=node_feature_info,
                edge_feature_info=edge_feature_info,
                edge_dir=dataset.get_edge_dir(),
            ),
        )

    def _setup_for_colocated(
        self,
        input_nodes: Optional[
            Union[
                torch.Tensor,
                Tuple[NodeType, torch.Tensor],
                abc.Mapping[int, torch.Tensor],
                Tuple[NodeType, abc.Mapping[int, torch.Tensor]],
            ]
        ],
        dataset: DistDataset,
        local_rank: int,
        local_world_size: int,
        device: torch.device,
        master_ip_address: str,
        node_rank: int,
        node_world_size: int,
        num_workers: int,
        worker_concurrency: int,
        channel_size: str,
        num_cpu_threads: Optional[int],
    ) -> tuple[NodeSamplerInput, MpDistSamplingWorkerOptions, DatasetSchema]:
        if input_nodes is None:
            if dataset.node_ids is None:
                raise ValueError(
                    "Dataset must have node ids if input_nodes are not provided."
                )
            if isinstance(dataset.node_ids, abc.Mapping):
                raise ValueError(
                    f"input_nodes must be provided for heterogeneous datasets, received node_ids with node types: {dataset.node_ids.keys()}"
                )
            input_nodes = dataset.node_ids

        if isinstance(input_nodes, abc.Mapping):
            raise ValueError(
                f"When using colocated mode, input nodes must be of type (torch.Tensor | (NodeType, torch.Tensor)), received {type(input_nodes)}"
            )
        elif isinstance(input_nodes, tuple) and isinstance(input_nodes[1], abc.Mapping):
            raise ValueError(
                f"When using colocated mode, input nodes must be of type (torch.Tensor | (NodeType, torch.Tensor)), received {type(input_nodes)} ({type(input_nodes[0])}, {type(input_nodes[1])})"
            )

        is_labeled_heterogeneous = False
        if isinstance(input_nodes, torch.Tensor):
            node_ids = input_nodes
            # If the dataset is heterogeneous, we may be in the "labeled homogeneous" setting;
            # if so, then we should use DEFAULT_HOMOGENEOUS_NODE_TYPE.
            if isinstance(dataset.node_ids, abc.Mapping):
                if (
                    len(dataset.node_ids) == 1
                    and DEFAULT_HOMOGENEOUS_NODE_TYPE in dataset.node_ids
                ):
                    node_type = DEFAULT_HOMOGENEOUS_NODE_TYPE
                    is_labeled_heterogeneous = True
                else:
                    raise ValueError(
                        f"For heterogeneous datasets, input_nodes must be a tuple of (node_type, node_ids) OR, if it is a labeled homogeneous dataset, input_nodes may be a torch.Tensor. Received node types: {dataset.node_ids.keys()}"
                    )
            else:
                node_type = None
        else:
            node_type, node_ids = input_nodes
            assert isinstance(
                dataset.node_ids, abc.Mapping
            ), "Dataset must be heterogeneous if provided input nodes are a tuple."

        curr_process_nodes = shard_nodes_by_process(
            input_nodes=node_ids,
            local_process_rank=local_rank,
            local_process_world_size=local_world_size,
        )

        input_data = NodeSamplerInput(node=curr_process_nodes, input_type=node_type)

        # Sets up processes and torch device for initializing the GLT DistNeighborLoader,
        # setting up RPC and worker groups to minimize the memory overhead and CPU contention.
        logger.info(
            f"Initializing neighbor loader worker in process: {local_rank}/{local_world_size} using device: {device}"
        )

        should_use_cpu_workers = device.type == "cpu"
        if should_use_cpu_workers and num_cpu_threads is None:
            logger.info(
                "Using CPU workers, but found num_cpu_threads to be None. "
                f"Will default num_cpu_threads to {DEFAULT_NUM_CPU_THREADS}."
            )
            num_cpu_threads = DEFAULT_NUM_CPU_THREADS

        neighbor_loader_ports = gigl.distributed.utils.get_free_ports_from_master_node(
            num_ports=local_world_size
        )
        neighbor_loader_port_for_current_rank = neighbor_loader_ports[local_rank]

        gigl.distributed.utils.init_neighbor_loader_worker(
            master_ip_address=master_ip_address,
            local_process_rank=local_rank,
            local_process_world_size=local_world_size,
            rank=node_rank,
            world_size=node_world_size,
            master_worker_port=neighbor_loader_port_for_current_rank,
            device=device,
            should_use_cpu_workers=should_use_cpu_workers,
            # Lever to explore tuning for CPU-based inference.
            num_cpu_threads=num_cpu_threads,
        )
        logger.info(
            f"Finished initializing neighbor loader worker: {local_rank}/{local_world_size}"
        )

        # Sets up worker options for the dataloader.
        dist_sampling_ports = gigl.distributed.utils.get_free_ports_from_master_node(
            num_ports=local_world_size
        )
        dist_sampling_port_for_current_rank = dist_sampling_ports[local_rank]
        worker_options = MpDistSamplingWorkerOptions(
            num_workers=num_workers,
            worker_devices=[torch.device("cpu") for _ in range(num_workers)],
            worker_concurrency=worker_concurrency,
            # Each worker will spawn several sampling workers, and all sampling workers spawned by
            # workers in one group need to be connected. Thus, we need the master ip address and
            # master port to initiate the connection.
            # Note that different groups of workers are independent, and thus the sampling processes
            # in different groups should be independent and should use different master ports.
            master_addr=master_ip_address,
            master_port=dist_sampling_port_for_current_rank,
            # Load testing shows that when num_rpc_threads exceeds 16, the performance will degrade.
            num_rpc_threads=min(dataset.num_partitions, 16),
            rpc_timeout=600,
            channel_size=channel_size,
            pin_memory=device.type == "cuda",
        )

        if isinstance(dataset.graph, dict):
            edge_types = list(dataset.graph.keys())
        else:
            edge_types = None

        return (
            input_data,
            worker_options,
            DatasetSchema(
                is_labeled_heterogeneous=is_labeled_heterogeneous,
                edge_types=edge_types,
                node_feature_info=dataset.node_feature_info,
                edge_feature_info=dataset.edge_feature_info,
                edge_dir=dataset.edge_dir,
            ),
        )

    def _collate_fn(self, msg: SampleMessage) -> Union[Data, HeteroData]:
        data = super()._collate_fn(msg)
        data = set_missing_features(
            data=data,
            node_feature_info=self._node_feature_info,
            edge_feature_info=self._edge_feature_info,
            device=self.to_device,
        )
        if isinstance(data, HeteroData):
            data = strip_label_edges(data)
        if self._is_labeled_heterogeneous:
            data = labeled_to_homogeneous(DEFAULT_HOMOGENEOUS_EDGE_TYPE, data)
        return data
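

# ---------------------------------------------------------------------------------------------
# Usage sketch (illustrative only, not invoked anywhere in this module).
# It assumes a colocated `DistDataset` that has already been built and partitioned, an
# initialized torch.distributed process group, and a homogeneous graph. The fanout, batch size,
# and loop body below are placeholder values chosen for illustration, not recommended settings.
# ---------------------------------------------------------------------------------------------
def _example_colocated_usage(dataset: DistDataset) -> None:
    loader = DistNeighborLoader(
        dataset=dataset,
        num_neighbors=[10, 10],  # Sample up to 10 neighbors per hop for 2 hops (assumed fanout).
        input_nodes=None,  # None => iterate over all nodes of the homogeneous dataset.
        batch_size=512,  # Assumed batch size.
        shuffle=True,
    )
    for batch in loader:
        # Each `batch` is a torch_geometric.data.Data (or HeteroData) subgraph with features
        # already collected; a model forward/backward pass would go here.
        pass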