# Source code for mattertune.finetune.optimizer

from __future__ import annotations

import fnmatch
from collections.abc import Iterable, Sequence
from typing import Annotated, Any, Literal

import nshconfig as C
import torch
import torch.nn as nn
from typing_extensions import NotRequired, TypeAliasType, TypedDict, assert_never


class PerParamHparamsDict(TypedDict):
    """A single per-parameter hyperparameter override entry."""

    patterns: Sequence[str]
    """Shell-style wildcard patterns matched against parameter names."""

    hparams: dict[str, Any]
    """Hyperparameters applied to every parameter matched by ``patterns``."""

    optimize: NotRequired[bool]
    """Whether the matched parameters are optimized at all. Defaults to True."""
class OptimizerConfigBase(C.Config):
    """Fields shared by every optimizer configuration."""

    per_parameter_hparams: Sequence[PerParamHparamsDict] | None = None
    """Optional per-parameter hyperparameter overrides.

    Each entry is a dictionary with the keys:

    - `patterns`: shell-style wildcard patterns matched against parameter names.
    - `hparams`: hyperparameters for the matched parameters; they override the
      optimizer-level values.
    - `optimize`: whether to optimize the matched parameters (default True);
      when False the matched parameters are left out of the optimizer.

    Parameters not matched by any entry keep the optimizer-level
    hyperparameters. Use this, for example, to give different parts of a model
    different learning rates."""
class AdamConfig(OptimizerConfigBase):
    """Hyperparameters for `torch.optim.Adam`."""

    name: Literal["Adam"] = "Adam"
    """Tag that selects this optimizer in the `OptimizerConfig` union."""

    lr: C.PositiveFloat
    """Learning rate (strictly positive)."""

    eps: C.NonNegativeFloat = 1e-8
    """Term added to the denominator for numerical stability."""

    betas: tuple[C.PositiveFloat, C.PositiveFloat] = (0.9, 0.999)
    """Decay rates of the running averages of the gradient and its square."""

    weight_decay: C.NonNegativeFloat = 0.0
    """Weight decay (L2 penalty) coefficient."""

    amsgrad: bool = False
    """Whether to use the AMSGrad variant of Adam."""
class AdamWConfig(OptimizerConfigBase):
    """Hyperparameters for `torch.optim.AdamW`."""

    name: Literal["AdamW"] = "AdamW"
    """Tag that selects this optimizer in the `OptimizerConfig` union."""

    lr: C.PositiveFloat
    """Learning rate (strictly positive)."""

    eps: C.NonNegativeFloat = 1e-8
    """Term added to the denominator for numerical stability."""

    betas: tuple[C.PositiveFloat, C.PositiveFloat] = (0.9, 0.999)
    """Decay rates of the running averages of the gradient and its square."""

    weight_decay: C.NonNegativeFloat = 0.01
    """Decoupled weight decay coefficient."""

    amsgrad: bool = False
    """Whether to use the AMSGrad variant of Adam."""
class SGDConfig(OptimizerConfigBase):
    """Hyperparameters for `torch.optim.SGD`."""

    name: Literal["SGD"] = "SGD"
    """Tag that selects this optimizer in the `OptimizerConfig` union."""

    lr: C.PositiveFloat
    """Learning rate (strictly positive)."""

    momentum: C.NonNegativeFloat = 0.0
    """Momentum factor."""

    weight_decay: C.NonNegativeFloat = 0.0
    """Weight decay (L2 penalty) coefficient."""

    # NOTE: the field name is a misspelling of "nesterov"; it is kept because it
    # is part of the public config schema. It is forwarded to SGD's `nesterov`.
    nestrov: bool = False
    """Whether to use Nesterov momentum. `torch.optim.SGD` rejects this unless
    `momentum` is greater than zero."""
# Union of all supported optimizer configs, discriminated by the `name` field.
OptimizerConfig = TypeAliasType(
    "OptimizerConfig",
    Annotated[
        AdamConfig | AdamWConfig | SGDConfig,
        C.Field(discriminator="name"),
    ],
)


def _named_parameters_matching_patterns(
    named_parameters: Iterable[tuple[str, nn.Parameter]],
    patterns: Iterable[str],
) -> Iterable[tuple[str, nn.Parameter, str]]:
    """Yield `(name, param, pattern)` for each parameter whose name matches a pattern.

    Args:
        named_parameters: `(name, parameter)` pairs to filter.
        patterns: Shell-style wildcard patterns. The first pattern (in the
            given order) that matches a name is reported as `pattern`.

    Yields:
        `(name, param, matching_pattern)` for every matching parameter, in
        the order of `named_parameters`. Non-matching parameters are skipped.
    """
    # Materialize once: the patterns are scanned for every parameter, and a
    # one-shot iterator (e.g. a generator) would otherwise be exhausted after
    # the first parameter, silently dropping all later matches.
    pattern_tuple = tuple(patterns)
    for name, param in named_parameters:
        # `fnmatchcase` rather than `fnmatch`: parameter names are not file
        # paths, and `fnmatch` applies `os.path.normcase`, which would make
        # matching case-insensitive on Windows only.
        matching_pattern = next(
            (
                pattern
                for pattern in pattern_tuple
                if fnmatch.fnmatchcase(name, pattern)
            ),
            None,
        )
        if matching_pattern is None:
            continue
        yield name, param, matching_pattern


def _split_parameters(
    named_parameters: Iterable[tuple[str, nn.Parameter]],
    pattern_lists: Iterable[Iterable[str]],
) -> tuple[list[list[torch.nn.Parameter]], list[torch.nn.Parameter]]:
    """Partition parameters into one list per pattern list plus the leftovers.

    Args:
        named_parameters: `(name, parameter)` pairs; consumed exactly once.
        pattern_lists: One collection of shell-style patterns per group.

    Returns:
        `(parameters, all_parameters)` where `parameters[i]` holds the
        parameters matched by `pattern_lists[i]` (a parameter matching several
        lists appears in each of them), and `all_parameters` holds, in their
        original order, the parameters matched by no list.
    """
    named_parameters_list = list(named_parameters)
    parameters: list[list[torch.nn.Parameter]] = []
    # Track matched parameters by identity: tensor `==` is elementwise, and a
    # set of ids makes the leftover filter O(n) instead of O(n * matched).
    matched_ids: set[int] = set()
    for patterns in pattern_lists:
        matching = [
            p
            for _, p, _ in _named_parameters_matching_patterns(
                named_parameters_list, patterns
            )
        ]
        parameters.append(matching)
        matched_ids.update(id(p) for p in matching)

    # The parameters stay referenced by `named_parameters_list`, so their ids
    # remain valid for the lookup below.
    all_parameters = [
        p for _, p in named_parameters_list if id(p) not in matched_ids
    ]
    return parameters, all_parameters
def create_optimizer(
    config: OptimizerConfig,
    named_parameters: Iterable[tuple[str, nn.Parameter]],
) -> torch.optim.Optimizer:
    """Build a `torch.optim` optimizer from an optimizer config.

    Args:
        config: An `AdamConfig`, `AdamWConfig` or `SGDConfig`. Its fields
            become the optimizer's default hyperparameters.
        named_parameters: `(name, parameter)` pairs to optimize; consumed
            exactly once.

    Returns:
        The constructed optimizer. Without `per_parameter_hparams`, every
        parameter shares the config's hyperparameters. Otherwise:

        - each entry with `optimize` unset or True becomes its own parameter
          group, with its `hparams` overriding the defaults;
        - parameters of entries with `optimize=False` are left out entirely;
        - parameters matched by no entry form a final group using the defaults.

        NOTE: a parameter matched by more than one optimized entry ends up in
        several groups, which torch rejects when the optimizer is built.
    """
    optimizer_cls: type[torch.optim.Optimizer]
    base_hparams: dict[str, Any]
    if isinstance(config, (AdamConfig, AdamWConfig)):
        # Adam and AdamW accept the same keyword arguments; only the class differs.
        optimizer_cls = (
            torch.optim.Adam if isinstance(config, AdamConfig) else torch.optim.AdamW
        )
        base_hparams = {
            "lr": config.lr,
            "eps": config.eps,
            "betas": config.betas,
            "weight_decay": config.weight_decay,
            "amsgrad": config.amsgrad,
        }
    elif isinstance(config, SGDConfig):
        optimizer_cls = torch.optim.SGD
        base_hparams = {
            "lr": config.lr,
            "momentum": config.momentum,
            "weight_decay": config.weight_decay,
            # The config field is spelled `nestrov`; torch expects `nesterov`.
            "nesterov": config.nestrov,
        }
    else:
        assert_never(config)

    group_specs = config.per_parameter_hparams
    if group_specs is None:
        # No overrides: a single implicit parameter group with the defaults.
        return optimizer_cls((param for _, param in named_parameters), **base_hparams)

    matched_groups, unmatched = _split_parameters(
        named_parameters, [spec["patterns"] for spec in group_specs]
    )
    # Later keys win: per-entry `hparams` override the defaults, and `params`
    # is always the list of matched parameters.
    param_groups: list[dict[str, Any]] = [
        {**base_hparams, **spec["hparams"], "params": group}
        for group, spec in zip(matched_groups, group_specs)
        if spec.get("optimize", True)
    ]
    if unmatched:
        param_groups.append({"params": unmatched, **base_hparams})
    return optimizer_cls(param_groups, **base_hparams)