# Source code for gigl.src.mocking.lib.user_defined_edge_sampling
from collections import defaultdict
from typing import List
from gigl.src.common.types.graph_data import EdgeUsageType, NodeId
from gigl.src.mocking.lib.mocked_dataset_resources import MockedDatasetInfo
[docs]
def sample_hydrate_user_def_edge(
    mocked_dataset_info: MockedDatasetInfo, edge_usage_type: EdgeUsageType
) -> dict[NodeId, List]:
    """
    Samples all available user-defined edges of the given usage type (e.g. pos/neg)
    for ``mocked_dataset_info.sample_edge_type`` and hydrates them with their
    features when features are available.

    Args:
        mocked_dataset_info: Mocked dataset whose ``sample_edge_type``,
            ``user_defined_edge_index`` and ``user_defined_edge_feats`` are read.
        edge_usage_type: Key selecting which user-defined edges to read for the
            sampled edge type.

    Returns:
        A mapping from each source node to a list with one entry per outgoing edge,
        in the order the edges appear in the edge index:
          - with edge features:    {src: [[dst, [f0, f1, ..., fn]], ...]}
          - without edge features: {src: [dst, ...]}
        The mapping is empty when the dataset has no user-defined edge index.

    Raises:
        AssertionError: If ``mocked_dataset_info.sample_edge_type`` is None.
        KeyError: If a user-defined mapping is present but has no entry for the
            sampled edge type / ``edge_usage_type``.
    """
    assert (
        mocked_dataset_info.sample_edge_type is not None
    ), "sample_edge_type is missing in mocked_dataset_info"

    src_to_dst_map = defaultdict(list)

    edge_index = (
        mocked_dataset_info.user_defined_edge_index[
            mocked_dataset_info.sample_edge_type  # type: ignore
        ][edge_usage_type]
        if mocked_dataset_info.user_defined_edge_index
        else None
    )
    if edge_index is None:
        # No user-defined edges at all: there is nothing to sample. Previously this
        # fell through and crashed with a TypeError when subscripting None.
        return src_to_dst_map

    edge_feats = (
        mocked_dataset_info.user_defined_edge_feats[
            mocked_dataset_info.sample_edge_type  # type: ignore
        ][edge_usage_type]
        if mocked_dataset_info.user_defined_edge_feats
        else None
    )

    # Row 0 of the edge index holds the source nodes, row 1 the destination nodes.
    src_nodes = edge_index[0].tolist()
    dst_nodes = edge_index[1].tolist()

    if edge_feats is not None:
        # The i-th feature row belongs to the i-th edge.
        for src, dst, feats in zip(src_nodes, dst_nodes, edge_feats.tolist()):
            src_to_dst_map[src].append([dst, feats])
    else:
        for src, dst in zip(src_nodes, dst_nodes):
            src_to_dst_map[src].append(dst)
    return src_to_dst_map