Source code for gigl.src.common.modeling_task_specs.utils.early_stop
import io
from typing import Optional, Tuple
import torch
import torch.nn as nn
from gigl.common.logger import Logger

logger = Logger()

class EarlyStopper:
    """
    Handles early stopping logic, tracking the best-performing model according to a provided criterion.
    """
    def __init__(
        self,
        early_stop_patience: int,
        should_maximize: bool,
        model: Optional[nn.Module] = None,
    ):
        """
        Args:
            early_stop_patience (int): Number of consecutive non-improving checks allowed before early stopping is triggered
            should_maximize (bool): Whether the provided criterion should be maximized (True) or minimized (False)
            model (Optional[nn.Module]): Optional model to provide to early stopper class. If provided, will
                keep track of the state dict of the best model.
        """
        self._should_maximize = should_maximize
        self._early_stop_counter = 0
        self._early_stop_patience = early_stop_patience
        self._prev_best = float("-inf") if self._should_maximize else float("inf")
        self._model = model
        self._best_model_buffer: Optional[io.BytesIO] = None
    def _has_metric_improved(self, value: float) -> bool:
        if self._should_maximize:
            return value > self._prev_best
        else:
            return value < self._prev_best
    def step(self, value: float) -> Tuple[bool, bool]:
        """
        Steps the early stopper with the provided criterion value. Returns whether the criterion improved over the
        previous best and whether early stopping should be triggered.
        Args:
            value (float): Criterion used for stepping through early stopper
        Returns:
            bool: Whether there was improvement over previous best criterion
            bool: Whether early stop patience has been reached, indicating early stopping
        """
        has_metric_improved: bool
        should_early_stop: bool
        if self._has_metric_improved(value=value):
            self._early_stop_counter = 0
            logger.info(
                f"Validation criterion improved to {value:.4f} over previous best {self._prev_best:.4f}. Resetting early stop counter."
            )
            self._prev_best = value
            if self._model is not None:
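                # Snapshot the current best model weights by serializing the state dict into an in-memory buffer.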
                self._best_model_buffer = io.BytesIO()
                self._best_model_buffer.seek(0)
                torch.save(self._model.state_dict(), self._best_model_buffer)
            has_metric_improved = True
        else:
            self._early_stop_counter += 1
            logger.info(
                f"Got validation criterion {value:.4f}, which is not an improvement over previous best {self._prev_best:.4f}. "
                f"No improvement for {self._early_stop_counter} consecutive checks."
            )
            has_metric_improved = False
        if self._early_stop_counter >= self._early_stop_patience:
            logger.info(
                f"Early stopping triggered after {self._early_stop_counter} checks without improvement"
            )
            should_early_stop = True
        else:
            should_early_stop = False
        return has_metric_improved, should_early_stop 
    @property
    def best_model_state_dict(self) -> Optional[dict[str, torch.Tensor]]:
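        """Returns the state dict of the best model observed so far, or None if no model was provided or no improvement has been recorded."""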
        if self._best_model_buffer is None:
            return None
        else:
            self._best_model_buffer.seek(0)
            return torch.load(self._best_model_buffer) 
    @property
    def best_criterion(self) -> float:
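        """Returns the best criterion value observed so far."""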
        return self._prev_best
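
A minimal usage sketch (not part of the module): model, train_one_epoch, evaluate, and max_epochs are hypothetical stand-ins for a training setup. The stopper is stepped once per validation check, and the best weights are restored from best_model_state_dict when training stops early.

early_stopper = EarlyStopper(early_stop_patience=3, should_maximize=True, model=model)
for epoch in range(max_epochs):
    train_one_epoch(model)        # hypothetical training step
    val_metric = evaluate(model)  # hypothetical validation criterion (e.g. MRR)
    has_improved, should_stop = early_stopper.step(value=val_metric)
    if should_stop:
        # Restore the weights captured at the best validation criterion.
        model.load_state_dict(early_stopper.best_model_state_dict)
        break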