Source code for gigl.src.common.models.utils.torch
from typing import Dict, List
import torch
from gigl.src.common.types.graph_data import NodeType
[docs]
def to_hetero_feat(
    h: torch.Tensor, type_indices: torch.LongTensor, types: List[str]
) -> Dict[NodeType, torch.Tensor]:
    """
    Split a homogeneous feature tensor into one tensor per node type.

    Each entry of ``types`` is paired with its position in the list; the
    entries of ``h`` whose ``type_indices`` value equals that position are
    gathered under the corresponding ``NodeType`` key.

    Args:
        h (torch.Tensor): feature tensor of the homogeneous graph.
        type_indices (torch.LongTensor): type index of each entry of ``h``;
            a value ``i`` refers to ``types[i]``. Presumably 1-D with one
            entry per row of ``h`` -- confirm against callers.
        types (List[str]): the possible type names, ordered so that list
            position matches the values used in ``type_indices``.

    Returns:
        Dict[NodeType, torch.Tensor]: one key per entry of ``types`` (in the
        same order), each mapped to the entries of ``h`` selected for that
        type.
    """
    # Keys are inserted in `types` order; if a name repeats, the selection
    # for its last position is the one that is kept.
    return {
        NodeType(type_name): h[torch.where(type_indices == position)]
        for position, type_name in enumerate(types)
    }