Source code for mattertune.finetune.optimizer

from __future__ import annotations

from collections.abc import Iterable
from typing import Annotated, Literal

import nshconfig as C
import torch
import torch.nn as nn
from typing_extensions import TypeAliasType, assert_never


class AdamConfig(C.Config):
    name: Literal["Adam"] = "Adam"
    """Name of the optimizer."""

    lr: C.PositiveFloat
    """Learning rate."""

    eps: C.NonNegativeFloat = 1e-8
    """Epsilon."""

    betas: tuple[C.PositiveFloat, C.PositiveFloat] = (0.9, 0.999)
    """Betas."""

    weight_decay: C.NonNegativeFloat = 0.0
    """Weight decay."""

    amsgrad: bool = False
    """Whether to use the AMSGrad variant of Adam."""

class AdamWConfig(C.Config):
    name: Literal["AdamW"] = "AdamW"
    """Name of the optimizer."""

    lr: C.PositiveFloat
    """Learning rate."""

    eps: C.NonNegativeFloat = 1e-8
    """Epsilon."""

    betas: tuple[C.PositiveFloat, C.PositiveFloat] = (0.9, 0.999)
    """Betas."""

    weight_decay: C.NonNegativeFloat = 0.01
    """Weight decay."""

    amsgrad: bool = False
    """Whether to use the AMSGrad variant of Adam."""

class SGDConfig(C.Config):
    name: Literal["SGD"] = "SGD"
    """Name of the optimizer."""

    lr: C.PositiveFloat
    """Learning rate."""

    momentum: C.NonNegativeFloat = 0.0
    """Momentum."""

    weight_decay: C.NonNegativeFloat = 0.0
    """Weight decay."""

    nestrov: bool = False
    """Whether to use Nesterov momentum."""

OptimizerConfig = TypeAliasType(
    "OptimizerConfig",
    Annotated[
        AdamConfig | AdamWConfig | SGDConfig,
        C.Field(discriminator="name"),
    ],
)
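Because the union is discriminated on the ``name`` field, a serialized optimizer config (for example, loaded from a YAML or JSON hyperparameter file) can be parsed back into the matching config class. A minimal sketch, assuming nshconfig configs behave like pydantic v2 models so ``TypeAdapter`` applies (an assumption; this module does not show the validation machinery):

    from pydantic import TypeAdapter  # assumption: C.Config is pydantic-based

    adapter = TypeAdapter(OptimizerConfig)
    cfg = adapter.validate_python({"name": "AdamW", "lr": 1e-4})
    assert isinstance(cfg, AdamWConfig)  # the "name" discriminator selects the class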
def create_optimizer(
    config: OptimizerConfig,
    parameters: Iterable[nn.Parameter],
) -> torch.optim.Optimizer:
    match config:
        case AdamConfig():
            optimizer = torch.optim.Adam(
                parameters,
                lr=config.lr,
                eps=config.eps,
                betas=config.betas,
                weight_decay=config.weight_decay,
                amsgrad=config.amsgrad,
            )
        case AdamWConfig():
            optimizer = torch.optim.AdamW(
                parameters,
                lr=config.lr,
                eps=config.eps,
                betas=config.betas,
                weight_decay=config.weight_decay,
                amsgrad=config.amsgrad,
            )
        case SGDConfig():
            optimizer = torch.optim.SGD(
                parameters,
                lr=config.lr,
                momentum=config.momentum,
                weight_decay=config.weight_decay,
                nesterov=config.nestrov,
            )
        case _:
            assert_never(config)
    return optimizer
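For reference, a short usage sketch; ``nn.Linear`` is only a placeholder standing in for whatever model is being fine-tuned:

    model = nn.Linear(16, 1)  # placeholder module, not part of this package
    config = AdamWConfig(lr=1e-4)
    optimizer = create_optimizer(config, model.parameters())
    assert isinstance(optimizer, torch.optim.AdamW)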