dicee.weight_averaging
Classes
| Class | Description |
|-------|-------------|
| ASWA | Adaptive stochastic weight averaging. |
| SWA | Stochastic Weight Averaging callback. |
| SWAG | Stochastic Weight Averaging - Gaussian (SWAG). |
| EMA | Exponential Moving Average (EMA) callback. |
| TWA | Train with Weight Averaging (TWA) using subspace projection + averaging. |
Module Contents
- class dicee.weight_averaging.ASWA(num_epochs, path)
Bases: dicee.abstracts.AbstractCallback

Adaptive stochastic weight averaging (ASWA) keeps track of the validation performance and updates the ensemble model accordingly.
- path
- num_epochs
- initial_eval_setting = None
- epoch_count = 0
- alphas = []
- val_aswa = -1
- static compute_mrr(trainer, model) → float
  Compute the Mean Reciprocal Rank (MRR) of the model on the validation set.
- get_aswa_state_dict(model)
- decide(running_model_state_dict, ensemble_state_dict, val_running_model, mrr_updated_ensemble_model)
Perform a hard update, a soft update, or a rejection (a minimal sketch of the decision follows the parameter list).
- Parameters:
  running_model_state_dict – state dict of the running (currently trained) model.
  ensemble_state_dict – state dict of the ASWA ensemble.
  val_running_model – validation MRR of the running model.
  mrr_updated_ensemble_model – validation MRR of the ensemble after a tentative soft update.
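The exact rule is not documented here; the following is a minimal sketch of the assumed three-way decision (hypothetical helper names, not the verbatim implementation):

```python
def decide_sketch(aswa, running_sd, ensemble_sd, val_running, mrr_updated_ensemble):
    # Hard update: the running model alone beats the ensemble on validation MRR.
    if val_running > aswa.val_aswa and val_running > mrr_updated_ensemble:
        aswa.val_aswa = val_running
        return "hard_update"   # ensemble := running model, restart averaging
    # Soft update: folding the running model into the ensemble improves validation MRR.
    if mrr_updated_ensemble > aswa.val_aswa:
        aswa.val_aswa = mrr_updated_ensemble
        return "soft_update"   # ensemble := weighted average with the running model
    # Rejection: neither update helps, keep the ensemble unchanged.
    return "reject"
```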
- class dicee.weight_averaging.SWA(swa_start_epoch, swa_c_epochs: int = 1, lr_init: float = 0.1, swa_lr: float = 0.05, max_epochs: int = None)
Bases: dicee.abstracts.AbstractCallback

Stochastic Weight Averaging callback.
- Initialize SWA callback.
- swa_start_epoch: int
The epoch at which to start SWA.
- swa_c_epochs: int
The number of epochs to use for SWA.
- lr_init: float
The initial learning rate.
- swa_lr: float
The learning rate to use during SWA.
- max_epochs: int
The maximum number of epochs (taken from args.num_epochs).
- swa_start_epoch
- swa_c_epochs = 1
- swa_lr = 0.05
- lr_init = 0.1
- max_epochs = None
- swa_model = None
- swa_n = 0
- current_epoch = -1
- static moving_average(swa_model, running_model, alpha)
Update SWA model with a moving average of the current model.
Math: θ_swa ← (1 - α) · θ_swa + α · θ, with α = 1 / (n + 1), where n is the number of models already averaged (tracked via self.swa_n).
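A minimal sketch of this update over PyTorch state_dicts (assumed tensor layout, not the verbatim implementation):

```python
import torch

@torch.no_grad()
def moving_average_sketch(swa_model, running_model, alpha: float):
    # theta_swa <- (1 - alpha) * theta_swa + alpha * theta
    swa_state = swa_model.state_dict()
    for name, theta in running_model.state_dict().items():
        if theta.dtype.is_floating_point:  # skip integer buffers, e.g. counters
            swa_state[name].mul_(1.0 - alpha).add_(theta, alpha=alpha)
```

With α = 1 / (n + 1), the SWA weights remain the arithmetic mean of all n + 1 snapshots averaged so far.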
- on_train_epoch_start(trainer, model)
Update learning rate according to SWA schedule.
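The exact schedule is not spelled out here; a common choice (following the original SWA recipe) anneals linearly from lr_init to swa_lr before swa_start_epoch and then holds swa_lr constant. A sketch under that assumption:

```python
def swa_lr_schedule(epoch: int, swa_start_epoch: int, lr_init: float, swa_lr: float) -> float:
    # Constant swa_lr once the SWA phase has begun.
    if epoch >= swa_start_epoch:
        return swa_lr
    # Linear anneal from lr_init down to swa_lr during the pre-SWA phase.
    t = epoch / max(swa_start_epoch, 1)
    return (1.0 - t) * lr_init + t * swa_lr
```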
- on_train_epoch_end(trainer, model)
Apply SWA averaging if conditions are met.
- on_fit_end(trainer, model)
Replace main model with SWA model at the end of training.
- class dicee.weight_averaging.SWAG(swa_start_epoch, swa_c_epochs: int = 1, lr_init: float = 0.1, swa_lr: float = 0.05, max_epochs: int = None, max_num_models: int = 20, var_clamp: float = 1e-30)
Bases: dicee.abstracts.AbstractCallback

Stochastic Weight Averaging - Gaussian (SWAG).
- Parameters:
- swa_start_epoch: int
  Epoch at which to start collecting weights.
- swa_c_epochs: int
  Interval of epochs between updates.
- lr_init: float
  Initial learning rate.
- swa_lr: float
  Learning rate during the SWA / SWAG phase.
- max_epochs: int
  Total number of epochs.
- max_num_models: int
  Number of models to keep for the low-rank covariance approximation.
- var_clamp: float
  Lower bound used to clamp small variances for numerical stability.
- swa_start_epoch
- swa_c_epochs = 1
- swa_lr = 0.05
- lr_init = 0.1
- max_epochs = None
- max_num_models = 20
- var_clamp = 1e-30
- mean = None
- sq_mean = None
- deviations = []
- gswa_n = 0
- current_epoch = -1
- get_mean_and_var()
Return mean + variance (diagonal part).
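A minimal sketch, assuming mean and sq_mean are flat tensors of running first and second moments:

```python
import torch

def get_mean_and_var_sketch(mean, sq_mean, var_clamp=1e-30):
    # Var[theta] = E[theta^2] - E[theta]^2, clamped from below for numerical stability.
    var = torch.clamp(sq_mean - mean ** 2, min=var_clamp)
    return mean, var
```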
- sample(base_model, scale=0.5)
Sample new model from SWAG posterior distribution.
Math: SWAG approximates the posterior as
θ ~ N(mean, Σ), where Σ ≈ diag(var) + (1 / (K - 1)) · D Dᵀ
- mean = running average of the weights
- var = elementwise variance (sq_mean - mean²)
- D = [dev_1, dev_2, …, dev_K], deviations from the mean (low-rank approximation)
- K = number of collected models
Sampling step:
1. θ_diag = mean + scale · std ⊙ ε, where ε ~ N(0, I)
2. θ_lowrank = θ_diag + (D z) / √(K - 1), where z ~ N(0, I_K)
The final sample is θ_lowrank.
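A minimal sketch of the two sampling steps above, assuming flat (D,) moment tensors and a list of K flat deviation vectors:

```python
import torch

def swag_sample_sketch(mean, sq_mean, deviations, scale=0.5, var_clamp=1e-30):
    # Step 1: diagonal part, theta = mean + scale * std * eps with eps ~ N(0, I).
    var = torch.clamp(sq_mean - mean ** 2, min=var_clamp)
    theta = mean + scale * var.sqrt() * torch.randn_like(mean)
    # Step 2: low-rank part, theta += (D z) / sqrt(K - 1) with z ~ N(0, I_K).
    K = len(deviations)
    if K > 1:
        D = torch.stack(deviations, dim=1)  # (num_params, K)
        z = torch.randn(K)
        theta = theta + (D @ z) / (K - 1) ** 0.5
    return theta
```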
- on_train_epoch_start(trainer, model)
Update LR schedule (same as SWA).
- on_train_epoch_end(trainer, model)
Collect Gaussian stats at the end of epochs after swa_start.
- on_fit_end(trainer, model)
Set model weights to the collected SWAG mean at the end of training.
- class dicee.weight_averaging.EMA(ema_start_epoch: int, decay: float = 0.999, max_epochs: int = None, ema_c_epochs: int = 1)
Bases: dicee.abstracts.AbstractCallback

Exponential Moving Average (EMA) callback.
- Parameters:
ema_start_epoch (int) – Epoch to start EMA.
decay (float) – EMA decay rate (typically 0.99 – 0.9999). Math: θ_ema ← decay · θ_ema + (1 - decay) · θ
max_epochs (int) – Maximum number of epochs.
- ema_start_epoch
- decay = 0.999
- max_epochs = None
- ema_c_epochs = 1
- ema_model = None
- current_epoch = -1
- static ema_update(ema_model, running_model, decay: float)
Update EMA model with an exponential moving average of the current model.
Math: θ_ema ← (1 - α) · θ_ema + α · θ, with α = 1 - decay, where decay is the EMA smoothing factor (typically 0.99 – 0.999). α controls how much the current model θ contributes to the EMA. The decay is fixed in code but could be extended to a schedule.
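A minimal sketch of this update over PyTorch state_dicts (assumed layout, not the verbatim implementation):

```python
import torch

@torch.no_grad()
def ema_update_sketch(ema_model, running_model, decay: float = 0.999):
    # theta_ema <- decay * theta_ema + (1 - decay) * theta
    ema_state = ema_model.state_dict()
    for name, theta in running_model.state_dict().items():
        if theta.dtype.is_floating_point:  # skip integer buffers
            ema_state[name].mul_(decay).add_(theta, alpha=1.0 - decay)
```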
- on_train_epoch_start(trainer, model)
Track current epoch.
- on_train_epoch_end(trainer, model)
Update EMA if past start epoch.
- on_fit_end(trainer, model)
Replace main model with EMA model at the end of training.
- class dicee.weight_averaging.TWA(twa_start_epoch: int, lr_init: float, num_samples: int = 5, reg_lambda: float = 0.0, max_epochs: int = None, twa_c_epochs: int = 1)
Bases: dicee.abstracts.AbstractCallback

Train with Weight Averaging (TWA) using subspace projection + averaging.
- Parameters
- twa_start_epoch: int
  Epoch to start TWA.
- lr_init: float
  Learning rate used for β updates.
- num_samples: int
  Number of sampled weight snapshots used to build the projection subspace.
- reg_lambda: float
  Regularization coefficient for β updates.
- max_epochs: int
  Total number of training epochs.
- twa_c_epochs: int
  Interval of epochs between TWA updates.
- twa_start_epoch
- num_samples = 5
- reg_lambda = 0.0
- max_epochs = None
- lr_init
- twa_c_epochs = 1
- current_epoch = -1
- weight_samples = []
- twa_model = None
- base_weights = None
- P = None
- beta = None
- sample_weights(model)
Collect sampled weights from the current model and maintain rolling buffer.
- build_projection(weight_samples, k=None)
Build a projection subspace from the collected weight samples.
- Parameters:
  weight_samples – list of flat weight tensors [(D,), …]
  k – number of basis vectors to keep; defaults to min(N, D).
- Returns:
  mean_w – (D,) base weight vector (the average)
  P – (D, k) projection matrix with the top-k basis directions
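A minimal sketch using a QR factorization of the centered snapshots to obtain an orthonormal basis (the actual implementation may use a different factorization):

```python
import torch

def build_projection_sketch(weight_samples, k=None):
    W = torch.stack(weight_samples, dim=1)  # (D, N) matrix of flat snapshots
    mean_w = W.mean(dim=1)                  # (D,) SWA-style average
    centered = W - mean_w.unsqueeze(1)      # deviations from the mean
    n = centered.shape[1]
    k = n if k is None else min(k, n)
    Q, _ = torch.linalg.qr(centered)        # orthonormal basis of the span, (D, N)
    return mean_w, Q[:, :k]                 # P: (D, k) projection matrix
```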
- on_train_epoch_start(trainer, model)
Track epoch.
- on_train_epoch_end(trainer, model)
Main TWA logic: build subspace and update in β space.
Math: TWA weight update:
w_twa = mean_w + P β
- mean_w = (1/n) · Σᵢ wᵢ (the SWA baseline)
- β ← (1 - η λ) · β - η · Pᵀ g
- g = gradient of the training loss w.r.t. the full model weights
- η = learning rate, λ = ridge regularization coefficient
- P = orthonormal basis spanning the sampled checkpoints {wᵢ}
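A minimal sketch of one β-space step from the formulas above, assuming g is the flattened gradient of the training loss:

```python
import torch

def twa_step_sketch(beta, P, g, eta: float, lam: float):
    # Ridge-regularized step in the subspace: beta <- (1 - eta*lambda)*beta - eta * P^T g
    beta = (1.0 - eta * lam) * beta - eta * (P.T @ g)
    # The full weights are then reconstructed as w_twa = mean_w + P @ beta.
    return beta
```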
- on_fit_end(trainer, model)
Replace with TWA model at the end.