Source code for gigl.common.utils.decorator
import functools
from typing import Callable, TypeVar

import tensorflow as tf
# Generic return type of the decorated function; used in decorator signatures so
# the wrapper is annotated as returning the same type as the wrapped callable.
_ReturnType = TypeVar("_ReturnType")
[docs]
def tf_on_cpu(func: Callable[..., _ReturnType]) -> Callable[..., _ReturnType]:
    """
    A decorator to run a function using TensorFlow's CPU device.

    The decorated function is invoked inside a ``tf.device("/CPU:0")`` context
    and its return value is passed through unchanged.

    Args:
        func: The callable to wrap. All positional and keyword arguments given
            to the wrapper are forwarded to it as-is.

    Returns:
        A wrapper that calls ``func`` under the CPU device context. The
        wrapper carries ``func``'s metadata (``__name__``, ``__doc__``,
        ``__qualname__``, ``__module__``) and exposes the original callable
        as ``__wrapped__``.
    """

    # Preserve the wrapped function's identity so introspection, logging and
    # documentation tools see the real function rather than "wrapper".
    @functools.wraps(func)
    def wrapper(*args, **kwargs) -> _ReturnType:
        with tf.device("/CPU:0"):
            return func(*args, **kwargs)

    return wrapper
