# Source code for mattertune.main

from __future__ import annotations

import logging
from collections.abc import Sequence
from datetime import timedelta
from typing import Any, Literal, NamedTuple

import nshconfig as C
from lightning.fabric.plugins.precision.precision import _PRECISION_INPUT
from lightning.pytorch import Trainer
from lightning.pytorch.strategies.strategy import Strategy

from .backbones import ModelConfig
from .callbacks.early_stopping import EarlyStoppingConfig
from .callbacks.model_checkpoint import ModelCheckpointConfig
from .data import DataModuleConfig, MatterTuneDataModule
from .finetune.base import FinetuneModuleBase
from .loggers import CSVLoggerConfig, LoggerConfig
from .registry import backbone_registry, data_registry

# Module-level logger named after this module (``__name__``).
log = logging.getLogger(__name__)

class TuneOutput(NamedTuple):
    """Result of :meth:`MatterTuner.tune`: the trained model together with its trainer."""

    model: FinetuneModuleBase
    """The Lightning module after training."""

    trainer: Trainer
    """The Lightning trainer that was used for training."""
class TrainerConfig(C.Config):
    """Configuration for the Lightning ``Trainer`` built by :class:`MatterTuner`.

    Most fields are forwarded one-to-one as keyword arguments of
    ``lightning.pytorch.Trainer``; :meth:`_to_lightning_kwargs` performs the
    translation.
    """

    accelerator: str = "auto"
    """Accelerator type to use ("cpu", "gpu", "tpu", "ipu", "hpu", "mps", "auto").
    Only string aliases are accepted by this field; to pass a custom accelerator
    instance, put it in ``additional_trainer_kwargs`` instead. Default: ``"auto"``."""

    strategy: str | Strategy = "auto"
    """Supports different training strategies with aliases as well custom strategies.
    Default: ``"auto"``.
    """

    num_nodes: int = 1
    """Number of GPU nodes for distributed training.
    Default: ``1``.
    """

    devices: list[int] | str | int = "auto"
    """The devices to use. Can be set to a sequence of device indices, "all" to indicate all available devices should be used, or ``"auto"`` for
    automatic selection based on the chosen accelerator. Default: ``"auto"``.
    """

    precision: _PRECISION_INPUT | None = "32-true"
    """Double precision (64, '64' or '64-true'), full precision (32, '32' or '32-true'),
    16bit mixed precision (16, '16', '16-mixed') or bfloat16 mixed precision ('bf16', 'bf16-mixed').
    Can be used on CPU, GPU, TPUs, HPUs or IPUs.
    Default: ``'32-true'``."""

    deterministic: bool | Literal["warn"] | None = None
    """
    If ``True``, sets whether PyTorch operations must use deterministic algorithms.
    Set to ``"warn"`` to use deterministic algorithms whenever possible, throwing warnings on operations
    that don't support deterministic mode. If not set, defaults to ``False``. Default: ``None``.
    """

    max_epochs: int | None = None
    """Stop training once this number of epochs is reached. Disabled by default (None).
    If both max_epochs and max_steps are not specified, defaults to ``max_epochs = 1000``.
    To enable infinite training, set ``max_epochs = -1``."""

    min_epochs: int | None = None
    """Force training for at least these many epochs. Disabled by default (None)."""

    max_steps: int = -1
    """Stop training after this number of steps. Disabled by default (-1). If ``max_steps = -1``
    and ``max_epochs = None``, will default to ``max_epochs = 1000``.
    To enable infinite training, set ``max_epochs`` to ``-1``."""

    min_steps: int | None = None
    """Force training for at least these number of steps. Disabled by default (``None``)."""

    max_time: str | timedelta | dict[str, int] | None = None
    """Stop training after this amount of time has passed. Disabled by default (``None``).
    The time duration can be specified in the format DD:HH:MM:SS (days, hours, minutes seconds), as a
    :class:`datetime.timedelta`, or a dictionary with keys that will be passed to
    :class:`datetime.timedelta`."""

    val_check_interval: int | float | None = None
    """How often to check the validation set. Pass a ``float`` in the range [0.0, 1.0] to check
    after a fraction of the training epoch. Pass an ``int`` to check after a fixed number of training
    batches. An ``int`` value can only be higher than the number of training batches when
    ``check_val_every_n_epoch=None``, which validates after every ``N`` training batches
    across epochs or during iteration-based training.
    Default: ``None`` (Lightning then uses ``1.0``).
    """

    check_val_every_n_epoch: int | None = 1
    """Perform a validation loop after every `N` training epochs. If ``None``,
    validation will be done solely based on the number of training batches, requiring ``val_check_interval``
    to be an integer value.
    Default: ``1``.
    """

    log_every_n_steps: int | None = None
    """How often to log within steps.
    Default: ``None`` (Lightning then uses ``50``).
    """

    gradient_clip_val: int | float | None = None
    """The value at which to clip gradients. Passing ``gradient_clip_val=None`` disables
    gradient clipping. If using Automatic Mixed Precision (AMP), the gradients will be unscaled before.
    Default: ``None``.
    """

    gradient_clip_algorithm: str | None = None
    """The gradient clipping algorithm to use. Pass ``gradient_clip_algorithm="value"``
    to clip by value, and ``gradient_clip_algorithm="norm"`` to clip by norm. By default it will
    be set to ``"norm"``.
    """

    checkpoint: ModelCheckpointConfig | None = None
    """The configuration for the model checkpoint. If set, its callback is added to the trainer."""

    early_stopping: EarlyStoppingConfig | None = None
    """The configuration for early stopping. If set, its callback is added to the trainer."""

    loggers: Sequence[LoggerConfig] | Literal["default"] = "default"
    """The loggers to use for logging training metrics.
    If ``"default"``, a single CSV logger writing to ``./logs`` is used.
    Default: ``"default"``.
    """

    additional_trainer_kwargs: dict[str, Any] = {}
    """
    Additional keyword arguments for the Lightning Trainer.
    These are applied last, so they override any Trainer kwarg derived from the
    fields above (including ``callbacks`` and ``logger``).
    This is for advanced users who want to customize the Lightning Trainer,
    and is not recommended for beginners.
    """

    def _to_lightning_kwargs(self) -> dict[str, Any]:
        """Translate this configuration into ``lightning.pytorch.Trainer`` kwargs.

        Returns:
            A new dict of Trainer keyword arguments. Entries from
            ``additional_trainer_kwargs`` are merged in last and take
            precedence over every key produced from the typed fields.
        """
        callbacks = []
        if self.checkpoint is not None:
            callbacks.append(self.checkpoint.create_callback())
        if self.early_stopping is not None:
            callbacks.append(self.early_stopping.create_callback())

        if self.loggers == "default":
            # The "default" preset is a single CSV logger under ./logs.
            loggers = [CSVLoggerConfig(save_dir="./logs").create_logger()]
        else:
            loggers = [logger_config.create_logger() for logger_config in self.loggers]

        kwargs: dict[str, Any] = {
            "callbacks": callbacks,
            "accelerator": self.accelerator,
            "strategy": self.strategy,
            "devices": self.devices,
            "num_nodes": self.num_nodes,
            "precision": self.precision,
            "deterministic": self.deterministic,
            "max_epochs": self.max_epochs,
            "min_epochs": self.min_epochs,
            "max_steps": self.max_steps,
            "min_steps": self.min_steps,
            "max_time": self.max_time,
            "val_check_interval": self.val_check_interval,
            "check_val_every_n_epoch": self.check_val_every_n_epoch,
            "log_every_n_steps": self.log_every_n_steps,
            "gradient_clip_val": self.gradient_clip_val,
            "gradient_clip_algorithm": self.gradient_clip_algorithm,
            "logger": loggers,
        }

        # User-supplied extras win over the typed fields above.
        kwargs.update(self.additional_trainer_kwargs)
        return kwargs
@backbone_registry.rebuild_on_registers
@data_registry.rebuild_on_registers
class MatterTunerConfig(C.Config):
    """Top-level MatterTune configuration bundling data, model and trainer settings."""

    data: DataModuleConfig
    """Settings describing the training data."""

    model: ModelConfig
    """Settings describing the model to fine-tune."""

    trainer: TrainerConfig = TrainerConfig()
    """Settings for the Lightning trainer (all defaults unless overridden)."""
class MatterTuner:
    """Runs a fine-tuning job described by a :class:`MatterTunerConfig`."""

    def __init__(self, config: MatterTunerConfig):
        """Store the configuration; nothing is built until :meth:`tune` is called.

        Args:
            config: The complete fine-tuning configuration.
        """
        self.config = config

    def tune(self, trainer_kwargs: dict[str, Any] | None = None) -> TuneOutput:
        """Build the model, data module and Lightning trainer, then train the model.

        Args:
            trainer_kwargs: Optional extra ``lightning.pytorch.Trainer`` keyword
                arguments. They are applied on top of the kwargs derived from
                ``config.trainer`` and therefore take precedence over them.

        Returns:
            A :class:`TuneOutput` holding the trained Lightning module and the
            trainer that trained it.

        Raises:
            ValueError: If the model requires inference mode to be disabled but
                the resolved trainer kwargs enable ``inference_mode``.
        """
        # Make sure all the necessary dependencies are installed
        self.config.model.ensure_dependencies()

        # Create the model
        lightning_module = self.config.model.create_model()
        assert isinstance(
            lightning_module, FinetuneModuleBase
        ), f'The backbone model must be a FinetuneModuleBase subclass. Got "{type(lightning_module)}".'

        # Create the datamodule.
        # NOTE(review): the constructor arguments were lost from the scraped
        # source; passing the data config is a reconstruction — confirm against
        # MatterTuneDataModule's signature.
        datamodule = MatterTuneDataModule(self.config.data)

        # Resolve the full trainer kwargs
        trainer_kwargs_: dict[str, Any] = self.config.trainer._to_lightning_kwargs()

        # Update with the user-specified kwargs in the method call
        if trainer_kwargs is not None:
            trainer_kwargs_.update(trainer_kwargs)

        if lightning_module.requires_disabled_inference_mode():
            # Any truthy inference_mode conflicts with the model's requirement.
            if trainer_kwargs_.get("inference_mode"):
                raise ValueError(
                    "The model requires inference_mode to be disabled. "
                    "But the provided trainer kwargs have inference_mode=True. "
                    "Please set inference_mode=False.\n"
                    "If you think this is a mistake, please report a bug."
                )

            log.warning(
                "The model requires inference_mode to be disabled. "
                "Setting inference_mode=False."
            )
            trainer_kwargs_["inference_mode"] = False

        # Create the trainer and run training.
        trainer = Trainer(**trainer_kwargs_)
        trainer.fit(lightning_module, datamodule)

        # Return the trained model
        return TuneOutput(model=lightning_module, trainer=trainer)