Source code for gigl.src.common.graph_builder.graph_builder_factory

from gigl.src.common.graph_builder.abstract_graph_builder import GraphBuilder
from gigl.src.common.graph_builder.pyg_graph_builder import PygGraphBuilder
from gigl.src.common.types.model import GraphBackend


class GraphBuilderFactory:
    """
    Factory that hands out a `GraphBuilder` for a given `GraphBackend`.
    """

    @classmethod
    def get_graph_builder(cls, backend_name: GraphBackend) -> GraphBuilder:
        """
        Return a new graph builder for the requested backend.

        Args:
            backend_name: The `GraphBackend` member identifying which
                builder implementation to construct.

        Returns:
            A freshly constructed `PygGraphBuilder` when `backend_name`
            equals `GraphBackend.PYG`.

        Raises:
            ValueError: If `backend_name` is anything else. The message
                lists the `value` of every `GraphBackend` member.
        """
        # Happy path first: PYG is the only backend handled here.
        if backend_name == GraphBackend.PYG:
            return PygGraphBuilder()

        # Anything that falls through is unsupported; report the accepted values.
        raise ValueError(
            f"{backend_name} is not valid. backend_name can be one of {[gb.value for gb in GraphBackend]}"
        )