"""Source code for gigl.experimental.knowledge_graph_embedding.common.torchrec.large_embedding_lookup."""

from typing import List

import torch
import torch.nn as nn
import torchrec

from gigl.common.logger import Logger

# Module-level logger shared by the classes in this module.
logger = Logger()
class LargeEmbeddingLookup(nn.Module):
    """Thin ``nn.Module`` wrapper around a torchrec ``EmbeddingBagCollection``.

    The wrapped collection is stored as ``self.ebc`` and is built from the
    supplied table configs. By default it is constructed on the ``meta``
    device, so no real embedding storage is allocated at construction time.
    NOTE(review): presumably the weights are materialized later by a sharding
    wrapper (e.g. torchrec's DistributedModelParallel); confirm against callers.
    """

    def __init__(
        self,
        embeddings_config: List[torchrec.EmbeddingBagConfig],
        *,
        device: torch.device = torch.device("meta"),
    ):
        """Build the embedding bag collection.

        Args:
            embeddings_config: Table configurations, passed as ``tables`` to
                ``torchrec.EmbeddingBagCollection``.
            device: Device on which the collection is constructed. Defaults to
                ``meta``, which was the previously hard-coded value, so
                existing callers see no change. Pass a concrete device (e.g.
                CPU) to get materialized weights without a sharding step.
        """
        super().__init__()
        self.ebc = torchrec.EmbeddingBagCollection(
            tables=embeddings_config,
            device=device,
        )
        # Log the created parameters so table layout can be inspected at startup.
        logger.info(
            f"EmbeddingBagCollection named parameters: {list(self.ebc.named_parameters())}"
        )

    def forward(
        self, sparse_features: torchrec.KeyedJaggedTensor
    ) -> torchrec.KeyedTensor:
        """Look up and pool embeddings for ``sparse_features``.

        Args:
            sparse_features: Sparse ids keyed by feature, as consumed by
                ``EmbeddingBagCollection``.

        Returns:
            The ``KeyedTensor`` produced by the wrapped collection.
        """
        return self.ebc(sparse_features)