# Source code for gigl.common.utils.proto_utils
import os
from tempfile import NamedTemporaryFile
from typing import Optional, Type, TypeVar

import yaml
from google.protobuf import message
from google.protobuf.json_format import MessageToDict, ParseDict
from omegaconf import OmegaConf

from gigl.common import LocalUri, Uri
from gigl.common.logger import Logger
from gigl.common.omegaconf_resolvers import register_resolvers
from gigl.src.common.utils.file_loader import FileLoader
# Generic protobuf message type: lets the ProtoUtils readers declare that they
# return an instance of exactly the message class passed in as ``proto_cls``.
T = TypeVar("T", bound=message.Message) 
class ProtoUtils:
    """Read and write protobuf messages as YAML or binary files at a ``Uri``.

    All transfers go through a ``FileLoader``.
    - Reads first copy ``uri`` into a local temporary file.
    - Writes serialize into a local temporary file that is then copied to ``uri``.

    The local temporary files are removed once the transfer is done.
    """

    def __init__(self, project: Optional[str] = None) -> None:
        """Create the utility.

        Args:
            project: Optional project identifier, forwarded unchanged to
                ``FileLoader``.
        """
        self.__file_loader = FileLoader(project=project)
        # Registers gigl's custom OmegaConf resolvers so that interpolations in
        # YAML configs can be resolved by ``read_proto_from_yaml``.
        register_resolvers()

    @staticmethod
    def _remove_file_quietly(path: str) -> None:
        """Delete a local temporary file; a file that is already gone is not an error."""
        try:
            os.remove(path)
        except FileNotFoundError:
            pass

    def read_proto_from_yaml(self, uri: Uri, proto_cls: Type[T]) -> T:
        """Load the YAML file at ``uri`` and parse it into a ``proto_cls`` message.

        The YAML is passed through OmegaConf (``OmegaConf.create`` followed by
        ``OmegaConf.to_object``), so interpolations are resolved before the data
        is handed to ``ParseDict``.

        Args:
            uri: Location of the YAML file.
            proto_cls: Protobuf message class to instantiate.

        Returns:
            A new ``proto_cls`` instance populated from the YAML content.

        Raises:
            google.protobuf.json_format.ParseError: If the YAML content does not
                match the message schema.
        """
        tfh = self.__file_loader.load_to_temp_file(file_uri_src=uri, delete=False)
        try:
            # Close the loader's handle before re-opening the file by name. This
            # flushes anything still buffered and avoids holding two handles to
            # the same file, which some platforms do not allow.
            tfh.close()
            # YAML is UTF-8 in practice; don't depend on the process locale.
            with open(tfh.name, "r", encoding="utf-8") as file:
                raw_data = yaml.safe_load(file)
        finally:
            # delete=False means nobody else removes this file.
            self._remove_file_quietly(tfh.name)
        omega_conf_obj = OmegaConf.create(raw_data)
        obj_dict = OmegaConf.to_object(omega_conf_obj)
        return ParseDict(js_dict=obj_dict, message=proto_cls())

    def read_proto_from_binary(self, uri: Uri, proto_cls: Type[T]) -> T:
        """Load a serialized (wire-format) protobuf from ``uri``.

        Args:
            uri: Location of the binary file.
            proto_cls: Protobuf message class to instantiate.

        Returns:
            A new ``proto_cls`` instance parsed from the file bytes.

        Raises:
            google.protobuf.message.DecodeError: If the bytes cannot be parsed.
        """
        tfh = self.__file_loader.load_to_temp_file(file_uri_src=uri, delete=False)
        try:
            # Flush and release the loader's handle before re-opening by name.
            tfh.close()
            with open(tfh.name, "rb") as file:
                proto_bytes = file.read()
        finally:
            self._remove_file_quietly(tfh.name)
        proto = proto_cls()
        proto.ParseFromString(proto_bytes)
        return proto

    def write_proto_to_yaml(self, proto: message.Message, uri: Uri) -> None:
        """Serialize ``proto`` to YAML and write it to ``uri``.

        The message is converted with ``MessageToDict`` using its default
        options, so field names are emitted in their JSON (lowerCamelCase) form.

        Args:
            proto: Message to serialize.
            uri: Destination of the YAML file.
        """
        proto_dict = MessageToDict(message=proto)
        yaml_str = yaml.dump(proto_dict, default_flow_style=False)
        tfh = NamedTemporaryFile(mode="w", encoding="utf-8", delete=False)
        try:
            # Write through the temp file's own handle instead of opening a
            # second one by name; leaving the ``with`` closes (and flushes) it.
            with tfh:
                tfh.write(yaml_str)
            self.__file_loader.load_file(
                file_uri_src=LocalUri(tfh.name), file_uri_dst=uri
            )
        finally:
            self._remove_file_quietly(tfh.name)

    def write_proto_to_binary(self, proto: message.Message, uri: Uri) -> None:
        """Serialize ``proto`` to protobuf wire format and write it to ``uri``.

        Args:
            proto: Message to serialize.
            uri: Destination of the binary file.
        """
        proto_bytes = proto.SerializeToString()
        tfh = NamedTemporaryFile(delete=False)
        try:
            with tfh:
                tfh.write(proto_bytes)
            self.__file_loader.load_file(
                file_uri_src=LocalUri(tfh.name), file_uri_dst=uri
            )
        finally:
            self._remove_file_quietly(tfh.name)