Source code for gigl.src.common.utils.data.training
import torch
from gigl.src.data_preprocessor.lib.types import FeatureSchema
[docs]
def filter_features(
    feature_schema: FeatureSchema,
    feature_names: list[str],
    x: torch.Tensor,
) -> torch.Tensor:
    """
    Returns tensor with features from x based on feature_names
    """
    indices = []
    for feature in feature_names:
        assert feature in feature_schema.feature_index, f"feature {feature} not found"
        start, end = feature_schema.feature_index[feature]
        indices.extend(list(range(start, end)))
    return x[:, indices].view(-1, len(indices))