Source code for gigl.common.beam.sharded_read
from dataclasses import dataclass
import apache_beam as beam
from apache_beam.io.gcp.bigquery import BigQueryQueryPriority
from apache_beam.io.gcp.internal.clients.bigquery import DatasetReference
from apache_beam.pvalue import PBegin
from google.cloud import bigquery
from gigl.common.logger import Logger

logger = Logger()

@dataclass(frozen=True)
class BigQueryShardedReadConfig:
    # The key in the table that we will use to split the data into shards. This should be used if we are operating on
    # very large tables, in which case we want to only read smaller slices of the table at a time to avoid oversized
    # status update payloads.
    shard_key: str
    # The project ID to use for temporary datasets when running sharded reads.
    project_id: str
    # The temporary BigQuery dataset name to use when running sharded reads.
    temp_dataset_name: str
    # The number of shards to split the data into. If not provided, the table will be sharded with a default
    # value of 20 shards.
    # TODO (mkolodner-sc): Instead of using this default, infer this value based on number of rows in table
    num_shards: int = 20
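
# Example usage (an illustrative sketch; the field values below are hypothetical):
#   sharded_read_config = BigQueryShardedReadConfig(
#       shard_key="node_id",
#       project_id="my-gcp-project",
#       temp_dataset_name="my_temp_dataset",
#       num_shards=50,
#   )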
 
def _assert_shard_key_in_table(table_name: str, shard_key: str) -> None:
    """
    Validate that the shard key is a valid column in the BigQuery table.
    """
    client = bigquery.Client()
    table_ref = bigquery.TableReference.from_string(table_name)
    table = client.get_table(table_ref)
    column_names = [field.name for field in table.schema]
    if shard_key not in column_names:
        raise ValueError(
            f"Shard key '{shard_key}' is not a valid column in table '{table_name}'. "
            f"Available columns: {column_names}"
        )
class ShardedExportRead(beam.PTransform):
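    """
    A PTransform that reads a BigQuery table in shards: it issues one
    export-based read per shard, filtering rows by a hash of the shard key,
    and flattens the per-shard PCollections into a single output PCollection.
    """
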
    def __init__(
        self,
        table_name: str,
        sharded_read_info: BigQueryShardedReadConfig,
        **kwargs,
    ):
        super().__init__()
        self._table_name: str = table_name
        self._num_shards: int = sharded_read_info.num_shards
        if self._num_shards <= 0:
            raise ValueError(
                f"Number of shards specified must be greater than 0, got {self._num_shards}"
            )
        self._shard_key: str = sharded_read_info.shard_key
        self._temp_dataset_reference: DatasetReference = DatasetReference(
            projectId=sharded_read_info.project_id,
            datasetId=sharded_read_info.temp_dataset_name,
        )
        self._kwargs = kwargs
        logger.info(
            f"Got ShardedExportRead arguments table_name={table_name}, sharded_read_info={sharded_read_info}, kwargs={kwargs}"
        )
        _assert_shard_key_in_table(self._table_name, self._shard_key)
    def expand(self, pbegin: PBegin):
        pcollection_list = []
        for i in range(self._num_shards):
            # We use FARM_FINGERPRINT as a deterministic hashing function, which allows us to shard
            # on keys of any type (e.g. strings, integers, etc.). We take the MOD of the returned INT64 value
            # with the number of shards first, and then take the ABS value to ensure the result is in range
            # [0, num_shards-1]. We do it in this order since ABS can error on the largest negative INT64 number,
            # which has no positive equivalent.
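            # For example, with the default 20 shards: FARM_FINGERPRINT can return INT64's minimum value,
            # -9223372036854775808, for which ABS would overflow; but MOD(-9223372036854775808, 20) = -8
            # (BigQuery's MOD keeps the sign of the dividend), and ABS(-8) = 8, which lands in [0, 19].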
            query = (
                f"SELECT * FROM `{self._table_name}` "
                f"WHERE ABS(MOD(FARM_FINGERPRINT(CAST({self._shard_key} AS STRING)), {self._num_shards})) = {i}"
            )
            pcollection_list.append(
                pbegin
                | f"Read slice {i}/{self._num_shards}"
                >> beam.io.ReadFromBigQuery(
                    query=query,
                    use_standard_sql=True,
                    method=beam.io.ReadFromBigQuery.Method.EXPORT,
                    query_priority=BigQueryQueryPriority.INTERACTIVE,
                    temp_dataset=self._temp_dataset_reference,
                    **self._kwargs,
                )
            )
        return pcollection_list | "Flatten shards" >> beam.Flatten()
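
# Usage sketch (illustrative; the pipeline setup and table name below are hypothetical):
#   with beam.Pipeline() as pipeline:
#       rows = pipeline | "Sharded read" >> ShardedExportRead(
#           table_name="my-gcp-project.my_dataset.my_table",
#           sharded_read_info=sharded_read_config,
#       )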