Source code for gigl.src.mocking.toy_asset_mocker
from typing import List
import torch
import yaml
from torch_geometric.data import HeteroData
[docs]
def load_toy_graph(graph_config: str) -> HeteroData:
    """Build a ``HeteroData`` toy graph from a YAML graph description.

    The YAML document is read with the following layout:

    * ``graph.node_types``: iterable of node type names.
    * ``graph.edge_types``: mapping of edge type name to an entry with
      ``src_node_type``, ``dst_node_type`` and ``relation_type``.
    * ``nodes[<node_type>]``: list of entries, each with a ``features`` value.
    * ``adj_list[<edge_type>]``: list of entries, each with a ``src`` id and a
      ``dst`` list of destination ids.

    Args:
        graph_config: Path to the YAML file describing the graph.

    Returns:
        A ``HeteroData`` object where each node type has ``x`` set, and each
        ``(src_type, relation_type, dst_type)`` edge type has ``edge_index``
        (2 rows, one column per edge) and a dummy ``edge_attr`` set.

    Raises:
        OSError: If the file cannot be opened.
        KeyError: If one of the keys listed above is missing.
    """
    # Parse into a separate name so the ``graph_config`` path argument is not
    # rebound to a value of a different type.
    with open(graph_config, "r", encoding="utf-8") as f:
        config = yaml.safe_load(f)

    node_config = config["graph"]["node_types"]
    edge_config = config["graph"]["edge_types"]

    data = HeteroData()

    # Add node features: one row per node, in the order listed in the YAML.
    for node_type in node_config:
        node_feats = [node["features"] for node in config["nodes"][node_type]]
        data[node_type].x = torch.tensor(node_feats)

    # Add edge indices and edge features.
    for edge_type in edge_config:
        edge_spec = edge_config[edge_type]
        edge_key = (
            edge_spec["src_node_type"],
            edge_spec["relation_type"],
            edge_spec["dst_node_type"],
        )

        # Flatten each adjacency entry into (src, dst) pairs.
        edge_pairs = [
            (adj["src"], dst)
            for adj in config["adj_list"][edge_type]
            for dst in adj["dst"]
        ]

        if edge_pairs:
            edge_index = torch.tensor(edge_pairs).t().contiguous()
        else:
            # torch.tensor([]) would give a 1-D float tensor; keep the
            # (2, num_edges) integer layout even when there are no edges.
            edge_index = torch.empty((2, 0), dtype=torch.long)

        data[edge_key].edge_index = edge_index
        # Dummy edge features: edge_index.T * 0.1
        data[edge_key].edge_attr = edge_index.t() * 0.1

    return data