mattertune.callbacks.ema

Functions

ema_update(ema_model_tuple, ...)

run_ema_update_cpu(ema_model_tuple, ...[, ...])

Classes

EMA(decay[, validate_original_weights, ...])

Implements Exponential Moving Averaging (EMA).

EMAConfig(*, decay[, ...])

EMAOptimizer(optimizer, device[, decay, ...])

EMAOptimizer is a wrapper for torch.optim.Optimizer that computes Exponential Moving Average of parameters registered in the optimizer.

class mattertune.callbacks.ema.EMA(decay, validate_original_weights=False, every_n_steps=1, cpu_offload=False)[source]

Implements Exponential Moving Averaging (EMA).

When training a model, this callback will maintain moving averages of the trained parameters. When evaluating, we use the moving-average copy of the trained parameters. When saving, we save an additional set of parameters with the prefix ema (see the usage sketch below).

Parameters:
  • decay (float) – The exponential decay used when calculating the moving average. Has to be between 0-1.

  • validate_original_weights (bool) – Validate with the original weights, as opposed to the EMA weights.

  • every_n_steps (int) – Apply EMA every N steps.

  • cpu_offload (bool) – Offload weights to CPU.
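
Example

A minimal usage sketch, assuming the standard Lightning Trainer API; the lightning.pytorch import path, my_module, and my_datamodule are placeholders, not part of this module.

import lightning.pytorch as pl

from mattertune.callbacks.ema import EMA

# Maintain an EMA of the trained parameters, updated every training step.
ema_callback = EMA(decay=0.999, every_n_steps=1, cpu_offload=False)

# my_module / my_datamodule stand in for a LightningModule and LightningDataModule.
trainer = pl.Trainer(max_epochs=10, callbacks=[ema_callback])
trainer.fit(my_module, datamodule=my_datamodule)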

__init__(decay, validate_original_weights=False, every_n_steps=1, cpu_offload=False)[source]
Parameters:
  • decay (float)

  • validate_original_weights (bool)

  • every_n_steps (int)

  • cpu_offload (bool)

on_fit_start(trainer, pl_module)[source]

Called when fit begins.

Parameters:
  • trainer (Trainer)

  • pl_module (LightningModule)

Return type:

None

on_validation_start(trainer, pl_module)[source]

Called when the validation loop begins.

Parameters:
  • trainer (Trainer)

  • pl_module (LightningModule)

Return type:

None

on_validation_end(trainer, pl_module)[source]

Called when the validation loop ends.

Parameters:
  • trainer (Trainer)

  • pl_module (LightningModule)

Return type:

None

on_test_start(trainer, pl_module)[source]

Called when the test begins.

Parameters:
  • trainer (Trainer)

  • pl_module (LightningModule)

Return type:

None

on_test_end(trainer, pl_module)[source]

Called when the test ends.

Parameters:
  • trainer (Trainer)

  • pl_module (LightningModule)

Return type:

None

swap_model_weights(trainer, saving_ema_model=False)[source]
Parameters:
  • trainer (Trainer)

  • saving_ema_model (bool)

save_ema_model(trainer)[source]

Saves an EMA copy of the model + EMA optimizer states for resume.

Parameters:

trainer (Trainer)

save_original_optimizer_state(trainer)[source]
Parameters:

trainer (Trainer)

mattertune.callbacks.ema.ema_update(ema_model_tuple, current_model_tuple, decay)[source]
mattertune.callbacks.ema.run_ema_update_cpu(ema_model_tuple, current_model_tuple, decay, pre_sync_stream=None)[source]
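
Neither helper carries a docstring here. As orientation, the sketch below shows the kind of tuple-wise, in-place update implied by the EMA formula; the use of torch._foreach_* ops is an assumption about the implementation, not a documented fact.

import torch

def ema_update_sketch(ema_params, current_params, decay):
    # In-place, tuple-wise version of:
    #   ema = decay * ema + (1 - decay) * current
    ema_params = list(ema_params)
    current_params = list(current_params)
    torch._foreach_mul_(ema_params, decay)
    torch._foreach_add_(ema_params, current_params, alpha=1.0 - decay)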
class mattertune.callbacks.ema.EMAOptimizer(optimizer, device, decay=0.9999, every_n_steps=1, current_step=0)[source]

EMAOptimizer is a wrapper for torch.optim.Optimizer that computes Exponential Moving Average of parameters registered in the optimizer.

EMA parameters are automatically updated after every step of the optimizer with the following formula:

ema_weight = decay * ema_weight + (1 - decay) * training_weight

To access the EMA parameters, use the swap_ema_weights() context manager, which performs a temporary in-place swap of the regular parameters with the EMA parameters.

Notes

  • EMAOptimizer is not compatible with APEX AMP O2.

Parameters:
  • optimizer (torch.optim.Optimizer) – optimizer to wrap

  • device (torch.device) – device for EMA parameters

  • decay (float) – decay factor

  • every_n_steps (int)

  • current_step (int)

Returns:

returns an instance of torch.optim.Optimizer that computes EMA of parameters

Example

model = Model().to(device)
opt = torch.optim.Adam(model.parameters())

opt = EMAOptimizer(opt, device, 0.9999)

for epoch in range(epochs):
    training_loop(model, opt)

    regular_eval_accuracy = evaluate(model)

    with opt.swap_ema_weights():
        ema_eval_accuracy = evaluate(model)

__init__(optimizer, device, decay=0.9999, every_n_steps=1, current_step=0)[source]
Parameters:
  • optimizer (Optimizer)

  • device (device)

  • decay (float)

  • every_n_steps (int)

  • current_step (int)

all_parameters()[source]
Return type:

Iterable[Tensor]

step(closure=None, **kwargs)[source]

Perform a single optimization step to update the parameters.

Parameters:

closure (Callable) – A closure that reevaluates the model and returns the loss. Optional for most optimizers.

update()[source]
swap_tensors(tensor1, tensor2)[source]
switch_main_parameter_weights(saving_ema_model=False)[source]
Parameters:

saving_ema_model (bool)

swap_ema_weights(enabled=True)[source]

A context manager that swaps the regular parameters with the EMA parameters in place. The original parameters are swapped back when the context manager exits.

Parameters:

enabled (bool) – whether the swap should be performed
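
For instance, the swap can be gated on a flag so the same evaluation code runs either with or without EMA weights (evaluate and model are placeholders carried over from the example above):

use_ema = True  # e.g. driven by a CLI flag or config

with opt.swap_ema_weights(enabled=use_ema):
    # With enabled=True the model temporarily holds the EMA weights;
    # the original weights are restored when the block exits.
    ema_or_regular_accuracy = evaluate(model)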

join()[source]
state_dict()[source]

Return the state of the optimizer as a dict.

It contains two entries:

  • state: a Dict holding current optimization state. Its content differs between optimizer classes, but some common characteristics hold. For example, state is saved per parameter, and the parameter itself is NOT saved. state is a Dictionary mapping parameter ids to a Dict with state corresponding to each parameter.

  • param_groups: a List containing all parameter groups where each parameter group is a Dict. Each parameter group contains metadata specific to the optimizer, such as learning rate and weight decay, as well as a List of parameter IDs of the parameters in the group. If a param group was initialized with named_parameters() the names content will also be saved in the state dict.

NOTE: The parameter IDs may look like indices but they are just IDs associating state with param_group. When loading from a state_dict, the optimizer will zip the param_group params (int IDs) and the optimizer param_groups (actual nn.Parameter s) in order to match state WITHOUT additional verification.

A returned state dict might look something like:

{
    'state': {
        0: {'momentum_buffer': tensor(...), ...},
        1: {'momentum_buffer': tensor(...), ...},
        2: {'momentum_buffer': tensor(...), ...},
        3: {'momentum_buffer': tensor(...), ...}
    },
    'param_groups': [
        {
            'lr': 0.01,
            'weight_decay': 0,
            ...
            'params': [0],
            'param_names': ['param0']  (optional)
        },
        {
            'lr': 0.001,
            'weight_decay': 0.5,
            ...
            'params': [1, 2, 3],
            'param_names': ['param1', 'layer.weight', 'layer.bias'] (optional)
        }
    ]
}
load_state_dict(state_dict)[source]

Load the optimizer state.

Parameters:

state_dict (dict) – optimizer state. Should be an object returned from a call to state_dict().

Note

The names of the parameters (if they exist under the “param_names” key of each param group in state_dict()) will not affect the loading process. To use the parameters’ names for custom cases (such as when the parameters in the loaded state dict differ from those initialized in the optimizer), a custom register_load_state_dict_pre_hook should be implemented to adapt the loaded dict accordingly. If param_names exist in loaded state dict param_groups they will be saved and override the current names, if present, in the optimizer state. If they do not exist in loaded state dict, the optimizer param_names will remain unchanged.
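
Since EMAOptimizer implements the standard state_dict() / load_state_dict() pair, checkpointing can follow the usual PyTorch pattern. A sketch, with the file name as a placeholder; whether the EMA copies are included in the returned state is an assumption, hedged in the comment.

import torch

# Save the wrapped optimizer state (and, presumably, the EMA shadow parameters).
torch.save(opt.state_dict(), "ema_optimizer.pt")

# ...later, after the model and the EMAOptimizer wrapper have been rebuilt:
opt.load_state_dict(torch.load("ema_optimizer.pt"))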

add_param_group(param_group)[source]

Add a param group to the Optimizer's param_groups.

This can be useful when fine tuning a pre-trained network as frozen layers can be made trainable and added to the Optimizer as training progresses.

Parameters:

param_group (dict) – Specifies what Tensors should be optimized along with group specific optimization options.

class mattertune.callbacks.ema.EMAConfig(*, decay, validate_original_weights=False, every_n_steps=1, cpu_offload=False)[source]
Parameters:
  • decay (float)

  • validate_original_weights (bool)

  • every_n_steps (int)

  • cpu_offload (bool)

decay: float
validate_original_weights: bool
every_n_steps: int
cpu_offload: bool
construct_callback()[source]
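
EMAConfig mirrors the EMA constructor arguments. A sketch of the intended flow, assuming construct_callback() returns a configured EMA callback instance; the Trainer wiring in the comment is illustrative only.

from mattertune.callbacks.ema import EMAConfig

config = EMAConfig(decay=0.999, every_n_steps=1)
ema_callback = config.construct_callback()  # presumably returns an EMA instance

# The callback can then be attached to a Lightning Trainer, e.g.:
#   trainer = pl.Trainer(callbacks=[ema_callback])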