import ast
from typing import Any, Union
from gigl.common.logger import Logger
from gigl.src.common.types.graph_data import EdgeType
def _validate_parsed_edge_type(parsed_edge_type: Any) -> None:
    """
    Validates that the parsed edge type is a tuple[str, str, str], denoting an edge type.

    Args:
        parsed_edge_type (Any): Candidate edge type, expected to be a tuple[str, str, str] whose items are
            the source node type, relation, and destination node type, respectively.
    Raises:
        ValueError: if not a tuple
        ValueError: if the tuple has a length which is not equal to 3
        ValueError: if not all elements of the tuple are strings
    """
    # Shape check first: the element check below iterates the value, so it must
    # already be known to be a 3-tuple.
    if not isinstance(parsed_edge_type, tuple) or len(parsed_edge_type) != 3:
        raise ValueError(
            f"Parsed edge type expected to be a tuple[str, str, str], got {parsed_edge_type}"
        )
    if not all(isinstance(component, str) for component in parsed_edge_type):
        # The message previously mentioned "integers", which contradicted the
        # string check performed here.
        raise ValueError(
            f"Edge type must be a tuple[str, str, str], got {parsed_edge_type}"
        )
def _validate_parsed_hops(parsed_fanout: Any) -> None:
    """
    Validates that the parsed fanout is correctly specified as a list of integers.

    Booleans are rejected even though ``bool`` is a subclass of ``int`` in Python: a value such as
    ``[True, False]`` is not a meaningful per-hop fanout and would otherwise pass silently.

    Args:
        parsed_fanout (Any): Fanout which is expected to be a list of integers
    Raises:
        ValueError: if not a list
        ValueError: if not all elements of the list are ints (booleans do not count as ints)
    """
    if not isinstance(parsed_fanout, list):
        raise ValueError(
            f"Parsed fanout expected to be a list, got {parsed_fanout} of type {type(parsed_fanout)}"
        )
    # isinstance(True, int) is True, so bool has to be excluded explicitly.
    if not all(
        isinstance(hop, int) and not isinstance(hop, bool) for hop in parsed_fanout
    ):
        raise ValueError(f"Fanout must contain integers, got {parsed_fanout}")
def parse_fanout(fanout_str: str) -> Union[list[int], dict[EdgeType, list[int]]]:
    """
    Parses a fanout specification from its string form.

    The string is evaluated with ``ast.literal_eval`` and must be the textual form of either a
    list[int] or a dict[tuple[str, str, str], list[int]], where the items of each tuple key are the
    source node type, relation, and destination node type, respectively.

    For example, a list[int] fanout can be given as
        '[10, 15, 20]'
    and a dict[EdgeType, list[int]] fanout can be given as
        '{("user", "to", "user"): [10, 10], ("user", "to", "item"): [20, 20]}'

    Args:
        fanout_str (str): Provided string to be parsed into fanout
    Returns:
        Union[list[int], dict[EdgeType, list[int]]]: Either a list of fanouts per hop, or a dictionary
            mapping each edge type to its respective fanouts per hop
    Raises:
        ValueError: if the evaluated value is neither a list nor a dict, if a dict key is not a
            tuple[str, str, str], or if any fanout is not a list of integers. Errors raised by
            ``ast.literal_eval`` on malformed input are not caught here.
    """
    evaluated = ast.literal_eval(fanout_str)

    # Homogeneous case: a single list of per-hop fanouts, returned as-is once validated.
    if isinstance(evaluated, list):
        _validate_parsed_hops(parsed_fanout=evaluated)
        logger.info(f"Parsed list fanout from args: {evaluated}")
        return evaluated

    if not isinstance(evaluated, dict):
        raise ValueError(
            f"Fanout must be parsed as either a dictionary or a list, got {evaluated} of type {type(evaluated)}"
        )

    # Heterogeneous case: validate every entry, then re-key the mapping by EdgeType.
    fanout_by_edge_type: dict[EdgeType, list[int]] = {}
    for raw_edge_type, hops in evaluated.items():
        _validate_parsed_edge_type(parsed_edge_type=raw_edge_type)
        _validate_parsed_hops(parsed_fanout=hops)
        src_node_type, relation, dst_node_type = raw_edge_type
        key = EdgeType(
            src_node_type=src_node_type,
            relation=relation,
            dst_node_type=dst_node_type,
        )
        fanout_by_edge_type[key] = hops
    return fanout_by_edge_type