import datetime
import itertools
import re
from typing import Dict, Iterable, List, Optional, Tuple, Union
import google.api_core.retry
import google.cloud.bigquery as bigquery
from google.api_core.exceptions import NotFound
from google.cloud.bigquery._helpers import _record_field_to_json
from google.cloud.bigquery.job import _AsyncJob
from google.cloud.bigquery.table import RowIterator
from gigl.common import GcsUri, LocalUri, Uri
from gigl.common.logger import Logger
from gigl.common.utils.retry import retry
from gigl.src.common.constants.time import DEFAULT_DATE_FORMAT
from gigl.src.common.utils.time import convert_days_to_ms, current_datetime
logger = Logger()
def _load_file_to_bq(
source: Uri,
bq_path: str,
client: bigquery.Client,
job_config: bigquery.LoadJobConfig,
) -> _AsyncJob:
try:
if isinstance(source, GcsUri):
# if gcs, need to use load_table_from_uri
load_job = client.load_table_from_uri(
source.uri, bq_path, job_config=job_config
)
elif isinstance(source, LocalUri):
# if local, need to use load_table_from_file
with open(source.uri, "rb") as source_path_obj:
# API request -- starts the job
load_job = client.load_table_from_file(
source_path_obj, bq_path, job_config=job_config
)
else:
raise ValueError(f"Unsupported source type: {type(source)}")
logger.info(f"Loading {source} to {bq_path}")
return load_job.result() # Waits for job to complete.
except Exception as e:
logger.exception(f"Could not load file to BQ. {repr(e)}")
raise e
@retry()
def _load_file_to_bq_with_retry(
source_path: Uri,
bq_path: str,
client: bigquery.Client,
job_config: bigquery.LoadJobConfig,
) -> _AsyncJob:
return _load_file_to_bq(source_path, bq_path, client, job_config)
class BqUtils:
def __init__(self, project: Optional[str] = None) -> None:
logger.info(f"BqUtils initialized with project: {project}")
self.__bq_client = bigquery.Client(project=project)
    def create_bq_dataset(self, dataset_id: str, exists_ok: bool = True) -> None:
dataset = bigquery.Dataset(dataset_id)
dataset.location = "US"
try:
self.__bq_client.create_dataset(dataset, exists_ok=exists_ok) # API request
logger.info(f"Created dataset {dataset_id}")
except Exception as e:
logger.exception(f"Could not create dataset. {repr(e)}")
def get_dataset_name_from_table(self, bq_path: str) -> str:
dataset_id = ".".join(bq_path.split(".")[:-1])
return dataset_id
def create_or_empty_bq_table(
self, bq_path: str, schema: Optional[List[bigquery.SchemaField]] = None
) -> None:
bq_path = self.format_bq_path(bq_path)
split_bq_path = bq_path.split(".")
if len(split_bq_path) == 2:
dataset_name = split_bq_path[0]
elif len(split_bq_path) == 3:
dataset_name = split_bq_path[1]
else:
raise Exception(f"Could not parse BQ table path: {bq_path}")
self.__bq_client.create_dataset(
dataset=dataset_name, exists_ok=True
) # No-Op if dataset exists
self.__bq_client.delete_table(
table=bq_path, not_found_ok=True
) # Deletes if table exists
table: Union[str, bigquery.Table] = (
bigquery.Table(table_ref=bq_path, schema=schema) if schema else bq_path
)
self.__bq_client.create_table(
table=table, exists_ok=False
) # Recreate the table
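    # Usage sketch (hypothetical project/dataset/table names): recreating a table
    # with an explicit schema so downstream loads fail fast on type mismatches.
    #
    #   bq_utils = BqUtils(project="my-project")
    #   schema = [
    #       bigquery.SchemaField("id", "INTEGER", mode="REQUIRED"),
    #       bigquery.SchemaField("name", "STRING"),
    #   ]
    #   bq_utils.create_or_empty_bq_table("my-project.my_dataset.my_table", schema=schema)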
def count_number_of_rows_in_bq_table(
self,
bq_table: str,
        labels: Optional[Dict[str, str]] = None,
) -> int:
bq_table = bq_table.replace(":", ".")
ROW_COUNTING_QUERY = f"""
SELECT count(1) AS ct FROM `{bq_table}`
"""
        result = self.run_query(query=ROW_COUNTING_QUERY, labels=labels or {})
        return next(iter(result))["ct"]
def count_number_of_columns_in_bq_table(
self,
bq_table: str,
) -> int:
schema = self.fetch_bq_table_schema(bq_table=bq_table)
        return len(schema)
def run_query(
self,
        query: str,
labels: Dict[str, str],
**job_config_args,
) -> RowIterator:
logger.info(f"Running query: {query}")
job_config = bigquery.QueryJobConfig(**job_config_args)
job_config.labels = labels
# Start the query, passing in the extra configuration.
try:
query_job = self.__bq_client.query(
query, location="US", job_config=job_config
) # API request - starts the query
# Waits for the query to finish and returns the result row iterator object.
result = query_job.result()
return result
except Exception as e:
logger.exception(f"Could not run query: {e}")
raise e
    @staticmethod
    def format_bq_path(bq_path: str) -> str:
        """Normalizes a BQ path to dot-delimited form, e.g. `project:dataset.table` -> `project.dataset.table`."""
        return bq_path.replace(":", ".")

    @staticmethod
    def join_path(path: str, *paths: str) -> str:
joined_path = ".".join([path, *paths])
assert joined_path.count(".") <= 2, f"Invalid BQ path: {joined_path}"
return BqUtils.format_bq_path(joined_path)
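    # Usage sketch (hypothetical identifiers): join_path asserts the result is at
    # most project.dataset.table deep before formatting it.
    #
    #   BqUtils.join_path("my-project", "my_dataset", "my_table")
    #   # -> "my-project.my_dataset.my_table"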
@staticmethod
def parse_bq_table_path(bq_table_path: str) -> Tuple[str, str, str]:
"""
Parses a joined bq table path into its project, dataset, and table names
Args:
bq_table_path (str): Joined bq table path of format `project.dataset.table`
Returns:
bq_project_id (str): Parsed BQ Project ID
bq_dataset_id (str): Parsed Dataset ID
bq_table_name (str): Parsed Table Name
"""
split_bq_table_path = BqUtils.format_bq_path(bq_table_path).split(".")
        assert (
            len(split_bq_table_path) == 3
        ), "bq_table_path should be in the format project.dataset.table"
bq_project_id, bq_dataset_id, bq_table_name = split_bq_table_path
return bq_project_id, bq_dataset_id, bq_table_name
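    # Usage sketch (hypothetical path):
    #
    #   project, dataset, table = BqUtils.parse_bq_table_path(
    #       "my-project.my_dataset.my_table"
    #   )
    #   # project == "my-project", dataset == "my_dataset", table == "my_table"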
def update_bq_dataset_retention(
self,
bq_dataset_path: str,
retention_in_days: int,
        apply_retroactively: bool = False,
) -> None:
"""
Update default retention for a whole BQ dataset.
This applies only to new tables unless apply_retroactively=True.
:param bq_dataset_path: The BigQuery dataset path in the format `project_id.dataset_id`.
:param retention_in_days: The number of days to retain data in BigQuery tables.
:param apply_retroactively: If True, applies this retention policy retroactively to all existing tables in the dataset.
"""
bq_dataset_path = BqUtils.format_bq_path(bq_dataset_path)
dataset = self.__bq_client.get_dataset(bq_dataset_path)
retention_in_ms = convert_days_to_ms(retention_in_days)
dataset.default_table_expiration_ms = retention_in_ms
try:
self.__bq_client.update_dataset(dataset, ["default_table_expiration_ms"])
logger.info(
f"Updated dataset {bq_dataset_path} with default expiration in {retention_in_days} days."
)
except Exception as e:
logger.exception(e)
if apply_retroactively:
for table_item in self.__bq_client.list_tables(dataset):
table_id = table_item.full_table_id
self.update_bq_table_retention(table_id, retention_in_days)
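    # Usage sketch (hypothetical dataset): set a 30-day default expiration for new
    # tables, and sweep existing tables so they pick up the same policy.
    #
    #   bq_utils.update_bq_dataset_retention(
    #       "my-project.my_dataset", retention_in_days=30, apply_retroactively=True
    #   )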
def update_bq_table_retention(
self, bq_table_path: str, retention_in_days: int
) -> None:
"""
Update retention of a single BQ table.
:param bq_table_path:
:param retention_in_days:
:param client:
:return:
"""
bq_table_path = BqUtils.format_bq_path(bq_table_path)
table = bigquery.Table(bq_table_path)
expiration_dt = current_datetime() + datetime.timedelta(days=retention_in_days)
table.expires = expiration_dt
try:
self.__bq_client.update_table(table, ["expires"])
logger.info(
f"Updated table {bq_table_path} to expire in {retention_in_days} days."
)
except Exception as e:
logger.exception(e)
def does_bq_table_exist(self, bq_table_path: str) -> bool:
exists = False
try:
bq_table_path = BqUtils.format_bq_path(bq_table_path)
self.__bq_client.get_table(bq_table_path) # Make an API request.
exists = True
logger.info(f"Table {bq_table_path} exists.")
except NotFound:
logger.info(f"Table {bq_table_path} not found.")
except Exception as e:
logger.info(f"Could not evaluate table existence. {repr(e)}")
return exists
def list_matching_tables(
self, bq_dataset_path: str, table_match_string: str
) -> List[str]:
bq_dataset_path = BqUtils.format_bq_path(bq_dataset_path)
tables = self.__bq_client.list_tables(bq_dataset_path)
matching_tables = list()
for table_list_item in tables:
if table_match_string in table_list_item.table_id:
formatted_table_path = BqUtils.format_bq_path(
table_list_item.full_table_id
)
matching_tables.append(formatted_table_path)
return matching_tables
def delete_matching_tables(
self, bq_dataset_path: str, table_match_string: str
) -> None:
try:
bq_dataset_path = BqUtils.format_bq_path(bq_dataset_path)
tables = self.__bq_client.list_tables(bq_dataset_path)
for table_list_item in tables:
if table_match_string in table_list_item.table_id:
formatted_table_path = BqUtils.format_bq_path(
table_list_item.full_table_id
)
self.__bq_client.delete_table(formatted_table_path)
logger.info(f"Deleted table {formatted_table_path}.")
except Exception as e:
logger.exception("Error in deleting tables." + repr(e))
def get_table_names_within_date_range(
self,
bq_dataset_path: str,
table_match_string: str,
start_date: str,
end_date: str,
) -> List[str]:
"""
start_date and end_date are in the format of 'YYYYMMDD'
table_match_string is a regex string to match table names
"""
_start_date = datetime.datetime.strptime(start_date, DEFAULT_DATE_FORMAT).date()
_end_date = datetime.datetime.strptime(end_date, DEFAULT_DATE_FORMAT).date()
filtered_tables_by_name = list()
filtered_tables_by_date = list()
bq_dataset_path = BqUtils.format_bq_path(bq_dataset_path)
all_tables = self.__bq_client.list_tables(bq_dataset_path)
for table in all_tables:
if re.search(table_match_string, table.table_id):
filtered_tables_by_name.append(table)
sorted_tables_by_date = sorted(
filtered_tables_by_name, key=lambda x: x.created, reverse=True
)
for table in sorted_tables_by_date:
if _start_date <= table.created.date() <= _end_date:
filtered_tables_by_date.append(
".".join([bq_dataset_path, table.table_id])
)
return filtered_tables_by_date
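    # Usage sketch (hypothetical values): find tables whose ids match a regex and
    # that were created in January 2024 (dates use DEFAULT_DATE_FORMAT, 'YYYYMMDD').
    #
    #   tables = bq_utils.get_table_names_within_date_range(
    #       "my-project.my_dataset",
    #       table_match_string=r"embeddings_.*",
    #       start_date="20240101",
    #       end_date="20240131",
    #   )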
def delete_bq_table_if_exist(
self, bq_table_path: str, not_found_ok: bool = True
) -> None:
"""bq_table_path = 'your-project.your_dataset.your_table'"""
bq_table_path = BqUtils.format_bq_path(bq_table_path)
try:
self.__bq_client.delete_table(bq_table_path, not_found_ok=not_found_ok)
logger.info(f"Table deleted '{bq_table_path}'")
except Exception as e:
logger.exception(f"Failed to delete table '{bq_table_path}' due to \n {e}")
def fetch_bq_table_schema(self, bq_table: str) -> Dict[str, bigquery.SchemaField]:
"""
Create a dictionary representation for SchemaFields from BigQuery table.
"""
bq_table = bq_table.replace(":", ".")
bq_schema = self.__bq_client.get_table(bq_table).schema
schema_dict = {field.name: field for field in bq_schema}
return schema_dict
def load_file_to_bq(
self,
source_path: Uri,
bq_path: str,
job_config: bigquery.LoadJobConfig,
retry: bool = False,
) -> _AsyncJob:
"""
        Uploads a single file to BigQuery.
Args:
source_path (Uri): The source file to upload.
bq_path (str): The BigQuery table path to upload to.
job_config (bigquery.LoadJobConfig): The job configuration for the upload.
retry (bool, optional): Whether to retry the upload if it fails. Defaults to False.
Returns: The job object for the upload.
"""
if retry:
result = _load_file_to_bq_with_retry(
source_path, bq_path, self.__bq_client, job_config
)
else:
result = _load_file_to_bq(
source_path, bq_path, self.__bq_client, job_config
)
return result
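    # Usage sketch (hypothetical paths): load a newline-delimited JSON file from GCS,
    # retrying on transient failures; a LocalUri source would stream the file instead.
    #
    #   job_config = bigquery.LoadJobConfig(
    #       source_format=bigquery.SourceFormat.NEWLINE_DELIMITED_JSON,
    #       write_disposition=bigquery.WriteDisposition.WRITE_TRUNCATE,
    #   )
    #   bq_utils.load_file_to_bq(
    #       source_path=GcsUri("gs://my-bucket/data.jsonl"),
    #       bq_path="my-project.my_dataset.my_table",
    #       job_config=job_config,
    #       retry=True,
    #   )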
def load_rows_to_bq(
self,
bq_path: str,
schema: List[bigquery.SchemaField],
rows: Iterable[Tuple],
) -> None:
        # Peek at the first row via a shared iterator so it is not re-yielded below.
        rows_iter = iter(rows)
        first_item = next(rows_iter, None)
if first_item is None:
logger.warning(f"No rows to insert into {bq_path}.")
return
_BQ_INSERT_REQUEST_LIMIT_BYTES = (
10_000_000 - 1_000_000
) # 10MB is the limit; we leave 1MB of buffer
_RETRY_BACKOFF = google.api_core.retry.Retry(
# retries if and only if the 'reason' is 'backendError' or 'rateLimitExceeded'
predicate=bigquery.DEFAULT_RETRY._predicate,
initial=1.0, # 1 second
maximum=60.0 * 5, # 5 minutes
multiplier=1.5,
timeout=60 * 30, # 30 mins
)
table = bigquery.Table(table_ref=bq_path, schema=schema)
batch_rows = []
estimated_batch_rows_size_bytes = 0
        for row in itertools.chain([first_item], rows_iter):
json_row: dict = _record_field_to_json(schema, row)
batch_rows.append(json_row)
estimated_batch_rows_size_bytes += len(str(json_row))
if estimated_batch_rows_size_bytes > _BQ_INSERT_REQUEST_LIMIT_BYTES:
self.__bq_client.insert_rows_json(
table=table,
json_rows=batch_rows,
retry=_RETRY_BACKOFF,
)
batch_rows = []
estimated_batch_rows_size_bytes = 0
if len(batch_rows) > 0:
self.__bq_client.insert_rows_json(
table=table,
json_rows=batch_rows,
retry=_RETRY_BACKOFF,
)
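    # Usage sketch (hypothetical schema and rows): rows are tuples aligned with the
    # schema; inserts are batched to stay under BigQuery's ~10MB request limit.
    #
    #   schema = [
    #       bigquery.SchemaField("id", "INTEGER"),
    #       bigquery.SchemaField("name", "STRING"),
    #   ]
    #   bq_utils.load_rows_to_bq(
    #       bq_path="my-project.my_dataset.my_table",
    #       schema=schema,
    #       rows=[(1, "a"), (2, "b")],
    #   )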
def check_columns_exist_in_table(
self, bq_table: str, columns: Iterable[str]
) -> None:
schema = self.fetch_bq_table_schema(bq_table=bq_table)
all_fields = set(schema.keys())
missing_fields = [field for field in columns if field not in all_fields]
if missing_fields:
raise ValueError(f"Fields {missing_fields} missing from table {bq_table}.")
else:
logger.info(f"All requisite fields found in table {bq_table}")
def export_to_gcs(
self,
bq_table_path: str,
destination_gcs_uri: GcsUri,
destination_format: str = "NEWLINE_DELIMITED_JSON",
) -> None:
"""
Export a BigQuery table to Google Cloud Storage.
Args:
bq_table_path (str): The full BigQuery table path to export.
            destination_gcs_uri (GcsUri): The destination GCS URI where the table will be exported.
                If the GCS URI contains a `*` wildcard, the table will be exported to multiple shards.
destination_format (str, optional): The format of the exported data. Defaults to 'NEWLINE_DELIMITED_JSON'.
'CSV', 'AVRO', 'PARQUET' also available.
"""
try:
job_config = bigquery.job.ExtractJobConfig()
job_config.destination_format = destination_format
extract_job = self.__bq_client.extract_table(
source=bigquery.TableReference.from_string(bq_table_path),
destination_uris=destination_gcs_uri.uri,
job_config=job_config,
)
logger.info(
f"Exporting `{bq_table_path}` to {destination_gcs_uri} with format '{destination_format}'..."
)
extract_job.result() # Waits for job to complete.
logger.info(
f"Exported `{bq_table_path}` to {destination_gcs_uri} successfully."
)
except Exception as e:
logger.exception(f"Failed to export table to GCS.")
raise e
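    # Usage sketch (hypothetical bucket): a wildcard in the destination URI shards
    # the export across multiple files, which BigQuery requires for tables over 1GB.
    #
    #   bq_utils.export_to_gcs(
    #       bq_table_path="my-project.my_dataset.my_table",
    #       destination_gcs_uri=GcsUri("gs://my-bucket/exports/part-*.jsonl"),
    #   )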