dicee.weight_averaging

Classes

ASWA

Adaptive stochastic weight averaging.

SWA

Stochastic Weight Averaging callback.

SWAG

Stochastic Weight Averaging - Gaussian (SWAG).

EMA

Exponential Moving Average (EMA) callback.

TWA

Train with Weight Averaging (TWA) using subspace projection + averaging.

Module Contents

class dicee.weight_averaging.ASWA(num_epochs, path)

Bases: dicee.abstracts.AbstractCallback

Adaptive stochastic weight averaging. ASWA keeps track of the validation performance and updates the ensemble model accordingly.

path
num_epochs
initial_eval_setting = None
epoch_count = 0
alphas = []
val_aswa = -1
on_fit_end(trainer, model)

Called at the end of training.

Parameters:
  • trainer

  • model

Return type:

None

static compute_mrr(trainer, model) → float
get_aswa_state_dict(model)
decide(running_model_state_dict, ensemble_state_dict, val_running_model, mrr_updated_ensemble_model)

Perform a hard update, a soft update, or a rejection (see the sketch after the parameter list).

Parameters:
  • running_model_state_dict

  • ensemble_state_dict

  • val_running_model

  • mrr_updated_ensemble_model
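The exact acceptance criteria are not spelled out here; purely as an assumed sketch, the parameter names suggest logic along these lines:

    def decide_sketch(self, running_model_state_dict, ensemble_state_dict,
                      val_running_model, mrr_updated_ensemble_model):
        # Hypothetical criteria inferred from the parameter names; the
        # actual dicee logic may differ.
        if val_running_model > self.val_aswa:
            # Hard update: restart the ensemble from the running model.
            self.val_aswa = val_running_model
            return running_model_state_dict
        if mrr_updated_ensemble_model > self.val_aswa:
            # Soft update: accept the averaged ensemble weights.
            self.val_aswa = mrr_updated_ensemble_model
            return ensemble_state_dict
        # Rejection: keep the current ensemble unchanged.
        return None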

on_train_epoch_end(trainer, model)

Called at the end of each epoch during training.

Parameters:
  • trainer

  • model

Return type:

None

class dicee.weight_averaging.SWA(swa_start_epoch, swa_c_epochs: int = 1, lr_init: float = 0.1, swa_lr: float = 0.05, max_epochs: int = None)

Bases: dicee.abstracts.AbstractCallback

Stochastic Weight Averaging callback.

Initialize SWA callback.
swa_start_epoch: int

The epoch at which to start SWA.

swa_c_epochs: int

The interval of epochs between SWA updates.

lr_init: float

The initial learning rate.

swa_lr: float

The learning rate to use during SWA.

max_epochs: int

The maximum number of epochs (taken from args.num_epochs).
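For reference, the callback can be constructed directly from the signature above (the values below are illustrative; how the callback is registered depends on your trainer setup):

    from dicee.weight_averaging import SWA

    swa_cb = SWA(swa_start_epoch=50, swa_c_epochs=1,
                 lr_init=0.1, swa_lr=0.05, max_epochs=100)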

swa_start_epoch
swa_c_epochs = 1
swa_lr = 0.05
lr_init = 0.1
max_epochs = None
swa_model = None
swa_n = 0
current_epoch = -1
static moving_average(swa_model, running_model, alpha)

Update the SWA model with a moving average of the current model.

Math:

θ_swa ← (1 - alpha) * θ_swa + alpha * θ

alpha = 1 / (n + 1), where n is the number of models already averaged (tracked as self.swa_n in code).
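For illustration, a minimal parameter-wise sketch of this update, assuming torch.nn.Module inputs (the actual implementation may operate on state dicts instead):

    import torch

    @torch.no_grad()
    def moving_average_sketch(swa_model, running_model, alpha):
        # theta_swa <- (1 - alpha) * theta_swa + alpha * theta
        for p_swa, p in zip(swa_model.parameters(), running_model.parameters()):
            p_swa.mul_(1.0 - alpha).add_(p, alpha=alpha)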

on_train_epoch_start(trainer, model)

Update learning rate according to SWA schedule.
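The schedule itself is not documented here; a common SWA choice, and one plausible reading of lr_init, swa_lr, and swa_start_epoch, is to anneal linearly from lr_init to swa_lr before swa_start_epoch and hold swa_lr constant afterwards:

    def swa_lr_schedule_sketch(epoch, swa_start_epoch, lr_init, swa_lr):
        # Assumed schedule: linear anneal toward swa_lr, then constant.
        if epoch >= swa_start_epoch:
            return swa_lr
        t = epoch / max(1, swa_start_epoch)
        return (1.0 - t) * lr_init + t * swa_lr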

on_train_epoch_end(trainer, model)

Apply SWA averaging if conditions are met.

on_fit_end(trainer, model)

Replace main model with SWA model at the end of training.

class dicee.weight_averaging.SWAG(swa_start_epoch, swa_c_epochs: int = 1, lr_init: float = 0.1, swa_lr: float = 0.05, max_epochs: int = None, max_num_models: int = 20, var_clamp: float = 1e-30)

Bases: dicee.abstracts.AbstractCallback

Stochastic Weight Averaging - Gaussian (SWAG).

Parameters

swa_start_epoch: int

Epoch at which to start collecting weights.

swa_c_epochs: int

Interval of epochs between updates.

lr_init: float

Initial learning rate.

swa_lr: float

Learning rate in the SWA / SWAG phase.

max_epochs: int

Total number of epochs.

max_num_models: int

Number of models to keep for the low-rank covariance approximation.

var_clamp: float

Clamp low variances for numerical stability.

swa_start_epoch
swa_c_epochs = 1
swa_lr = 0.05
lr_init = 0.1
max_epochs = None
max_num_models = 20
var_clamp = 1e-30
mean = None
sq_mean = None
deviations = []
gswa_n = 0
current_epoch = -1
get_mean_and_var()

Return mean + variance (diagonal part).

sample(base_model, scale=0.5)

Sample new model from SWAG posterior distribution.

Math:

From SWAG, the posterior is approximated as θ ~ N(mean, Σ), where Σ ≈ diag(var) + (1 / (K - 1)) * D D^T, with:
  • mean = running average of the weights

  • var = elementwise variance (sq_mean - mean^2)

  • D = [dev_1, dev_2, …, dev_K], the deviations from the mean (low-rank approximation)

  • K = number of collected models

Sampling step:
  1. θ_diag = mean + scale * std ⊙ ε, where ε ~ N(0, I)

  2. θ_lowrank = θ_diag + (D z) / sqrt(K - 1), where z ~ N(0, I_K)

The final sample is θ_lowrank.
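A self-contained sketch of this sampling rule on flat weight vectors (an assumed representation: mean and sq_mean as (D,) tensors, deviations as a list of K (D,) tensors):

    import math
    import torch

    def swag_sample_sketch(mean, sq_mean, deviations, scale=0.5, var_clamp=1e-30):
        # Diagonal part: theta = mean + scale * std ⊙ eps, eps ~ N(0, I).
        var = torch.clamp(sq_mean - mean ** 2, min=var_clamp)
        theta = mean + scale * var.sqrt() * torch.randn_like(mean)
        # Low-rank part: theta += (D z) / sqrt(K - 1), z ~ N(0, I_K).
        K = len(deviations)
        if K > 1:
            Dmat = torch.stack(deviations, dim=1)  # (D, K)
            z = torch.randn(K)
            theta = theta + (Dmat @ z) / math.sqrt(K - 1)
        return theta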

on_train_epoch_start(trainer, model)

Update LR schedule (same as SWA).

on_train_epoch_end(trainer, model)

Collect Gaussian statistics at the end of each epoch after swa_start_epoch.

on_fit_end(trainer, model)

Set model weights to the collected SWAG mean at the end of training.

class dicee.weight_averaging.EMA(ema_start_epoch: int, decay: float = 0.999, max_epochs: int = None, ema_c_epochs: int = 1)

Bases: dicee.abstracts.AbstractCallback

Exponential Moving Average (EMA) callback.

Parameters:
  • ema_start_epoch (int) – Epoch to start EMA.

  • decay (float) – EMA decay rate (typical: 0.99 to 0.9999). Math: θ_ema ← decay * θ_ema + (1 - decay) * θ

  • max_epochs (int) – Maximum number of epochs.

ema_start_epoch
decay = 0.999
max_epochs = None
ema_c_epochs = 1
ema_model = None
current_epoch = -1
static ema_update(ema_model, running_model, decay: float)

Update the EMA model with an exponential moving average of the current model.

Math:

θ_ema ← (1 - alpha) * θ_ema + alpha * θ

alpha = 1 - decay, where decay is the EMA smoothing factor (typical: 0.99 to 0.999). alpha controls how much the current model θ contributes to the EMA. decay is fixed in code but could be extended to a schedule.
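A parameter-wise sketch of this rule, assuming torch.nn.Module inputs (the real callback may work on state dicts):

    import torch

    @torch.no_grad()
    def ema_update_sketch(ema_model, running_model, decay):
        # theta_ema <- decay * theta_ema + (1 - decay) * theta
        for p_ema, p in zip(ema_model.parameters(), running_model.parameters()):
            p_ema.mul_(decay).add_(p, alpha=1.0 - decay)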

on_train_epoch_start(trainer, model)

Track current epoch.

on_train_epoch_end(trainer, model)

Update EMA if past start epoch.

on_fit_end(trainer, model)

Replace main model with EMA model at the end of training.

class dicee.weight_averaging.TWA(twa_start_epoch: int, lr_init: float, num_samples: int = 5, reg_lambda: float = 0.0, max_epochs: int = None, twa_c_epochs: int = 1)

Bases: dicee.abstracts.AbstractCallback

Train with Weight Averaging (TWA) using subspace projection + averaging.

Parameters

twa_start_epoch: int

Epoch to start TWA.

lr_init: float

Learning rate used for β updates.

num_samples: int

Number of sampled weight snapshots used to build the projection subspace.

reg_lambda: float

Regularization coefficient for β updates.

max_epochs: int

Total number of training epochs.

twa_c_epochs: int

Interval of epochs between TWA updates.

twa_start_epoch
num_samples = 5
reg_lambda = 0.0
max_epochs = None
lr_init
twa_c_epochs = 1
current_epoch = -1
weight_samples = []
twa_model = None
base_weights = None
P = None
beta = None
sample_weights(model)

Collect sampled weights from the current model and maintain a rolling buffer.

build_projection(weight_samples, k=None)

Build the projection subspace from collected weight samples.

Parameters:
  • weight_samples – list of flat weight tensors [(D,), …]

  • k – number of basis vectors to keep. Defaults to min(N, D).

Returns:

  • mean_w – (D,) base weight vector (the average)

  • P – (D, k) projection matrix with the top-k basis directions

Return type:

tuple
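One way to realize this is an SVD over the centered samples; the following is a sketch under the stated shapes, not necessarily the dicee implementation:

    import torch

    def build_projection_sketch(weight_samples, k=None):
        W = torch.stack(weight_samples)      # (N, D) matrix of flat weights
        mean_w = W.mean(dim=0)               # (D,) average (SWA baseline)
        centered = W - mean_w                # deviations from the mean
        # Right singular vectors of the centered matrix span the subspace.
        _, _, Vh = torch.linalg.svd(centered, full_matrices=False)
        k = min(W.shape) if k is None else k
        P = Vh[:k].T                         # (D, k) orthonormal basis
        return mean_w, P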

on_train_epoch_start(trainer, model)

Track epoch.

on_train_epoch_end(trainer, model)

Main TWA logic: build subspace and update in β space.

Math:

TWA weight update:

w_twa = mean_w + P * β

mean_w = (1/n) * sum_i w_i (the SWA baseline)

β ← (1 - η * λ) * β - η * P^T * g

where g is the gradient of the training loss with respect to the full model weights, η is the learning rate, λ is the ridge regularization coefficient, and P is an orthonormal basis spanning the sampled checkpoints {w_i}.
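Put together, one β-space step and the weight reconstruction look like this (a flat-vector sketch operating on torch tensors; g is assumed to be the flattened gradient of the training loss):

    def twa_beta_step_sketch(beta, P, g, eta, reg_lambda):
        # beta <- (1 - eta * lambda) * beta - eta * P^T g
        return (1.0 - eta * reg_lambda) * beta - eta * (P.T @ g)

    def twa_weights_sketch(mean_w, P, beta):
        # w_twa = mean_w + P beta
        return mean_w + P @ beta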

on_fit_end(trainer, model)

Replace with TWA model at the end.