import asyncio
import gc
from collections import defaultdict
from typing import Optional, Union
import torch
from graphlearn_torch.channel import SampleMessage
from graphlearn_torch.distributed import DistNeighborSampler
from graphlearn_torch.sampler import (
HeteroSamplerOutput,
NeighborOutput,
NodeSamplerInput,
SamplerOutput,
)
from graphlearn_torch.typing import EdgeType, NodeType
from graphlearn_torch.utils import count_dict, merge_dict, reverse_edge_type
from gigl.distributed.sampler import (
NEGATIVE_LABEL_METADATA_KEY,
POSITIVE_LABEL_METADATA_KEY,
ABLPNodeSamplerInput,
)
from gigl.utils.data_splitters import PADDING_NODE
# TODO (mkolodner-sc): Investigate upstreaming this change back to GLT
class DistABLPNeighborSampler(DistNeighborSampler):
    """
    Distributed neighbor sampler that also fans out around ABLP supervision nodes.

    We inherit from the GLT DistNeighborSampler base class and override the
    _sample_from_nodes function. Specifically, we introduce functionality to parse
    ABLPNodeSamplerInput, which contains information about the supervision nodes and
    node types that we also want to fanout around. We add the supervision nodes to the
    initial fanout seeds, and inject the label information into the metadata of the
    sampler output.
    """

    def _add_label_seeds(
        self,
        label_by_edge_types: dict[EdgeType, torch.Tensor],
        metadata_key_prefix: str,
        label_edge_index: int,
        input_seeds_builder: dict[Union[str, NodeType], list[torch.Tensor]],
        metadata: dict[str, torch.Tensor],
    ) -> None:
        """
        Register one family of labels (positive or negative) as seeds and metadata.

        For every supervision edge type, the non-padding label node ids are moved to
        ``self.device`` and appended to ``input_seeds_builder`` under the node type
        found at ``edge_type[label_edge_index]``. The *unfiltered* label tensor (padding
        included, device untouched) is stored in ``metadata``.

        Args:
            label_by_edge_types: Label node ids keyed by supervision edge type.
            metadata_key_prefix: Prefix of the metadata key; the stringified edge type
                tuple is appended to it.
            label_edge_index: Position inside the edge type tuple that holds the node
                type of the label nodes.
            input_seeds_builder: Per-node-type list of seed tensors; mutated in place.
            metadata: Output metadata dict; mutated in place.
        """
        for edge_type, label_tensor in label_by_edge_types.items():
            # PADDING_NODE entries are placeholders, not real nodes, so they must not
            # be used as fanout seeds.
            filtered_label_tensor = label_tensor[label_tensor != PADDING_NODE].to(
                self.device
            )
            input_seeds_builder[edge_type[label_edge_index]].append(
                filtered_label_tensor
            )
            # Update the metadata per label edge type.
            # We do this because GLT only supports dict[str, torch.Tensor] for metadata.
            metadata[f"{metadata_key_prefix}{str(tuple(edge_type))}"] = label_tensor

    async def _sample_from_nodes(
        self,
        inputs: NodeSamplerInput,
    ) -> Optional[SampleMessage]:
        """
        Sample a multi-hop subgraph around the anchor nodes and their label nodes.

        The anchor seeds (``inputs.node``) and every non-padding positive/negative
        label node are used as first-hop seeds. The ``batch`` field of the output only
        contains the anchor seeds. Label tensors are attached to the output metadata
        under ``f"{POSITIVE_LABEL_METADATA_KEY}{str(tuple(edge_type))}"`` and
        ``f"{NEGATIVE_LABEL_METADATA_KEY}{str(tuple(edge_type))}"``.

        Args:
            inputs: Must be an ``ABLPNodeSamplerInput``.

        Returns:
            A ``HeteroSamplerOutput`` when ``self.dist_graph.data_cls == "hetero"``,
            otherwise a ``SamplerOutput``. The return annotation is kept as declared
            for interface compatibility.

        Raises:
            AssertionError: If ``inputs`` is not an ``ABLPNodeSamplerInput``; if the
                graph is heterogeneous and ``inputs.input_type`` is None; or if the
                graph is homogeneous and the seeds span more than one node type.
        """
        assert isinstance(inputs, ABLPNodeSamplerInput)
        input_seeds = inputs.node.to(self.device)
        input_type = inputs.input_type

        # Since GLT swaps src/dst for edge_dir = "out",
        # and GiGL assumes that supervision edge types are always
        # (anchor_node_type, to, supervision_node_type),
        # we need to index into supervision edge types accordingly.
        label_edge_index = 0 if self.edge_dir == "in" else 2

        # Go through the positive and negative labels and add them to the metadata and
        # input seeds builder. We need to sample from the supervision nodes as well,
        # and ensure that we are sampling from the correct node type.
        metadata: dict[str, torch.Tensor] = {}
        input_seeds_builder: dict[
            Union[str, NodeType], list[torch.Tensor]
        ] = defaultdict(list)
        input_seeds_builder[input_type].append(input_seeds)
        self._add_label_seeds(
            inputs.positive_label_by_edge_types,
            POSITIVE_LABEL_METADATA_KEY,
            label_edge_index,
            input_seeds_builder,
            metadata,
        )
        self._add_label_seeds(
            inputs.negative_label_by_edge_types,
            NEGATIVE_LABEL_METADATA_KEY,
            label_edge_index,
            input_seeds_builder,
            metadata,
        )

        # As a perf optimization, we *could* have `input_nodes` be only the unique nodes,
        # but since torch.unique() calls a sort, we should investigate if it's worth it.
        # TODO(kmonte, mkolodner-sc): Investigate if this is worth it.
        input_nodes: dict[Union[str, NodeType], torch.Tensor] = {
            node_type: torch.cat(seeds, dim=0).to(self.device)
            for node_type, seeds in input_seeds_builder.items()
        }

        # Drop every reference to the per-type seed fragments before sampling. The
        # loop temporaries live in `_add_label_seeds`, so they are already gone and
        # there is nothing to `del` here even when no labels were provided.
        for value in input_seeds_builder.values():
            value.clear()
        input_seeds_builder.clear()
        del input_seeds_builder
        gc.collect()

        self.max_input_size: int = max(self.max_input_size, input_seeds.numel())
        inducer = self._acquire_inducer()
        try:
            is_hetero = self.dist_graph.data_cls == "hetero"
            output: NeighborOutput
            if is_hetero:
                assert input_type is not None
                out_nodes_hetero: dict[NodeType, list[torch.Tensor]] = {}
                out_rows_hetero: dict[EdgeType, list[torch.Tensor]] = {}
                out_cols_hetero: dict[EdgeType, list[torch.Tensor]] = {}
                out_edges_hetero: dict[EdgeType, list[torch.Tensor]] = {}
                num_sampled_nodes_hetero: dict[NodeType, list[torch.Tensor]] = {}
                num_sampled_edges_hetero: dict[EdgeType, list[torch.Tensor]] = {}

                src_dict = inducer.init_node(input_nodes)
                # Only the anchor seeds are reported as the batch; label nodes are
                # seeds for sampling but are exposed through `metadata` instead.
                batch = {input_type: input_seeds}
                merge_dict(src_dict, out_nodes_hetero)
                count_dict(src_dict, num_sampled_nodes_hetero, 1)

                for i in range(self.num_hops):
                    task_dict: dict[EdgeType, asyncio.Task] = {}
                    nbr_dict: dict[EdgeType, list[torch.Tensor]] = {}
                    edge_dict: dict[EdgeType, torch.Tensor] = {}

                    # Launch one sampling task per edge type so the hops for
                    # different edge types are in flight concurrently.
                    for etype in self.edge_types:
                        req_num = self.num_neighbors[etype][i]
                        if self.edge_dir == "in":
                            srcs = src_dict.get(etype[-1], None)
                            if srcs is not None and srcs.numel() > 0:
                                # Results are keyed by the reversed edge type so that
                                # `etype[0]` below is the type of the source nodes.
                                task_dict[
                                    reverse_edge_type(etype)
                                ] = self._loop.create_task(
                                    self._sample_one_hop(srcs, req_num, etype)
                                )
                        elif self.edge_dir == "out":
                            srcs = src_dict.get(etype[0], None)
                            if srcs is not None and srcs.numel() > 0:
                                task_dict[etype] = self._loop.create_task(
                                    self._sample_one_hop(srcs, req_num, etype)
                                )

                    for etype, task in task_dict.items():
                        output = await task
                        if output.nbr.numel() == 0:
                            continue
                        nbr_dict[etype] = [
                            src_dict[etype[0]],
                            output.nbr,
                            output.nbr_num,
                        ]
                        if output.edge is not None:
                            edge_dict[etype] = output.edge

                    # Nothing was sampled at this hop: skip induction and keep the
                    # current frontier for the next hop.
                    if not nbr_dict:
                        continue

                    nodes_dict, rows_dict, cols_dict = inducer.induce_next(nbr_dict)
                    merge_dict(nodes_dict, out_nodes_hetero)
                    merge_dict(rows_dict, out_rows_hetero)
                    merge_dict(cols_dict, out_cols_hetero)
                    merge_dict(edge_dict, out_edges_hetero)
                    count_dict(nodes_dict, num_sampled_nodes_hetero, i + 2)
                    count_dict(cols_dict, num_sampled_edges_hetero, i + 1)
                    src_dict = nodes_dict

                sample_output = HeteroSamplerOutput(
                    node={
                        ntype: torch.cat(nodes)
                        for ntype, nodes in out_nodes_hetero.items()
                    },
                    row={
                        etype: torch.cat(rows)
                        for etype, rows in out_rows_hetero.items()
                    },
                    col={
                        etype: torch.cat(cols)
                        for etype, cols in out_cols_hetero.items()
                    },
                    edge=(
                        {
                            etype: torch.cat(eids)
                            for etype, eids in out_edges_hetero.items()
                        }
                        if self.with_edge
                        else None
                    ),
                    batch=batch,
                    num_sampled_nodes=num_sampled_nodes_hetero,
                    num_sampled_edges=num_sampled_edges_hetero,
                    input_type=input_type,
                    metadata=metadata,
                )
            else:
                assert (
                    len(input_nodes) == 1
                ), f"Expected 1 input node type, got {len(input_nodes)}"
                assert (
                    input_type == list(input_nodes.keys())[0]
                ), f"Expected input type {input_type}, got {list(input_nodes.keys())[0]}"

                srcs = inducer.init_node(input_nodes[input_type])
                batch = input_seeds
                out_nodes: list[torch.Tensor] = []
                out_edges: list[
                    tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]
                ] = []
                # `Tensor.size(0)` returns a plain int, so these hold per-hop counts.
                num_sampled_nodes: list[int] = []
                num_sampled_edges: list[int] = []
                out_nodes.append(srcs)
                num_sampled_nodes.append(srcs.size(0))

                # Sample subgraph.
                for req_num in self.num_neighbors:
                    output = await self._sample_one_hop(srcs, req_num, None)
                    if output.nbr.numel() == 0:
                        break
                    nodes, rows, cols = inducer.induce_next(
                        srcs, output.nbr, output.nbr_num
                    )
                    out_nodes.append(nodes)
                    out_edges.append((rows, cols, output.edge))
                    num_sampled_nodes.append(nodes.size(0))
                    num_sampled_edges.append(cols.size(0))
                    srcs = nodes

                # NOTE(review): if the very first hop yields no neighbors,
                # `out_edges` is empty and the torch.cat calls below receive an
                # empty list — confirm whether seeds without neighbors can reach
                # this point and what empty row/col tensors should look like.
                sample_output = SamplerOutput(
                    node=torch.cat(out_nodes),
                    row=torch.cat([e[0] for e in out_edges]),
                    col=torch.cat([e[1] for e in out_edges]),
                    edge=(
                        torch.cat([e[2] for e in out_edges])
                        if self.with_edge
                        else None
                    ),
                    batch=batch,
                    num_sampled_nodes=num_sampled_nodes,
                    num_sampled_edges=num_sampled_edges,
                    metadata=metadata,
                )
        finally:
            # Reclaim inducer into pool on every exit path, so a failed or cancelled
            # sampling call does not leave the inducer checked out.
            self.inducer_pool.put(inducer)

        return sample_output