gigl.src.common.modeling_task_specs.utils.infer

Functions

infer_root_embeddings(model, graph, root_node_indices, ...)

infer_task_inputs(model, gbml_config_pb_wrapper, ...)

infer_training_batch(model, training_batch, ...)

Module Contents

gigl.src.common.modeling_task_specs.utils.infer.infer_root_embeddings(model, graph, root_node_indices, gbml_config_pb_wrapper, device)[source]
Parameters:

- model
- graph
- root_node_indices
- gbml_config_pb_wrapper
- device

Return type:

torch.Tensor

gigl.src.common.modeling_task_specs.utils.infer.infer_task_inputs(model, gbml_config_pb_wrapper, main_batch, random_neg_batch, should_eval, device)[source]
Parameters:

- model
- gbml_config_pb_wrapper
- main_batch
- random_neg_batch
- should_eval
- device

Return type:

gigl.src.common.types.task_inputs.NodeAnchorBasedLinkPredictionTaskInputs

gigl.src.common.modeling_task_specs.utils.infer.infer_training_batch(model, training_batch, gbml_config_pb_wrapper, device)[source]
Parameters:

- model
- training_batch
- gbml_config_pb_wrapper
- device

Return type:

Dict[gigl.src.common.types.graph_data.CondensedNodeType, torch.Tensor]