# Source code for dicee.weight_averaging

import os
import copy
import json
import torch
import torch.nn as nn
from torch._dynamo.eval_frame import OptimizedModule
from pytorch_lightning.utilities import rank_zero_only

from .abstracts import AbstractCallback
from dicee.models.ensemble import EnsembleKGE
from .evaluation.ensemble import evaluate_ensemble_link_prediction_performance

class ASWA(AbstractCallback):
    """Adaptive Stochastic Weight Averaging (ASWA).

    ASWA keeps track of the validation performance and updates the ensemble
    model accordingly: at the end of each epoch the running model is either
    hard-copied into the ensemble, soft-averaged into it, or rejected,
    depending on validation MRR.

    Parameters
    ----------
    num_epochs : int
        Total number of training epochs.
    path : str
        Directory in which the ensemble parameters (``aswa.pt``) are stored.
    """

    def __init__(self, num_epochs, path):
        super().__init__()
        self.path = path
        self.num_epochs = num_epochs
        # Trainer's original eval setting; restored in on_fit_end.
        self.initial_eval_setting = None
        self.epoch_count = 0
        # Ensemble weights; sum(alphas) is the effective number of averaged models.
        self.alphas = []
        # Validation MRR of the current ASWA ensemble (-1 => not initialized yet).
        self.val_aswa = -1

    @rank_zero_only
    def on_fit_end(self, trainer, model):
        """Load the final ASWA parameters into the model and restore the eval setting."""
        if self.initial_eval_setting:
            # Restore the evaluation setting that was overridden during training.
            trainer.evaluator.args.eval_model = self.initial_eval_setting
        param_ensemble = torch.load(f"{self.path}/aswa.pt", torch.device("cpu"))
        model.load_state_dict(param_ensemble)

    @staticmethod
    def compute_mrr(trainer, model) -> float:
        """Return the validation MRR of a CPU copy of ``model`` in eval mode."""
        # (2) Enable eval mode on a deep copy so training state is untouched.
        eval_model = copy.deepcopy(model)
        eval_model.eval()
        # (3) MRR performance on the validation data of the running model.
        eval_model.to("cpu")
        last_val_mrr_running_model = trainer.evaluator.eval(
            dataset=trainer.dataset,
            trained_model=eval_model,
            form_of_labelling=trainer.form_of_labelling,
            during_training=True)["Val"]["MRR"]
        del eval_model
        return last_val_mrr_running_model

    def get_aswa_state_dict(self, model):
        """Return a provisionally updated ASWA state dict.

        The stored ensemble parameters are soft-updated with the running
        model's parameters: ``(ens * sum(alphas) + params) / (1 + sum(alphas))``.
        Non-float tensors (e.g. integer buffers) are left untouched.
        """
        if isinstance(model, EnsembleKGE):
            ensemble_state_dict = torch.load(f"{self.path}/aswa.pt")
        else:
            ensemble_state_dict = torch.load(
                f"{self.path}/aswa.pt",
                torch.device(next(model.parameters()).device))
        # Perform provisional parameter update.
        with torch.no_grad():
            for k, parameters in model.state_dict().items():
                if parameters.dtype == torch.float:
                    ensemble_state_dict[k] = (ensemble_state_dict[k] * sum(self.alphas)
                                              + parameters) / (1 + sum(self.alphas))
        return ensemble_state_dict

    def decide(self, running_model_state_dict, ensemble_state_dict,
               val_running_model, mrr_updated_ensemble_model):
        """Perform a hard update, a soft update, or a rejection.

        Parameters
        ----------
        running_model_state_dict : dict
            State dict of the running model.
        ensemble_state_dict : dict
            Provisionally updated ASWA state dict (lookahead ensemble).
        val_running_model : float
            Validation MRR of the running model.
        mrr_updated_ensemble_model : float
            Validation MRR of the lookahead ensemble.

        Returns
        -------
        bool or None
            True when a decision was taken.
        """
        # (1) HARD UPDATE: the running model beats both the lookahead ensemble
        # and the current ensemble -> restart the ensemble from it.
        if val_running_model > mrr_updated_ensemble_model and val_running_model > self.val_aswa:
            # (1.1) Save the running model as ASWA.
            torch.save(running_model_state_dict, f=f"{self.path}/aswa.pt")
            # (1.2) Reset alphas/ensemble weights.
            self.alphas.clear()
            # (1.3) Store the validation performance of ASWA.
            self.val_aswa = val_running_model
            return True
        # (2) SOFT UPDATE: the lookahead ensemble improves on the current one.
        if mrr_updated_ensemble_model > self.val_aswa:
            self.val_aswa = mrr_updated_ensemble_model
            torch.save(ensemble_state_dict, f=f"{self.path}/aswa.pt")
            self.alphas.append(1.0)
            return True
        # (3) REJECTION: keep the current ensemble, record a zero weight.
        if self.val_aswa > mrr_updated_ensemble_model:
            self.alphas.append(0)
            return True
        # NOTE(review): an exact tie (val_aswa == mrr_updated_ensemble_model)
        # falls through and returns None without appending an alpha — confirm
        # this is intended.

    @rank_zero_only
    def on_train_epoch_end(self, trainer, model):
        """Evaluate running model and ensemble, then update ASWA accordingly."""
        if isinstance(model, OptimizedModule):
            model = model._orig_mod
        # (1) Increment epoch counter.
        self.epoch_count += 1
        # (2) Save the given eval setting if it is not saved.
        if self.initial_eval_setting is None:
            self.initial_eval_setting = trainer.evaluator.args.eval_model
            trainer.evaluator.args.eval_model = "val"
        # (3) Compute MRR of the running model.
        val_running_model = self.compute_mrr(trainer, model)
        # (4) Initialize ASWA if it is not initialized.
        if self.val_aswa == -1:
            torch.save(model.state_dict(), f=f"{self.path}/aswa.pt")
            self.alphas.append(1.0)
            self.val_aswa = val_running_model
            return True
        # (5) Load ASWA ensemble parameters (provisionally updated).
        ensemble_state_dict = self.get_aswa_state_dict(model)
        # (6) Initialize an ASWA ensemble with (5).
        ensemble = copy.deepcopy(model)
        ensemble.load_state_dict(ensemble_state_dict)
        ensemble.to("cpu")
        # (7) Evaluate (6) on the validation data, i.e., the lookahead operation.
        mrr_updated_ensemble_model = trainer.evaluator.eval(
            dataset=trainer.dataset,
            trained_model=ensemble,
            form_of_labelling=trainer.form_of_labelling,
            during_training=True)["Val"]["MRR"]
        del ensemble
        # (8) Decide whether ASWA should be updated via the current running model.
        self.decide(model.state_dict(), ensemble_state_dict,
                    val_running_model, mrr_updated_ensemble_model)
class SWA(AbstractCallback):
    """Stochastic Weight Averaging callback.

    Parameters
    ----------
    swa_start_epoch : int
        The epoch at which to start SWA.
    swa_c_epochs : int
        Interval of epochs between SWA updates.
    lr_init : float
        The initial learning rate.
    swa_lr : float
        The learning rate to use during the SWA phase.
    max_epochs : int
        The maximum number of epochs (args.num_epochs).
    """

    def __init__(self, swa_start_epoch, swa_c_epochs: int = 1, lr_init: float = 0.1,
                 swa_lr: float = 0.05, max_epochs: int = None):
        super().__init__()
        self.swa_start_epoch = swa_start_epoch
        self.swa_c_epochs = swa_c_epochs
        self.swa_lr = swa_lr
        self.lr_init = lr_init
        self.max_epochs = max_epochs
        self.swa_model = None   # running-average model, created lazily
        self.swa_n = 0          # number of models averaged so far
        self.current_epoch = -1

    @staticmethod
    def moving_average(swa_model, running_model, alpha):
        """Update SWA model with a moving average of the current model.

        Math:
            theta_swa <- (1 - alpha) * theta_swa + alpha * theta
            alpha = 1 / (n + 1), where n = number of models already averaged
            (n is tracked via self.swa_n by the caller).
        """
        with torch.no_grad():
            swa_model.to(next(running_model.parameters()).device)
            for swa_param, param in zip(swa_model.parameters(), running_model.parameters()):
                swa_param.data = (1.0 - alpha) * swa_param.data + alpha * param.data

    def on_train_epoch_start(self, trainer, model):
        """Update the learning rate according to the SWA schedule."""
        # Track the current epoch, falling back to a local counter.
        if hasattr(trainer, 'current_epoch'):
            self.current_epoch = trainer.current_epoch
        else:
            self.current_epoch += 1
        if self.current_epoch < self.swa_start_epoch:
            return
        # Piecewise-linear decay from lr_init to swa_lr between 50% and 90%
        # of training; constant afterwards.
        t = self.current_epoch / self.max_epochs
        lr_ratio = self.swa_lr / self.lr_init
        if t <= 0.5:
            factor = 1.0
        elif t <= 0.9:
            factor = 1.0 - (1.0 - lr_ratio) * (t - 0.5) / 0.4
        else:
            factor = lr_ratio
        # Locate the optimizer(s): EnsembleKGE keeps its own, otherwise they
        # live on the trainer.
        if isinstance(model, EnsembleKGE):
            optimizers = getattr(model, "optimizers", [])
        elif hasattr(trainer, "optimizers") and trainer.optimizers:
            optimizers = trainer.optimizers if isinstance(trainer.optimizers, list) else [trainer.optimizers]
        elif hasattr(trainer, "optimizer") and trainer.optimizer is not None:
            optimizers = [trainer.optimizer]
        else:
            optimizers = None
        if optimizers is not None:
            for optimizer in optimizers:
                for param_group in optimizer.param_groups:
                    param_group['lr'] = self.lr_init * factor

    @rank_zero_only
    def on_train_epoch_end(self, trainer, model):
        """Apply SWA averaging if conditions are met."""
        if self.current_epoch < self.swa_start_epoch:
            return
        if (self.current_epoch - self.swa_start_epoch) % self.swa_c_epochs != 0:
            return
        running_model = model._orig_mod if isinstance(model, OptimizedModule) else model
        if self.swa_model is None:
            # Initialize the SWA model as a copy of the running model.
            if isinstance(running_model, EnsembleKGE):
                self.swa_model = type(running_model)(running_model.models)
            else:
                self.swa_model = type(running_model)(running_model.args)
            self.swa_model.load_state_dict(running_model.state_dict())
        if isinstance(running_model, EnsembleKGE):
            # Update each submodel and its SWA counterpart.
            for submodel, swa_submodel in zip(running_model.models, self.swa_model.models):
                self.moving_average(swa_submodel, submodel, 1.0 / (self.swa_n + 1))
        else:
            # Single model case.
            self.moving_average(self.swa_model, running_model, 1.0 / (self.swa_n + 1))
        self.swa_n += 1
        # Expose the SWA model for periodic evaluation. Use .get(...) —
        # consistent with the EMA callback — so a missing key cannot raise KeyError.
        if model.args.get("eval_every_n_epochs", 0) > 0 or model.args.get("eval_at_epochs") is not None:
            trainer.wa_model = self.swa_model

    @rank_zero_only
    def on_fit_end(self, trainer, model):
        """Replace the main model with the SWA model at the end of training."""
        if self.swa_model is not None and self.swa_n > 0:
            # Copy SWA weights back to the main model directly.
            model.load_state_dict(self.swa_model.state_dict())
class SWAG(AbstractCallback):
    """Stochastic Weight Averaging - Gaussian (SWAG).

    Parameters
    ----------
    swa_start_epoch : int
        Epoch at which to start collecting weights.
    swa_c_epochs : int
        Interval of epochs between updates.
    lr_init : float
        Initial LR.
    swa_lr : float
        LR in the SWAG phase.
    max_epochs : int
        Total number of epochs.
    max_num_models : int
        Number of deviation vectors kept for the low-rank covariance approx.
    var_clamp : float
        Lower clamp on the variance for numerical stability.
    """

    def __init__(self, swa_start_epoch, swa_c_epochs: int = 1, lr_init: float = 0.1,
                 swa_lr: float = 0.05, max_epochs: int = None,
                 max_num_models: int = 20, var_clamp: float = 1e-30):
        super().__init__()
        self.swa_start_epoch = swa_start_epoch
        self.swa_c_epochs = swa_c_epochs
        self.swa_lr = swa_lr
        self.lr_init = lr_init
        self.max_epochs = max_epochs
        self.max_num_models = max_num_models
        self.var_clamp = var_clamp
        # Running statistics for the Gaussian posterior approximation.
        self.mean = None        # running mean of flattened weights
        self.sq_mean = None     # running mean of squared weights
        self.deviations = []    # last max_num_models deviations from the mean
        self.gswa_n = 0         # number of collected snapshots
        self.current_epoch = -1

    def _collect_stats(self, model):
        """Collect weights to update mean, sq_mean and covariance deviations.

        Math:
            theta_i  : parameter vector at collection step i, n = gswa_n
            mean     : mu_{n+1} = (n * mu_n + theta_{n+1}) / (n + 1)
            sq_mean  : m2_{n+1} = (n * m2_n + theta_{n+1}^2) / (n + 1)
                       (used for the variance approx: sigma^2 = m2 - mu^2)
            deviation: dev_{n+1} = theta_{n+1} - mu_{n+1}; the last
                       max_num_models deviations form the low-rank factor.
        """
        # Collect current model weights as a flat vector on CPU.
        vec = nn.utils.parameters_to_vector(model.parameters()).detach().cpu()
        if self.mean is None:
            self.mean = vec.clone()
            self.sq_mean = vec.clone() ** 2
        else:
            self.mean = (self.mean * self.gswa_n + vec) / (self.gswa_n + 1)
            self.sq_mean = (self.sq_mean * self.gswa_n + vec ** 2) / (self.gswa_n + 1)
        # Low-rank covariance info (rolling buffer of deviations).
        dev = (vec - self.mean).unsqueeze(1)
        self.deviations.append(dev)
        if len(self.deviations) > self.max_num_models:
            self.deviations.pop(0)
        self.gswa_n += 1

    def get_mean_and_var(self):
        """Return (mean, clamped diagonal variance), or (None, None) if empty."""
        if self.mean is None:
            return None, None
        var = torch.clamp(self.sq_mean - self.mean ** 2, min=self.var_clamp)
        return self.mean, var

    def sample(self, base_model, scale=0.5):
        """Sample weights from the SWAG posterior and load them into base_model.

        Math:
            theta ~ N(mean, Sigma),  Sigma ~= diag(var) + D D^T / (K - 1)
            1. theta_diag = mean + scale * std * eps,     eps ~ N(0, I)
            2. theta      = theta_diag + D z / sqrt(K-1), z ~ N(0, I_K)

        Raises
        ------
        RuntimeError
            If no statistics have been collected yet.
        """
        if self.mean is None:
            raise RuntimeError("No SWAG stats collected yet.")
        # Mean and variance of the diagonal part.
        mean, var = self.get_mean_and_var()
        std = torch.sqrt(var)
        # Diagonal Gaussian perturbation.
        sample_vec = mean + scale * std * torch.randn_like(std)
        # Low-rank covariance perturbation.
        if self.deviations:
            D = torch.cat(self.deviations, dim=1)            # (num_params, K)
            z = torch.randn(D.shape[1], device=D.device)     # K-dim Gaussian
            denom = max(1, len(self.deviations) - 1) ** 0.5  # sqrt(K - 1), guarded
            sample_vec += (D @ z) / denom
        # Load sampled weights into the given model in place.
        nn.utils.vector_to_parameters(sample_vec, base_model.parameters())
        return base_model

    def on_train_epoch_start(self, trainer, model):
        """Update the LR schedule (same shape as SWA)."""
        if hasattr(trainer, 'current_epoch'):
            self.current_epoch = trainer.current_epoch
        else:
            self.current_epoch += 1
        if self.current_epoch < self.swa_start_epoch:
            return
        # Piecewise-linear decay from lr_init to swa_lr.
        t = self.current_epoch / self.max_epochs
        lr_ratio = self.swa_lr / self.lr_init
        if t <= 0.5:
            factor = 1.0
        elif t <= 0.9:
            factor = 1.0 - (1.0 - lr_ratio) * (t - 0.5) / 0.4
        else:
            factor = lr_ratio
        # Update trainer optimizers.
        if hasattr(trainer, "optimizers") and trainer.optimizers:
            optimizers = trainer.optimizers if isinstance(trainer.optimizers, list) else [trainer.optimizers]
        elif hasattr(trainer, "optimizer") and trainer.optimizer is not None:
            optimizers = [trainer.optimizer]
        else:
            optimizers = []
        for optimizer in optimizers:
            for pg in optimizer.param_groups:
                pg['lr'] = self.lr_init * factor

    @rank_zero_only
    def on_train_epoch_end(self, trainer, model):
        """Collect Gaussian stats at the end of epochs after swa_start_epoch."""
        if self.current_epoch >= self.swa_start_epoch and \
                (self.current_epoch - self.swa_start_epoch) % self.swa_c_epochs == 0:
            running_model = model._orig_mod if isinstance(model, OptimizedModule) else model
            self._collect_stats(running_model)

    @rank_zero_only
    def on_fit_end(self, trainer, model):
        """Evaluate sampled SWAG models, then set the model to the SWAG mean."""
        # Guard: if training ended before any stats were collected, sample()
        # would raise RuntimeError — leave the model untouched instead.
        if self.mean is None:
            return
        sample_models = []
        for _ in range(self.max_num_models):
            model_copy = copy.deepcopy(model)
            sample_models.append(self.sample(model_copy))
        # NOTE(review): `batch_size=trainer.num_training_batches` passes the
        # *number* of training batches as a batch size — confirm this is intended.
        ensemble_eval_report = evaluate_ensemble_link_prediction_performance(
            models=sample_models,
            triples=trainer.dataset.test_set,
            er_vocab=trainer.dataset.er_vocab.result(),
            weights=None,
            batch_size=trainer.num_training_batches,
            weighted_averaging=False,
            normalize_scores=False)
        ensemble_eval_report_path = os.path.join(model.args["full_storage_path"],
                                                 "swag_eval_report.json")
        # Write the evaluation report to the JSON file.
        with open(ensemble_eval_report_path, 'w', encoding='utf-8') as f:
            json.dump(ensemble_eval_report, f, indent=4, ensure_ascii=False)
        # Final weights: the collected SWAG mean.
        nn.utils.vector_to_parameters(self.mean.to(next(model.parameters()).device),
                                      model.parameters())
class EMA(AbstractCallback):
    """Exponential Moving Average (EMA) of model weights.

    Maintains a shadow copy of the model whose parameters follow
    ``theta_ema <- decay * theta_ema + (1 - decay) * theta``.

    Parameters
    ----------
    ema_start_epoch : int
        First epoch at which EMA updates are applied.
    decay : float
        EMA smoothing factor (typical range 0.99 - 0.9999).
    max_epochs : int, optional
        Maximum number of epochs.
    ema_c_epochs : int
        Interval (in epochs) between EMA updates.
    """

    def __init__(self, ema_start_epoch: int, decay: float = 0.999,
                 max_epochs: int = None, ema_c_epochs: int = 1):
        super().__init__()
        self.ema_start_epoch = ema_start_epoch
        self.decay = decay
        self.max_epochs = max_epochs
        self.ema_c_epochs = ema_c_epochs
        self.ema_model = None   # shadow model, created lazily on first update
        self.current_epoch = -1

    @staticmethod
    def ema_update(ema_model, running_model, decay: float):
        """In-place EMA step.

        Math:
            theta_ema <- decay * theta_ema + (1 - decay) * theta
            i.e. alpha = 1 - decay controls how much of the current model
            contributes; decay is fixed here (could be scheduled later).
        """
        with torch.no_grad():
            target_device = next(running_model.parameters()).device
            ema_model.to(target_device)
            for shadow, live in zip(ema_model.parameters(), running_model.parameters()):
                shadow.data.mul_(decay).add_(live.data, alpha=1.0 - decay)

    @rank_zero_only
    def on_train_epoch_start(self, trainer, model):
        """Keep self.current_epoch in sync with the trainer (or count locally)."""
        self.current_epoch = getattr(trainer, 'current_epoch', self.current_epoch + 1)

    @rank_zero_only
    def on_train_epoch_end(self, trainer, model):
        """Apply an EMA update once past the start epoch, every ema_c_epochs."""
        epochs_since_start = self.current_epoch - self.ema_start_epoch
        if epochs_since_start < 0 or epochs_since_start % self.ema_c_epochs != 0:
            return
        running_model = model._orig_mod if isinstance(model, OptimizedModule) else model
        if self.ema_model is None:
            # Bootstrap the shadow model as an exact copy of the running model.
            if isinstance(running_model, EnsembleKGE):
                self.ema_model = type(running_model)(running_model.models)
            else:
                self.ema_model = type(running_model)(running_model.args)
            self.ema_model.load_state_dict(running_model.state_dict())
        # Fixed decay — EMA starts late, so no warm-up schedule is used.
        if isinstance(running_model, EnsembleKGE):
            for live, shadow in zip(running_model.models, self.ema_model.models):
                self.ema_update(shadow, live, self.decay)
        else:
            self.ema_update(self.ema_model, running_model, self.decay)
        # Expose the EMA model for periodic evaluation.
        if model.args.get("eval_every_n_epochs", 0) > 0 or model.args.get("eval_at_epochs") is not None:
            trainer.swa_model = self.ema_model

    @rank_zero_only
    def on_fit_end(self, trainer, model):
        """Copy the EMA weights into the main model when training finishes."""
        if self.ema_model is None:
            return
        model.load_state_dict(self.ema_model.state_dict())
class TWA(AbstractCallback):
    """Train with Weight Averaging (TWA) using subspace projection + averaging.

    Parameters
    ----------
    twa_start_epoch : int
        Epoch to start TWA.
    lr_init : float
        Learning rate used for beta updates.
    num_samples : int
        Number of sampled weight snapshots used to build the projection subspace.
    reg_lambda : float
        Ridge regularization coefficient for beta updates.
    max_epochs : int
        Total number of training epochs.
    twa_c_epochs : int
        Interval of epochs between TWA updates.
    """

    def __init__(self, twa_start_epoch: int, lr_init: float, num_samples: int = 5,
                 reg_lambda: float = 0.0, max_epochs: int = None, twa_c_epochs: int = 1):
        super().__init__()
        self.twa_start_epoch = twa_start_epoch
        self.num_samples = num_samples
        self.reg_lambda = reg_lambda
        self.max_epochs = max_epochs
        self.lr_init = lr_init
        self.twa_c_epochs = twa_c_epochs
        # State variables.
        self.current_epoch = -1
        self.weight_samples = []   # rolling buffer of flat weight snapshots
        self.twa_model = None      # averaged model, created lazily
        self.base_weights = None   # mean of snapshots (SWA baseline)
        self.P = None              # projection matrix (D, k)
        self.beta = None           # coefficients in the subspace

    def sample_weights(self, model):
        """Snapshot the current flat weights into a rolling buffer of size num_samples."""
        w = torch.cat([p.data.view(-1).clone() for p in model.parameters()])
        self.weight_samples.append(w)
        if len(self.weight_samples) > self.num_samples:
            self.weight_samples.pop(0)
        return w  # return latest sample (sometimes useful)

    def build_projection(self, weight_samples, k=None):
        """Build a projection subspace from collected weight samples.

        Parameters
        ----------
        weight_samples : list of flat weight tensors, each of shape (D,)
        k : int, optional
            Number of basis vectors to keep. Defaults to min(N, D).

        Returns
        -------
        mean_w : torch.Tensor, shape (D,)
            Base weight vector (average of samples).
        P : torch.Tensor, shape (D, k)
            Projection matrix with the top-k right-singular directions.
        """
        W = torch.stack(weight_samples)   # (N, D)
        mean_w = W.mean(dim=0)            # (D,)
        centered = W - mean_w             # (N, D)
        # SVD of the centered snapshot matrix; V spans the snapshot subspace.
        U, S, Vh = torch.linalg.svd(centered, full_matrices=False)
        V = Vh.T                          # (D, min(N, D))
        if k is None:
            k = min(W.size(0), V.size(1))  # default: use all N basis vectors
        P = V[:, :k]                       # (D, k)
        return mean_w, P

    @rank_zero_only
    def on_train_epoch_start(self, trainer, model):
        """Track the current epoch."""
        if hasattr(trainer, 'current_epoch'):
            self.current_epoch = trainer.current_epoch
        else:
            self.current_epoch += 1

    @rank_zero_only
    def on_train_epoch_end(self, trainer, model):
        """Main TWA logic: build the subspace and update in beta space.

        Math:
            w_twa  = mean_w + P @ beta
            mean_w = (1/n) * sum_i w_i                     (SWA baseline)
            beta  <- (1 - eta * lambda) * beta - eta * P^T g
            g      = gradient of the training loss w.r.t. full weights
            eta    = lr_init, lambda = reg_lambda
            P      = orthonormal basis spanning sampled checkpoints {w_i}
        """
        # Before TWA starts: only collect weight snapshots.
        if self.current_epoch < self.twa_start_epoch:
            self.sample_weights(model)
            return
        if (self.current_epoch - self.twa_start_epoch) % self.twa_c_epochs != 0:
            return
        running_model = model._orig_mod if isinstance(model, OptimizedModule) else model
        if self.twa_model is None:
            # Initialize the TWA model as a copy of the running model.
            if isinstance(running_model, EnsembleKGE):
                self.twa_model = type(running_model)(running_model.models)
            else:
                self.twa_model = type(running_model)(running_model.args)
            self.twa_model.load_state_dict(running_model.state_dict())
            # Build the projection subspace from checkpoints {w_1, ..., w_n}.
            mean_w, P = self.build_projection(self.weight_samples)
            self.base_weights = mean_w
            self.P = P
            # Initialize coefficients beta in the subspace.
            self.beta = torch.zeros(P.size(1), device=mean_w.device)
        # Training gradients of the running model (zeros for params without grad).
        # NOTE(review): at epoch end the .grad buffers may already have been
        # zeroed/stepped by the optimizer — confirm they still hold the last
        # batch's gradients here.
        grads = torch.cat([
            (p.grad if p.grad is not None else torch.zeros_like(p)).view(-1)
            for p in running_model.parameters()
        ])
        with torch.no_grad():
            # Project the gradient into the subspace with ridge regularization.
            self.beta = (1 - self.lr_init * self.reg_lambda) * self.beta \
                - self.lr_init * (self.P.t() @ grads)
            # Reconstruct full TWA weights.
            new_w = self.base_weights + self.P @ self.beta
            # Load the reconstructed weights into the TWA model.
            idx = 0
            for p in self.twa_model.parameters():
                numel = p.numel()
                p.data.copy_(new_w[idx: idx + numel].view_as(p))
                idx += numel
        # Expose the TWA model for periodic evaluation.
        if model.args.get("eval_every_n_epochs", 0) > 0 or model.args.get("eval_at_epochs") is not None:
            trainer.wa_model = self.twa_model

    @rank_zero_only
    def on_fit_end(self, trainer, model):
        """Replace the main model with the TWA model at the end of training."""
        if self.twa_model is not None:
            model.load_state_dict(self.twa_model.state_dict())