gigl.src.common.modeling_task_specs.utils.infer
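Functions
- infer_root_embeddings(model, graph, root_node_indices, gbml_config_pb_wrapper, device)
- infer_task_inputs(model, gbml_config_pb_wrapper, main_batch, random_neg_batch, should_eval, device)
- infer_training_batch(model, training_batch, gbml_config_pb_wrapper, device)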
Module Contents
- gigl.src.common.modeling_task_specs.utils.infer.infer_root_embeddings(model, graph, root_node_indices, gbml_config_pb_wrapper, device)
  - Parameters:
    - model (Union[torch.nn.parallel.DistributedDataParallel, torch.nn.Module])
    - graph (Union[torch_geometric.data.Data, torch_geometric.data.hetero_data.HeteroData])
    - root_node_indices (torch.LongTensor)
    - gbml_config_pb_wrapper (gigl.src.common.types.pb_wrappers.gbml_config.GbmlConfigPbWrapper)
    - device (torch.device)
  - Return type:
    - torch.Tensor
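This page documents only signatures, so what follows is a hedged sketch, not GiGL's implementation. It illustrates one plausible reading of `infer_root_embeddings`: run the model once over all nodes, then gather the rows at `root_node_indices`. To stay self-contained it assumes a plain node-feature tensor in place of a `torch_geometric` graph and a `torch.nn.Linear` stand-in for the encoder; `unwrap_model` and `sketch_infer_root_embeddings` are hypothetical names. The `unwrap_model` helper shows one common way to accept the documented `Union[DistributedDataParallel, Module]` type, since `DistributedDataParallel` exposes the wrapped network on its `.module` attribute.

```python
from typing import Union

import torch
from torch.nn.parallel import DistributedDataParallel


def unwrap_model(
    model: Union[DistributedDataParallel, torch.nn.Module],
) -> torch.nn.Module:
    """Return the wrapped network if `model` is a DDP wrapper, else `model` itself."""
    # DistributedDataParallel stores the wrapped network on its `.module` attribute.
    if isinstance(model, DistributedDataParallel):
        return model.module
    return model


def sketch_infer_root_embeddings(
    model: Union[DistributedDataParallel, torch.nn.Module],
    node_features: torch.Tensor,
    root_node_indices: torch.Tensor,  # int64 indices, i.e. a LongTensor
    device: torch.device,
) -> torch.Tensor:
    """Hypothetical sketch: embed every node, then keep only the root-node rows."""
    # Assumption: the model already lives on `device`; only inputs are moved here.
    encoder = unwrap_model(model)
    # Assumption: one forward pass yields one embedding row per node.
    all_embeddings = encoder(node_features.to(device))
    # Indexing with an integer tensor gathers the requested rows in order.
    return all_embeddings[root_node_indices.to(device)]


if __name__ == "__main__":
    torch.manual_seed(0)
    model = torch.nn.Linear(8, 4)  # stand-in encoder: 8 input features -> 4-dim embeddings
    features = torch.randn(10, 8)  # 10 nodes
    roots = torch.tensor([0, 3, 7])
    embeddings = sketch_infer_root_embeddings(model, features, roots, torch.device("cpu"))
    print(embeddings.shape)  # torch.Size([3, 4])
```

Integer-tensor indexing returns one row per entry of `root_node_indices`, in the order given, so duplicate indices produce duplicate rows.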
 
- gigl.src.common.modeling_task_specs.utils.infer.infer_task_inputs(model, gbml_config_pb_wrapper, main_batch, random_neg_batch, should_eval, device)
  - Parameters:
    - model (Union[torch.nn.parallel.DistributedDataParallel, torch.nn.Module])
    - gbml_config_pb_wrapper (gigl.src.common.types.pb_wrappers.gbml_config.GbmlConfigPbWrapper)
    - main_batch
    - random_neg_batch (gigl.src.training.v1.lib.data_loaders.rooted_node_neighborhood_data_loader.RootedNodeNeighborhoodBatch)
    - should_eval (bool)
    - device (torch.device)
  - Return type:
    - gigl.src.common.types.task_inputs.NodeAnchorBasedLinkPredictionTaskInputs
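The signature suggests that `infer_task_inputs` embeds a main batch and a batch of random negatives and packages both as inputs for a node-anchor-based link prediction task. The sketch below is hypothetical: `SketchTaskInputs` stands in for `NodeAnchorBasedLinkPredictionTaskInputs`, plain feature tensors stand in for the two batch types, and it assumes (this page does not confirm it) that `should_eval` switches the model to evaluation mode and turns off gradient tracking.

```python
from dataclasses import dataclass

import torch


@dataclass
class SketchTaskInputs:
    """Hypothetical stand-in for NodeAnchorBasedLinkPredictionTaskInputs."""

    main_embeddings: torch.Tensor
    random_neg_embeddings: torch.Tensor


def sketch_infer_task_inputs(
    model: torch.nn.Module,
    main_features: torch.Tensor,
    random_neg_features: torch.Tensor,
    should_eval: bool,
    device: torch.device,
) -> SketchTaskInputs:
    """Embed the main batch and the random negatives with the same model."""
    # Assumption: should_eval=True means evaluation mode (e.g. dropout disabled).
    model.train(not should_eval)
    # Assumption: gradients are only tracked when not evaluating.
    with torch.set_grad_enabled(not should_eval):
        main_embeddings = model(main_features.to(device))
        random_neg_embeddings = model(random_neg_features.to(device))
    return SketchTaskInputs(main_embeddings, random_neg_embeddings)


if __name__ == "__main__":
    torch.manual_seed(0)
    model = torch.nn.Linear(8, 4)
    inputs = sketch_infer_task_inputs(
        model, torch.randn(5, 8), torch.randn(3, 8), should_eval=True, device=torch.device("cpu")
    )
    print(inputs.main_embeddings.shape, inputs.random_neg_embeddings.shape)
    # torch.Size([5, 4]) torch.Size([3, 4])
```

Using `torch.set_grad_enabled(not should_eval)` rather than a conditional `torch.no_grad()` block keeps a single code path for both modes.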
 
- gigl.src.common.modeling_task_specs.utils.infer.infer_training_batch(model, training_batch, gbml_config_pb_wrapper, device)
  - Parameters:
    - model (Union[torch.nn.parallel.DistributedDataParallel, torch.nn.Module])
    - training_batch (Union[gigl.src.training.v1.lib.data_loaders.node_anchor_based_link_prediction_data_loader.NodeAnchorBasedLinkPredictionBatch, gigl.src.training.v1.lib.data_loaders.rooted_node_neighborhood_data_loader.RootedNodeNeighborhoodBatch, torch_geometric.data.Data, torch_geometric.data.hetero_data.HeteroData])
    - gbml_config_pb_wrapper (gigl.src.common.types.pb_wrappers.gbml_config.GbmlConfigPbWrapper)
    - device (torch.device)
  - Return type:
    - dict[gigl.src.common.types.graph_data.CondensedNodeType, torch.Tensor]
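`infer_training_batch` returns a dictionary keyed by condensed node type, which lets homogeneous and heterogeneous batches share one return shape. The hedged sketch below shows only that idea: it assumes condensed node types are integer ids (modeled with a hypothetical `CondensedNodeTypeSketch` alias), uses one hypothetical encoder per node type instead of a real GiGL model, and keys a homogeneous batch under an assumed default id of 0.

```python
from typing import Dict, NewType, Union

import torch

# Assumption: condensed node types are small integer ids; this alias is hypothetical.
CondensedNodeTypeSketch = NewType("CondensedNodeTypeSketch", int)
DEFAULT_NODE_TYPE = CondensedNodeTypeSketch(0)  # assumed id for homogeneous graphs


def sketch_infer_training_batch(
    encoders: Dict[CondensedNodeTypeSketch, torch.nn.Module],
    batch: Union[torch.Tensor, Dict[CondensedNodeTypeSketch, torch.Tensor]],
) -> Dict[CondensedNodeTypeSketch, torch.Tensor]:
    """Return one embedding tensor per node type present in the batch."""
    if isinstance(batch, torch.Tensor):
        # Homogeneous batch: a single node type, keyed under the default id.
        batch = {DEFAULT_NODE_TYPE: batch}
    # Heterogeneous batch: embed each node type with its own encoder.
    return {node_type: encoders[node_type](features) for node_type, features in batch.items()}


if __name__ == "__main__":
    torch.manual_seed(0)
    encoders = {
        CondensedNodeTypeSketch(0): torch.nn.Linear(8, 4),
        CondensedNodeTypeSketch(1): torch.nn.Linear(6, 4),
    }
    hetero = sketch_infer_training_batch(
        encoders,
        {CondensedNodeTypeSketch(0): torch.randn(2, 8), CondensedNodeTypeSketch(1): torch.randn(3, 6)},
    )
    print({node_type: tuple(t.shape) for node_type, t in hetero.items()})  # {0: (2, 4), 1: (3, 4)}
```

Returning a dictionary even for a single node type means downstream code can iterate over `.items()` without special-casing homogeneous graphs.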
 
