# Source code for gigl.common.utils.jupyter_magics

# type: ignore
import hashlib
import os
import pathlib
from collections import defaultdict
from difflib import unified_diff
from enum import Enum
from typing import Optional, Type, Union

import matplotlib.pyplot as plt
import networkx as nx
import tensorflow as tf
import torch_geometric.utils
import yaml
from IPython.display import HTML, display
from torch_geometric.data import HeteroData

from gigl.common import Uri
from gigl.common.collections.frozen_dict import FrozenDict
from gigl.src.common.graph_builder.pyg_graph_builder import PygGraphBuilder
from gigl.src.common.translators.gbml_protos_translator import GbmlProtosTranslator
from gigl.src.common.types.graph_data import CondensedNodeType, EdgeType, Node, NodeType
from gigl.src.common.types.pb_wrappers.gbml_config import GbmlConfigPbWrapper
from gigl.src.common.utils.file_loader import FileLoader
from snapchat.research.gbml import training_samples_schema_pb2

def _nth_parent(path: pathlib.Path, levels: int) -> pathlib.Path:
    """Return ``path`` with ``.parent`` applied ``levels`` times."""
    for _ in range(levels):
        path = path.parent
    return path


# Directory five ``.parent`` steps above this module's file; presumably the root
# of the GiGL repository checkout.
gigl_root_dir = _nth_parent(pathlib.Path(__file__), 5)


def change_working_dir_to_gigl_root():
    """Switch the process working directory to ``gigl_root_dir``.

    Intended for notebooks, so that repo-relative file paths resolve the same way no
    matter which directory the notebook itself lives in. Prints the new directory.
    """
    target_dir = gigl_root_dir
    os.chdir(target_dir)
    print(f"Changed working directory to: {target_dir}")
# Hex colours used for node borders by GraphVisualizer.visualize_graph.
BLACK = "#000000"  # border of isolated (degree-0) nodes
CHARCOAL = "#36454F"  # default node border
class GraphVisualizerLayoutMode(Enum):
    """How ``GraphVisualizer`` positions nodes and draws edges.

    HOMOGENEOUS: structure-driven layout (Kamada-Kawai for small graphs, spring
        layout otherwise) with straight edges.
    BIPARTITE: nodes grouped by node type (circular for one type, two-column
        bipartite layout for two, spring layout for more) with curved edges.
    """

    HOMOGENEOUS = "homogeneous"
    BIPARTITE = "bipartite"
class PbVisualizerFromOutput(Enum):
    """Which pipeline output ``PbVisualizer.find_node_pb`` reads samples from.

    SGS: the outputs listed in ``shared_config.flattened_graph_metadata``
        (presumably the subgraph sampler's output).
    SPLIT_TRAIN / SPLIT_VAL / SPLIT_TEST: the matching split of
        ``shared_config.dataset_metadata``.
    """

    SGS = "sgs"
    SPLIT_TRAIN = "split_train"
    SPLIT_VAL = "split_val"
    SPLIT_TEST = "split_test"
class PbVisualizer:
    """Find and plot GiGL training-sample protos for a frozen task config.

    Construction runs one BigQuery query per condensed node type, against the
    ``enumerated_node_ids_bq_table`` recorded in the preprocessed metadata. The
    results build two lookup tables between enumerated node ids (what the sample
    protos contain) and the original, unenumerated node ids.
    """

    def __init__(self, frozen_task_config: GbmlConfigPbWrapper):
        """
        Args:
            frozen_task_config: Wrapped frozen GbmlConfig whose preprocessed metadata
                lists an enumerated-node-id BigQuery table for every node type.
        """
        self.frozen_task_config = frozen_task_config
        preprocessed_metadata = (
            frozen_task_config.preprocessed_metadata_pb_wrapper.preprocessed_metadata_pb
        )
        graph_metadata_pb_wrapper = frozen_task_config.graph_metadata_pb_wrapper

        # Function-scope import kept from the original; presumably avoids loading the
        # BigQuery client at module import time.
        from gigl.src.common.utils.bq import BqUtils

        bq_utils = BqUtils()

        # (condensed_node_type, enumerated_node_id) -> (node_type, unenumerated_node_id)
        self.enumerated_node_to_unenumerated_node_id_map: dict[
            tuple[int, int], tuple[str, int]
        ] = {}
        for (
            condensed_node_type,
            node_metadata,
        ) in preprocessed_metadata.condensed_node_type_to_preprocessed_metadata.items():
            node_type = graph_metadata_pb_wrapper.condensed_node_type_to_node_type_map[
                CondensedNodeType(condensed_node_type)
            ]
            result = bq_utils.run_query(
                query=f"""
                SELECT int_id, node_id
                FROM `{node_metadata.enumerated_node_ids_bq_table}`
                """,
                labels={},
            )
            for row in result:
                self.enumerated_node_to_unenumerated_node_id_map[
                    (condensed_node_type, row.int_id)
                ] = (node_type, row.node_id)

        # Inverse of the map above:
        # (node_type, unenumerated_node_id) -> (condensed_node_type, enumerated_node_id)
        self.unenumerated_node_id_to_enumerated_node_id_map: dict[
            tuple[str, int], tuple[int, int]
        ] = {
            (node_type, node_id): (CondensedNodeType(condensed_node_type), int_id)
            for (condensed_node_type, int_id), (
                node_type,
                node_id,
            ) in self.enumerated_node_to_unenumerated_node_id_map.items()
        }

    def plot_pb(
        self,
        pb: Union[
            training_samples_schema_pb2.RootedNodeNeighborhood,
            training_samples_schema_pb2.NodeAnchorBasedLinkPredictionSample,
        ],
        layout_mode: GraphVisualizerLayoutMode = GraphVisualizerLayoutMode.BIPARTITE,
    ):
        """Plot the neighborhood stored in ``pb``, labelling nodes by unenumerated id.

        For a NodeAnchorBasedLinkPredictionSample, the root node and the positive
        edges are highlighted too.

        Args:
            pb: The sample to plot; a falsy value just prints a message.
            layout_mode: Layout passed through to GraphVisualizer.visualize_graph.
        """
        if not pb:
            print("No pb to plot")
            return
        graph_metadata_pb_wrapper = self.frozen_task_config.graph_metadata_pb_wrapper
        graph_data = GbmlProtosTranslator.graph_data_from_GraphPb(
            samples=[pb.neighborhood],
            graph_metadata_pb_wrapper=graph_metadata_pb_wrapper,
            builder=PygGraphBuilder(),
        )

        # The proto carries enumerated ids, but visualize_graph relabels every drawn
        # node to its unenumerated id. Translate the positive edges and the root into
        # that same id space, or the root is never highlighted and positive edges
        # spawn extra nodes that carry enumerated ids.
        pos_edges: Optional[dict[tuple[str, str, str], list[tuple[int, int]]]] = None
        global_root_node: Optional[tuple[int, str]] = None
        if isinstance(
            pb, training_samples_schema_pb2.NodeAnchorBasedLinkPredictionSample
        ):
            node_type_to_condensed = (
                graph_metadata_pb_wrapper.node_type_to_condensed_node_type_map
            )
            pos_edges = defaultdict(list)
            for edge in pb.pos_edges:
                edge_type: EdgeType = (
                    graph_metadata_pb_wrapper.condensed_edge_type_to_edge_type_map[
                        edge.condensed_edge_type
                    ]
                )
                _, src_id = self.enumerated_node_to_unenumerated_node_id_map[
                    (node_type_to_condensed[edge_type.src_node_type], edge.src_node_id)
                ]
                _, dst_id = self.enumerated_node_to_unenumerated_node_id_map[
                    (node_type_to_condensed[edge_type.dst_node_type], edge.dst_node_id)
                ]
                pos_edges[
                    (
                        edge_type.src_node_type,
                        edge_type.relation,
                        edge_type.dst_node_type,
                    )
                ].append((src_id, dst_id))
            (
                root_node_type,
                root_node_id,
            ) = self.enumerated_node_to_unenumerated_node_id_map[
                (pb.root_node.condensed_node_type, pb.root_node.node_id)
            ]
            global_root_node = (root_node_id, root_node_type)

        subgraph_node_to_unenumerated_node_id_map: dict[Node, Node] = {}
        for (
            node,
            global_node,
        ) in graph_data.subgraph_node_to_global_node_mapping.items():
            condensed_node_type = (
                graph_metadata_pb_wrapper.node_type_to_condensed_node_type_map[
                    global_node.type
                ]
            )
            (
                unenumerated_node_type,
                unenumerated_node_id,
            ) = self.enumerated_node_to_unenumerated_node_id_map[
                (condensed_node_type, global_node.id)
            ]
            subgraph_node_to_unenumerated_node_id_map[node] = Node(
                id=unenumerated_node_id,
                type=NodeType(unenumerated_node_type),
            )

        return GraphVisualizer.visualize_graph(
            data=graph_data.to_hetero_data(),
            layout_mode=layout_mode,
            subgraph_node_to_unenumerated_node_id_map=subgraph_node_to_unenumerated_node_id_map,
            pos_edges=pos_edges,
            global_root_node=global_root_node,
        )

    def _resolve_tfrecord_uri_prefix(
        self,
        unenumerated_node_type: str,
        from_output: PbVisualizerFromOutput,
        pb_type: Type[
            Union[
                training_samples_schema_pb2.NodeAnchorBasedLinkPredictionSample,
                training_samples_schema_pb2.RootedNodeNeighborhood,
            ]
        ],
    ) -> str:
        """Return the TFRecord URI prefix holding ``pb_type`` samples for ``from_output``.

        Raises:
            AssertionError: If the config's metadata lacks the link-prediction field
                (see NOTE below).
            ValueError: For an unsupported ``pb_type`` or ``from_output``.
        """
        shared_config = self.frozen_task_config.shared_config
        is_main_sample = (
            pb_type == training_samples_schema_pb2.NodeAnchorBasedLinkPredictionSample
        )
        is_neighborhood = pb_type == training_samples_schema_pb2.RootedNodeNeighborhood

        # NOTE(review): on protobuf messages, hasattr() is True for every declared
        # field, so these asserts likely never fire; HasField()/WhichOneof() may be
        # what was intended. Confirm against the proto schema.
        if from_output == PbVisualizerFromOutput.SGS:
            flattened_graph_metadata = shared_config.flattened_graph_metadata
            assert hasattr(
                flattened_graph_metadata, "node_anchor_based_link_prediction_output"
            ), f"find_node_pb only supported for node_anchor_based_link_prediction, not {flattened_graph_metadata}"
            output = flattened_graph_metadata.node_anchor_based_link_prediction_output
            if is_main_sample:
                return output.tfrecord_uri_prefix
            if is_neighborhood:
                return output.node_type_to_random_negative_tfrecord_uri_prefix[
                    unenumerated_node_type
                ]
            raise ValueError(f"Unsupported pb_type: {pb_type}")

        assert hasattr(
            shared_config.dataset_metadata,
            "node_anchor_based_link_prediction_dataset",
        ), f"find_node_pb only supported for node_anchor_based_link_prediction, not {shared_config.dataset_metadata}"
        dataset = shared_config.dataset_metadata.node_anchor_based_link_prediction_dataset
        if is_main_sample:
            main_data_uris = {
                PbVisualizerFromOutput.SPLIT_TRAIN: dataset.train_main_data_uri,
                PbVisualizerFromOutput.SPLIT_VAL: dataset.val_main_data_uri,
                PbVisualizerFromOutput.SPLIT_TEST: dataset.test_main_data_uri,
            }
            if from_output not in main_data_uris:
                raise ValueError(f"Unsupported from_output: {from_output}")
            return main_data_uris[from_output]
        if is_neighborhood:
            # Hold the map fields themselves and index only the selected one, so no
            # key lookup happens on the other splits' proto maps.
            negative_data_uri_maps = {
                PbVisualizerFromOutput.SPLIT_TRAIN: dataset.train_node_type_to_random_negative_data_uri,
                PbVisualizerFromOutput.SPLIT_VAL: dataset.val_node_type_to_random_negative_data_uri,
                PbVisualizerFromOutput.SPLIT_TEST: dataset.test_node_type_to_random_negative_data_uri,
            }
            if from_output not in negative_data_uri_maps:
                raise ValueError(f"Unsupported from_output: {from_output}")
            return negative_data_uri_maps[from_output][unenumerated_node_type]
        raise ValueError(f"Unsupported pb_type: {pb_type}")

    def find_node_pb(
        self,
        unenumerated_node_id: int,
        unenumerated_node_type: str,
        from_output: PbVisualizerFromOutput,
        pb_type: Type[
            Union[
                training_samples_schema_pb2.NodeAnchorBasedLinkPredictionSample,
                training_samples_schema_pb2.RootedNodeNeighborhood,
            ]
        ],
    ) -> Optional[
        Union[
            training_samples_schema_pb2.NodeAnchorBasedLinkPredictionSample,
            training_samples_schema_pb2.RootedNodeNeighborhood,
        ]
    ]:
        """Scan a pipeline output for the sample rooted at the given node.

        Args:
            unenumerated_node_id: Original (unenumerated) id of the node to find.
            unenumerated_node_type: Node type of that node.
            from_output: Which output (SGS or a split) to read.
            pb_type: Sample proto class to parse the TFRecords as.

        Returns:
            The first parsed sample whose ``root_node.node_id`` equals the node's
            enumerated id, or None if no record matches.

        Raises:
            KeyError: If the node is not in the enumerated-id tables.
            ValueError: For an unsupported ``pb_type`` or ``from_output``.
        """
        tfrecord_uri_prefix = self._resolve_tfrecord_uri_prefix(
            unenumerated_node_type=unenumerated_node_type,
            from_output=from_output,
            pb_type=pb_type,
        )
        uri = tfrecord_uri_prefix + "*.tfrecord"
        (
            search_node_type,
            search_node_id,
        ) = self.unenumerated_node_id_to_enumerated_node_id_map[
            (unenumerated_node_type, unenumerated_node_id)
        ]
        print(
            f"The node id {unenumerated_node_id}, type {unenumerated_node_type} maps to node id {search_node_id}, type {search_node_type}"
        )
        ds = tf.data.TFRecordDataset(tf.io.gfile.glob(uri)).as_numpy_iterator()
        print(f" Looking for node {search_node_id} in {uri}")
        # pb_type was validated by _resolve_tfrecord_uri_prefix, so it is one of the
        # two supported message classes.
        # NOTE(review): only root_node.node_id is compared, not condensed_node_type.
        # If one output mixes roots of several node types with overlapping
        # enumerated ids, the first id match wins.
        for bytestr in ds:
            pb = pb_type()
            pb.ParseFromString(bytestr)
            if pb.root_node.node_id == search_node_id:
                return pb
        return None
class GraphVisualizer:
    """
    Static helpers that draw a PyG ``HeteroData`` graph with NetworkX and matplotlib,
    colouring nodes by node type and edges by edge type.
    """

    # Fixed node color palette — extend as needed.
    node_colors = [
        "#64B5F6",  # blue
        "#E57373",  # red
        "#81C784",  # green
        "#FFD54F",  # yellow
        "#BA68C8",  # purple
        "#4DB6AC",  # teal
        "#F06292",  # pink
        "#A1887F",  # brown
        "#FFB74D",  # orange
    ]

    # Fixed edge color palette — chosen for a white background.
    edge_colors = [
        "#1565C0",  # medium blue
        "#43A047",  # vivid green
        "#000000",  # black
    ]

    @staticmethod
    def _color_from_palette(name: str, palette: list) -> str:
        """Pick an entry of ``palette`` for ``name`` using a SHA-256 hash.

        SHA-256 (rather than the built-in ``hash``) keeps colours identical across
        processes and runs.
        """
        hash_value = int(hashlib.sha256(name.encode("utf-8")).hexdigest(), 16)
        return palette[hash_value % len(palette)]

    @staticmethod
    def assign_node_color(name: str) -> str:
        """Assign a node color to a name based on deterministic hash and a fixed palette."""
        return GraphVisualizer._color_from_palette(name, GraphVisualizer.node_colors)

    @staticmethod
    def assign_edge_color(name: str) -> str:
        """Assign an edge color to a name based on deterministic hash and a fixed palette (optimized for white background)."""
        return GraphVisualizer._color_from_palette(name, GraphVisualizer.edge_colors)

    @staticmethod
    def _create_type_grouped_layout(
        g,
        node_index_to_type,
        node_types,
        seed=42,
        layout_mode=GraphVisualizerLayoutMode.BIPARTITE,
    ):
        """Compute node positions for ``g`` according to ``layout_mode``.

        Warning: heuristic layout code written for quick notebook inspection; the
        spacing constants below are rough heuristics.

        Args:
            g: NetworkX graph to lay out.
            node_index_to_type: Maps nodes of ``g`` to node types; nodes missing from
                it are grouped under ``"unknown"``.
            node_types: Unused; kept for call-site compatibility.
            seed: Seed for the spring layouts.
            layout_mode: HOMOGENEOUS uses Kamada-Kawai for <= 30 nodes, falling back
                to a spring layout if it fails, and a spring layout for larger
                graphs. BIPARTITE groups nodes by type: a circular layout for one
                type, a two-column bipartite layout for two, and a spring layout for
                more.

        Returns:
            Dict mapping node -> position; ``{}`` for an empty graph.

        Raises:
            ValueError: If ``layout_mode`` is not a known GraphVisualizerLayoutMode.
        """
        num_nodes = len(g.nodes())
        if num_nodes == 0:
            return {}

        if layout_mode == GraphVisualizerLayoutMode.HOMOGENEOUS:
            if num_nodes > 30:
                # Large graphs: spring layout whose ideal spacing grows with size.
                k = max(3.0, num_nodes / 6.0)
                return nx.spring_layout(g, seed=seed, k=k, iterations=250, scale=20)
            try:
                # Large scale so the 500-pt node markers do not overlap.
                return nx.kamada_kawai_layout(g, scale=15)
            except Exception as e:
                # Deliberate best-effort: fall through to a spring layout.
                print(
                    f"Kamada-Kawai layout failed: {e}, falling back to spring layout"
                )
            k = max(4.0, num_nodes / 3.0)
            return nx.spring_layout(g, seed=seed, k=k, iterations=300, scale=15)

        if layout_mode == GraphVisualizerLayoutMode.BIPARTITE:
            # Group nodes by type, preserving first-seen order of the types.
            type_to_nodes = defaultdict(list)
            for node in g.nodes():
                type_to_nodes[node_index_to_type.get(node, "unknown")].append(node)
            if len(type_to_nodes) == 1:
                return nx.circular_layout(g, scale=15)
            if len(type_to_nodes) == 2:
                first_type_nodes = set(next(iter(type_to_nodes.values())))
                return nx.bipartite_layout(g, first_type_nodes, scale=15)
            k = max(4.0, num_nodes / 3.0)  # spacing grows with node count
            return nx.spring_layout(g, seed=seed, k=k, iterations=300, scale=15)

        raise ValueError(f"Invalid layout mode: {layout_mode}")

    @staticmethod
    def _edge_type_labels(data: HeteroData) -> dict:
        """Map ``(src_node_type, dst_node_type)`` to a legend label from ``data.edge_types``.

        When several edge types connect the same pair of node types, the first one in
        ``data.edge_types`` wins.
        """
        labels = {}
        edge_types = data.edge_types if hasattr(data, "edge_types") else None
        for et in edge_types or []:
            if len(et) == 3:  # (src_type, relation, dst_type)
                labels.setdefault((et[0], et[2]), f"{et[0]} --{et[1]}--> {et[2]}")
            elif isinstance(et, tuple) and len(et) == 2:  # some formats use (src, dst)
                labels.setdefault((et[0], et[1]), f"{et[0]} --> {et[1]}")
        return labels

    @staticmethod
    def _add_positive_edges(g, node_index_to_type, pos_edges) -> set:
        """Add ``pos_edges`` (and any missing endpoint nodes) to ``g``.

        Mutates ``g`` and ``node_index_to_type`` in place; a newly added endpoint gets
        the node type from its edge type.

        Returns:
            The set of ``(src_id, dst_id)`` pairs, used to colour those edges.
        """
        pos_edge_pairs = set()
        for (src_node_type, _relation, dst_node_type), edge_pairs in (
            pos_edges or {}
        ).items():
            for src_id, dst_id in edge_pairs:
                pos_edge_pairs.add((src_id, dst_id))
                if not g.has_node(src_id):
                    g.add_node(src_id)
                    node_index_to_type[src_id] = src_node_type
                if not g.has_node(dst_id):
                    g.add_node(dst_id)
                    node_index_to_type[dst_id] = dst_node_type
                if not g.has_edge(src_id, dst_id):
                    g.add_edge(src_id, dst_id)
        return pos_edge_pairs

    @staticmethod
    def _build_legend_elements(
        node_type_to_color,
        show_isolated,
        show_root,
        show_pos_edges,
        edge_type_to_color,
    ) -> list:
        """Build the matplotlib legend handles for visualize_graph."""
        legend_elements = []
        # Node types are listed only when there is more than one.
        if len(node_type_to_color) > 1:
            for node_type in sorted(node_type_to_color):
                legend_elements.append(
                    plt.Line2D(
                        [0],
                        [0],
                        marker="o",
                        color="w",
                        markerfacecolor=node_type_to_color[node_type],
                        markersize=10,
                        label=f"Node: {node_type}",
                    )
                )
        if show_isolated:
            legend_elements.append(
                plt.Line2D(
                    [0],
                    [0],
                    marker="o",
                    color="black",
                    markerfacecolor="white",
                    markeredgewidth=3,
                    markersize=10,
                    label="Isolated nodes",
                )
            )
        if show_root:
            legend_elements.append(
                plt.Line2D(
                    [0],
                    [0],
                    marker="o",
                    color="#E53935",
                    markerfacecolor="white",
                    markeredgewidth=4,
                    markersize=15,  # larger marker to mirror the larger root node
                    label="Root node",
                )
            )
        if show_pos_edges:
            legend_elements.append(
                plt.Line2D(
                    [0],
                    [0],
                    color="#E53935",
                    linewidth=3,
                    label="Positive edges",
                )
            )
        for edge_type in sorted(edge_type_to_color):
            legend_elements.append(
                plt.Line2D(
                    [0],
                    [0],
                    color=edge_type_to_color[edge_type],
                    linewidth=2,
                    label=f"Edge: {edge_type}",
                )
            )
        return legend_elements

    @staticmethod
    def visualize_graph(
        data: HeteroData,
        seed=42,
        layout_mode=GraphVisualizerLayoutMode.BIPARTITE,
        subgraph_node_to_unenumerated_node_id_map: Optional[
            FrozenDict[Node, Node]
        ] = None,
        # pos_edges is a dictionary of edge type (src_node_type, relation, dst_node_type) to list of (src_node_id, dst_node_id) pairs
        pos_edges: Optional[dict[tuple[str, str, str], list[tuple[int, int]]]] = None,
        global_root_node: Optional[tuple[int, str]] = None,
    ):
        """Draw ``data`` with matplotlib and show it.

        Warning: heuristic plotting code written for notebook inspection.

        Nodes are coloured by node type. The root node, if given, is drawn larger
        with a red border; isolated nodes get a thick black border; positive edges,
        if given, are drawn in red. Other edges are coloured by edge type.

        Args:
            data: The HeteroData graph to draw.
            seed: Seed for the spring layouts; fix it for reproducible pictures.
            layout_mode: GraphVisualizerLayoutMode.HOMOGENEOUS (straight edges) or
                GraphVisualizerLayoutMode.BIPARTITE (curved edges).
            subgraph_node_to_unenumerated_node_id_map: Optional map from each
                subgraph ``Node(type, id)`` to the node to display instead. When
                given, NetworkX nodes are relabeled to the mapped ids and types.
            pos_edges: Optional map from edge type ``(src_node_type, relation,
                dst_node_type)`` to ``(src_id, dst_id)`` pairs. The ids must be in the
                same id space as the drawn nodes; missing nodes and edges are added.
            global_root_node: Optional ``(node_id, node_type)``; only the id is used
                to pick the node to highlight.
        """
        # Map NetworkX node indices back to node types before conversion:
        # to_networkx is assumed to number HeteroData nodes consecutively, one node
        # type after another in data.node_types order.
        node_index_to_type = {}
        current_index = 0
        for node_type in data.node_types:
            if hasattr(data[node_type], "num_nodes"):
                for _ in range(data[node_type].num_nodes):
                    node_index_to_type[current_index] = node_type
                    current_index += 1

        g = torch_geometric.utils.to_networkx(data)

        if subgraph_node_to_unenumerated_node_id_map:
            # NOTE(review): the lookup key uses the NetworkX index as the Node id, which
            # matches the subgraph ids only if they are numbered the same way. Also,
            # nodes of different types sharing an unenumerated id are merged by
            # relabel_nodes. Confirm both hold for the graphs being plotted.
            mapping = {}
            relabeled_node_index_to_type = {}
            for node in g.nodes():
                node_type = node_index_to_type.get(node, "unknown")
                unenumerated_node: Node = subgraph_node_to_unenumerated_node_id_map[
                    Node(type=node_type, id=node)
                ]
                mapping[node] = unenumerated_node.id
                relabeled_node_index_to_type[unenumerated_node.id] = unenumerated_node.type
            g = nx.relabel_nodes(g, mapping)
            node_index_to_type = relabeled_node_index_to_type  # type: ignore

        pos_edge_pairs = GraphVisualizer._add_positive_edges(
            g, node_index_to_type, pos_edges
        )

        # Node fill colours: seed with the HeteroData types, extend with any others.
        node_type_to_color = {
            node_type: GraphVisualizer.assign_node_color(node_type)
            for node_type in data.node_types
        }
        node_colors = []
        for node in g.nodes():
            node_type = node_index_to_type.get(node, "unknown")
            if node_type not in node_type_to_color:
                node_type_to_color[node_type] = GraphVisualizer.assign_node_color(
                    node_type
                )
            node_colors.append(node_type_to_color[node_type])

        plt.figure(figsize=(10, 6))

        actual_node_types = set(node_index_to_type.values())
        pos = GraphVisualizer._create_type_grouped_layout(
            g, node_index_to_type, actual_node_types, seed, layout_mode
        )
        # Safety net: never draw a non-empty graph without positions.
        if pos is None or (len(g.nodes()) > 0 and not pos):
            print("Layout generation failed, using fallback spring layout")
            k = max(4.0, len(g.nodes()) / 3.0)
            pos = nx.spring_layout(g, seed=seed, k=k, iterations=300, scale=15)

        # Border/size styling: root > isolated > default.
        isolated_nodes = {node for node in g.nodes() if g.degree(node) == 0}
        root_node_id = global_root_node[0] if global_root_node else None
        node_edge_colors = []
        node_line_widths = []
        node_sizes = []
        for node in g.nodes():
            if node == root_node_id:
                node_edge_colors.append("#E53935")  # red border for the root
                node_line_widths.append(4)
                node_sizes.append(1000)  # twice the normal size
            elif node in isolated_nodes:
                node_edge_colors.append(BLACK)
                node_line_widths.append(3)
                node_sizes.append(500)
            else:
                node_edge_colors.append(CHARCOAL)
                node_line_widths.append(1)
                node_sizes.append(500)

        # Edge colours: red for positive edges, otherwise one colour per edge type.
        edge_labels = GraphVisualizer._edge_type_labels(data)
        edge_type_to_color = {}
        edge_colors = []
        for src, dst in g.edges():
            if (src, dst) in pos_edge_pairs:
                edge_colors.append("#E53935")
                continue
            src_node_type = node_index_to_type.get(src, "unknown")
            dst_node_type = node_index_to_type.get(dst, "unknown")
            edge_type = edge_labels.get(
                (src_node_type, dst_node_type),
                f"{src_node_type} --> {dst_node_type}",
            )
            if edge_type not in edge_type_to_color:
                edge_type_to_color[edge_type] = GraphVisualizer.assign_edge_color(
                    edge_type
                )
            edge_colors.append(edge_type_to_color[edge_type])

        nx.draw_networkx_nodes(
            g,
            pos,
            node_color=node_colors if node_colors else "lightblue",  # type: ignore
            edgecolors=node_edge_colors if node_edge_colors else CHARCOAL,  # type: ignore
            linewidths=node_line_widths if node_line_widths else 1,  # type: ignore
            node_size=node_sizes if node_sizes else 500,  # type: ignore
        )

        if g.edges() and edge_colors:
            if layout_mode == GraphVisualizerLayoutMode.HOMOGENEOUS:
                edge_style = {"alpha": 0.9}  # straight edges
            else:
                # Curved, slightly transparent edges reduce visual overlap.
                edge_style = {"alpha": 0.8, "connectionstyle": "arc3,rad=0.1"}
            nx.draw_networkx_edges(
                g,
                pos,
                edge_color=edge_colors,  # type: ignore
                width=0.75,  # 75% of the default edge width
                **edge_style,
            )

        # Labels last so they sit on top of nodes and edges.
        nx.draw_networkx_labels(g, pos, font_size=10, font_weight="bold")

        legend_elements = GraphVisualizer._build_legend_elements(
            node_type_to_color=node_type_to_color,
            show_isolated=bool(isolated_nodes),
            show_root=bool(global_root_node),
            show_pos_edges=bool(pos_edges),
            edge_type_to_color=edge_type_to_color,
        )
        if legend_elements:
            plt.legend(
                handles=legend_elements, loc="upper right", bbox_to_anchor=(1.4, 1)
            )
        plt.show()
def sort_yaml_dict_recursively(obj: object) -> object:
    """Return a copy of ``obj`` in which every dict has its keys in sorted order.

    GiGL's proto serialization does not guarantee the order of keys, so configs are
    normalized this way before diffing to keep the diff stable. Dicts are rebuilt
    with sorted keys, recursively. Lists are recursed into but keep their element
    order, since list order can be meaningful. Any other value is returned unchanged.

    If a dict's keys cannot be compared with each other (e.g. a mix of ``int`` and
    ``str`` keys, which YAML allows), that dict keeps its original key order instead
    of raising ``TypeError``. This is similar to PyYAML's own key sorting in
    ``yaml.dump``.
    """
    if isinstance(obj, dict):
        try:
            keys = sorted(obj)
        except TypeError:
            keys = list(obj)
        return {key: sort_yaml_dict_recursively(obj[key]) for key in keys}
    if isinstance(obj, list):
        return [sort_yaml_dict_recursively(item) for item in obj]
    return obj
def show_colored_unified_diff(f1_lines, f2_lines, f1_name, f2_name):
    """Display a unified diff of two line lists as colored HTML (for notebooks).

    The diff runs from ``f1_lines`` to ``f2_lines``. Lines only in ``f1_lines`` are
    shown with ``-`` in red and lines only in ``f2_lines`` with ``+`` in green. Hunk
    headers (``@@``) are blue; everything else is black.

    Args:
        f1_lines: Lines of the "from" side; labelled ``f1_name`` in the diff header.
        f2_lines: Lines of the "to" side; labelled ``f2_name`` in the diff header.
        f1_name: Label for ``f1_lines``.
        f2_name: Label for ``f2_lines``.
    """
    from html import escape

    html_lines = []
    for line in unified_diff(f1_lines, f2_lines, fromfile=f1_name, tofile=f2_name):
        if line.startswith("+") and not line.startswith("+++"):
            color = "#228B22"  # green
        elif line.startswith("-") and not line.startswith("---"):
            color = "#B22222"  # red
        elif line.startswith("@"):
            color = "#1E90FF"  # blue
        else:
            color = "#000000"  # black
        # Escape so text containing <, > or & is shown literally, not parsed as HTML.
        html_lines.append(
            f'<pre style="margin:0; color:{color}; background-color:white;">{escape(line.rstrip())}</pre>'
        )
    display(HTML("".join(html_lines)))


def show_task_config_colored_unified_diff(
    f1_uri: Uri, f2_uri: Uri, f1_name: str, f2_name: str
):
    """
    Displays a colored unified diff of two YAML task config files.

    Each file is parsed with ``yaml.safe_load``, normalized with
    ``sort_yaml_dict_recursively`` and re-dumped, so key-order-only differences do
    not show up. The diff runs from the ``f2_uri`` file to the ``f1_uri`` file: lines
    present only in the first file are shown as ``+``. The local naming suggests
    ``f1`` is a frozen config and ``f2`` the template it came from.

    Args:
        f1_uri (Uri): URI of the first file (the "to" side of the diff).
        f2_uri (Uri): URI of the second file (the "from" side of the diff).
        f1_name (str): Label shown for the first file.
        f2_name (str): Label shown for the second file.
    """
    file_loader = FileLoader()

    def _load_normalized_yaml(uri: Uri) -> str:
        # NOTE(review): the object returned by load_to_temp_file is never closed here
        # (nor in the original) — confirm whether FileLoader cleans it up.
        with open(file_loader.load_to_temp_file(file_uri_src=uri).name, "r") as f:
            data = yaml.safe_load(f)
        # yaml.dump also sorts keys by default.
        return yaml.dump(sort_yaml_dict_recursively(data))

    frozen_task_config_file_contents = _load_normalized_yaml(f1_uri)
    template_task_config_file_contents = _load_normalized_yaml(f2_uri)
    # Template (f2) is the "from" side, frozen (f1) the "to" side; each is labelled
    # with its own file's name.
    show_colored_unified_diff(
        template_task_config_file_contents.splitlines(),
        frozen_task_config_file_contents.splitlines(),
        f1_name=f2_name,
        f2_name=f1_name,
    )