Source code for gigl.src.common.models.layers.normalization
from typing import Dict, Union
import torch
from torch.nn import functional as F
from gigl.src.common.types.graph_data import NodeType
[docs]
def l2_normalize_embeddings(
    node_typed_embeddings: Union[torch.Tensor, Dict[NodeType, torch.Tensor]]
) -> Union[torch.Tensor, Dict[NodeType, torch.Tensor]]:
    """L2-normalize embeddings along their last dimension.

    Each tensor is passed through ``F.normalize(..., p=2, dim=-1)``. The
    argument itself is never modified: a tensor input yields a new tensor and
    a dict input yields a new dict, so callers that still hold a reference to
    the un-normalized embeddings keep them intact.

    Args:
        node_typed_embeddings: Either a single embedding tensor, or a dict
            mapping node type to that node type's embedding tensor.

    Returns:
        The normalized embeddings in the same form as the input: a tensor for
        a tensor input, or a new dict with the same keys (in the same order)
        for a dict input.

    Raises:
        ValueError: If ``node_typed_embeddings`` is neither a ``torch.Tensor``
            nor a ``dict``.
    """
    if isinstance(node_typed_embeddings, dict):
        # Build a fresh dict instead of reassigning keys on the caller's dict;
        # writing back in place would silently overwrite the caller's raw
        # embeddings, unlike the tensor branch, which has no side effects.
        return {
            node_type: F.normalize(embeddings, p=2, dim=-1)
            for node_type, embeddings in node_typed_embeddings.items()
        }
    if isinstance(node_typed_embeddings, torch.Tensor):
        return F.normalize(node_typed_embeddings, p=2, dim=-1)
    raise ValueError(
        f"Expected type torch.Tensor or Dict[NodeType, torch.Tensor], got type {type(node_typed_embeddings)}"
    )