dicee.weight_averaging
Classes
ASWA – Adaptive stochastic weight averaging.
SWA – Stochastic Weight Averaging callback.
SWAG – Stochastic Weight Averaging - Gaussian (SWAG).
EMA – Exponential Moving Average (EMA) callback.
TWA – Train with Weight Averaging (TWA) using subspace projection + averaging.
Module Contents
- class dicee.weight_averaging.ASWA(num_epochs, path)[source]
Bases: dicee.abstracts.AbstractCallback
Adaptive stochastic weight averaging. ASWA keeps track of the validation performance and updates the ensemble model accordingly (a sketch of this update follows the method entry below).
- path
- num_epochs
- initial_eval_setting = None
- epoch_count = 0
- alphas = []
- val_aswa = -1
- on_fit_end(trainer, model)[source]
Called at the end of training.
- Parameters:
trainer:
model:
- Return type:
None
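The exact update rule lives in the dicee source, but the role of the attributes above (alphas holding the ensemble coefficients, val_aswa tracking the best validation score) can be illustrated with a simplified sketch. All names here (aswa_update, ensemble_state, candidate_val_mrr) are hypothetical, not the dicee API:

```python
import torch

@torch.no_grad()
def aswa_update(ensemble_state, alphas, running_model, candidate_val_mrr, best_val_mrr):
    # Candidate ensemble: fold the current weights into the running
    # weighted average with coefficient 1.0.
    weight_sum = sum(alphas)
    running = running_model.state_dict()
    candidate = {name: (w * weight_sum + running[name]) / (weight_sum + 1.0)
                 for name, w in ensemble_state.items()}
    # Adaptive part: accept the candidate only if it improves validation MRR,
    # otherwise keep the previous ensemble unchanged.
    if candidate_val_mrr > best_val_mrr:
        return candidate, alphas + [1.0], candidate_val_mrr
    return ensemble_state, alphas, best_val_mrr
```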
- class dicee.weight_averaging.SWA(swa_start_epoch, swa_c_epochs: int = 1, lr_init: float = 0.1, swa_lr: float = 0.05, max_epochs: int = None)[source]
Bases: dicee.abstracts.AbstractCallback
Stochastic Weight Averaging callback. A minimal running-average sketch follows the attribute list below.
- Initialize SWA callback.
- swa_start_epoch: int
The epoch at which to start SWA.
- swa_c_epochs: int
The number of epochs to use for SWA.
- lr_init: float
The initial learning rate.
- swa_lr: float
The learning rate to use during SWA.
- max_epochs: int
The maximum number of epochs (taken from args.num_epochs).
- swa_start_epoch
- swa_c_epochs = 1
- swa_lr = 0.05
- lr_init = 0.1
- max_epochs = None
- swa_model = None
- swa_n = 0
- current_epoch = -1
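As promised above, here is how swa_model and swa_n typically interact under the standard SWA running average; swa_update is a hypothetical helper for illustration, not the dicee API:

```python
import copy

import torch

@torch.no_grad()
def swa_update(swa_model, running_model, swa_n: int):
    # First collection: start the average from a copy of the current model.
    if swa_model is None:
        return copy.deepcopy(running_model), 1
    # Equal-weight running average: w_swa <- (n * w_swa + w) / (n + 1).
    for p_swa, p in zip(swa_model.parameters(), running_model.parameters()):
        p_swa.mul_(swa_n / (swa_n + 1.0)).add_(p, alpha=1.0 / (swa_n + 1.0))
    return swa_model, swa_n + 1
```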
- class dicee.weight_averaging.SWAG(swa_start_epoch, swa_c_epochs: int = 1, lr_init: float = 0.1, swa_lr: float = 0.05, max_epochs: int = None, max_num_models: int = 20, var_clamp: float = 1e-30)[source]
Bases: dicee.abstracts.AbstractCallback
Stochastic Weight Averaging - Gaussian (SWAG).
- Parameters:
- swa_start_epoch: int
Epoch at which to start collecting weights.
- swa_c_epochs: int
Interval of epochs between updates.
- lr_init: float
Initial learning rate.
- swa_lr: float
Learning rate in the SWA / GSWA phase.
- max_epochs: int
Total number of epochs.
- max_num_models: int
Number of models to keep for the low-rank covariance approximation.
- var_clamp: float
Clamps low variances for numerical stability.
- swa_start_epoch
- swa_c_epochs = 1
- swa_lr = 0.05
- lr_init = 0.1
- max_epochs = None
- max_num_models = 20
- var_clamp = 1e-30
- mean = None
- sq_mean = None
- deviations = []
- gswa_n = 0
- current_epoch = -1
- sample(base_model, scale=0.5)[source]
Sample a new model from the SWAG posterior distribution.
Math:
SWAG approximates the posterior as
θ ~ N(mean, Σ), with Σ ≈ diag(var) + (1 / (K - 1)) · D Dᵀ
where
- mean = running average of the weights
- var = elementwise variance (sq_mean - mean²)
- D = [dev_1, dev_2, …, dev_K], deviations from the mean (low-rank approximation)
- K = number of collected models
Sampling step:
1. θ_diag = mean + scale · std ⊙ ε, where ε ~ N(0, I)
2. θ_lowrank = θ_diag + (D z) / sqrt(K - 1), where z ~ N(0, I_K)
The final sample is θ_lowrank.
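A minimal sketch of this two-step sampling over flattened weights, assuming mean, sq_mean, and deviations are flat tensors as in the attributes above; swag_sample is a hypothetical stand-in for the method, not the dicee implementation:

```python
import torch

@torch.no_grad()
def swag_sample(mean, sq_mean, deviations, scale=0.5, var_clamp=1e-30):
    # Step 1, diagonal part: θ_diag = mean + scale · std ⊙ ε, ε ~ N(0, I).
    var = torch.clamp(sq_mean - mean ** 2, min=var_clamp)
    theta = mean + scale * var.sqrt() * torch.randn_like(mean)
    # Step 2, low-rank part: θ += (D z) / sqrt(K - 1), z ~ N(0, I_K).
    K = len(deviations)
    if K > 1:
        D = torch.stack(deviations, dim=1)  # (num_params, K)
        z = torch.randn(K)
        theta = theta + (D @ z) / ((K - 1) ** 0.5)
    return theta  # flat weight vector drawn from the SWAG posterior
```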
- class dicee.weight_averaging.EMA(ema_start_epoch: int, decay: float = 0.999, max_epochs: int = None, ema_c_epochs: int = 1)[source]
Bases: dicee.abstracts.AbstractCallback
Exponential Moving Average (EMA) callback.
- Parameters:
ema_start_epoch (int) – Epoch to start EMA.
decay (float) – EMA decay rate (typical: 0.99 - 0.9999).
Math: θ_ema ← decay · θ_ema + (1 - decay) · θ
max_epochs (int) – Maximum number of epochs.
- ema_start_epoch
- decay = 0.999
- max_epochs = None
- ema_c_epochs = 1
- ema_model = None
- current_epoch = -1
- static ema_update(ema_model, running_model, decay: float)[source]
Update the EMA model with an exponential moving average of the current model.
Math:
θ_ema ← (1 - α) · θ_ema + α · θ, where α = 1 - decay
Here decay is the EMA smoothing factor (typically 0.99 - 0.999) and α controls how much of the current model θ contributes to the EMA. The decay is fixed in code but could be extended to a schedule.
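In PyTorch this update is typically a short in-place loop. The sketch below matches the documented signature but is not necessarily the dicee implementation:

```python
import torch

@torch.no_grad()
def ema_update(ema_model, running_model, decay: float = 0.999):
    # θ_ema ← decay · θ_ema + (1 - decay) · θ, applied parameter-wise.
    for p_ema, p in zip(ema_model.parameters(), running_model.parameters()):
        p_ema.mul_(decay).add_(p, alpha=1.0 - decay)
```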
- class dicee.weight_averaging.TWA(twa_start_epoch: int, lr_init: float, num_samples: int = 5, reg_lambda: float = 0.0, max_epochs: int = None, twa_c_epochs: int = 1)[source]
Bases: dicee.abstracts.AbstractCallback
Train with Weight Averaging (TWA) using subspace projection + averaging.
- Parameters:
- twa_start_epoch: int
Epoch at which to start TWA.
- lr_init: float
Learning rate used for the β updates.
- num_samples: int
Number of sampled weight snapshots used to build the projection subspace.
- reg_lambda: float
Regularization coefficient for the β updates.
- max_epochs: int
Total number of training epochs.
- twa_c_epochs: int
Interval of epochs between TWA updates.
- twa_start_epoch
- num_samples = 5
- reg_lambda = 0.0
- max_epochs = None
- lr_init
- twa_c_epochs = 1
- current_epoch = -1
- weight_samples = []
- twa_model = None
- base_weights = None
- P = None
- beta = None
- sample_weights(model)[source]
Collect sampled weights from the current model and maintain a rolling buffer.
- build_projection(weight_samples, k=None)[source]
Build the projection subspace from the collected weight samples.
- Parameters:
weight_samples: list of flat weight tensors [(D,), …]
k: number of basis vectors to keep. Defaults to min(N, D).
- Returns:
mean_w: (D,) base weight vector (the average)
P: (D, k) projection matrix with the top-k basis directions
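One standard way to satisfy this contract is a thin SVD of the centered snapshot matrix; the sketch below assumes weight_samples is a list of equal-length 1-D tensors and is not necessarily how dicee builds P:

```python
import torch

def build_projection(weight_samples, k=None):
    W = torch.stack(weight_samples, dim=0)  # (N, D) snapshot matrix
    mean_w = W.mean(dim=0)                  # (D,) average weight vector
    if k is None:
        k = min(W.shape[0], W.shape[1])
    # The right singular vectors of the centered snapshots form an
    # orthonormal basis of the subspace spanned by the samples.
    _, _, Vh = torch.linalg.svd(W - mean_w, full_matrices=False)
    P = Vh[:k].T                            # (D, k), orthonormal columns
    return mean_w, P
```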
- on_train_epoch_end(trainer, model)[source]
Main TWA logic: build the subspace and update in β space.
Math:
TWA weight update:
w_twa = mean_w + P · β
mean_w = (1/n) · Σᵢ wᵢ (the SWA baseline)
β ← (1 - η · λ) · β - η · Pᵀ · g
where
- g = gradient of the training loss w.r.t. the full model weights
- η = learning rate, λ = ridge regularization coefficient
- P = orthonormal basis spanning the sampled checkpoints {wᵢ}
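Translated into code, one β step might look like the following; twa_step and flat_grad are hypothetical names, and this is a sketch of the update rule above rather than the dicee implementation:

```python
import torch

@torch.no_grad()
def twa_step(beta, P, mean_w, flat_grad, eta: float, lam: float):
    # β ← (1 - η·λ) · β - η · Pᵀ g: ridge-regularized step in the subspace.
    beta = (1.0 - eta * lam) * beta - eta * (P.T @ flat_grad)
    # Map back to full weight space: w_twa = mean_w + P β.
    w_twa = mean_w + P @ beta
    return beta, w_twa
```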