dicee.weight_averaging
======================

.. py:module:: dicee.weight_averaging


Classes
-------

.. autoapisummary::

   dicee.weight_averaging.ASWA
   dicee.weight_averaging.SWA
   dicee.weight_averaging.SWAG
   dicee.weight_averaging.EMA
   dicee.weight_averaging.TWA


Module Contents
---------------

.. py:class:: ASWA(num_epochs, path)

   Bases: :py:obj:`dicee.abstracts.AbstractCallback`

   Adaptive Stochastic Weight Averaging.

   ASWA keeps track of the validation performance and updates the ensemble
   model accordingly.


   .. py:attribute:: path


   .. py:attribute:: num_epochs


   .. py:attribute:: initial_eval_setting
      :value: None


   .. py:attribute:: epoch_count
      :value: 0


   .. py:attribute:: alphas
      :value: []


   .. py:attribute:: val_aswa
      :value: -1


   .. py:method:: on_fit_end(trainer, model)

      Called at the end of training.

      Parameters
      ----------
      trainer:
      model:

      :rtype: None


   .. py:method:: compute_mrr(trainer, model) -> float
      :staticmethod:


   .. py:method:: get_aswa_state_dict(model)


   .. py:method:: decide(running_model_state_dict, ensemble_state_dict, val_running_model, mrr_updated_ensemble_model)

      Perform a hard update, a soft update, or a rejection.

      :param running_model_state_dict:
      :param ensemble_state_dict:
      :param val_running_model:
      :param mrr_updated_ensemble_model:


   .. py:method:: on_train_epoch_end(trainer, model)

      Called at the end of each epoch during training.

      Parameters
      ----------
      trainer:
      model:

      :rtype: None


.. py:class:: SWA(swa_start_epoch, swa_c_epochs: int = 1, lr_init: float = 0.1, swa_lr: float = 0.05, max_epochs: int = None)

   Bases: :py:obj:`dicee.abstracts.AbstractCallback`

   Stochastic Weight Averaging callback.

   Initialize the SWA callback.

   Parameters
   ----------
   swa_start_epoch: int
       The epoch at which to start SWA.
   swa_c_epochs: int
       The interval, in epochs, between SWA updates.
   lr_init: float
       The initial learning rate.
   swa_lr: float
       The learning rate to use during SWA.
   max_epochs: int
       The maximum number of epochs (``args.num_epochs``).


   .. py:attribute:: swa_start_epoch


   .. py:attribute:: swa_c_epochs
      :value: 1


   .. py:attribute:: swa_lr
      :value: 0.05


   .. py:attribute:: lr_init
      :value: 0.1


   .. py:attribute:: max_epochs
      :value: None


   .. py:attribute:: swa_model
      :value: None


   .. py:attribute:: swa_n
      :value: 0


   .. py:attribute:: current_epoch
      :value: -1


   .. py:method:: moving_average(swa_model, running_model, alpha)
      :staticmethod:

      Update the SWA model with a moving average of the current model.

      Math::

          θ_swa ← (1 - alpha) * θ_swa + alpha * θ
          alpha = 1 / (n + 1), where n is the number of models already averaged
          (n is tracked via self.swa_n in code)


   .. py:method:: on_train_epoch_start(trainer, model)

      Update the learning rate according to the SWA schedule.


   .. py:method:: on_train_epoch_end(trainer, model)

      Apply SWA averaging if the conditions are met.


   .. py:method:: on_fit_end(trainer, model)

      Replace the main model with the SWA model at the end of training.
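The averaging rule above amounts to maintaining a running mean of the weights.
Below is a minimal sketch of one such update, assuming the weights are handled
as plain ``state_dict`` tensors; the helper name ``swa_step`` is an
illustrative assumption, not part of the dicee API.

.. code-block:: python

   import torch

   def swa_step(swa_state: dict, running_state: dict, n: int) -> dict:
       """One SWA update: θ_swa ← (1 - alpha) * θ_swa + alpha * θ,
       with alpha = 1 / (n + 1), i.e. the running mean over the
       n + 1 models averaged so far (n plays the role of swa_n)."""
       alpha = 1.0 / (n + 1)
       with torch.no_grad():
           for key, theta in running_state.items():
               swa_state[key] = (1.0 - alpha) * swa_state[key] + alpha * theta
       return swa_state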
.. py:class:: SWAG(swa_start_epoch, swa_c_epochs: int = 1, lr_init: float = 0.1, swa_lr: float = 0.05, max_epochs: int = None, max_num_models: int = 20, var_clamp: float = 1e-30)

   Bases: :py:obj:`dicee.abstracts.AbstractCallback`

   Stochastic Weight Averaging - Gaussian (SWAG).

   Parameters
   ----------
   swa_start_epoch : int
       Epoch at which to start collecting weights.
   swa_c_epochs : int
       Interval of epochs between updates.
   lr_init : float
       Initial learning rate.
   swa_lr : float
       Learning rate during the SWA / SWAG phase.
   max_epochs : int
       Total number of epochs.
   max_num_models : int
       Number of models kept for the low-rank covariance approximation.
   var_clamp : float
       Lower bound applied to the variances, for numerical stability.


   .. py:attribute:: swa_start_epoch


   .. py:attribute:: swa_c_epochs
      :value: 1


   .. py:attribute:: swa_lr
      :value: 0.05


   .. py:attribute:: lr_init
      :value: 0.1


   .. py:attribute:: max_epochs
      :value: None


   .. py:attribute:: max_num_models
      :value: 20


   .. py:attribute:: var_clamp
      :value: 1e-30


   .. py:attribute:: mean
      :value: None


   .. py:attribute:: sq_mean
      :value: None


   .. py:attribute:: deviations
      :value: []


   .. py:attribute:: gswa_n
      :value: 0


   .. py:attribute:: current_epoch
      :value: -1


   .. py:method:: get_mean_and_var()

      Return the mean and the variance (the diagonal part of the covariance).


   .. py:method:: sample(base_model, scale=0.5)

      Sample a new model from the SWAG posterior distribution.

      Math::

          The SWAG posterior is approximated as:
          θ ~ N(mean, Σ),  where Σ ≈ diag(var) + (1 / (K - 1)) * D Dᵀ

          mean = running average of the weights
          var  = elementwise variance (sq_mean - mean²)
          D    = [dev_1, dev_2, ..., dev_K], deviations from the mean (low-rank approximation)
          K    = number of collected models

          Sampling:
          1. θ_diag    = mean + scale * std ⊙ ε,        ε ~ N(0, I)
          2. θ_lowrank = θ_diag + (D z) / sqrt(K - 1),  z ~ N(0, I_K)
          The final sample is θ_lowrank.


   .. py:method:: on_train_epoch_start(trainer, model)

      Update the learning-rate schedule (same as SWA).


   .. py:method:: on_train_epoch_end(trainer, model)

      Collect the Gaussian statistics at the end of each epoch after ``swa_start_epoch``.


   .. py:method:: on_fit_end(trainer, model)

      Set the model weights to the collected SWAG mean at the end of training.


.. py:class:: EMA(ema_start_epoch: int, decay: float = 0.999, max_epochs: int = None, ema_c_epochs: int = 1)

   Bases: :py:obj:`dicee.abstracts.AbstractCallback`

   Exponential Moving Average (EMA) callback.

   :param ema_start_epoch: Epoch at which to start EMA.
   :type ema_start_epoch: int
   :param decay: EMA decay rate (typically 0.99 - 0.9999).
                 Math: ``θ_ema ← decay * θ_ema + (1 - decay) * θ``
   :type decay: float
   :param max_epochs: Maximum number of epochs.
   :type max_epochs: int


   .. py:attribute:: ema_start_epoch


   .. py:attribute:: decay
      :value: 0.999


   .. py:attribute:: max_epochs
      :value: None


   .. py:attribute:: ema_c_epochs
      :value: 1


   .. py:attribute:: ema_model
      :value: None


   .. py:attribute:: current_epoch
      :value: -1


   .. py:method:: ema_update(ema_model, running_model, decay: float)
      :staticmethod:

      Update the EMA model with an exponential moving average of the current model.

      Math::

          θ_ema ← (1 - alpha) * θ_ema + alpha * θ
          alpha = 1 - decay, where decay is the EMA smoothing factor (typically 0.99 - 0.999)
          alpha controls how much the current model θ contributes to the EMA
          decay is fixed in code; this could be extended to a scheduled decay


   .. py:method:: on_train_epoch_start(trainer, model)

      Track the current epoch.


   .. py:method:: on_train_epoch_end(trainer, model)

      Update the EMA model if past the start epoch.


   .. py:method:: on_fit_end(trainer, model)

      Replace the main model with the EMA model at the end of training.
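For reference, here is a minimal sketch of the two update rules documented
above (SWAG posterior sampling and the EMA update), written for flattened
weight vectors. The function names and signatures are illustrative
assumptions, not the dicee API.

.. code-block:: python

   import torch

   def swag_sample(mean, sq_mean, deviations, scale=0.5, var_clamp=1e-30):
       """Draw one weight vector from the SWAG posterior
       N(mean, diag(var) + D Dᵀ / (K - 1))."""
       var = torch.clamp(sq_mean - mean ** 2, min=var_clamp)
       # 1. Diagonal part: θ_diag = mean + scale * std ⊙ ε, ε ~ N(0, I)
       theta = mean + scale * var.sqrt() * torch.randn_like(mean)
       # 2. Low-rank part: θ += (D z) / sqrt(K - 1), z ~ N(0, I_K)
       K = len(deviations)
       if K > 1:
           D = torch.stack(deviations, dim=1)  # (num_params, K)
           z = torch.randn(K)
           theta = theta + (D @ z) / (K - 1) ** 0.5
       return theta

   def ema_update(ema_w, w, decay=0.999):
       """θ_ema ← decay * θ_ema + (1 - decay) * θ, i.e. alpha = 1 - decay."""
       return decay * ema_w + (1.0 - decay) * w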
.. py:class:: TWA(twa_start_epoch: int, lr_init: float, num_samples: int = 5, reg_lambda: float = 0.0, max_epochs: int = None, twa_c_epochs: int = 1)

   Bases: :py:obj:`dicee.abstracts.AbstractCallback`

   Train with Weight Averaging (TWA) using subspace projection and averaging.

   Parameters
   ----------
   twa_start_epoch : int
       Epoch at which to start TWA.
   lr_init : float
       Learning rate used for the β updates.
   num_samples : int
       Number of sampled weight snapshots used to build the projection subspace.
   reg_lambda : float
       Regularization coefficient for the β updates.
   max_epochs : int
       Total number of training epochs.
   twa_c_epochs : int
       Interval of epochs between TWA updates.


   .. py:attribute:: twa_start_epoch


   .. py:attribute:: num_samples
      :value: 5


   .. py:attribute:: reg_lambda
      :value: 0.0


   .. py:attribute:: max_epochs
      :value: None


   .. py:attribute:: lr_init


   .. py:attribute:: twa_c_epochs
      :value: 1


   .. py:attribute:: current_epoch
      :value: -1


   .. py:attribute:: weight_samples
      :value: []


   .. py:attribute:: twa_model
      :value: None


   .. py:attribute:: base_weights
      :value: None


   .. py:attribute:: P
      :value: None


   .. py:attribute:: beta
      :value: None


   .. py:method:: sample_weights(model)

      Collect sampled weights from the current model and maintain a rolling buffer.


   .. py:method:: build_projection(weight_samples, k=None)

      Build the projection subspace from the collected weight samples.

      :param weight_samples: List of flat weight tensors ``[(D,), ...]``.
      :param k: Number of basis vectors to keep. Defaults to ``min(N, D)``.
      :returns: ``mean_w``, the ``(D,)`` base weight vector (the average), and
                ``P``, the ``(D, k)`` projection matrix holding the top-k basis directions.


   .. py:method:: on_train_epoch_start(trainer, model)

      Track the current epoch.


   .. py:method:: on_train_epoch_end(trainer, model)

      Main TWA logic: build the subspace and update in β space.

      Math::

          TWA weight update:
          w_twa  = mean_w + P * beta
          mean_w = (1/n) * Σ_i w_i  (the SWA baseline)
          beta  ← (1 - eta * lambda) * beta - eta * Pᵀ * g

          g      = gradient of the training loss w.r.t. the full model weights
          eta    = learning rate, lambda = ridge regularization
          P      = orthonormal basis spanning the sampled checkpoints {w_i}


   .. py:method:: on_fit_end(trainer, model)

      Replace the main model with the TWA model at the end of training.
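A minimal sketch of the β-space update documented in ``on_train_epoch_end``,
written for a flattened gradient vector and an orthonormal basis ``P``; the
helper names are illustrative assumptions, not the dicee API.

.. code-block:: python

   import torch

   def twa_beta_step(beta, P, grad, eta, reg_lambda=0.0):
       """One TWA update in the projected subspace:
       β ← (1 - eta * lambda) * β - eta * Pᵀ g."""
       return (1.0 - eta * reg_lambda) * beta - eta * (P.T @ grad)

   def twa_weights(mean_w, P, beta):
       """Reconstruct the full weights: w_twa = mean_w + P β."""
       return mean_w + P @ beta

Because ``P`` has only ``k`` columns, the state optimized after
``twa_start_epoch`` shrinks from ``D`` full weights to ``k`` coefficients.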