Source code for gigl.common.utils.decorator
import functools
from typing import Callable, TypeVar

import tensorflow as tf
# Generic return type used by the decorators in this module so that the
# decorated callable is annotated with the same return type as the function
# it wraps.
_ReturnType = TypeVar("_ReturnType")
[docs]
def tf_on_cpu(func: Callable[..., _ReturnType]) -> Callable[..., _ReturnType]:
    """
    A decorator to run a function using TensorFlow's CPU device.

    The wrapped function is invoked inside a ``tf.device("/CPU:0")`` scope
    on every call. Positional and keyword arguments are forwarded unchanged,
    and the function's return value is returned as-is. Exceptions raised by
    the function propagate to the caller.

    Args:
        func: The callable to execute under the CPU device scope.

    Returns:
        A wrapper around ``func`` that preserves its metadata (``__name__``,
        ``__doc__``, ``__module__``, ``__wrapped__``, ...).
    """

    # functools.wraps copies the wrapped function's metadata onto the wrapper
    # so introspection, logging and documentation tools see the original
    # function rather than a generic "wrapper".
    @functools.wraps(func)
    def wrapper(*args, **kwargs) -> _ReturnType:
        with tf.device("/CPU:0"):
            return func(*args, **kwargs)

    return wrapper