Source code for gigl.src.common.modeling_task_specs.utils.profiler_wrapper
import tempfile
from distutils.util import strtobool
from torch.profiler import (
ProfilerActivity,
profile,
schedule,
tensorboard_trace_handler,
)
from gigl.common import LocalUri
from gigl.common.logger import Logger
# Module-level logger; TorchProfiler.__init__ reports its configuration through it.
# NOTE(review): assumes gigl's Logger can be constructed without arguments - confirm.
logger = Logger()

# Local directory handed to tensorboard_trace_handler as the trace output location.
# mkdtemp() is used instead of TemporaryDirectory().name: the TemporaryDirectory
# object would be discarded immediately, and its finalizer removes the directory
# (with a ResourceWarning), leaving this constant pointing at a path that no
# longer exists. mkdtemp() creates the directory and leaves cleanup to the caller.
TMP_PROFILER_LOG_DIR_NAME = LocalUri(tempfile.mkdtemp())
class TorchProfiler:
    """Configure and build ``torch.profiler.profile`` contexts from keyword options.

    Every option may be supplied as a string (e.g. when it comes from a config
    file); numeric options are converted with ``int`` and boolean options are
    parsed by :meth:`_parse_bool`, which also accepts real ``bool`` values.

    Keyword Args:
        wait: Forwarded to ``torch.profiler.schedule(wait=...)``. Default ``5``.
        warmup: Forwarded to ``torch.profiler.schedule(warmup=...)``. Default ``2``.
        active: Forwarded to ``torch.profiler.schedule(active=...)``. Default ``2``.
        repeat: Forwarded to ``torch.profiler.schedule(repeat=...)``. Default ``1``.
        profile_memory: Forwarded to ``profile(profile_memory=...)``. Default ``"True"``.
        record_shapes: Forwarded to ``profile(record_shapes=...)``. Default ``"False"``.
        with_stack: Forwarded to ``profile(with_stack=...)``. Default ``"False"``.

    Raises:
        ValueError: If a numeric option cannot be converted by ``int`` or a
            boolean option is not one of the recognised truth tokens.
    """

    # Truth tokens recognised for boolean options (compared case-insensitively).
    # These are the same tokens distutils.util.strtobool accepted; they are kept
    # locally because distutils was removed from the standard library in 3.12.
    _TRUE_TOKENS = frozenset({"y", "yes", "t", "true", "on", "1"})
    _FALSE_TOKENS = frozenset({"n", "no", "f", "false", "off", "0"})

    def __init__(self, **kwargs) -> None:
        # Traces are written under the module-level temp directory with
        # use_gzip=True.
        self.trace_handler = tensorboard_trace_handler(
            dir_name=TMP_PROFILER_LOG_DIR_NAME, use_gzip=True  # type: ignore
        )

        # Schedule parameters; int() lets callers pass either ints or numeric strings.
        self.wait = int(kwargs.get("wait", 5))
        self.warmup = int(kwargs.get("warmup", 2))
        self.active = int(kwargs.get("active", 2))
        self.repeat = int(kwargs.get("repeat", 1))
        self.tracing_schedule = schedule(
            wait=self.wait,
            warmup=self.warmup,
            active=self.active,
            repeat=self.repeat,
        )

        # Boolean switches; accept real bools as well as truth-token strings.
        self.profile_memory = self._parse_bool(kwargs.get("profile_memory", "True"))
        self.record_shapes = self._parse_bool(kwargs.get("record_shapes", "False"))
        self.with_stack = self._parse_bool(kwargs.get("with_stack", "False"))

        logger.info(f"Profiler will be instantiated with {self.__dict__}")

    @staticmethod
    def _parse_bool(value) -> bool:
        """Convert a bool or a truth-token string into a ``bool``.

        Args:
            value: A ``bool`` (returned unchanged) or a value whose ``str()``
                form is, case-insensitively, one of ``y/yes/t/true/on/1`` or
                ``n/no/f/false/off/0``.

        Returns:
            The corresponding boolean.

        Raises:
            ValueError: If the value is not a recognised truth token.
        """
        if isinstance(value, bool):
            return value
        token = str(value).lower()
        if token in TorchProfiler._TRUE_TOKENS:
            return True
        if token in TorchProfiler._FALSE_TOKENS:
            return False
        raise ValueError(f"invalid truth value {token!r}")

    def profiler_context(self) -> profile:
        """Return a new ``torch.profiler.profile`` built from this configuration.

        The profile requests both CPU and CUDA activities and uses the schedule,
        trace handler and boolean switches prepared in ``__init__``.
        """
        return profile(
            activities=[
                ProfilerActivity.CPU,
                ProfilerActivity.CUDA,
            ],
            schedule=self.tracing_schedule,
            on_trace_ready=self.trace_handler,
            profile_memory=self.profile_memory,
            record_shapes=self.record_shapes,
            with_stack=self.with_stack,
        )