gigl.distributed.dist_neighbor_sampler
Classes
DistABLPNeighborSampler: We inherit from the GLT DistNeighborSampler base class and override the _sample_from_nodes function.
Module Contents
- class gigl.distributed.dist_neighbor_sampler.DistABLPNeighborSampler(data, num_neighbors=None, with_edge=False, with_neg=False, with_weight=False, edge_dir='out', collect_features=False, channel=None, use_all2all=False, concurrency=1, device=None, seed=None)
Bases: graphlearn_torch.distributed.DistNeighborSampler
We inherit from the GLT DistNeighborSampler base class and override the _sample_from_nodes function. Specifically, we add functionality to parse ABLPNodeSamplerInput, which contains the supervision nodes (and their node types) that we also want to fan out around. We add the supervision nodes to the initial fanout seeds and inject the label information into the metadata of the output SampleMessage.
- Parameters:
data (graphlearn_torch.distributed.dist_dataset.DistDataset)
num_neighbors (Optional[graphlearn_torch.typing.NumNeighbors])
with_edge (bool)
with_neg (bool)
with_weight (bool)
edge_dir (Literal['in', 'out'])
collect_features (bool)
channel (Optional[graphlearn_torch.channel.ChannelBase])
use_all2all (bool)
concurrency (int)
device (Optional[torch.device])
seed (int)
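The two steps described in the class docstring (adding supervision nodes to the fanout seeds, then attaching labels to the output metadata) can be illustrated with a small, framework-free sketch. It does not call GiGL or GLT: `ABLPInputSketch`, `build_fanout_seeds`, `inject_label_metadata`, and every field and metadata key name below are hypothetical, and plain Python lists and dicts stand in for the tensors and SampleMessage the real sampler works with.

```python
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class ABLPInputSketch:
    """Hypothetical stand-in for ABLPNodeSamplerInput (field names are assumptions)."""

    anchor_nodes: List[int]
    anchor_node_type: str
    # Assumed layout: anchor node id -> positive supervision node ids.
    positive_labels: Dict[int, List[int]]
    supervision_node_type: str
    # Assumed layout: anchor node id -> negative supervision node ids.
    negative_labels: Dict[int, List[int]] = field(default_factory=dict)


def build_fanout_seeds(inp: ABLPInputSketch) -> Dict[str, List[int]]:
    """Merge anchors and supervision nodes into per-node-type seed lists.

    Anchors come first, supervision nodes are appended after them, and
    duplicates within a node type are dropped in first-seen order.
    """
    per_type: Dict[str, List[int]] = {inp.anchor_node_type: list(inp.anchor_nodes)}
    supervision = [
        node
        for labels in (inp.positive_labels, inp.negative_labels)
        for ids in labels.values()
        for node in ids
    ]
    # If anchors and supervision nodes share a type, this extends the anchor list.
    per_type.setdefault(inp.supervision_node_type, []).extend(supervision)
    # dict.fromkeys keeps insertion order, giving an order-preserving de-dup.
    return {node_type: list(dict.fromkeys(ids)) for node_type, ids in per_type.items()}


def inject_label_metadata(
    metadata: Dict[str, object], inp: ABLPInputSketch
) -> Dict[str, object]:
    """Return a copy of `metadata` with the label maps attached (hypothetical keys)."""
    out = dict(metadata)
    out["positive_labels"] = {a: list(ids) for a, ids in inp.positive_labels.items()}
    out["negative_labels"] = {a: list(ids) for a, ids in inp.negative_labels.items()}
    return out


# Usage: two "user" anchors with "item" supervision nodes.
inp = ABLPInputSketch(
    anchor_nodes=[0, 1],
    anchor_node_type="user",
    positive_labels={0: [10, 11], 1: [11]},
    supervision_node_type="item",
    negative_labels={0: [12]},
)
seeds = build_fanout_seeds(inp)  # {"user": [0, 1], "item": [10, 11, 12]}
metadata = inject_label_metadata({}, inp)
```

Seeding the supervision nodes alongside the anchors means one fanout produces neighborhoods for both ends of each labeled pair, while carrying the labels in metadata lets a downstream consumer recover which sampled nodes are positives or negatives for each anchor.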