# Source code for gigl.src.mocking.lib.mocked_dataset_resources
from dataclasses import dataclass
from typing import Dict, List, Optional
import torch
from gigl.src.common.translators.gbml_protos_translator import GbmlProtosTranslator
from gigl.src.common.types.graph_data import (
CondensedEdgeType,
CondensedNodeType,
EdgeType,
EdgeUsageType,
NodeType,
)
from gigl.src.common.types.pb_wrappers.graph_metadata import GraphMetadataPbWrapper
from gigl.src.common.types.task_metadata import TaskMetadataType
from snapchat.research.gbml import graph_schema_pb2
@dataclass
class MockedDatasetInfo:
    """Tensors and naming metadata that describe one mocked graph dataset.

    Every size reported by this class is read straight from tensor shapes:

    * ``node_feats[node_type]`` -- dimension 0 is the node count and
      dimension 1 is the number of node features.
    * ``edge_index[edge_type]`` -- dimension 1 is the number of edges.
    * ``edge_feats[edge_type]`` -- dimension 1 is the number of edge features.
    * ``user_defined_edge_index[edge_type][usage]`` -- dimension 1 is the
      number of user-defined edges for that usage type.
    * ``user_defined_edge_feats[edge_type][usage]`` -- dimension 1 is the
      number of user-defined edge features for that usage type.

    Attributes:
        edge_index: Edge tensors of the main graph, keyed by edge type.
        node_feats: Node feature tensors, keyed by node type.
        edge_feats: Optional edge feature tensors, keyed by edge type.
        node_labels: Optional node label tensors, keyed by node type.
        sample_node_type: Optional node type; not read by any member of this
            class (presumably the supervision node type -- TODO confirm).
        sample_edge_type: Optional supervision edge type (see the TODO on the
            field); not read by any member of this class.
        edge_src_column_name: Name used for the edge source column.
        edge_dst_column_name: Name used for the edge destination column.
        node_id_column_name: Name used for the node id column.
        node_label_column_name: Name used for the node label column.
        user_defined_edge_index: Optional user-defined edge tensors, keyed by
            edge type and then by edge usage type.
        user_defined_edge_feats: Optional user-defined edge feature tensors,
            keyed by edge type and then by edge usage type.
        version: Optional version string; not read by any member of this
            class.
    """

    # TODO: (svij) Deprecate MockedDatasetInfo in favor of pyg.HeteroData

    edge_index: Dict[EdgeType, torch.Tensor]
    node_feats: Dict[NodeType, torch.Tensor]
    edge_feats: Optional[Dict[EdgeType, torch.Tensor]] = None
    node_labels: Optional[Dict[NodeType, torch.Tensor]] = None
    sample_node_type: Optional[NodeType] = None
    # TODO (tzhao-sc): currently only supporting 1 supervision edge type, we would need
    # to extend this to support multiple supervision edge types for HGS stage 2
    sample_edge_type: Optional[EdgeType] = None
    edge_src_column_name: str = "src"
    edge_dst_column_name: str = "dst"
    node_id_column_name: str = "node_id"
    node_label_column_name: str = "node_label"
    user_defined_edge_index: Optional[
        Dict[EdgeType, Dict[EdgeUsageType, torch.Tensor]]
    ] = None
    user_defined_edge_feats: Optional[
        Dict[EdgeType, Dict[EdgeUsageType, torch.Tensor]]
    ] = None
    version: Optional[str] = None

    @property
    def node_types(self) -> List[NodeType]:
        """Node types, in the insertion order of ``node_feats``."""
        return list(self.node_feats.keys())

    @property
    def edge_types(self) -> List[EdgeType]:
        """Edge types, in the insertion order of ``edge_index``."""
        return list(self.edge_index.keys())

    @property
    def num_nodes(self) -> Dict[NodeType, int]:
        """Number of nodes per node type (dimension 0 of each feature tensor)."""
        return {
            node_type: node_feat.shape[0]
            for node_type, node_feat in self.node_feats.items()
        }

    def get_num_edges(self, edge_type: EdgeType, edge_usage_type: EdgeUsageType) -> int:
        """Return the number of edges of ``edge_type`` for the given usage.

        Args:
            edge_type: Edge type to count.
            edge_usage_type: ``MAIN`` counts edges in ``edge_index``;
                ``POSITIVE`` / ``NEGATIVE`` count edges in
                ``user_defined_edge_index[edge_type]``.

        Returns:
            The edge count, or 0 when no user-defined edges are registered
            for the requested edge type / usage combination.

        Raises:
            KeyError: If ``edge_usage_type`` is ``MAIN`` and ``edge_type`` is
                not a key of ``edge_index``.
        """
        if edge_usage_type == EdgeUsageType.MAIN:
            # Direct lookup; there is no need to size every edge type first.
            return self.edge_index[edge_type].shape[1]

        if (
            self.user_defined_edge_index is None
            or edge_type not in self.user_defined_edge_index
        ):
            return 0

        # User-defined edges are stored per edge type, then per usage type.
        udl_edge_index = self.user_defined_edge_index[edge_type]
        is_udl_usage = edge_usage_type in (
            EdgeUsageType.POSITIVE,
            EdgeUsageType.NEGATIVE,
        )
        if is_udl_usage and edge_usage_type in udl_edge_index:
            return udl_edge_index[edge_usage_type].shape[1]
        return 0

    @property
    def num_node_features(self) -> Dict[NodeType, int]:
        """Number of features per node type (dimension 1 of each feature tensor)."""
        return {
            node_type: feats.shape[1] for node_type, feats in self.node_feats.items()
        }

    @property
    def num_node_distinct_labels(self) -> Dict[NodeType, int]:
        """Number of distinct label values per node type.

        Returns an empty dict when ``node_labels`` is ``None`` or empty.
        """
        if not self.node_labels:
            return {}
        return {
            node_type: labels.unique().numel()
            for node_type, labels in self.node_labels.items()
        }

    @property
    def num_edge_features(self) -> Dict[EdgeType, int]:
        """Number of features per edge type.

        When ``edge_feats`` is provided, only the edge types present in it
        are reported. Otherwise every edge type in ``edge_index`` maps to 0.
        """
        if self.edge_feats:
            return {
                edge_type: feats.shape[1]
                for edge_type, feats in self.edge_feats.items()
            }
        return {edge_type: 0 for edge_type in self.edge_types}

    @property
    def num_user_def_edge_features(self) -> Dict[EdgeType, Dict[EdgeUsageType, int]]:
        """Number of user-defined edge features per edge type and usage type.

        When ``user_defined_edge_feats`` is provided, the result mirrors its
        keys. Otherwise every edge type in ``edge_index`` maps to 0 for both
        the ``POSITIVE`` and ``NEGATIVE`` usage types.
        """
        if self.user_defined_edge_feats:
            return {
                edge_type: {
                    edge_usage_type: feats.shape[1]
                    for edge_usage_type, feats in udl_edge_feats.items()
                }
                for edge_type, udl_edge_feats in self.user_defined_edge_feats.items()
            }
        return {
            edge_type: {
                EdgeUsageType.POSITIVE: 0,
                EdgeUsageType.NEGATIVE: 0,
            }
            for edge_type in self.edge_types
        }

    # NOTE(review): the listing this class was taken from had a second, stacked
    # `@property` decorator at this point. Stacking it would wrap a property in
    # a property and break attribute access, so only one is kept. A property
    # may have been dropped from the listing here -- confirm against the
    # repository.
    @property
    def default_node_type(self) -> NodeType:
        """First node type in ``node_feats``; raises IndexError if there is none."""
        return self.node_types[0]

    @property
    def default_edge_type(self) -> EdgeType:
        """First edge type in ``edge_index``; raises IndexError if there is none."""
        return self.edge_types[0]