Source code for gigl.src.common.models.layers.normalization

from typing import Union, overload

import torch
from torch.nn import functional as F

from gigl.src.common.types.graph_data import NodeType


@overload
def l2_normalize_embeddings(
    node_typed_embeddings: torch.Tensor,
) -> torch.Tensor: ...


@overload
def l2_normalize_embeddings(
    node_typed_embeddings: dict[NodeType, torch.Tensor],
) -> dict[NodeType, torch.Tensor]: ...


def l2_normalize_embeddings(
    node_typed_embeddings: Union[torch.Tensor, dict[NodeType, torch.Tensor]],
) -> Union[torch.Tensor, dict[NodeType, torch.Tensor]]:
    """L2-normalize embeddings along their last dimension.

    Accepts either a single tensor or a dict of tensors keyed by node type
    and applies ``F.normalize(..., p=2, dim=-1)`` to each tensor.

    Args:
        node_typed_embeddings: A tensor, or a dict mapping node types to
            tensors.

    Returns:
        For a tensor input, the tensor returned by ``F.normalize``. For a
        dict input, the *same* dict object that was passed in, with every
        value replaced by its normalized tensor (the dict is updated in
        place).

    Raises:
        ValueError: If the input is neither a ``dict`` nor a
            ``torch.Tensor``.
    """
    if isinstance(node_typed_embeddings, dict):
        # Overwrite each entry in place; only existing keys are reassigned,
        # so the dict's size never changes while it is being iterated.
        for ntype in node_typed_embeddings:
            node_typed_embeddings[ntype] = F.normalize(
                node_typed_embeddings[ntype],  # ty: ignore[invalid-argument-type] TODO(ty-torch-keyed-access): fix ty false positives for torch-backed keyed container access.
                p=2,
                dim=-1,
            )  # ty: ignore[invalid-assignment] TODO(ty-torch-keyed-access): fix ty false positives for torch-backed keyed container access.
        return node_typed_embeddings

    if isinstance(node_typed_embeddings, torch.Tensor):
        return F.normalize(node_typed_embeddings, p=2, dim=-1)

    raise ValueError(
        f"Expected type torch.Tensor or dict[NodeType, torch.Tensor], got type {type(node_typed_embeddings)}"
    )