Source code for mattertune.callbacks.early_stopping

from __future__ import annotations

from typing import Literal

import nshconfig as C


class EarlyStoppingConfig(C.Config):
    monitor: str = "val/total_loss"
    """Quantity to be monitored."""

    min_delta: float = 0.0
    """Minimum change in monitored quantity to qualify as an improvement.
    Changes of less than or equal to `min_delta` will count as no improvement.
    Default: ``0.0``."""

    patience: int = 3
    """Number of validation checks with no improvement after which training will be stopped.
    Default: ``3``."""

    verbose: bool = False
    """Whether to print messages when improvement is found or early stopping is triggered.
    Default: ``False``."""

    mode: Literal["min", "max"] = "min"
    """One of 'min' or 'max'. In 'min' mode, training stops when the monitored quantity
    stops decreasing; in 'max' mode it stops when the quantity stops increasing.
    Default: ``'min'``."""

    strict: bool = True
    """Whether to raise an error if the monitored metric is not found in validation metrics.
    Default: ``True``."""

    check_finite: bool = True
    """Whether to stop training when the monitor becomes NaN or infinite.
    Default: ``True``."""

    stopping_threshold: float | None = None
    """Stop training immediately once the monitored quantity reaches this threshold.
    Default: ``None``."""

    divergence_threshold: float | None = None
    """Stop training as soon as the monitored quantity becomes worse than this threshold.
    Default: ``None``."""

    check_on_train_epoch_end: bool | None = None
    """Whether to run early stopping at the end of the training epoch.
    If False, the check runs at the end of validation.
    Default: ``None``."""

    log_rank_zero_only: bool = False
    """Whether to log the status of early stopping only for the rank 0 process.
    Default: ``False``."""
    def create_callback(self):
        """Creates an EarlyStopping callback instance from this config."""
        from lightning.pytorch.callbacks.early_stopping import EarlyStopping

        return EarlyStopping(
            monitor=self.monitor,
            min_delta=self.min_delta,
            patience=self.patience,
            verbose=self.verbose,
            mode=self.mode,
            strict=self.strict,
            check_finite=self.check_finite,
            stopping_threshold=self.stopping_threshold,
            divergence_threshold=self.divergence_threshold,
            check_on_train_epoch_end=self.check_on_train_epoch_end,
            log_rank_zero_only=self.log_rank_zero_only,
        )
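

A minimal usage sketch, not part of this module: the config is constructed, converted to a Lightning callback via create_callback(), and handed to a Trainer. The Trainer arguments shown here are illustrative assumptions, not values prescribed by mattertune.

# Hypothetical usage example; Trainer settings are assumptions for illustration.
from lightning.pytorch import Trainer

config = EarlyStoppingConfig(monitor="val/total_loss", patience=5, mode="min")
early_stopping = config.create_callback()

# Pass the callback to a Lightning Trainer alongside any other callbacks.
trainer = Trainer(max_epochs=100, callbacks=[early_stopping])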