Source code for gigl.common.utils.compute.serialization.serialize_np
from typing import Tuple, TypedDict
import msgpack
import numpy as np
from gigl.common.utils.compute.serialization.coder import CoderProtocol
[docs]
class EncodedNdArray(TypedDict):
    """Dictionary form of a ``np.ndarray`` used as the msgpack payload.

    The keys mirror the mapping built when an array is encoded and read back
    when it is decoded: ``dtype``, ``shape`` and ``data``.
    """

    # String form of the array's dtype (``str(array.dtype)``).
    dtype: str
    # Dimensions of the array, used to reshape the flat buffer on decode.
    shape: Tuple[int, ...]
    # Raw element buffer of the array.
    data: bytes
[docs]
class NumpyCoder(CoderProtocol[np.ndarray]):
    """Coder that converts ``np.ndarray`` objects to and from msgpack bytes.

    An array is packed as a map holding its dtype string, its shape and its
    raw element buffer, and is rebuilt from that map when decoding.
    """

    def encode(self, obj: np.ndarray) -> bytes:
        """Serialize ``obj`` into msgpack bytes.

        Args:
            obj: Array to serialize.

        Returns:
            The msgpack-encoded payload.

        Raises:
            TypeError: If the array has ``object`` dtype.
        """
        # msgpack calls the ``default`` hook for values it cannot pack itself.
        return msgpack.dumps(
            obj, default=self.__encode_nd_array_helper, use_bin_type=True
        )

    def decode(self, byte_str: bytes) -> np.ndarray:
        """Rebuild an array from bytes produced by :meth:`encode`.

        Args:
            byte_str: msgpack payload to deserialize.

        Returns:
            The reconstructed array.
        """
        # ``object_hook`` turns the decoded map back into an array.
        return msgpack.loads(
            byte_str, object_hook=self.__decode_nd_array_helper, raw=False
        )

    @staticmethod
    def __decode_nd_array_helper(obj: EncodedNdArray) -> np.ndarray:
        """Turn a decoded ``dtype``/``shape``/``data`` map back into an array.

        NOTE: ``np.frombuffer`` wraps the buffer without copying it; when the
        buffer is an immutable ``bytes`` object the resulting array is
        read-only, so callers should copy it before mutating.
        """
        return np.frombuffer(obj["data"], dtype=obj["dtype"]).reshape(obj["shape"])

    @staticmethod
    def __encode_nd_array_helper(array: np.ndarray) -> EncodedNdArray:
        """Convert ``array`` into its ``dtype``/``shape``/``data`` map.

        Raises:
            TypeError: If the array has ``object`` dtype.
        """
        # Validate before touching the buffer so an unsupported array fails
        # fast, without first paying for a ``tobytes()`` copy. Object arrays
        # hold references to Python objects rather than raw values, so their
        # bytes cannot be restored on the decoding side.
        if array.dtype == object:
            raise TypeError(f"can't convert np.ndarray of type {array.dtype}")
        # ``array.data`` (a memoryview) exposes the existing buffer without a
        # copy, but it only matches the C-order layout assumed by the decoder
        # when the array is C-contiguous; otherwise ``tobytes()`` produces a
        # C-ordered copy.
        serialized_array = (
            array.data if array.flags["C_CONTIGUOUS"] else array.tobytes()
        )
        return {
            "dtype": str(array.dtype),
            "shape": array.shape,
            "data": serialized_array,
        }