mattertune.callbacks.ema
Functions

run_ema_update_cpu

Classes

EMA – Implements Exponential Moving Averaging (EMA).
EMAOptimizer – Wrapper for torch.optim.Optimizer that computes an Exponential Moving Average of the parameters registered in the optimizer.
EMAConfig
- class mattertune.callbacks.ema.EMA(decay, validate_original_weights=False, every_n_steps=1, cpu_offload=False)[source]
Implements Exponential Moving Averaging (EMA).
When training a model, this callback maintains moving averages of the trained parameters. When evaluating, the moving-average copy of the trained parameters is used. When saving, an additional set of parameters is saved with the prefix ema. (A usage sketch follows the parameter list below.)
- Parameters:
decay (float) – The exponential decay used when calculating the moving average. Must be between 0 and 1.
validate_original_weights (bool) – Validate with the original weights, as opposed to the EMA weights.
every_n_steps (int) – Apply EMA every N steps.
cpu_offload (bool) – Offload weights to CPU.
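A minimal attachment sketch, assuming the lightning.pytorch import path and a user-defined LightningModule; MyLightningModule, my_datamodule, and the Trainer settings are placeholders, while the EMA arguments come from the signature above:

import lightning.pytorch as pl

from mattertune.callbacks.ema import EMA

# Attach the EMA callback to a Lightning Trainer. `MyLightningModule` and
# `my_datamodule` are placeholders for your own training setup.
ema_callback = EMA(decay=0.999, every_n_steps=1)
trainer = pl.Trainer(max_epochs=10, callbacks=[ema_callback])
# trainer.fit(MyLightningModule(), datamodule=my_datamodule)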
- __init__(decay, validate_original_weights=False, every_n_steps=1, cpu_offload=False)[source]
- Parameters:
decay (float)
validate_original_weights (bool)
every_n_steps (int)
cpu_offload (bool)
- on_fit_start(trainer, pl_module)[source]
Called when fit begins.
- Parameters:
trainer (Trainer)
pl_module (LightningModule)
- Return type:
None
- on_validation_start(trainer, pl_module)[source]
Called when the validation loop begins.
- Parameters:
trainer (Trainer)
pl_module (LightningModule)
- Return type:
None
- on_validation_end(trainer, pl_module)[source]
Called when the validation loop ends.
- Parameters:
trainer (Trainer)
pl_module (LightningModule)
- Return type:
None
- on_test_start(trainer, pl_module)[source]
Called when the test begins.
- Parameters:
trainer (Trainer)
pl_module (LightningModule)
- Return type:
None
- on_test_end(trainer, pl_module)[source]
Called when the test ends.
- Parameters:
trainer (Trainer)
pl_module (LightningModule)
- Return type:
None
- swap_model_weights(trainer, saving_ema_model=False)[source]
- Parameters:
trainer (Trainer)
saving_ema_model (bool)
- mattertune.callbacks.ema.run_ema_update_cpu(ema_model_tuple, current_model_tuple, decay, pre_sync_stream=None)[source]
- class mattertune.callbacks.ema.EMAOptimizer(optimizer, device, decay=0.9999, every_n_steps=1, current_step=0)[source]
EMAOptimizer is a wrapper for torch.optim.Optimizer that computes Exponential Moving Average of parameters registered in the optimizer.
EMA parameters are automatically updated after every step of the optimizer with the following formula:
ema_weight = decay * ema_weight + (1 - decay) * training_weight
To access EMA parameters, use the swap_ema_weights() context manager to perform a temporary in-place swap of regular parameters with EMA parameters.
Notes
EMAOptimizer is not compatible with APEX AMP O2.
- Parameters:
optimizer (torch.optim.Optimizer) – optimizer to wrap
device (torch.device) – device for EMA parameters
decay (float) – decay factor
every_n_steps (int)
current_step (int)
- Returns:
an instance of torch.optim.Optimizer that computes EMA of parameters
Example
model = Model().to(device)
opt = torch.optim.Adam(model.parameters())
opt = EMAOptimizer(opt, device, 0.9999)

for epoch in range(epochs):
    training_loop(model, opt)

    regular_eval_accuracy = evaluate(model)

    with opt.swap_ema_weights():
        ema_eval_accuracy = evaluate(model)
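For reference, the update formula above can be reproduced directly on plain tensors; this is an illustrative sketch, not the optimizer's internal implementation:

import torch

decay = 0.9999
ema_weight = torch.zeros(4)       # stand-in for the EMA copy of a parameter
training_weight = torch.ones(4)   # stand-in for the current training parameter

# ema_weight = decay * ema_weight + (1 - decay) * training_weight
ema_weight.mul_(decay).add_(training_weight, alpha=1 - decay)
# each entry is now (1 - decay) = 1e-4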
- __init__(optimizer, device, decay=0.9999, every_n_steps=1, current_step=0)[source]
- Parameters:
optimizer (Optimizer)
device (device)
decay (float)
every_n_steps (int)
current_step (int)
- step(closure=None, **kwargs)[source]
Perform a single optimization step to update parameters.
- Parameters:
closure (Callable) – A closure that reevaluates the model and returns the loss. Optional for most optimizers.
- swap_ema_weights(enabled=True)[source]
A context manager to in-place swap regular parameters with EMA parameters. It swaps back to the original regular parameters on context manager exit.
- Parameters:
enabled (bool) – whether the swap should be performed
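Continuing the Example above, a small sketch of the enabled flag; use_ema is an illustrative name, e.g. read from a config:

use_ema = True  # illustrative flag, e.g. taken from a config

# With enabled=False the swap is skipped, so the same code path can evaluate
# either the EMA weights or the current training weights.
with opt.swap_ema_weights(enabled=use_ema):
    accuracy = evaluate(model)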
- state_dict()[source]
Return the state of the optimizer as a dict.
It contains two entries:
- state: a Dict holding the current optimization state. Its content differs between optimizer classes, but some common characteristics hold. For example, state is saved per parameter, and the parameter itself is NOT saved. state is a Dictionary mapping parameter ids to a Dict with state corresponding to each parameter.
- param_groups: a List containing all parameter groups where each parameter group is a Dict. Each parameter group contains metadata specific to the optimizer, such as learning rate and weight decay, as well as a List of parameter IDs of the parameters in the group. If a param group was initialized with named_parameters() the names content will also be saved in the state dict.
NOTE: The parameter IDs may look like indices but they are just IDs associating state with param_group. When loading from a state_dict, the optimizer will zip the param_group params (int IDs) and the optimizer param_groups (actual nn.Parameters) in order to match state WITHOUT additional verification.
A returned state dict might look something like:

{
    'state': {
        0: {'momentum_buffer': tensor(...), ...},
        1: {'momentum_buffer': tensor(...), ...},
        2: {'momentum_buffer': tensor(...), ...},
        3: {'momentum_buffer': tensor(...), ...}
    },
    'param_groups': [
        {
            'lr': 0.01,
            'weight_decay': 0,
            ...
            'params': [0],
            'param_names': ['param0']  (optional)
        },
        {
            'lr': 0.001,
            'weight_decay': 0.5,
            ...
            'params': [1, 2, 3],
            'param_names': ['param1', 'layer.weight', 'layer.bias']  (optional)
        }
    ]
}
- load_state_dict(state_dict)[source]
Load the optimizer state.
- Parameters:
state_dict (dict) – optimizer state. Should be an object returned from a call to state_dict().
Note
The names of the parameters (if they exist under the "param_names" key of each param group in state_dict()) will not affect the loading process. To use the parameters' names for custom cases (such as when the parameters in the loaded state dict differ from those initialized in the optimizer), a custom register_load_state_dict_pre_hook should be implemented to adapt the loaded dict accordingly. If param_names exist in the loaded state dict param_groups, they will be saved and override the current names, if present, in the optimizer state. If they do not exist in the loaded state dict, the optimizer param_names will remain unchanged.
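Continuing the Example above, a short round-trip sketch through these two methods (the file name is illustrative):

import torch

# Persist the wrapped optimizer's state dict to disk ...
torch.save(opt.state_dict(), "ema_optimizer.pt")

# ... and restore it later, after the model and EMAOptimizer have been
# rebuilt the same way as before.
opt.load_state_dict(torch.load("ema_optimizer.pt"))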
- add_param_group(param_group)[source]
Add a param group to the Optimizer's param_groups.
This can be useful when fine-tuning a pre-trained network, as frozen layers can be made trainable and added to the Optimizer as training progresses.
- Parameters:
param_group (dict) – Specifies what Tensors should be optimized along with group specific optimization options.
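Continuing the Example above, a sketch of unfreezing a layer mid-training and registering it with the optimizer (new_layer and the learning rate are illustrative):

# Make a previously frozen layer trainable ...
for p in new_layer.parameters():
    p.requires_grad_(True)

# ... and register its parameters as a new param group with its own learning rate.
opt.add_param_group({"params": list(new_layer.parameters()), "lr": 1e-4})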
- class mattertune.callbacks.ema.EMAConfig(*, decay, validate_original_weights=False, every_n_steps=1, cpu_offload=False)[source]
- Parameters:
decay (float)
validate_original_weights (bool)
every_n_steps (int)
cpu_offload (bool)
- decay: float
- validate_original_weights: bool
- every_n_steps: int
- cpu_offload: bool
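A minimal construction sketch; the fields are keyword-only and mirror the EMA callback's constructor arguments (how the config is consumed elsewhere in mattertune is outside the scope of this page):

from mattertune.callbacks.ema import EMAConfig

ema_config = EMAConfig(
    decay=0.9999,
    validate_original_weights=False,
    every_n_steps=1,
    cpu_offload=False,
)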