Source code for gigl.src.inference.v1.lib.inference_output_schema
from __future__ import annotations
from typing import Dict, List, NamedTuple, Optional
import google.cloud.bigquery as bigquery
[docs]
# Name of the output field that holds the node identifier in the default schemas.
DEFAULT_NODE_ID_FIELD = "node_id"
[docs]
# Name of the output field that holds embeddings (a REPEATED FLOAT field in the default schemas).
DEFAULT_EMBEDDING_FIELD = "emb"
[docs]
# Name of the output field that holds predictions (a REQUIRED INTEGER field in the default schemas).
DEFAULT_PREDICTION_FIELD = "pred"
[docs]
class InferenceOutputBigqueryTableSchema(NamedTuple):
    """Lightweight record describing an inference output table.

    Bundles the schema dictionary that is handed to
    ``beam.io.WriteToBigQuery`` together with the name of the field that
    carries the node identifier, so the identifier can be tracked during
    de-enumeration.

    Attributes:
        schema: Table schema as a mapping to a list of per-field
            dictionaries, or ``None`` when unset.
        node_field: Name of the field holding the node identifier, or
            ``None`` when unset.
    """

    # Both entries default to None, so the tuple can be created empty.
    schema: Optional[Dict[str, List[Dict[str, str]]]] = None
    node_field: Optional[str] = None
[docs]
class InferenceOutputBigqueryTableSchemaBuilder:
    """Fluent builder that assembles an ``InferenceOutputBigqueryTableSchema``.

    Typical use: call ``add_field`` for every output field, mark one of them
    as the node identifier with ``register_node_field``, then call ``build``.
    ``build`` clears the builder, so one instance can produce several schemas.
    """

    def __init__(self) -> None:
        # Start from a clean slate; reset() creates the instance attributes.
        self.reset()

    def reset(self) -> None:
        """Drop every collected field and forget the registered node field."""
        self._fields: Dict[str, bigquery.SchemaField] = {}
        self._node_field: Optional[str] = None

    def add_field(
        self, name: str, field_type: str, mode: str
    ) -> InferenceOutputBigqueryTableSchemaBuilder:
        """Add (or overwrite) the field called ``name``.

        Args:
            name: Field name; also used as the key under which it is stored.
            field_type: Passed to ``bigquery.SchemaField`` as ``field_type``.
            mode: Passed to ``bigquery.SchemaField`` as ``mode``.

        Returns:
            This builder, to allow call chaining.
        """
        schema_field = bigquery.SchemaField(
            name=name, field_type=field_type, mode=mode
        )
        self._fields[name] = schema_field
        return self

    def register_node_field(
        self, name: str
    ) -> InferenceOutputBigqueryTableSchemaBuilder:
        """Mark an already-added field as the node identifier field.

        Args:
            name: Name of a field previously passed to ``add_field``.

        Returns:
            This builder, to allow call chaining.

        Raises:
            AssertionError: If no field called ``name`` has been added.
        """
        assert name in self._fields, f"Could not find field {name} in output fields."
        self._node_field = name
        return self

    def _build_schema_property(self) -> Dict[str, List[Dict[str, str]]]:
        """Render the collected fields as a ``{"fields": [...]}`` dictionary."""
        return {
            "fields": [
                {"name": spec.name, "type": spec.field_type, "mode": spec.mode}
                for spec in self._fields.values()
            ]
        }

    def build(self) -> InferenceOutputBigqueryTableSchema:
        """Produce the schema container and reset the builder.

        Returns:
            An ``InferenceOutputBigqueryTableSchema`` holding the rendered
            schema dictionary and the registered node field name.

        Raises:
            AssertionError: If no node field has been registered.
        """
        assert (
            self._node_field is not None
        ), "Node field must be defined before building."
        assert self._fields is not None, "_fields must be defined before building."
        built = InferenceOutputBigqueryTableSchema(
            schema=self._build_schema_property(), node_field=self._node_field
        )
        # Clear state only after the result is assembled, so the builder is reusable.
        self.reset()
        return built
def _build_default_table_schema(
    field: str, should_run_unenumeration: bool = False
) -> InferenceOutputBigqueryTableSchema:
    """Build a two-field schema: the node id field plus one output field.

    Args:
        field: Which output field to include; must equal
            ``DEFAULT_EMBEDDING_FIELD`` or ``DEFAULT_PREDICTION_FIELD``.
        should_run_unenumeration: When true the node id field is typed
            ``STRING``; otherwise it is typed ``INTEGER``.

    Returns:
        The built schema, with ``DEFAULT_NODE_ID_FIELD`` registered as the
        node field.

    Raises:
        ValueError: If ``field`` is neither of the two supported names.
    """
    node_id_type = "STRING" if should_run_unenumeration else "INTEGER"

    schema_builder = InferenceOutputBigqueryTableSchemaBuilder()
    schema_builder.add_field(
        name=DEFAULT_NODE_ID_FIELD, field_type=node_id_type, mode="REQUIRED"
    )

    if field == DEFAULT_EMBEDDING_FIELD:
        schema_builder.add_field(
            name=DEFAULT_EMBEDDING_FIELD, field_type="FLOAT", mode="REPEATED"
        )
    elif field == DEFAULT_PREDICTION_FIELD:
        schema_builder.add_field(
            name=DEFAULT_PREDICTION_FIELD, field_type="INTEGER", mode="REQUIRED"
        )
    else:
        raise ValueError(
            f"Expected field to be one of {DEFAULT_EMBEDDING_FIELD, DEFAULT_PREDICTION_FIELD}, got {field}"
        )

    # The node id field must be registered before build() will succeed.
    return schema_builder.register_node_field(name=DEFAULT_NODE_ID_FIELD).build()
[docs]
# Embeddings schema built without un-enumeration: INTEGER node id plus a REPEATED FLOAT "emb" field.
DEFAULT_EMBEDDINGS_TABLE_SCHEMA = _build_default_table_schema(
    field=DEFAULT_EMBEDDING_FIELD, should_run_unenumeration=False
)
[docs]
# Predictions schema built without un-enumeration: INTEGER node id plus a REQUIRED INTEGER "pred" field.
DEFAULT_PREDICTIONS_TABLE_SCHEMA = _build_default_table_schema(
    field=DEFAULT_PREDICTION_FIELD, should_run_unenumeration=False
)
[docs]
# Embeddings schema built with un-enumeration: STRING node id plus a REPEATED FLOAT "emb" field.
UNENUMERATED_EMBEDDINGS_TABLE_SCHEMA = _build_default_table_schema(
    field=DEFAULT_EMBEDDING_FIELD, should_run_unenumeration=True
)
[docs]
# Predictions schema built with un-enumeration: STRING node id plus a REQUIRED INTEGER "pred" field.
UNENUMERATED_PREDICTIONS_TABLE_SCHEMA = _build_default_table_schema(
    field=DEFAULT_PREDICTION_FIELD, should_run_unenumeration=True
)