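gigl.src.common.modeling_task_specs.utils.infer
Functions
- infer_root_embeddings(model, graph, root_node_indices, gbml_config_pb_wrapper, device)
- infer_task_inputs(model, gbml_config_pb_wrapper, main_batch, random_neg_batch, should_eval, device)
- infer_training_batch(model, training_batch, gbml_config_pb_wrapper, device)
Module Contents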
- gigl.src.common.modeling_task_specs.utils.infer.infer_root_embeddings(model, graph, root_node_indices, gbml_config_pb_wrapper, device)
- Parameters:
model (Union[torch.nn.parallel.DistributedDataParallel, torch.nn.Module])
graph (Union[torch_geometric.data.Data, torch_geometric.data.hetero_data.HeteroData])
root_node_indices (torch.LongTensor)
gbml_config_pb_wrapper (gigl.src.common.types.pb_wrappers.gbml_config.GbmlConfigPbWrapper)
device (torch.device)
- Return type:
torch.Tensor
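The signature above suggests a common pattern: run a model over a graph in inference mode and keep only the rows that belong to the root nodes. The following is a minimal sketch of that pattern, not the library's implementation. `TinyEncoder`, `sketch_root_embeddings`, and the use of a plain feature tensor in place of a `torch_geometric` graph are all hypothetical stand-ins, and the config wrapper argument is omitted because its role is not described in this page.

```python
from typing import Union

import torch
from torch import nn


class TinyEncoder(nn.Module):
    """Hypothetical stand-in for a GNN; it embeds node features and ignores edges."""

    def __init__(self, in_dim: int, out_dim: int) -> None:
        super().__init__()
        self.linear = nn.Linear(in_dim, out_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)


def sketch_root_embeddings(
    model: Union[nn.parallel.DistributedDataParallel, nn.Module],
    node_features: torch.Tensor,
    root_node_indices: torch.Tensor,
    device: torch.device,
) -> torch.Tensor:
    """Assumed behavior: embed every node, then select the root-node rows."""
    # A DistributedDataParallel wrapper forwards calls to the wrapped module,
    # so both accepted model types can be called the same way.
    model.eval()  # note: this switches the model out of training mode
    with torch.no_grad():  # inference only, so no autograd graph is built
        all_embeddings = model(node_features.to(device))
    # Index with the root node positions to get one row per root node.
    return all_embeddings[root_node_indices.to(device)]


device = torch.device("cpu")
model = TinyEncoder(in_dim=3, out_dim=4)
features = torch.randn(5, 3)        # 5 nodes with 3 features each
roots = torch.tensor([0, 3])        # int64 indices of the root nodes
root_emb = sketch_root_embeddings(model, features, roots, device)
```

In this sketch the result has one row per entry in `roots` and carries no gradient information, which matches the single `torch.Tensor` return type listed above. Whether the real function embeds all nodes first or restricts computation to the roots is not stated here.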
- gigl.src.common.modeling_task_specs.utils.infer.infer_task_inputs(model, gbml_config_pb_wrapper, main_batch, random_neg_batch, should_eval, device)
- Parameters:
model (Union[torch.nn.parallel.DistributedDataParallel, torch.nn.Module])
gbml_config_pb_wrapper (gigl.src.common.types.pb_wrappers.gbml_config.GbmlConfigPbWrapper)
main_batch
random_neg_batch (gigl.src.training.v1.lib.data_loaders.rooted_node_neighborhood_data_loader.RootedNodeNeighborhoodBatch)
should_eval (bool)
device (torch.device)
- Return type:
gigl.src.common.types.task_inputs.NodeAnchorBasedLinkPredictionTaskInputs
- gigl.src.common.modeling_task_specs.utils.infer.infer_training_batch(model, training_batch, gbml_config_pb_wrapper, device)
- Parameters:
model (Union[torch.nn.parallel.DistributedDataParallel, torch.nn.Module])
training_batch (Union[gigl.src.training.v1.lib.data_loaders.node_anchor_based_link_prediction_data_loader.NodeAnchorBasedLinkPredictionBatch, gigl.src.training.v1.lib.data_loaders.rooted_node_neighborhood_data_loader.RootedNodeNeighborhoodBatch, torch_geometric.data.Data, torch_geometric.data.hetero_data.HeteroData])
gbml_config_pb_wrapper (gigl.src.common.types.pb_wrappers.gbml_config.GbmlConfigPbWrapper)
device (torch.device)
- Return type:
Dict[gigl.src.common.types.graph_data.CondensedNodeType, torch.Tensor]