Source code for gigl.src.common.graph_builder.pyg_graph_builder
from __future__ import annotations
import torch
from gigl.common.collections.frozen_dict import FrozenDict
from gigl.common.logger import Logger
from gigl.src.common.graph_builder.abstract_graph_builder import GraphBuilder
from gigl.src.common.graph_builder.pyg_graph_data import PygGraphData
from gigl.src.common.types.graph_data import Node, NodeId
[docs]
class PygGraphBuilder(GraphBuilder[PygGraphData]):
    """Graph builder that materializes its accumulated state as ``PygGraphData``.

    The containers read by :meth:`build` (``subgraph_node_id_counter``,
    ``ordered_edges``, the node/edge feature dicts, the ``should_register_*``
    flags and ``global_node_to_subgraph_node_map``) are not defined in this
    class; presumably they are provided by ``GraphBuilder`` and (re)initialized
    by ``reset()`` -- verify against the base class.
    """

    def __init__(self) -> None:
        # Start from the same state build() leaves behind (it also ends with
        # a call to reset()).
        self.reset()

    def build(self) -> PygGraphData:
        """Assemble a ``PygGraphData`` from the builder's current state, then reset.

        For every node type in ``subgraph_node_id_counter`` a feature matrix
        ``x`` is registered whose row ``i`` belongs to subgraph node id ``i``.
        When node features are disabled, ``x`` is an all-ones
        ``(num_nodes, 1)`` placeholder. For every edge type in
        ``ordered_edges`` that has at least one edge, ``edge_index`` is
        registered in the stored edge order, plus ``edge_attr`` when edge
        features are enabled. Edge types without edges get neither attribute.

        Returns:
            The populated ``PygGraphData``; it also carries a frozen copy of
            the global-node -> subgraph-node mapping.

        Raises:
            AssertionError: If edge features are enabled and the stored
                feature for an edge is ``None``.
        """
        data = PygGraphData()

        # NOTE(review): `logger` is not defined in the visible part of this
        # module (only `Logger` is imported) -- confirm a module-level logger
        # instance exists, otherwise the debug calls below raise NameError.

        # Register node features.
        for node_type, num_nodes in self.subgraph_node_id_counter.items():
            logger.debug(f"Registering {num_nodes} nodes of type {node_type}")
            if self.should_register_node_features:
                data[node_type].x = torch.stack(
                    [
                        self.subgraph_node_features_dict[
                            Node(type=node_type, id=NodeId(node_id))
                        ]
                        for node_id in range(num_nodes)
                    ]
                )
            else:
                # PyG expects node features of at least size 1, so fall back to
                # a single constant feature per node. Allocating the matrix in
                # one call is equivalent to stacking `num_nodes` copies of
                # torch.ones(1) and also copes with a zero node count.
                data[node_type].x = torch.ones(num_nodes, 1)

        # Register edge indices and (optionally) edge features.
        for edge_type, ordered_edges in self.ordered_edges.items():
            logger.debug(f"Registering {len(ordered_edges)} edges of type {edge_type}")
            src_node_list: list[int] = []
            dst_node_list: list[int] = []
            edge_features_list: list[torch.Tensor] = []
            for edge in ordered_edges:
                src_node_list.append(int(edge.src_node.id))
                dst_node_list.append(int(edge.dst_node.id))
                if self.should_register_edge_features:
                    edge_feature = self.subgraph_edge_feature_dict[edge]
                    assert (
                        edge_feature is not None
                    ), f"Missing edge feature for edge {edge}"
                    edge_features_list.append(edge_feature)

            # The source and destination lists grow in lockstep, so checking
            # one of them is enough to know whether this type has any edges.
            if not src_node_list:
                continue

            edge_key = tuple(edge_type)
            # Only populated when edge features are enabled (see loop above).
            if edge_features_list:
                data[edge_key].edge_attr = torch.stack(edge_features_list)
            # Row 0 holds source ids, row 1 destination ids.
            data[edge_key].edge_index = torch.tensor(
                [src_node_list, dst_node_list],
                dtype=torch.long,
            )

        data.global_node_to_subgraph_node_mapping = FrozenDict(
            self.global_node_to_subgraph_node_map.copy()
        )
        self.reset()  # reset before returning
        return data