# Source code for gigl.src.common.models.layers.normalization
from typing import Union, overload
import torch
from torch.nn import functional as F
from gigl.src.common.types.graph_data import NodeType
@overload
def l2_normalize_embeddings(
    node_typed_embeddings: torch.Tensor,
) -> torch.Tensor:
    # Tensor overload: a single embedding tensor in, a single tensor out,
    # L2-normalized along its last dimension (dim=-1).
    ...
@overload
def l2_normalize_embeddings(
    node_typed_embeddings: dict[NodeType, torch.Tensor],
) -> dict[NodeType, torch.Tensor]:
    # Mapping overload: a dict keyed by node type in, a dict keyed by node
    # type out, with each tensor value L2-normalized along dim=-1.
    ...
def l2_normalize_embeddings(
    node_typed_embeddings: Union[torch.Tensor, dict[NodeType, torch.Tensor]],
) -> Union[torch.Tensor, dict[NodeType, torch.Tensor]]:
    """L2-normalize embeddings along their last dimension.

    Every tensor is passed through ``F.normalize(..., p=2, dim=-1)``.

    Args:
        node_typed_embeddings: Either a single ``torch.Tensor`` or a ``dict``
            mapping each node type to a tensor. Presumably each tensor is
            (num_nodes, embedding_dim); TODO confirm against callers.

    Returns:
        For a tensor input, the result of ``F.normalize``. For a dict input,
        the *same* dict object that was passed in. Its values have been
        overwritten in place with their normalized versions, so the caller's
        dict is mutated.

    Raises:
        ValueError: If the input is neither a ``torch.Tensor`` nor a ``dict``.
    """
    # Single tensor: return the normalized result directly.
    if isinstance(node_typed_embeddings, torch.Tensor):
        return F.normalize(node_typed_embeddings, p=2, dim=-1)

    # Anything that is neither a tensor nor a dict is rejected.
    if not isinstance(node_typed_embeddings, dict):
        raise ValueError(
            f"Expected type torch.Tensor or dict[NodeType, torch.Tensor], got type {type(node_typed_embeddings)}"
        )

    # Dict: replace each value in place. Only existing keys are reassigned
    # (nothing is added or removed), so writing while iterating is safe.
    for ntype in node_typed_embeddings:
        node_typed_embeddings[ntype] = F.normalize(
            node_typed_embeddings[ntype],  # ty: ignore[invalid-argument-type] TODO(ty-torch-keyed-access): fix ty false positives for torch-backed keyed container access.
            p=2,
            dim=-1,
        )  # ty: ignore[invalid-assignment] TODO(ty-torch-keyed-access): fix ty false positives for torch-backed keyed container access.
    return node_typed_embeddings