# Module: gigl.src.common.graph_builder.graph_builder_factory
from gigl.src.common.graph_builder.abstract_graph_builder import GraphBuilder
from gigl.src.common.graph_builder.pyg_graph_builder import PygGraphBuilder
from gigl.src.common.types.model import GraphBackend
class GraphBuilderFactory:
    """
    Instantiates a `GraphBuilder` object based on valid `GraphBackend` names
    """

    @classmethod
    def get_graph_builder(cls, backend_name: GraphBackend) -> GraphBuilder:
        """
        Return a new `GraphBuilder` for the requested graph backend.

        Args:
            backend_name (GraphBackend): Backend to build graphs for. Only
                `GraphBackend.PYG` is currently supported.

        Returns:
            GraphBuilder: A freshly constructed `PygGraphBuilder` when
            `backend_name` equals `GraphBackend.PYG`.

        Raises:
            ValueError: If `backend_name` is not a supported backend. The
                message lists the values of every `GraphBackend` member.
        """
        if backend_name == GraphBackend.PYG:
            return PygGraphBuilder()
        # Any other backend has no builder implementation here.
        raise ValueError(
            f"{backend_name} is not valid. backend_name can be one of {[gb.value for gb in GraphBackend]}"
        )