Source code for gigl.src.common.modeling_task_specs.utils.infer

from collections import defaultdict
from typing import Dict, List, Set, Union

import torch
import torch.nn as nn
from torch_geometric.data import Data
from torch_geometric.data.hetero_data import HeteroData

from gigl.src.common.models.layers.decoder import LinkPredictionDecoder
from gigl.src.common.models.layers.loss import ModelResultType
from gigl.src.common.types.graph_data import (
    CondensedEdgeType,
    CondensedNodeType,
    NodeId,
    NodeType,
)
from gigl.src.common.types.pb_wrappers.gbml_config import GbmlConfigPbWrapper
from gigl.src.common.types.task_inputs import (
    BatchCombinedScores,
    BatchEmbeddings,
    BatchScores,
    InputBatch,
    NodeAnchorBasedLinkPredictionTaskInputs,
)
from gigl.src.training.v1.lib.data_loaders.node_anchor_based_link_prediction_data_loader import (
    NodeAnchorBasedLinkPredictionBatch,
)
from gigl.src.training.v1.lib.data_loaders.rooted_node_neighborhood_data_loader import (
    RootedNodeNeighborhoodBatch,
)

# TODO (mkolodner-sc) Move PyG Logic to PyG-specific location


def infer_training_batch(
    model: Union[torch.nn.parallel.DistributedDataParallel, nn.Module],
    training_batch: Union[
        NodeAnchorBasedLinkPredictionBatch,
        RootedNodeNeighborhoodBatch,
        Data,
        HeteroData,
    ],
    gbml_config_pb_wrapper: GbmlConfigPbWrapper,
    device: torch.device,
) -> Dict[CondensedNodeType, torch.Tensor]:
    # Compute embeddings for all nodes in the main and random batches.
    if isinstance(
        training_batch,
        (NodeAnchorBasedLinkPredictionBatch, RootedNodeNeighborhoodBatch),
    ):
        training_batch = training_batch.graph
    node_type_to_condensed_node_type_map = (
        gbml_config_pb_wrapper.graph_metadata_pb_wrapper.node_type_to_condensed_node_type_map
    )
    supervision_node_types = (
        gbml_config_pb_wrapper.task_metadata_pb_wrapper.get_supervision_edge_node_types(
            should_include_src_nodes=True,
            should_include_dst_nodes=True,
        )
    )
    output_node_types = [NodeType(node_type) for node_type in supervision_node_types]
    training_batch = training_batch.to(device=device)
    node_type_to_embeddings: Dict[NodeType, torch.Tensor] = model(
        data=training_batch, output_node_types=output_node_types, device=device
    )
    return {
        node_type_to_condensed_node_type_map[node_type]: node_type_to_embeddings[
            node_type
        ]
        for node_type in node_type_to_embeddings
    }
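
# The model invoked above is expected to expose a forward signature of
# model(data=..., output_node_types=..., device=...) and return a mapping from
# node type to an embedding tensor. A minimal, hypothetical sketch of that
# contract follows; _StubEncoder and the toy dict-of-tensors batch are
# illustrative stand-ins, not gigl or PyG APIs.


class _StubEncoder(nn.Module):
    def __init__(self, dim: int = 8):
        super().__init__()
        self.lin = nn.Linear(dim, dim)

    def forward(self, data, output_node_types, device):
        # Return one embedding matrix per requested node type.
        return {
            node_type: self.lin(data[node_type].to(device))
            for node_type in output_node_types
        }


def _stub_encoder_example() -> None:
    model = _StubEncoder()
    fake_batch = {"user": torch.randn(4, 8), "item": torch.randn(6, 8)}
    embeddings = model(
        data=fake_batch, output_node_types=["user", "item"], device=torch.device("cpu")
    )
    assert embeddings["user"].shape == (4, 8)
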
def infer_root_embeddings(
    model: Union[torch.nn.parallel.DistributedDataParallel, nn.Module],
    graph: Union[Data, HeteroData],
    root_node_indices: torch.LongTensor,
    gbml_config_pb_wrapper: GbmlConfigPbWrapper,
    device: torch.device,
) -> torch.Tensor:
    batch_graph = graph.to(device=device)
    batch_root_node_indices = root_node_indices.to(device=device)
    output_node_types = list(
        gbml_config_pb_wrapper.task_metadata_pb_wrapper.get_supervision_edge_node_types(
            should_include_src_nodes=True,
            should_include_dst_nodes=False,
        )
    )
    # TODO (mkolodner) Add support for multiple root_node_indices node types in Stage 3 HGS
    if len(output_node_types) != 1:
        raise NotImplementedError(
            "Stage 3 HGS is not yet supported -- training can only be performed with one unique source node type."
        )
    node_type_to_embeddings: Dict[NodeType, torch.Tensor] = model(
        data=batch_graph, output_node_types=output_node_types, device=device
    )
    root_embeddings = node_type_to_embeddings[output_node_types[0]][
        batch_root_node_indices
    ]
    return root_embeddings
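
# The root-embedding lookup above reduces to indexing the encoder output for
# the single output node type with the batch-local root node indices. A
# self-contained illustration with toy shapes (assumed sizes only):


def _root_indexing_example() -> None:
    all_embeddings = torch.randn(10, 8)  # embeddings for every node of the output type
    root_node_indices = torch.tensor([0, 3, 7])  # batch-local root ids
    root_embeddings = all_embeddings[root_node_indices]
    assert root_embeddings.shape == (3, 8)
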
def infer_task_inputs(
    model: Union[torch.nn.parallel.DistributedDataParallel, nn.Module],
    gbml_config_pb_wrapper: GbmlConfigPbWrapper,
    main_batch: NodeAnchorBasedLinkPredictionBatch,
    random_neg_batch: RootedNodeNeighborhoodBatch,
    should_eval: bool,
    device: torch.device,
) -> NodeAnchorBasedLinkPredictionTaskInputs:
    # Initialize empty container values
    batch_scores: List[Dict[CondensedEdgeType, BatchScores]] = []
    batch_combined_scores: Dict[CondensedEdgeType, BatchCombinedScores] = {}
    pos_embeddings: Dict[CondensedEdgeType, torch.FloatTensor] = {}
    hard_neg_embeddings: Dict[CondensedEdgeType, torch.FloatTensor] = {}
    repeated_anchor_embeddings: Dict[CondensedEdgeType, torch.FloatTensor] = {}
    _pos_embeddings: Dict[CondensedEdgeType, List[torch.FloatTensor]] = defaultdict(
        list
    )
    _hard_neg_embeddings: Dict[
        CondensedEdgeType, List[torch.FloatTensor]
    ] = defaultdict(list)
    _positive_ids: Dict[CondensedEdgeType, List[torch.LongTensor]] = defaultdict(list)
    _hard_neg_ids: Dict[CondensedEdgeType, List[torch.LongTensor]] = defaultdict(list)
    # Map of Condensed Edge Type to list of num_pos_nodes for retrieval calculation
    repeated_anchor_count: Dict[CondensedEdgeType, List[int]] = defaultdict(list)
    # Populate the main batch and rooted node neighborhood (RNN) task input fields
    input_batch = InputBatch(main_batch=main_batch, random_neg_batch=random_neg_batch)
    batch_result_types: Set[ModelResultType]
    decoder: LinkPredictionDecoder
    # Unwrap any DDP layers
    if isinstance(model, torch.nn.parallel.DistributedDataParallel):
        decoder = model.module.decode
        batch_result_types = model.module.tasks.result_types
    else:
        decoder = model.decode
        batch_result_types = model.tasks.result_types
    # If every loss requires only the input batch, skip the forward pass and
    # return the input batch immediately to avoid unnecessary computation such
    # as encoding and decoding.
    should_forward_batch: bool = (
        ModelResultType.batch_scores in batch_result_types
        or ModelResultType.batch_embeddings in batch_result_types
        or ModelResultType.batch_combined_scores in batch_result_types
        or should_eval
    )
    if not should_forward_batch:
        return NodeAnchorBasedLinkPredictionTaskInputs(
            input_batch=input_batch,
            batch_embeddings=None,
            batch_scores=batch_scores,
            batch_combined_scores=batch_combined_scores,
        )
    # Forward input batch through model
    main_embeddings: Dict[CondensedNodeType, torch.Tensor] = infer_training_batch(
        model=model,
        training_batch=main_batch,
        gbml_config_pb_wrapper=gbml_config_pb_wrapper,
        device=device,
    )
    random_neg_embeddings = infer_training_batch(
        model=model,
        training_batch=random_neg_batch,
        gbml_config_pb_wrapper=gbml_config_pb_wrapper,
        device=device,
    )
    main_batch_node_id_mapping: Dict[
        CondensedNodeType, Dict[NodeId, NodeId]
    ] = main_batch.condensed_node_type_to_subgraph_id_to_global_node_id
    random_negative_batch_node_id_mapping: Dict[
        CondensedNodeType, Dict[NodeId, NodeId]
    ] = random_neg_batch.condensed_node_type_to_subgraph_id_to_global_node_id
    # Get all condensed anchor node types for retrieving query embeddings
    anchor_node_types = list(
        gbml_config_pb_wrapper.task_metadata_pb_wrapper.get_supervision_edge_node_types(
            should_include_src_nodes=True,
            should_include_dst_nodes=False,
        )
    )
    condensed_anchor_node_types = [
        gbml_config_pb_wrapper.graph_metadata_pb_wrapper.node_type_to_condensed_node_type_map[
            node_type
        ]
        for node_type in anchor_node_types
    ]
    # TODO (mkolodner-sc) Add support for multiple root_node_indices node types in Stage 3 HGS
    if len(condensed_anchor_node_types) != 1:
        raise NotImplementedError(
            "Stage 3 HGS is not yet supported -- training can only be performed with one unique source node type."
        )
    main_batch_root_node_indices = main_batch.root_node_indices.to(device=device)
    query_embeddings = main_embeddings[condensed_anchor_node_types[0]][
        main_batch_root_node_indices
    ]
    # Get random negative (rooted node neighborhood) embeddings and scores
    random_neg_root_embeddings: Dict[CondensedNodeType, torch.FloatTensor] = {}
    random_neg_scores: Dict[CondensedNodeType, torch.FloatTensor] = {}
    for condensed_node_type in random_neg_batch.condensed_node_type_to_root_node_indices_map:
        random_neg_root_node_indices = (
            random_neg_batch.condensed_node_type_to_root_node_indices_map[
                condensed_node_type
            ].to(device=device)
        )
        random_neg_root_embeddings[condensed_node_type] = (
            random_neg_embeddings[condensed_node_type][random_neg_root_node_indices]  # type: ignore
            if random_neg_root_node_indices.numel()
            else torch.FloatTensor([]).to(device=device)
        )
        if ModelResultType.batch_scores in batch_result_types or should_eval:
            random_neg_scores[condensed_node_type] = (
                decoder(
                    query_embeddings, random_neg_root_embeddings[condensed_node_type]
                )
                if random_neg_root_embeddings[condensed_node_type].numel()
                else torch.FloatTensor([]).to(device=device)
            )
    # Loop through all root nodes and populate ids, embeddings, and scores per condensed edge type
    for root_node_idx, root_node in enumerate(main_batch_root_node_indices):
        root_node = torch.unsqueeze(root_node, 0)  # shape=[1]
        _batch_scores: Dict[CondensedEdgeType, BatchScores] = {}
        for supervision_edge_type in (
            gbml_config_pb_wrapper.task_metadata_pb_wrapper.get_supervision_edge_types()
        ):
            condensed_supervision_edge_type = gbml_config_pb_wrapper.graph_metadata_pb_wrapper.edge_type_to_condensed_edge_type_map[
                supervision_edge_type
            ]
            (
                condensed_anchor_node_type,
                condensed_supervision_target_node_type,
            ) = gbml_config_pb_wrapper.graph_metadata_pb_wrapper.condensed_edge_type_to_condensed_node_types[
                condensed_supervision_edge_type
            ]
            pos_nodes: torch.LongTensor = main_batch.pos_supervision_edge_data[
                condensed_supervision_edge_type
            ].root_node_to_target_node_id[
                root_node.item()
            ]  # shape=[num_pos_nodes]
            hard_neg_nodes: torch.LongTensor = main_batch.hard_neg_supervision_edge_data[
                condensed_supervision_edge_type
            ].root_node_to_target_node_id[
                root_node.item()
            ]  # shape=[num_hard_neg_nodes]
            repeated_anchor_count[condensed_supervision_edge_type].append(
                pos_nodes.numel()
            )
            if pos_nodes.numel():
                _pos_embeddings[condensed_supervision_edge_type].append(
                    main_embeddings[condensed_supervision_target_node_type][pos_nodes]  # type: ignore
                )
                _positive_ids[condensed_supervision_edge_type].append(pos_nodes)
            if hard_neg_nodes.numel():
                _hard_neg_embeddings[condensed_supervision_edge_type].append(
                    main_embeddings[condensed_supervision_target_node_type][hard_neg_nodes]  # type: ignore
                )
                _hard_neg_ids[condensed_supervision_edge_type].append(hard_neg_nodes)
            # If any tasks need batch score information, decode embeddings into scores
            if ModelResultType.batch_scores in batch_result_types or should_eval:
                pos_scores = (
                    decoder(
                        main_embeddings[condensed_anchor_node_type][root_node],
                        main_embeddings[condensed_supervision_target_node_type][
                            pos_nodes
                        ],
                    )
                    if pos_nodes.numel()
                    else torch.FloatTensor([]).to(device=device)
                )
                hard_neg_scores = (
                    decoder(
                        main_embeddings[condensed_anchor_node_type][root_node],
                        main_embeddings[condensed_supervision_target_node_type][
                            hard_neg_nodes
                        ],
                    )
                    if hard_neg_nodes.numel()
                    else torch.FloatTensor([]).to(device=device)
                )
                random_neg_scores_root = random_neg_scores[
                    condensed_supervision_target_node_type
                ][[root_node_idx], :].to(device=device)
                _batch_scores[condensed_supervision_edge_type] = BatchScores(
                    pos_scores=pos_scores,
                    hard_neg_scores=hard_neg_scores,
                    random_neg_scores=random_neg_scores_root,  # type: ignore
                )
        if ModelResultType.batch_scores in batch_result_types or should_eval:
            batch_scores.append(_batch_scores)
    # Loop through all condensed edge types and collapse lists of the same type into a single tensor
    for supervision_edge_type in (
        gbml_config_pb_wrapper.task_metadata_pb_wrapper.get_supervision_edge_types()
    ):
        condensed_supervision_edge_type = gbml_config_pb_wrapper.graph_metadata_pb_wrapper.edge_type_to_condensed_edge_type_map[
            supervision_edge_type
        ]
        (
            condensed_anchor_node_type,
            condensed_supervision_target_node_type,
        ) = gbml_config_pb_wrapper.graph_metadata_pb_wrapper.condensed_edge_type_to_condensed_node_types[
            condensed_supervision_edge_type
        ]
        pos_embeddings[condensed_supervision_edge_type] = (
            torch.cat(tuple(_pos_embeddings[condensed_supervision_edge_type]))  # type: ignore
            if len(_pos_embeddings[condensed_supervision_edge_type])
            else torch.tensor([])
        )
        hard_neg_embeddings[condensed_supervision_edge_type] = (
            torch.cat(tuple(_hard_neg_embeddings[condensed_supervision_edge_type]))  # type: ignore
            if len(_hard_neg_embeddings[condensed_supervision_edge_type])
            else torch.tensor([])
        )
        repeated_anchor_embeddings[
            condensed_supervision_edge_type
        ] = query_embeddings.repeat_interleave(
            torch.tensor(repeated_anchor_count[condensed_supervision_edge_type]).to(
                device=device
            ),
            dim=0,  # type: ignore
        )
        # If needed, calculate task inputs for retrieval loss per condensed edge type
        if ModelResultType.batch_combined_scores in batch_result_types:
            candidate_embeddings = torch.cat(
                (
                    pos_embeddings[condensed_supervision_edge_type].to(device=device),
                    hard_neg_embeddings[condensed_supervision_edge_type].to(
                        device=device
                    ),
                    random_neg_root_embeddings[
                        condensed_supervision_target_node_type
                    ].to(device=device),
                )
            )
            repeated_subgraph_query_ids = (
                main_batch_root_node_indices.repeat_interleave(
                    torch.tensor(
                        repeated_anchor_count[condensed_supervision_edge_type]
                    ).to(device=device)
                )
            )
            repeated_global_query_ids = torch.tensor(
                [
                    main_batch_node_id_mapping[condensed_anchor_node_type][
                        node_id.item()
                    ]
                    for node_id in repeated_subgraph_query_ids
                ]
            ).to(device=device)
            subgraph_positive_ids = (
                torch.cat(tuple(_positive_ids[condensed_supervision_edge_type]))
                if len(_positive_ids[condensed_supervision_edge_type])
                else torch.tensor([])
            )
            global_positive_ids = torch.tensor(
                [
                    main_batch_node_id_mapping[condensed_supervision_target_node_type][
                        node_id.item()
                    ]
                    for node_id in subgraph_positive_ids
                ]
            ).to(device=device)
            subgraph_hard_neg_ids = (
                torch.cat(tuple(_hard_neg_ids[condensed_supervision_edge_type]))
                if len(_hard_neg_ids[condensed_supervision_edge_type])
                else torch.tensor([])
            )
            global_hard_neg_ids = torch.tensor(
                [
                    main_batch_node_id_mapping[condensed_supervision_target_node_type][
                        node_id.item()
                    ]
                    for node_id in subgraph_hard_neg_ids
                ]
            ).to(device=device)
            random_neg_root_node_indices = (
                random_neg_batch.condensed_node_type_to_root_node_indices_map[
                    condensed_supervision_target_node_type
                ].to(device=device)
            )
            subgraph_random_neg_ids = (
                random_neg_root_node_indices
                if random_neg_root_node_indices.numel()
                else torch.tensor([])
            )
            global_random_neg_ids = torch.tensor(
                [
                    random_negative_batch_node_id_mapping[
                        condensed_supervision_target_node_type
                    ][node_id.item()]
                    for node_id in subgraph_random_neg_ids
                ]
            ).to(device=device)
            repeated_candidate_scores = (
                decoder(
                    repeated_anchor_embeddings[condensed_supervision_edge_type],
                    candidate_embeddings,
                )
                if repeated_anchor_embeddings[condensed_supervision_edge_type].numel()
                else torch.tensor([])
            )
            batch_combined_scores[
                condensed_supervision_edge_type
            ] = BatchCombinedScores(
                repeated_candidate_scores=repeated_candidate_scores,
                positive_ids=global_positive_ids,  # type: ignore
                hard_neg_ids=global_hard_neg_ids,  # type: ignore
                random_neg_ids=global_random_neg_ids,  # type: ignore
                repeated_query_ids=repeated_global_query_ids,  # type: ignore
                num_unique_query_ids=main_batch_root_node_indices.shape[0],
            )
    # Populate all computed embeddings for task input
    batch_embeddings = BatchEmbeddings(
        query_embeddings=query_embeddings,  # type: ignore
        repeated_query_embeddings=repeated_anchor_embeddings,  # type: ignore
        pos_embeddings=pos_embeddings,  # type: ignore
        hard_neg_embeddings=hard_neg_embeddings,  # type: ignore
        random_neg_embeddings=random_neg_root_embeddings,  # type: ignore
    )
    return NodeAnchorBasedLinkPredictionTaskInputs(
        input_batch=input_batch,
        batch_embeddings=batch_embeddings,
        batch_scores=batch_scores,
        batch_combined_scores=batch_combined_scores,
    )
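
# The repeat_interleave calls above align each anchor's query embedding with
# its variable number of positive candidates, so the repeated queries line up
# row-wise with the concatenated positives. A self-contained illustration with
# toy sizes (assumed shapes only):


def _repeat_interleave_alignment_example() -> None:
    query_embeddings = torch.randn(3, 4)  # one row per anchor node
    repeated_anchor_count = torch.tensor([2, 0, 3])  # num positives per anchor
    repeated = query_embeddings.repeat_interleave(repeated_anchor_count, dim=0)
    # Anchor 0 appears twice, anchor 1 is dropped, anchor 2 appears three times,
    # matching a concatenated positives tensor with 5 rows.
    assert repeated.shape == (5, 4)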