# Source code for mattertune.callbacks.model_checkpoint

from __future__ import annotations

from datetime import timedelta
from typing import Literal

import nshconfig as C


class ModelCheckpointConfig(C.Config):
    """Configuration for Lightning's ``ModelCheckpoint`` callback.

    Every field is forwarded unchanged, under the same keyword name, to
    ``lightning.pytorch.callbacks.model_checkpoint.ModelCheckpoint`` by
    :meth:`create_callback`.
    """

    dirpath: str | None = None
    """Directory to save the model file. Default: ``None``."""

    filename: str | None = None
    """Checkpoint filename. Can contain named formatting options. Default: ``None``."""

    monitor: str | None = None
    """Quantity to monitor. Default: ``None``."""

    verbose: bool = False
    """Verbosity mode. Default: ``False``."""

    save_last: Literal[True, False, "link"] | None = None
    """When True or "link", saves a 'last.ckpt' checkpoint when a checkpoint is saved. Default: ``None``."""

    save_top_k: int = 1
    """If save_top_k=k, save k models with best monitored quantity.
    ``0`` saves no models and ``-1`` saves all of them. Default: ``1``."""

    save_weights_only: bool = False
    """If True, only save model weights. Default: ``False``."""

    mode: Literal["min", "max"] = "min"
    """One of {'min', 'max'}. Decides whether a lower ('min') or a higher ('max')
    value of the monitored quantity counts as better when choosing which
    checkpoints to keep, e.g. 'min' for a loss and 'max' for an accuracy.
    Default: ``'min'``."""

    auto_insert_metric_name: bool = True
    """Whether to automatically insert metric name in checkpoint filename. Default: ``True``."""

    every_n_train_steps: int | None = None
    """Number of training steps between checkpoints. Default: ``None``."""

    train_time_interval: timedelta | None = None
    """Checkpoints are monitored at the specified time interval. Default: ``None``."""

    every_n_epochs: int | None = None
    """Number of epochs between checkpoints. Default: ``None``."""

    save_on_train_epoch_end: bool | None = None
    """Whether to run checkpointing at end of training epoch. Default: ``None``."""

    enable_version_counter: bool = True
    """Whether to append version to existing filenames. Default: ``True``."""

    def create_callback(self):
        """Creates a ModelCheckpoint callback instance from this config.

        Returns:
            A ``lightning.pytorch.callbacks.model_checkpoint.ModelCheckpoint``
            constructed with every field of this config passed as the keyword
            argument of the same name.
        """
        # Imported inside the method, so Lightning is only loaded when a
        # callback is actually built.
        from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint

        return ModelCheckpoint(
            dirpath=self.dirpath,
            filename=self.filename,
            monitor=self.monitor,
            verbose=self.verbose,
            save_last=self.save_last,
            save_top_k=self.save_top_k,
            save_weights_only=self.save_weights_only,
            mode=self.mode,
            auto_insert_metric_name=self.auto_insert_metric_name,
            every_n_train_steps=self.every_n_train_steps,
            train_time_interval=self.train_time_interval,
            every_n_epochs=self.every_n_epochs,
            save_on_train_epoch_end=self.save_on_train_epoch_end,
            enable_version_counter=self.enable_version_counter,
        )