Source code for gigl.src.common.models.pyg.utils

import inspect
from typing import Any, Dict, Iterable, Tuple

import torch_geometric

# Parameter names accepted by the constructor of PyG's ``MessagePassing`` base
# class, i.e. the arguments a message-passing layer may forward to its base
# class. Names appear in the order ``inspect.signature`` reports them.
MESSAGE_PASSING_BASE_CLS_ARGS = [
    param_name
    for param_name in inspect.signature(
        torch_geometric.nn.conv.MessagePassing
    ).parameters
]
def filter_dict(
    input_dict: Dict[str, Any], keys_to_keep: Iterable[str] = ()
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """
    Splits an input dictionary into kept and discarded items based on keys to keep.

    Args:
        input_dict: Input dictionary. It is not modified.
        keys_to_keep: Iterable of keys to keep from the input dictionary (all
            others are returned as discarded). Any iterable is accepted,
            including a one-shot iterator or generator; it is consumed exactly
            once. Defaults to an empty tuple, in which case every item is
            discarded.

    Returns:
        remaining_kwargs: Dictionary containing the items whose keys are in
            ``keys_to_keep``, in ``input_dict`` iteration order.
        discarded_kwargs: Dictionary containing all other items, in
            ``input_dict`` iteration order.
    """
    # Materialize the keys once. With a generator, repeated ``in`` checks would
    # consume it, and the set gives O(1) lookups instead of a linear scan per key.
    keep = frozenset(keys_to_keep)
    remaining_kwargs: Dict[str, Any] = {}
    discarded_kwargs: Dict[str, Any] = {}
    # Single pass over the input; each item lands in exactly one of the outputs.
    for key, value in input_dict.items():
        target = remaining_kwargs if key in keep else discarded_kwargs
        target[key] = value
    return remaining_kwargs, discarded_kwargs