Source code for gigl.src.common.utils.data.training
from typing import List
import torch
from gigl.src.data_preprocessor.lib.types import FeatureSchema
[docs]
def filter_features(
    feature_schema: FeatureSchema,
    feature_names: List[str],
    x: torch.Tensor,
) -> torch.Tensor:
    """
    Returns tensor with features from x based on feature_names.

    Each name is looked up in ``feature_schema.feature_index``, which maps it
    to a ``(start, end)`` half-open column span. Columns ``start .. end - 1``
    of every requested feature are gathered from dim 1 of ``x`` in the order
    the names are given (a name listed twice yields its columns twice).

    Args:
        feature_schema: Schema whose ``feature_index`` maps feature names to
            ``(start, end)`` column spans of ``x``.
        feature_names: Features to keep, in output order. May be empty.
        x: Tensor whose dim 1 holds the feature columns (presumably
            ``(num_rows, num_feature_columns)`` -- confirm with callers).

    Returns:
        A 2-D tensor with one column per selected feature column; for a 2-D
        ``x`` its shape is ``(x.shape[0], n_selected)``. When nothing is
        selected the result has shape ``(x.shape[0], 0)``.

    Raises:
        AssertionError: if a name in ``feature_names`` is not present in
            ``feature_schema.feature_index``.
    """
    indices: List[int] = []
    for feature in feature_names:
        # Raise explicitly instead of using ``assert`` so the check survives
        # ``python -O``; AssertionError is kept so callers see the same
        # exception type as before.
        if feature not in feature_schema.feature_index:
            raise AssertionError(f"feature {feature} not found")
        start, end = feature_schema.feature_index[feature]
        indices.extend(range(start, end))

    selected = x[:, indices]
    if not indices:
        # reshape/view(-1, 0) is ambiguous for a 0-element tensor and raises,
        # so pin the row count explicitly.
        return selected.reshape(x.shape[0], 0)
    # ``reshape`` (rather than ``view``) works whether or not the gathered
    # tensor is contiguous, and returns a view whenever ``view`` would.
    return selected.reshape(-1, len(indices))