Source code for gigl.src.mocking.lib.user_defined_edge_sampling
from collections import defaultdict
from typing import Dict, List
from gigl.src.common.types.graph_data import EdgeUsageType, NodeId
from gigl.src.mocking.lib.mocked_dataset_resources import MockedDatasetInfo
[docs]
def sample_hydrate_user_def_edge(
    mocked_dataset_info: MockedDatasetInfo, edge_usage_type: EdgeUsageType
) -> Dict[NodeId, List]:
    """
    Samples all available user-defined edges of the given usage type and
    hydrates them with their features, grouped by source node.

    Edges are looked up in ``mocked_dataset_info.user_defined_edge_index`` (and
    features in ``mocked_dataset_info.user_defined_edge_feats``) under
    ``mocked_dataset_info.sample_edge_type`` and then ``edge_usage_type``.

    Args:
        mocked_dataset_info: Dataset description holding the user-defined edge
            index / features and the ``sample_edge_type`` to read them for.
        edge_usage_type: Which group of user-defined edges to read
            (e.g. positive or negative).

    Returns:
        A ``defaultdict(list)`` keyed by source node id.
        - With edge features: ``{src: [[dst, [f0, f1, ..., fn]], ...]}``
        - Without edge features: ``{src: [dst, ...]}``
        The mapping is empty when no user-defined edge index is available.

    Raises:
        AssertionError: If ``mocked_dataset_info.sample_edge_type`` is None.
    """
    assert (
        mocked_dataset_info.sample_edge_type is not None
    ), "sample_edge_type is missing in mocked_dataset_info"
    sample_edge_type = mocked_dataset_info.sample_edge_type

    src_to_dst_map: Dict[NodeId, List] = defaultdict(list)

    edge_index = (
        mocked_dataset_info.user_defined_edge_index[sample_edge_type][
            edge_usage_type
        ]
        if mocked_dataset_info.user_defined_edge_index
        else None
    )
    if edge_index is None:
        # No user-defined edges to sample. Previously this fell through to
        # `edge_index[0]` and failed with an opaque TypeError on None.
        return src_to_dst_map

    edge_feats = (
        mocked_dataset_info.user_defined_edge_feats[sample_edge_type][
            edge_usage_type
        ]
        if mocked_dataset_info.user_defined_edge_feats
        else None
    )

    # Row 0 holds the source ids and row 1 the destination ids.
    # NOTE(review): assumes edge_index is a 2 x num_edges tensor/array and
    # edge_feats has one row per edge -- confirm against MockedDatasetInfo.
    src_nodes = edge_index[0].tolist()
    dst_nodes = edge_index[1].tolist()

    if edge_feats is not None:
        for src, dst, feats in zip(src_nodes, dst_nodes, edge_feats.tolist()):
            src_to_dst_map[src].append([dst, feats])
    else:
        for src, dst in zip(src_nodes, dst_nodes):
            src_to_dst_map[src].append(dst)
    return src_to_dst_map