# Source code for gigl.src.mocking.lib.mocked_dataset_resources
from dataclasses import dataclass
from typing import Optional
import torch
from gigl.src.common.translators.gbml_protos_translator import GbmlProtosTranslator
from gigl.src.common.types.graph_data import (
    CondensedEdgeType,
    CondensedNodeType,
    EdgeType,
    EdgeUsageType,
    NodeType,
)
from gigl.src.common.types.pb_wrappers.graph_metadata import GraphMetadataPbWrapper
from gigl.src.common.types.task_metadata import TaskMetadataType
from snapchat.research.gbml import graph_schema_pb2
@dataclass
[docs]
class MockedDatasetInfo:
    # TODO: (svij) Deprecate MockedDatasetInfo in favor of pyg.HeteroData
    @property
[docs]
    def node_types(self) -> list[NodeType]:
        return list(self.node_feats.keys()) 
    @property
[docs]
    def edge_types(self) -> list[EdgeType]:
        return list(self.edge_index.keys()) 
    @property
[docs]
    def num_nodes(self) -> dict[NodeType, int]:
        return {
            node_type: node_feat.shape[0]
            for node_type, node_feat in self.node_feats.items()
        } 
[docs]
    def get_num_edges(self, edge_type: EdgeType, edge_usage_type: EdgeUsageType) -> int:
        """Return the number of edges of ``edge_type`` for the given usage type.

        Args:
            edge_type: Edge type whose edges are counted.
            edge_usage_type: Which edge set to count. ``MAIN`` reads
                ``edge_index``; ``POSITIVE`` and ``NEGATIVE`` read
                ``user_defined_edge_index``.

        Returns:
            The size of dimension 1 of the matching edge index tensor. For a
            non-``MAIN`` usage type, 0 is returned when there is no
            user-defined edge index for this edge type and usage type.

        Raises:
            KeyError: If ``edge_usage_type`` is ``MAIN`` and ``edge_type`` is
                not a key of ``edge_index``.
        """
        if edge_usage_type == EdgeUsageType.MAIN:
            # Index the requested edge type directly instead of sizing every
            # edge type just to read a single entry.
            return self.edge_index[edge_type].shape[1]

        udl_edge_index = self.user_defined_edge_index
        if udl_edge_index is None or edge_type not in udl_edge_index:
            return 0

        # User-defined edges are keyed by edge type first, then by usage type.
        edge_index_by_usage = udl_edge_index[edge_type]
        counts_as_udl = edge_usage_type in (
            EdgeUsageType.POSITIVE,
            EdgeUsageType.NEGATIVE,
        )
        if counts_as_udl and edge_usage_type in edge_index_by_usage:
            return edge_index_by_usage[edge_usage_type].shape[1]
        return 0
    @property
[docs]
    def num_node_features(self) -> dict[NodeType, int]:
        return {
            node_type: feats.shape[1] for node_type, feats in self.node_feats.items()
        } 
    @property
[docs]
    def num_node_distinct_labels(self) -> dict[NodeType, int]:
        if not self.node_labels:
            return {}
        return {
            node_type: labels.unique().numel()
            for node_type, labels in self.node_labels.items()
        } 
    @property
[docs]
    def num_edge_features(self) -> dict[EdgeType, int]:
        if self.edge_feats:
            return {
                edge_type: feats.shape[1]
                for edge_type, feats in self.edge_feats.items()
            }
        else:
            return {edge_type: 0 for edge_type in self.edge_types} 
    @property
[docs]
    def num_user_def_edge_features(self) -> dict[EdgeType, dict[EdgeUsageType, int]]:
        num_user_def_edge_feats = {}
        if self.user_defined_edge_feats:
            for edge_type, udl_edge_feats in self.user_defined_edge_feats.items():
                num_user_def_edge_feats[edge_type] = {
                    edge_usage_type: feats.shape[1]
                    for edge_usage_type, feats in udl_edge_feats.items()
                }
        else:
            for edge_type in self.edge_types:
                num_user_def_edge_feats[edge_type] = {
                    edge_usage_type: 0
                    for edge_usage_type in [
                        EdgeUsageType.POSITIVE,
                        EdgeUsageType.NEGATIVE,
                    ]
                }
        return num_user_def_edge_feats 
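    # NOTE(review): the property that was defined here is missing from this listing;
    # its orphaned decorator is dropped so it does not stack onto default_node_type.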
    @property
[docs]
    def default_node_type(self) -> NodeType:
        return self.node_types[0] 
    @property
[docs]
    def default_edge_type(self) -> EdgeType:
        return self.edge_types[0] 
[docs]
    # Main edge index tensor per edge type; get_num_edges reads the edge count
    # from dimension 1 of each tensor.
    edge_index: dict[EdgeType, torch.Tensor] 
[docs]
    # Node feature tensor per node type; dimension 0 is read as the node count
    # and dimension 1 as the feature count.
    node_feats: dict[NodeType, torch.Tensor] 
[docs]
    # Optional edge feature tensor per edge type; dimension 1 is read as the
    # feature count. When unset or empty, every edge type reports 0 features.
    edge_feats: Optional[dict[EdgeType, torch.Tensor]] = None 
[docs]
    # Optional label tensor per node type; num_node_distinct_labels counts the
    # distinct values in each tensor.
    node_labels: Optional[dict[NodeType, torch.Tensor]] = None 
[docs]
    # Not read by the methods shown in this listing; presumably the node type
    # used for supervision/sampling -- TODO confirm against callers.
    sample_node_type: Optional[NodeType] = None 
    # TODO (tzhao-sc): currently only supporting 1 supervision edge type, we would need
    #      to extend this to support multiple supervision edge types for HGS stage 2
[docs]
    sample_edge_type: Optional[EdgeType] = None 
[docs]
    # The four *_column_name fields below are not read by the methods shown in
    # this listing; presumably they name columns of the tabular form of the
    # mocked data -- TODO confirm against callers.
    edge_src_column_name: str = "src" 
[docs]
    edge_dst_column_name: str = "dst" 
[docs]
    node_id_column_name: str = "node_id" 
[docs]
    node_label_column_name: str = "node_label" 
[docs]
    # Optional user-defined edges, keyed by edge type and then by usage type;
    # get_num_edges reads the POSITIVE / NEGATIVE edge count from dimension 1.
    user_defined_edge_index: Optional[
        dict[EdgeType, dict[EdgeUsageType, torch.Tensor]]
    ] = None 
[docs]
    # Optional features for the user-defined edges, with the same nesting;
    # dimension 1 of each tensor is read as the feature count.
    user_defined_edge_feats: Optional[
        dict[EdgeType, dict[EdgeUsageType, torch.Tensor]]
    ] = None 
[docs]
    # Not read by the methods shown in this listing; presumably a version tag
    # for the mocked dataset -- TODO confirm.
    version: Optional[str] = None