"""Utilities for running Python unit tests (module ``gigl.common.utils.test_utils``)."""
import argparse
import time
import unittest
from concurrent.futures import ProcessPoolExecutor
from dataclasses import dataclass
from typing import Iterator, Tuple, Union
from unittest import TestCase

from gigl.common import LocalUri
from gigl.common.logger import Logger

logger = Logger()
@dataclass(frozen=True)
class TestArgs:
    """Container for CLI arguments to Python tests.

    Attributes:
        test_file_pattern (str): Glob pattern for filtering which test files to run.
            See doc comment in `parse_args` for more details.
    """

    # The class name starts with "Test", so tell pytest this is not a test
    # class; otherwise importing it into a test module triggers a collection
    # warning. It is unannotated, so the dataclass does not treat it as a field.
    __test__ = False

    test_file_pattern: str
def parse_args() -> TestArgs:
    """Parses test-exclusive CLI arguments.

    Uses ``parse_known_args``, so unrecognized command-line arguments are
    ignored rather than causing an error.

    Returns:
        TestArgs: The parsed test arguments.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-tf",
        "--test_file_pattern",
        default="*_test.py",
        help="""
        Glob pattern for filtering which test files to run. By default runs *all* files ("*_test.py").
        Only *one* glob pattern is supported at a time.
        Only the file *name* is matched against the pattern, so passing a file *path* (containing "/") will match nothing.
        Examples:
        ```
        -tf="frozen_dict_test.py"
        -tf="pyg*_test.py"
        ```
        """,
    )
    # Unknown arguments are deliberately discarded instead of rejected.
    args, _ = parser.parse_known_args()
    test_args = TestArgs(test_file_pattern=args.test_file_pattern)
    logger.info(f"Test args: {test_args}")
    return test_args
def _run_individual_test(test: TestCase) -> Tuple[bool, int]:
# If we don't have any test cases, we skip running the test.
# This reduces some noise in the logs.
if test.countTestCases() == 0:
logger.warning(
f"Test {test} has no test cases to run. Skipping execution of this test."
)
return (True, 0)
runner = unittest.TextTestRunner(verbosity=2)
result: unittest.TestResult = runner.run(test=test)
return (result.wasSuccessful(), test.countTestCases())
def run_tests(
    start_dir: LocalUri, pattern: str, use_sequential_execution: bool = False
) -> bool:
    """Discover and run all unit tests under ``start_dir`` matching ``pattern``.

    Tests are discovered with ``unittest.TestLoader.discover``. By default,
    each top-level entry of the discovered suite (typically one suite per
    discovered test module) is run in a separate worker process through a
    ``ProcessPoolExecutor``. Pass ``use_sequential_execution=True`` to run the
    whole suite in the current process instead.

    Args:
        start_dir (LocalUri): Local directory to start test discovery from.
        pattern (str): Glob pattern that test file names must match
            (e.g. "*_test.py").
        use_sequential_execution (bool): Whether to run all tests sequentially
            in the current process instead of in parallel worker processes.

    Returns:
        bool: Whether all tests passed successfully.
    """
    start = time.perf_counter()
    loader = unittest.TestLoader()
    # Recursively find test files under ``start_dir`` whose names match ``pattern``.
    suite: unittest.TestSuite = loader.discover(
        start_dir=start_dir.uri,
        pattern=pattern,
    )
    was_successful: bool
    total_num_test_cases: int = 0
    if use_sequential_execution:
        runner = unittest.TextTestRunner(verbosity=2)
        was_successful = runner.run(suite).wasSuccessful()
        total_num_test_cases = suite.countTestCases()
    else:
        was_successful = True
        with ProcessPoolExecutor() as executor:
            # Iterate the suite via its public iterator instead of the private
            # ``_tests`` list; each top-level entry is sent to a worker process.
            results: Iterator[Tuple[bool, int]] = executor.map(
                _run_individual_test, suite
            )
            # Consume every result (no short-circuit) so all counts are summed.
            for batch_successful, num_test_cases_ran in results:
                was_successful = was_successful and batch_successful
                total_num_test_cases += num_test_cases_ran
    logger.info(f"Ran {total_num_test_cases}/{suite.countTestCases()} test cases")
    finish = time.perf_counter()
    logger.info(f"It took {finish - start:.2f} second(s) to run tests")
    return was_successful