# type: ignore
import hashlib
import os
import pathlib
from collections import defaultdict
from difflib import unified_diff
from enum import Enum
from typing import Optional, Type, Union
import matplotlib.pyplot as plt
import networkx as nx
import tensorflow as tf
import torch_geometric.utils
import yaml
from IPython.display import HTML, display
from torch_geometric.data import HeteroData
from gigl.common import Uri
from gigl.common.collections.frozen_dict import FrozenDict
from gigl.src.common.graph_builder.pyg_graph_builder import PygGraphBuilder
from gigl.src.common.translators.gbml_protos_translator import GbmlProtosTranslator
from gigl.src.common.types.graph_data import CondensedNodeType, EdgeType, Node, NodeType
from gigl.src.common.types.pb_wrappers.gbml_config import GbmlConfigPbWrapper
from gigl.src.common.utils.file_loader import FileLoader
from snapchat.research.gbml import training_samples_schema_pb2
[docs]
# Directory five `.parent` hops above this file's own path. Presumably the
# root of the GiGL checkout, as the name suggests -- TODO confirm if this
# module is ever moved to a different depth in the package tree.
gigl_root_dir = pathlib.Path(__file__).parent.parent.parent.parent.parent
[docs]
def change_working_dir_to_gigl_root():
    """
    Change the process working directory to ``gigl_root_dir`` and print it.

    Meant to be called from notebooks so that relative file paths resolve the
    same way regardless of which directory the notebook itself lives in.
    Note that this affects the whole process, not just the calling cell.
    """
    os.chdir(gigl_root_dir)
    print(f"Changed working directory to: {gigl_root_dir}")
[docs]
class GraphVisualizerLayoutMode(Enum):
    """Layout strategies understood by ``GraphVisualizer``."""

    # Structure-driven layout (Kamada-Kawai or spring, depending on graph size).
    HOMOGENEOUS = "homogeneous"
    # Layout chosen from the number of distinct node types in the graph
    # (circular, bipartite, or spring).
    BIPARTITE = "bipartite"
 
[docs]
class PbVisualizerFromOutput(Enum):
    """Pipeline output that ``PbVisualizer.find_node_pb`` should search for a sample."""

    # Samples referenced by the task config's flattened graph metadata.
    # ``find_node_pb`` branches on this member, so it must exist.
    SGS = "sgs"
    # Samples referenced by the task config's dataset metadata, one per split.
    SPLIT_TRAIN = "split_train"
    SPLIT_VAL = "split_val"
    SPLIT_TEST = "split_test"
 
[docs]
class PbVisualizer:
    def __init__(self, frozen_task_config: GbmlConfigPbWrapper):
        """
        Build lookup tables between enumerated and original ("unenumerated") node ids.

        For every condensed node type in the preprocessed metadata, the whole
        ``enumerated_node_ids_bq_table`` is read through ``BqUtils.run_query``
        and held in memory, so construction cost grows with the number of
        nodes in those tables.

        Args:
            frozen_task_config: Task config wrapper; its preprocessed metadata
                and graph metadata wrappers are read here and the wrapper is
                kept on the instance for later use.
        """
        # Kept as a function-scope import; presumably to avoid pulling in
        # BigQuery dependencies when the module is merely imported -- TODO confirm.
        from gigl.src.common.utils.bq import BqUtils

        self.frozen_task_config = frozen_task_config
        preprocessed_metadata = (
            frozen_task_config.preprocessed_metadata_pb_wrapper.preprocessed_metadata_pb
        )
        graph_metadata_pb_wrapper = frozen_task_config.graph_metadata_pb_wrapper
        bq_utils = BqUtils()

        # dict[tuple[condensed_node_type, enumerated_node_id], tuple[node_type, unenumerated_node_id]]
        self.enumerated_node_to_unenumerated_node_id_map: dict[
            tuple[int, int], tuple[str, int]
        ] = {}
        for (
            condensed_node_type,
            node_metadata,
        ) in preprocessed_metadata.condensed_node_type_to_preprocessed_metadata.items():
            node_type = graph_metadata_pb_wrapper.condensed_node_type_to_node_type_map[
                CondensedNodeType(condensed_node_type)
            ]
            result = bq_utils.run_query(
                query=f"""
                SELECT int_id, node_id FROM `{node_metadata.enumerated_node_ids_bq_table}`
                """,
                labels={},
            )
            for row in result:
                self.enumerated_node_to_unenumerated_node_id_map[
                    (condensed_node_type, row.int_id)
                ] = (node_type, row.node_id)

        # Inverse of the map above:
        # dict[tuple[node_type, unenumerated_node_id], tuple[condensed_node_type, enumerated_node_id]]
        self.unenumerated_node_id_to_enumerated_node_id_map: dict[
            tuple[str, int], tuple[int, int]
        ] = {
            (node_type, node_id): (CondensedNodeType(condensed_node_type), int_id)
            for (condensed_node_type, int_id), (
                node_type,
                node_id,
            ) in self.enumerated_node_to_unenumerated_node_id_map.items()
        }
[docs]
    def plot_pb(
        self,
        pb: Union[
            training_samples_schema_pb2.RootedNodeNeighborhood,
            training_samples_schema_pb2.NodeAnchorBasedLinkPredictionSample,
        ],
        layout_mode: GraphVisualizerLayoutMode = GraphVisualizerLayoutMode.BIPARTITE,
    ):
        """
        Plot the neighborhood stored in a training-sample proto.

        The neighborhood is converted to a ``HeteroData`` graph and handed to
        ``GraphVisualizer.visualize_graph`` together with a mapping that
        relabels every node with its original (unenumerated) id. For a
        ``NodeAnchorBasedLinkPredictionSample`` the root node and the positive
        edges are passed as well; they are translated to the same
        unenumerated id space first, because the visualizer compares them
        against the relabelled node ids.

        Args:
            pb: The sample to plot. A falsy value prints a message and returns.
            layout_mode: Layout strategy forwarded to the visualizer.

        Returns:
            Whatever ``GraphVisualizer.visualize_graph`` returns, or ``None``
            when there is nothing to plot.

        Raises:
            KeyError: If a node id in the sample is missing from the
                enumerated-id lookup table built in ``__init__``.
        """
        if not pb:
            print("No pb to plot")
            return

        builder = PygGraphBuilder()
        graph_metadata_pb_wrapper = self.frozen_task_config.graph_metadata_pb_wrapper
        graph_data = GbmlProtosTranslator.graph_data_from_GraphPb(
            samples=[pb.neighborhood],
            graph_metadata_pb_wrapper=graph_metadata_pb_wrapper,
            builder=builder,
        )

        enumerated_to_unenumerated = self.enumerated_node_to_unenumerated_node_id_map
        node_type_to_condensed_node_type = (
            graph_metadata_pb_wrapper.node_type_to_condensed_node_type_map
        )

        # Extract positive edges if this is a NodeAnchorBasedLinkPredictionSample
        pos_edges: Optional[dict[tuple[str, str, str], list[tuple[int, int]]]] = None
        global_root_node: Optional[tuple[int, str]] = None
        if isinstance(
            pb, training_samples_schema_pb2.NodeAnchorBasedLinkPredictionSample
        ):
            pos_edges = defaultdict(list)
            for edge in pb.pos_edges:
                edge_type: EdgeType = (
                    graph_metadata_pb_wrapper.condensed_edge_type_to_edge_type_map[
                        edge.condensed_edge_type
                    ]
                )
                # The proto carries enumerated ids; translate both endpoints so
                # they line up with the relabelled nodes in the drawn graph.
                _, src_unenumerated_id = enumerated_to_unenumerated[
                    (
                        node_type_to_condensed_node_type[edge_type.src_node_type],
                        edge.src_node_id,
                    )
                ]
                _, dst_unenumerated_id = enumerated_to_unenumerated[
                    (
                        node_type_to_condensed_node_type[edge_type.dst_node_type],
                        edge.dst_node_id,
                    )
                ]
                pos_edges[
                    (
                        edge_type.src_node_type,
                        edge_type.relation,
                        edge_type.dst_node_type,
                    )
                ].append((src_unenumerated_id, dst_unenumerated_id))

            # Same translation for the root node, otherwise the highlight is
            # matched against ids from a different id space.
            _, root_unenumerated_id = enumerated_to_unenumerated[
                (pb.root_node.condensed_node_type, pb.root_node.node_id)
            ]
            global_root_node = (
                root_unenumerated_id,
                graph_metadata_pb_wrapper.condensed_node_type_to_node_type_map[
                    pb.root_node.condensed_node_type
                ],
            )

        # Map each subgraph-local node to a Node carrying the original id/type.
        subgraph_node_to_unenumerated_node_id_map: dict[Node, Node] = {}
        for (
            node,
            global_node,
        ) in graph_data.subgraph_node_to_global_node_mapping.items():
            condensed_node_type = node_type_to_condensed_node_type[global_node.type]
            (
                unenumerated_node_type,
                unenumerated_node_id,
            ) = enumerated_to_unenumerated[(condensed_node_type, global_node.id)]
            subgraph_node_to_unenumerated_node_id_map[node] = Node(
                id=unenumerated_node_id,
                type=NodeType(unenumerated_node_type),
            )

        return GraphVisualizer.visualize_graph(
            data=graph_data.to_hetero_data(),
            layout_mode=layout_mode,
            subgraph_node_to_unenumerated_node_id_map=subgraph_node_to_unenumerated_node_id_map,
            pos_edges=pos_edges,
            global_root_node=global_root_node,
        )
[docs]
    def find_node_pb(
        self,
        unenumerated_node_id: int,
        unenumerated_node_type: str,
        from_output: PbVisualizerFromOutput,
        pb_type: Type[
            Union[
                training_samples_schema_pb2.NodeAnchorBasedLinkPredictionSample,
                training_samples_schema_pb2.RootedNodeNeighborhood,
            ]
        ],
    ) -> Optional[
        Union[
            training_samples_schema_pb2.NodeAnchorBasedLinkPredictionSample,
            training_samples_schema_pb2.RootedNodeNeighborhood,
        ]
    ]:
        """
        Scan TFRecord output for the sample rooted at a given original node id.

        The TFRecord location is picked from the frozen task config based on
        ``from_output`` and ``pb_type``; every ``*.tfrecord`` file under that
        prefix is read sequentially until a record whose ``root_node.node_id``
        equals the node's enumerated id is found.

        Args:
            unenumerated_node_id: Original (pre-enumeration) id of the node.
            unenumerated_node_type: Node type of the node; also used as the key
                into the per-node-type random-negative URI maps.
            from_output: Which pipeline output to search.
            pb_type: Proto class to parse each record as.

        Returns:
            The first matching parsed proto, or ``None`` if no record matches.

        Raises:
            ValueError: If ``pb_type`` or ``from_output`` is not supported.
            KeyError: If the (type, id) pair is not in the lookup table built
                in ``__init__``.
            AssertionError: If the task config lacks the expected
                node-anchor-based link prediction metadata.
        """
        tfrecord_uri_prefix: str
        if from_output == PbVisualizerFromOutput.SGS:
            flattened_graph_metadata = (
                self.frozen_task_config.shared_config.flattened_graph_metadata
            )
            assert hasattr(
                flattened_graph_metadata, "node_anchor_based_link_prediction_output"
            ), f"find_node_pb only supported for node_anchor_based_link_prediction, not {flattened_graph_metadata}"
            if (
                pb_type
                == training_samples_schema_pb2.NodeAnchorBasedLinkPredictionSample
            ):
                tfrecord_uri_prefix = (
                    flattened_graph_metadata.node_anchor_based_link_prediction_output.tfrecord_uri_prefix
                )
            elif pb_type == training_samples_schema_pb2.RootedNodeNeighborhood:
                tfrecord_uri_prefix = flattened_graph_metadata.node_anchor_based_link_prediction_output.node_type_to_random_negative_tfrecord_uri_prefix[
                    unenumerated_node_type
                ]
            else:
                raise ValueError(f"Unsupported pb_type: {pb_type}")
        else:
            assert hasattr(
                self.frozen_task_config.shared_config.dataset_metadata,
                "node_anchor_based_link_prediction_dataset",
            ), f"find_node_pb only supported for node_anchor_based_link_prediction, not {self.frozen_task_config.shared_config.dataset_metadata}"
            dataset = (
                self.frozen_task_config.shared_config.dataset_metadata.node_anchor_based_link_prediction_dataset
            )
            if (
                pb_type
                == training_samples_schema_pb2.NodeAnchorBasedLinkPredictionSample
            ):
                if from_output == PbVisualizerFromOutput.SPLIT_TRAIN:
                    tfrecord_uri_prefix = dataset.train_main_data_uri
                elif from_output == PbVisualizerFromOutput.SPLIT_VAL:
                    tfrecord_uri_prefix = dataset.val_main_data_uri
                elif from_output == PbVisualizerFromOutput.SPLIT_TEST:
                    tfrecord_uri_prefix = dataset.test_main_data_uri
                else:
                    raise ValueError(f"Unsupported from_output: {from_output}")
            elif pb_type == training_samples_schema_pb2.RootedNodeNeighborhood:
                if from_output == PbVisualizerFromOutput.SPLIT_TRAIN:
                    tfrecord_uri_prefix = (
                        dataset.train_node_type_to_random_negative_data_uri[
                            unenumerated_node_type
                        ]
                    )
                elif from_output == PbVisualizerFromOutput.SPLIT_VAL:
                    tfrecord_uri_prefix = (
                        dataset.val_node_type_to_random_negative_data_uri[
                            unenumerated_node_type
                        ]
                    )
                elif from_output == PbVisualizerFromOutput.SPLIT_TEST:
                    tfrecord_uri_prefix = (
                        dataset.test_node_type_to_random_negative_data_uri[
                            unenumerated_node_type
                        ]
                    )
                else:
                    raise ValueError(f"Unsupported from_output: {from_output}")
            else:
                raise ValueError(f"Unsupported pb_type: {pb_type}")

        uri = tfrecord_uri_prefix + "*.tfrecord"
        # The records are keyed by enumerated ids, so translate before searching.
        (
            search_node_type,
            search_node_id,
        ) = self.unenumerated_node_id_to_enumerated_node_id_map[
            (unenumerated_node_type, unenumerated_node_id)
        ]
        print(
            f"The node id {unenumerated_node_id}, type {unenumerated_node_type} maps to node id {search_node_id}, type {search_node_type}"
        )
        ds = tf.data.TFRecordDataset(tf.io.gfile.glob(uri)).as_numpy_iterator()
        print(f" Looking for node {search_node_id} in {uri}")

        # pb_type was validated by the branches above, so it can be
        # instantiated directly for each record. Linear scan; stops at the
        # first record rooted at the requested node.
        for bytestr in ds:
            candidate = pb_type()
            candidate.ParseFromString(bytestr)
            if candidate.root_node.node_id == search_node_id:
                return candidate
        return None
 
[docs]
class GraphVisualizer:
    """
    Used to build and visualize graph which is user configured in a yaml file.
    """
    # Fixed node color palette — extend as needed
[docs]
    # Palette of node fill colors (hex strings). `assign_node_color` picks an
    # entry by hashing a name modulo the palette length, so appending entries
    # changes which color existing names map to.
    node_colors = [
        "#64B5F6",  # blue
        "#E57373",  # red
        "#81C784",  # green
        "#FFD54F",  # yellow
        "#BA68C8",  # purple
        "#4DB6AC",  # teal
        "#F06292",  # pink
        "#A1887F",  # brown
        "#FFB74D",  # orange
    ]
    # Fixed edge color palette — best for white background
[docs]
    # Palette of edge colors (hex strings). `assign_edge_color` picks an entry
    # by hashing a name modulo the palette length; with only three entries,
    # distinct edge types can share a color.
    edge_colors = [
        "#1565C0",  # medium blue
        "#43A047",  # vivid green
        "#000000",  # black
    ]
    @staticmethod
[docs]
    def assign_node_color(name: str) -> str:
        """Assign a node color to a name based on deterministic hash and a fixed palette."""
        # Use SHA256 for deterministic hashing
        hash_value = int(hashlib.sha256(name.encode("utf-8")).hexdigest(), 16)
        return GraphVisualizer.node_colors[
            hash_value % len(GraphVisualizer.node_colors)
        ] 
    @staticmethod
[docs]
    def assign_edge_color(name: str) -> str:
        """Assign an edge color to a name based on deterministic hash and a fixed palette (optimized for white background)."""
        # Use SHA256 for deterministic hashing
        hash_value = int(hashlib.sha256(name.encode("utf-8")).hexdigest(), 16)
        return GraphVisualizer.edge_colors[
            hash_value % len(GraphVisualizer.edge_colors)
        ] 
    @staticmethod
    def _create_type_grouped_layout(
        g,
        node_index_to_type,
        node_types,
        seed=42,
        layout_mode=GraphVisualizerLayoutMode.BIPARTITE,
    ):
        """
        Compute node positions for ``g`` according to ``layout_mode``.

        Warning: the heuristics below are best-effort and were not carefully
        tuned; they are good enough for small notebook visualizations.

        - Empty graph: returns ``{}`` regardless of mode.
        - HOMOGENEOUS: Kamada-Kawai for up to 30 nodes (falling back to a
          spring layout if it raises), spring layout for larger graphs.
        - BIPARTITE: nodes are grouped by ``node_index_to_type`` (missing
          entries count as ``"unknown"``); one group gives a circular layout,
          two groups a bipartite layout, more a spring layout.

        Args:
            g: NetworkX graph to lay out.
            node_index_to_type: Mapping from node to its node type.
            node_types: Not used by this function.
            seed: Seed passed to the spring layouts.
            layout_mode: Which strategy to apply.

        Returns:
            Mapping from node to position, as produced by the NetworkX layout.

        Raises:
            ValueError: If ``layout_mode`` is neither supported mode (only
                reached for a non-empty graph).
        """
        node_count = len(g.nodes())
        if node_count == 0:
            return {}

        if layout_mode == GraphVisualizerLayoutMode.HOMOGENEOUS:
            if node_count > 30:
                # Large graphs: go straight to a spring layout.
                spacing = max(3.0, node_count / 6.0)
                return nx.spring_layout(
                    g, seed=seed, k=spacing, iterations=250, scale=20
                )
            try:
                # Large scale keeps the drawn nodes from overlapping.
                return nx.kamada_kawai_layout(g, scale=15)
            except Exception as e:
                print(
                    f"Kamada-Kawai layout failed: {e}, falling back to spring layout"
                )
                spacing = max(4.0, node_count / 3.0)
                return nx.spring_layout(
                    g, seed=seed, k=spacing, iterations=300, scale=15
                )

        if layout_mode != GraphVisualizerLayoutMode.BIPARTITE:
            raise ValueError(f"Invalid layout mode: {layout_mode}")

        # Bucket nodes by type, preserving first-seen order of the types.
        nodes_by_type = defaultdict(list)
        for node in g.nodes():
            nodes_by_type[node_index_to_type.get(node, "unknown")].append(node)

        type_count = len(nodes_by_type)
        if type_count == 1:
            return nx.circular_layout(g, scale=15)
        if type_count == 2:
            # The first type encountered forms one side of the bipartite layout.
            first_type = next(iter(nodes_by_type))
            return nx.bipartite_layout(g, set(nodes_by_type[first_type]), scale=15)

        # Three or more types: spring layout with spacing scaled by node count.
        spacing = max(4.0, node_count / 3.0)
        return nx.spring_layout(g, seed=seed, k=spacing, iterations=300, scale=15)
    @staticmethod
[docs]
    def visualize_graph(
        data: HeteroData,
        seed=42,
        layout_mode=GraphVisualizerLayoutMode.BIPARTITE,
        subgraph_node_to_unenumerated_node_id_map: Optional[
            FrozenDict[Node, Node]
        ] = None,
        # pos_edges is a dictionary of edge type (src_node_type, relation, dst_node_type) to list of (src_node_id, dst_node_id) pairs
        pos_edges: Optional[dict[tuple[str, str, str], list[tuple[int, int]]]] = None,
        global_root_node: Optional[tuple[int, str]] = None,
    ):
        """
        Draw a ``HeteroData`` graph with NetworkX/matplotlib and show the figure.

        Warning: the layout and styling heuristics are best-effort and were
        not carefully tuned; they serve the purpose for notebook inspection.

        Args:
            data: The HeteroData object to visualize.
            seed: Seed for the randomized layouts; fix it for reproducible plots.
            layout_mode: Either GraphVisualizerLayoutMode.HOMOGENEOUS or
                GraphVisualizerLayoutMode.BIPARTITE.
            subgraph_node_to_unenumerated_node_id_map: Optional mapping from a
                local ``Node`` (type + index in the converted graph) to the
                ``Node`` whose id/type should be displayed. When given, graph
                nodes are relabelled with the mapped ids.
            pos_edges: Optional positive edges keyed by
                (src_node_type, relation, dst_node_type); drawn in red and
                added to the graph (with their endpoints) if missing. Ids are
                compared against the (possibly relabelled) graph node ids.
            global_root_node: Optional (node_id, node_type) of the root node;
                the node with that id is drawn larger with a red border.

        Returns:
            None. The figure is displayed via ``plt.show()``.
        """
        # Colors used for borders and highlights.
        black = "#000000"  # border of isolated nodes
        charcoal = "#333333"  # default node border
        highlight_red = "#E53935"  # root node border and positive edges

        # Build a mapping from global node indices to node types BEFORE conversion.
        # HeteroData stores nodes by type - we need to map the global indices
        # that NetworkX will use back to the original node types.
        node_index_to_type = {}
        current_index = 0
        for node_type in data.node_types:
            if hasattr(data[node_type], "num_nodes"):
                num_nodes = data[node_type].num_nodes
                for _ in range(num_nodes):
                    node_index_to_type[current_index] = node_type
                    current_index += 1

        # Convert to NetworkX
        g = torch_geometric.utils.to_networkx(data)

        if subgraph_node_to_unenumerated_node_id_map:
            mapping = {}
            new_node_index_to_type = {}
            for node in g.nodes():
                node_type = node_index_to_type.get(node, "unknown")
                local_node = Node(type=node_type, id=node)
                unenumerated_node: Node = subgraph_node_to_unenumerated_node_id_map[
                    local_node
                ]
                mapping[node] = unenumerated_node.id
                # Preserve the node type information for the relabelled node
                new_node_index_to_type[unenumerated_node.id] = unenumerated_node.type
            g = nx.relabel_nodes(g, mapping)
            # From here on, types are looked up by the relabelled ids
            node_index_to_type = new_node_index_to_type  # type: ignore

        # Add positive edges to the graph if they don't already exist
        pos_edge_pairs = set()
        if pos_edges:
            for (
                src_node_type,
                relation,
                dst_node_type,
            ), edge_pairs in pos_edges.items():
                for src_id, dst_id in edge_pairs:
                    pos_edge_pairs.add((src_id, dst_id))
                    # Add nodes if they don't exist
                    if src_id not in g.nodes():
                        g.add_node(src_id)
                        node_index_to_type[src_id] = src_node_type
                    if dst_id not in g.nodes():
                        g.add_node(dst_id)
                        node_index_to_type[dst_id] = dst_node_type
                    # Add edge if it doesn't exist
                    if not g.has_edge(src_id, dst_id):
                        g.add_edge(src_id, dst_id)

        # Create node type to color mapping
        node_type_to_color = {}
        for node_type in data.node_types:
            node_type_to_color[node_type] = GraphVisualizer.assign_node_color(node_type)

        # Assign colors based on the mapping we built; types not present in
        # `data.node_types` (e.g. "unknown") get a color on first use.
        node_colors = []
        for node in g.nodes():
            node_type = node_index_to_type.get(node, "unknown")
            if node_type not in node_type_to_color:
                node_type_to_color[node_type] = GraphVisualizer.assign_node_color(
                    node_type
                )
            node_colors.append(node_type_to_color[node_type])

        # Create a larger figure for better node spacing
        plt.figure(figsize=(10, 6))

        # Generate a layout based on the selected mode
        actual_node_types = set(node_index_to_type.values())
        pos = GraphVisualizer._create_type_grouped_layout(
            g, node_index_to_type, actual_node_types, seed, layout_mode
        )
        # Safety check: if pos is None or empty and we have nodes, create a fallback layout
        if pos is None or (len(g.nodes()) > 0 and not pos):
            print("Layout generation failed, using fallback spring layout")
            k = max(4.0, len(g.nodes()) / 3.0)
            pos = nx.spring_layout(g, seed=seed, k=k, iterations=300, scale=15)

        # Identify isolated nodes and root node for special styling.
        # A set keeps the per-node membership test below constant-time.
        isolated_nodes = {node for node in g.nodes() if g.degree(node) == 0}
        root_node_id = global_root_node[0] if global_root_node else None

        # Create border styling (red border for root node, thick black border for isolated nodes)
        node_edge_colors = []
        node_line_widths = []
        node_sizes = []
        for node in g.nodes():
            if node == root_node_id:
                node_edge_colors.append(highlight_red)  # Red border for root node
                node_line_widths.append(4)  # Thick border for root node
                node_sizes.append(1000)  # Twice the size for root node
            elif node in isolated_nodes:
                node_edge_colors.append(black)  # Black border for isolated nodes
                node_line_widths.append(3)
                node_sizes.append(500)  # Normal size
            else:
                node_edge_colors.append(charcoal)  # Default border color
                node_line_widths.append(1)
                node_sizes.append(500)  # Normal size

        # Create edge type to color mapping
        edge_type_to_color = {}
        edge_colors = []
        # Extract edge types from the graph (now includes any added positive edges)
        for edge in g.edges():
            # Check if this is a positive edge
            is_positive_edge = (edge[0], edge[1]) in pos_edge_pairs
            if is_positive_edge:
                edge_colors.append(highlight_red)  # Red color for positive edges
            else:
                # Get node types for source and destination
                src_node_type = node_index_to_type.get(edge[0], "unknown")
                dst_node_type = node_index_to_type.get(edge[1], "unknown")
                # Create edge type identifier
                edge_type = f"{src_node_type} → {dst_node_type}"
                # Look for a more specific edge type in HeteroData if available
                if hasattr(data, "edge_types") and data.edge_types:
                    for et in data.edge_types:
                        if len(et) == 3:  # (src_type, relation, dst_type)
                            if et[0] == src_node_type and et[2] == dst_node_type:
                                edge_type = f"{et[0]} --{et[1]}--> {et[2]}"
                                break
                        elif (
                            isinstance(et, tuple) and len(et) == 2
                        ):  # Some formats might be (src, dst)
                            if et[0] == src_node_type and et[1] == dst_node_type:
                                edge_type = f"{et[0]} → {et[1]}"
                                break
                # Assign color to edge type
                if edge_type not in edge_type_to_color:
                    edge_type_to_color[edge_type] = GraphVisualizer.assign_edge_color(
                        edge_type
                    )
                edge_colors.append(edge_type_to_color[edge_type])

        # Draw nodes first
        nx.draw_networkx_nodes(
            g,
            pos,
            node_color=node_colors if node_colors else "lightblue",  # type: ignore
            edgecolors=node_edge_colors if node_edge_colors else charcoal,  # type: ignore
            linewidths=node_line_widths if node_line_widths else 1,  # type: ignore
            node_size=node_sizes if node_sizes else 500,  # type: ignore
        )

        # Draw edges - straight for homogeneous, curved for bipartite
        if g.edges() and edge_colors:
            if layout_mode == GraphVisualizerLayoutMode.HOMOGENEOUS:
                # Straight edges for homogeneous graphs
                nx.draw_networkx_edges(
                    g,
                    pos,
                    edge_color=edge_colors,  # type: ignore
                    width=0.75,  # 75% of default edge width
                    alpha=0.9,  # Less transparent for cleaner look
                )
            else:
                # Curved edges for bipartite graphs to reduce overlap
                nx.draw_networkx_edges(
                    g,
                    pos,
                    edge_color=edge_colors,  # type: ignore
                    width=0.75,  # 75% of default edge width
                    alpha=0.8,  # Slightly transparent for better overlap visibility
                    connectionstyle="arc3,rad=0.1",  # Curved edges to reduce overlap
                )

        # Draw labels last so they appear on top
        nx.draw_networkx_labels(
            g,
            pos,
            font_size=10,
            font_weight="bold",
        )

        # Add a legend to show node type colors and edge types
        legend_elements = []
        # Add node types (skipped when there is only a single type)
        if len(node_type_to_color) > 1:
            for node_type in sorted(node_type_to_color.keys()):
                legend_elements.append(
                    plt.Line2D(
                        [0],
                        [0],
                        marker="o",
                        color="w",
                        markerfacecolor=node_type_to_color[node_type],
                        markersize=10,
                        label=f"Node: {node_type}",
                    )
                )
        # Add isolated node indicator
        if isolated_nodes:
            legend_elements.append(
                plt.Line2D(
                    [0],
                    [0],
                    marker="o",
                    color="black",
                    markerfacecolor="white",
                    markeredgewidth=3,
                    markersize=10,
                    label="Isolated nodes",
                )
            )
        # Add root node indicator
        if global_root_node:
            legend_elements.append(
                plt.Line2D(
                    [0],
                    [0],
                    marker="o",
                    color="#E53935",
                    markerfacecolor="white",
                    markeredgewidth=4,
                    markersize=15,  # Larger marker to represent the larger size
                    label="Root node",
                )
            )
        # Add positive edges to legend if they exist
        if pos_edges:
            legend_elements.append(
                plt.Line2D(
                    [0],
                    [0],
                    color="#E53935",
                    linewidth=3,
                    label="Positive edges",
                )
            )
        # Add edge types
        if edge_type_to_color:
            for edge_type in sorted(edge_type_to_color.keys()):
                legend_elements.append(
                    plt.Line2D(
                        [0],
                        [0],
                        color=edge_type_to_color[edge_type],
                        linewidth=2,
                        label=f"Edge: {edge_type}",
                    )
                )
        if legend_elements:
            plt.legend(
                handles=legend_elements, loc="upper right", bbox_to_anchor=(1.4, 1)
            )
        plt.show()
 
[docs]
def sort_yaml_dict_recursively(obj: dict) -> dict:
    """
    Return a copy of ``obj`` with every nested dict's keys in sorted order.

    The GiGL proto serialization code does not guarantee the order of the
    original keys, so configs are normalized this way before diffing to keep
    the diff stable. Lists keep their element order (only dicts found inside
    them are normalized); any other value is returned unchanged.
    """
    if isinstance(obj, list):
        return [sort_yaml_dict_recursively(element) for element in obj]
    if not isinstance(obj, dict):
        return obj
    ordered = {}
    for key in sorted(obj):
        ordered[key] = sort_yaml_dict_recursively(obj[key])
    return ordered
[docs]
def show_colored_unified_diff(f1_lines, f2_lines, f1_name, f2_name):
    """
    Display a color-coded unified diff from ``f1_lines`` to ``f2_lines``.

    Additions are green, removals red, hunk headers blue and everything else
    (context and file headers) black. The result is rendered through
    IPython's ``display(HTML(...))``.

    Args:
        f1_lines: Lines of the "from" side of the diff.
        f2_lines: Lines of the "to" side of the diff.
        f1_name: Label shown in the ``---`` header, i.e. for ``f1_lines``.
        f2_name: Label shown in the ``+++`` header, i.e. for ``f2_lines``.
    """
    # Function-scope import so the module's import block is left untouched.
    from html import escape

    html_lines = []
    for line in unified_diff(f1_lines, f2_lines, fromfile=f1_name, tofile=f2_name):
        if line.startswith("+") and not line.startswith("+++"):
            color = "#228B22"  # green
        elif line.startswith("-") and not line.startswith("---"):
            color = "#B22222"  # red
        elif line.startswith("@"):
            color = "#1E90FF"  # blue
        else:
            color = "#000000"  # black
        # Escape the text so that characters such as "<" or "&" in the diffed
        # content are shown literally instead of being parsed as markup.
        html_lines.append(
            f'<pre style="margin:0; color:{color}; background-color:white;">{escape(line.rstrip())}</pre>'
        )
    display(HTML("".join(html_lines)))


def show_task_config_colored_unified_diff(
    f1_uri: Uri, f2_uri: Uri, f1_name: str, f2_name: str
):
    """
    Displays a colored unified diff of two task config files.

    Both files are loaded as YAML and re-dumped with recursively sorted keys
    so that key ordering alone never shows up as a difference. The diff is
    rendered from the second file to the first: lines present only in the
    first file appear as additions, lines present only in the second file as
    removals.

    Args:
        f1_uri (Uri): URI of the first file (the "to" side of the diff).
        f2_uri (Uri): URI of the second file (the "from" side of the diff).
        f1_name (str): Label for the first file in the diff header.
        f2_name (str): Label for the second file in the diff header.
    """
    file_loader = FileLoader()

    def _load_normalized_yaml(uri: Uri) -> str:
        # Fetch to a temp file, parse, and dump with deterministic key order.
        with open(file_loader.load_to_temp_file(file_uri_src=uri).name, "r") as f:
            data = yaml.safe_load(f)
        # yaml.dump sorts keys by default
        return yaml.dump(sort_yaml_dict_recursively(data))

    f1_contents = _load_normalized_yaml(f1_uri)
    f2_contents = _load_normalized_yaml(f2_uri)

    # Diff direction is f2 -> f1, so each side is passed with its own label.
    show_colored_unified_diff(
        f2_contents.splitlines(),
        f1_contents.splitlines(),
        f1_name=f2_name,
        f2_name=f1_name,
    )