mattertune.configs.recipes.ema

class mattertune.configs.recipes.ema.EMARecipeConfig(*, name='ema', decay, validate_original_weights=False, every_n_steps=1, cpu_offload=False)
Parameters:
  • name (Literal['ema'])

  • decay (Annotated[float, Gt(gt=0)])

  • validate_original_weights (bool)

  • every_n_steps (int)

  • cpu_offload (bool)

name: Literal['ema']
decay: C.PositiveFloat

The decay factor used when updating the exponential moving average. Must be between 0 and 1 (the type annotation itself only enforces that it is positive).
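A minimal sketch of the standard exponential-moving-average update, written over plain Python lists. It assumes the callback applies this rule per parameter; `ema_update` is a hypothetical helper, not part of mattertune. Values of `decay` close to 1 make the average change slowly.

```python
def ema_update(ema: list[float], params: list[float], decay: float) -> list[float]:
    """Return the next moving average: decay * ema + (1 - decay) * params, elementwise."""
    # The docs require decay between 0 and 1; the annotation alone only enforces decay > 0.
    if not 0.0 < decay < 1.0:
        raise ValueError(f"decay must be in (0, 1), got {decay}")
    return [decay * e + (1.0 - decay) * p for e, p in zip(ema, params)]


# With decay = 0.5 the average moves halfway toward the current weights.
print(ema_update([0.0, 0.0], [1.0, 2.0], decay=0.5))  # [0.5, 1.0]
```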

validate_original_weights: bool

Validate using the original weights instead of the EMA weights.
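As a minimal sketch of what this flag controls, assuming the callback follows the usual EMA pattern of swapping the EMA weights into the model for validation and restoring the originals afterwards. Plain dicts stand in for model parameters, and `validation_weights` is hypothetical.

```python
from collections.abc import Iterator
from contextlib import contextmanager


@contextmanager
def validation_weights(
    model: dict[str, float],
    ema: dict[str, float],
    validate_original_weights: bool = False,
) -> Iterator[dict[str, float]]:
    """Temporarily load the EMA weights into `model` for validation, then restore them."""
    if validate_original_weights:
        # Validate the model's own (original) weights; nothing to swap.
        yield model
        return
    original = dict(model)  # back up the original weights
    model.update(ema)       # swap in the EMA weights
    try:
        yield model
    finally:
        # Always restore the original weights, even if validation raises.
        model.clear()
        model.update(original)


model = {"w": 1.0}
ema = {"w": 0.8}
with validation_weights(model, ema) as weights:
    print(weights["w"])  # 0.8
print(model["w"])  # 1.0
```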

every_n_steps: int

Update the EMA weights once every N training steps, where N is the value of this field.
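A scalar sketch of the gating this field implies. The step-counting convention (steps counted from 1, update when the count is a multiple of `every_n_steps`) is an assumption, and `run_ema` is a hypothetical helper.

```python
def run_ema(weight_history: list[float], decay: float, every_n_steps: int = 1) -> float:
    """Track a scalar weight's EMA, updating only on every `every_n_steps`-th step."""
    if every_n_steps < 1:
        raise ValueError("every_n_steps must be >= 1")
    ema = weight_history[0]  # initialise the average from the starting weight
    for step, weight in enumerate(weight_history[1:], start=1):
        # Assumed convention: update whenever the step count is a multiple of N.
        if step % every_n_steps == 0:
            ema = decay * ema + (1.0 - decay) * weight
    return ema


# Four steps with N = 2: only steps 2 and 4 update the average.
print(run_ema([0.0, 1.0, 1.0, 1.0, 1.0], decay=0.5, every_n_steps=2))  # 0.75
```

With the default `every_n_steps=1` the same history yields 0.9375, since all four steps contribute.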

cpu_offload: bool

Offload the EMA weights to CPU memory.

create_lightning_callback()

Creates the PyTorch Lightning callback for this recipe, or returns None if no callback is needed.

class mattertune.configs.recipes.ema.RecipeConfigBase

Base configuration for recipes.

abstract create_lightning_callback()

Creates the PyTorch Lightning callback for this recipe, or returns None if no callback is needed.

Return type:

Callback | None
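Because a recipe may return either a callback or None, whatever gathers the callbacks has to drop the None results. The sketch below shows that pattern with the standard `abc` module. `Recipe`, `DummyCallback`, `collect_callbacks` and both subclasses are hypothetical; `DummyCallback` stands in for a real Lightning `Callback` so the example runs without Lightning installed. How mattertune itself gathers these callbacks is not described on this page.

```python
from abc import ABC, abstractmethod
from typing import Optional


class DummyCallback:
    """Placeholder for a real Lightning Callback instance."""


class Recipe(ABC):
    """Hypothetical stand-in for RecipeConfigBase."""

    @abstractmethod
    def create_lightning_callback(self) -> Optional[DummyCallback]: ...


class NoCallbackRecipe(Recipe):
    def create_lightning_callback(self) -> Optional[DummyCallback]:
        return None  # this recipe needs no callback


class DummyCallbackRecipe(Recipe):
    def create_lightning_callback(self) -> Optional[DummyCallback]:
        return DummyCallback()


def collect_callbacks(recipes: list[Recipe]) -> list[DummyCallback]:
    """Build every recipe's callback, skipping recipes that return None."""
    callbacks = []
    for recipe in recipes:
        callback = recipe.create_lightning_callback()
        if callback is not None:
            callbacks.append(callback)
    return callbacks


print(len(collect_callbacks([NoCallbackRecipe(), DummyCallbackRecipe()])))  # 1
```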

classmethod ensure_dependencies()

Ensure that all dependencies are installed.

This method should raise an exception if any dependencies are missing, with a message indicating which dependencies are missing and how to install them.
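A minimal sketch of this contract, using `importlib.util.find_spec` to detect missing modules without importing them. The class names, the `required` attribute and the error wording are hypothetical, not mattertune's implementation.

```python
import importlib.util


class RecipeWithDeps:
    """Hypothetical recipe; subclasses list (import name, pip package) pairs."""

    required: list[tuple[str, str]] = []

    @classmethod
    def ensure_dependencies(cls) -> None:
        # find_spec returns None when a top-level module cannot be found.
        missing = [pkg for mod, pkg in cls.required if importlib.util.find_spec(mod) is None]
        if missing:
            raise ImportError(
                f"{cls.__name__} requires missing packages: {', '.join(missing)}. "
                f"Install them with: pip install {' '.join(missing)}"
            )


class NeedsUnknownPackage(RecipeWithDeps):
    # Deliberately nonexistent module, to trigger the error path.
    required = [("not_a_real_module_xyz", "not-a-real-package")]


RecipeWithDeps.ensure_dependencies()  # nothing required, so this passes silently
try:
    NeedsUnknownPackage.ensure_dependencies()
except ImportError as err:
    print(err)
```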