Source code for gigl.common.utils.compute.serialization.serialize_np

from typing import Tuple, TypedDict

import msgpack
import numpy as np

from gigl.common.utils.compute.serialization.coder import CoderProtocol


class EncodedNdArray(TypedDict):
    """Plain-dict description of a NumPy array used as its msgpack payload.

    ``NumpyCoder`` builds one of these for every array it encodes and reads
    the same three keys back when decoding.
    """

    # String form of the array's dtype (produced with ``str(array.dtype)``).
    dtype: str
    # Dimensions of the array, taken from ``array.shape``.
    shape: Tuple
    # Raw contents of the array's buffer.
    data: bytes
class NumpyCoder(CoderProtocol[np.ndarray]):
    """Coder that converts ``np.ndarray`` objects to and from msgpack bytes.

    An array is represented on the wire as a map with the keys ``dtype``,
    ``shape`` and ``data`` (see ``EncodedNdArray``). Arrays with ``object``
    dtype are rejected with ``TypeError``.
    """

    def encode(self, obj: np.ndarray) -> bytes:
        """Serialize ``obj`` into msgpack bytes.

        Args:
            obj: The array to serialize.

        Returns:
            The msgpack-encoded representation of ``obj``.

        Raises:
            TypeError: If ``obj`` has ``object`` dtype.
        """
        return msgpack.dumps(
            obj, default=self.__encode_nd_array_helper, use_bin_type=True
        )

    def decode(self, byte_str: bytes) -> np.ndarray:
        """Rebuild an array from bytes produced by :meth:`encode`.

        Args:
            byte_str: msgpack payload holding an encoded array.

        Returns:
            An array with the dtype and shape recorded in the payload.
        """
        return msgpack.loads(
            byte_str, object_hook=self.__decode_nd_array_helper, raw=False
        )

    @staticmethod
    def __decode_nd_array_helper(obj: EncodedNdArray) -> np.ndarray:
        """msgpack ``object_hook``: turn a decoded map back into an array."""
        # NOTE(review): np.frombuffer wraps the decoded buffer instead of
        # copying it; per the NumPy docs the result is read-only when the
        # buffer is immutable -- callers that need to mutate should copy.
        return np.frombuffer(obj["data"], dtype=obj["dtype"]).reshape(obj["shape"])

    @staticmethod
    def __encode_nd_array_helper(array: np.ndarray) -> EncodedNdArray:
        """msgpack ``default`` hook: describe ``array`` as a plain map.

        Raises:
            TypeError: If ``array`` has ``object`` dtype.
        """
        # Reject unsupported arrays before doing any serialization work.
        if array.dtype == object:
            raise TypeError(f"can't convert np.ndarray of type {array.dtype}")

        serialized_array = None
        if array.flags["C_CONTIGUOUS"]:
            # Using array.data is a slight optimization: it exposes the
            # existing buffer instead of copying it like tobytes() does.
            try:
                serialized_array = array.data
            except ValueError:
                # Some dtypes (e.g. datetime64 / timedelta64) cannot be
                # exported through the buffer protocol; copy them instead.
                serialized_array = None
        if serialized_array is None:
            serialized_array = array.tobytes()

        return {
            "dtype": str(array.dtype),
            "shape": array.shape,
            "data": serialized_array,
        }