# Source code for gigl.experimental.knowledge_graph_embedding.common.torchrec.large_embedding_lookup
from typing import List
import torch
import torch.nn as nn
import torchrec
from gigl.common.logger import Logger
# Module-level logger; the original code called `logger.info` without ever
# defining `logger` (only the `Logger` class was imported).
logger = Logger()


class LargeEmbeddingLookup(nn.Module):
    """Embedding lookup backed by a torchrec ``EmbeddingBagCollection``.

    The collection is constructed on PyTorch's ``meta`` device, so its
    parameters carry shape/dtype metadata but no backing storage.
    NOTE(review): this presumably relies on a later step (e.g. torchrec
    sharding / ``DistributedModelParallel``) to materialize the tables on a
    real device. Confirm with callers before using this module directly.

    Args:
        embeddings_config: Table configurations, passed unchanged as
            ``tables=`` to ``torchrec.EmbeddingBagCollection``.
    """

    def __init__(self, embeddings_config: List[torchrec.EmbeddingBagConfig]):
        super().__init__()
        # Meta-device tensors have no storage, so no table memory is
        # allocated at construction time regardless of table size.
        self.ebc = torchrec.EmbeddingBagCollection(
            tables=embeddings_config,
            device=torch.device("meta"),
        )
        logger.info(
            f"EmbeddingBagCollection named parameters: {list(self.ebc.named_parameters())}"
        )

    def forward(
        self, sparse_features: torchrec.KeyedJaggedTensor
    ) -> torchrec.KeyedTensor:
        """Look up and pool embeddings for the given sparse features.

        Args:
            sparse_features: Jagged sparse ids, keyed by feature name, fed
                directly to the underlying ``EmbeddingBagCollection``.

        Returns:
            The ``KeyedTensor`` produced by the ``EmbeddingBagCollection``
            forward pass.
        """
        return self.ebc(sparse_features)