import os
import copy
import json
import torch
import torch.nn as nn
from torch._dynamo.eval_frame import OptimizedModule
from pytorch_lightning.utilities import rank_zero_only
from .abstracts import AbstractCallback
from dicee.models.ensemble import EnsembleKGE
from .evaluation.ensemble import evaluate_ensemble_link_prediction_performance
class ASWA(AbstractCallback):
    """Adaptive Stochastic Weight Averaging (ASWA).

    ASWA keeps track of the validation performance and updates the ensemble
    (parameter-averaged) model accordingly. At the end of every epoch it
    performs a lookahead evaluation and then a hard update, a soft update,
    or a rejection of the running model's parameters.
    """

    def __init__(self, num_epochs, path):
        """
        Parameters
        ----------
        num_epochs : int
            Total number of training epochs (bookkeeping only).
        path : str
            Directory in which the ASWA checkpoint ``aswa.pt`` is stored.
        """
        super().__init__()
        self.path = path
        self.num_epochs = num_epochs
        # Original trainer.evaluator.args.eval_model value; restored on fit end.
        self.initial_eval_setting = None
        self.epoch_count = 0
        # Ensemble weights; sum(self.alphas) acts as the effective count of averaged models.
        self.alphas = []
        # Validation MRR of the current ASWA ensemble; -1 means "not initialized yet".
        self.val_aswa = -1

    @rank_zero_only
    def on_fit_end(self, trainer, model):
        """Restore the original eval setting and load the ASWA parameters into `model`."""
        if self.initial_eval_setting:
            # Restore the evaluation setting that was overridden during training.
            trainer.evaluator.args.eval_model = self.initial_eval_setting
        aswa_path = f"{self.path}/aswa.pt"
        # Guard: the checkpoint only exists if at least one epoch completed.
        if os.path.exists(aswa_path):
            param_ensemble = torch.load(aswa_path, torch.device("cpu"))
            model.load_state_dict(param_ensemble)

    @staticmethod
    def compute_mrr(trainer, model) -> float:
        """Return the validation MRR of a CPU copy of `model` in eval mode."""
        # Deep-copy so that eval-mode/device changes never touch the running model.
        eval_model = copy.deepcopy(model)
        eval_model.eval()
        eval_model.to("cpu")
        last_val_mrr_running_model = trainer.evaluator.eval(
            dataset=trainer.dataset,
            trained_model=eval_model,
            form_of_labelling=trainer.form_of_labelling,
            during_training=True)["Val"]["MRR"]
        del eval_model
        return last_val_mrr_running_model

    def get_aswa_state_dict(self, model):
        """Return a provisional ASWA state dict with `model`'s parameters folded in.

        Implements theta_ens <- (theta_ens * sum(alphas) + theta) / (1 + sum(alphas))
        for float parameters only; non-float entries keep the stored values.
        """
        if isinstance(model, EnsembleKGE):
            ensemble_state_dict = torch.load(f"{self.path}/aswa.pt")
        else:
            ensemble_state_dict = torch.load(f"{self.path}/aswa.pt",
                                             torch.device(next(model.parameters()).device))
        # Perform the provisional parameter update.
        with torch.no_grad():
            for k, parameters in model.state_dict().items():
                if parameters.dtype == torch.float:
                    ensemble_state_dict[k] = (ensemble_state_dict[k] * sum(self.alphas)
                                              + parameters) / (1 + sum(self.alphas))
        return ensemble_state_dict

    def decide(self, running_model_state_dict, ensemble_state_dict,
               val_running_model, mrr_updated_ensemble_model) -> bool:
        """Perform a hard update, a soft update, or a rejection.

        Parameters
        ----------
        running_model_state_dict : dict
            State dict of the running model.
        ensemble_state_dict : dict
            Provisional ASWA state dict (running model already folded in).
        val_running_model : float
            Validation MRR of the running model.
        mrr_updated_ensemble_model : float
            Validation MRR of the provisional ensemble.

        Returns
        -------
        bool
            Always True; the decision is recorded via side effects.
        """
        # (1) HARD UPDATE: the running model beats both the provisional ensemble
        # and the current ASWA model.
        if val_running_model > mrr_updated_ensemble_model and val_running_model > self.val_aswa:
            # (1.1) Save the running model as the new ASWA model.
            torch.save(running_model_state_dict, f=f"{self.path}/aswa.pt")
            # (1.2) Reset alphas/ensemble weights.
            self.alphas.clear()
            # (1.3) Store the validation performance of ASWA.
            self.val_aswa = val_running_model
            return True
        # (2) SOFT UPDATE: the provisional ensemble improves over current ASWA.
        if mrr_updated_ensemble_model > self.val_aswa:
            self.val_aswa = mrr_updated_ensemble_model
            torch.save(ensemble_state_dict, f=f"{self.path}/aswa.pt")
            self.alphas.append(1.0)
            return True
        # (3) REJECTION: keep the stored ASWA model and record a zero weight.
        # This branch also covers exact equality, which previously fell through
        # and returned None without appending any alpha.
        self.alphas.append(0)
        return True

    @rank_zero_only
    def on_train_epoch_end(self, trainer, model):
        """Run one ASWA lookahead/update step at the end of a training epoch."""
        if isinstance(model, OptimizedModule):
            model = model._orig_mod
        # (1) Increment epoch counter.
        self.epoch_count += 1
        # (2) Remember the given eval setting once and force validation-only evaluation.
        if self.initial_eval_setting is None:
            self.initial_eval_setting = trainer.evaluator.args.eval_model
            trainer.evaluator.args.eval_model = "val"
        # (3) Compute MRR of the running model.
        val_running_model = self.compute_mrr(trainer, model)
        # (4) Initialize ASWA with the running model if not done yet.
        if self.val_aswa == -1:
            torch.save(model.state_dict(), f=f"{self.path}/aswa.pt")
            self.alphas.append(1.0)
            self.val_aswa = val_running_model
            return True
        # (5) Provisional ASWA ensemble parameters including the running model.
        ensemble_state_dict = self.get_aswa_state_dict(model)
        # (6) Materialize the provisional ensemble on CPU.
        ensemble = copy.deepcopy(model)
        ensemble.load_state_dict(ensemble_state_dict)
        ensemble.to("cpu")
        # (7) Lookahead: evaluate the provisional ensemble on the validation data.
        mrr_updated_ensemble_model = trainer.evaluator.eval(
            dataset=trainer.dataset,
            trained_model=ensemble,
            form_of_labelling=trainer.form_of_labelling,
            during_training=True)["Val"]["MRR"]
        del ensemble
        # (8) Decide whether ASWA should be updated via the current running model.
        self.decide(model.state_dict(), ensemble_state_dict,
                    val_running_model, mrr_updated_ensemble_model)
class SWA(AbstractCallback):
    """Stochastic Weight Averaging callback.

    Maintains a running average of model weights starting at `swa_start_epoch`
    and replaces the trained model with the averaged one at the end of fit.

    Parameters
    ----------
    swa_start_epoch : int
        The epoch at which to start SWA.
    swa_c_epochs : int
        The number of epochs between SWA averaging steps.
    lr_init : float
        The initial learning rate.
    swa_lr : float
        The learning rate to use during the SWA phase.
    max_epochs : int
        The maximum number of epochs (args.num_epochs).
    """

    def __init__(self, swa_start_epoch, swa_c_epochs: int = 1, lr_init: float = 0.1,
                 swa_lr: float = 0.05, max_epochs: int = None):
        super().__init__()
        self.swa_start_epoch = swa_start_epoch
        self.swa_c_epochs = swa_c_epochs
        self.swa_lr = swa_lr
        self.lr_init = lr_init
        self.max_epochs = max_epochs
        self.swa_model = None       # running weight-average model
        self.swa_n = 0              # number of models averaged so far
        self.current_epoch = -1

    @staticmethod
    def moving_average(swa_model, running_model, alpha):
        """Update the SWA model with a moving average of the current model.

        Math:
            theta_swa <- (1 - alpha) * theta_swa + alpha * theta
            alpha = 1 / (n + 1), where n = number of models already averaged
            (n is tracked via self.swa_n by the caller).
        """
        with torch.no_grad():
            swa_model.to(next(running_model.parameters()).device)
            for swa_param, param in zip(swa_model.parameters(), running_model.parameters()):
                swa_param.data = (1.0 - alpha) * swa_param.data + alpha * param.data

    def on_train_epoch_start(self, trainer, model):
        """Update the learning rate according to the SWA schedule."""
        # Track the current epoch, with a fallback counter if the trainer
        # does not expose one.
        if hasattr(trainer, 'current_epoch'):
            self.current_epoch = trainer.current_epoch
        else:
            self.current_epoch += 1
        if self.current_epoch < self.swa_start_epoch:
            return
        # Piecewise schedule: constant lr_init, linear anneal to swa_lr
        # between 50% and 90% of training, then constant swa_lr.
        t = self.current_epoch / self.max_epochs
        lr_ratio = self.swa_lr / self.lr_init
        if t <= 0.5:
            factor = 1.0
        elif t <= 0.9:
            factor = 1.0 - (1.0 - lr_ratio) * (t - 0.5) / 0.4
        else:
            factor = lr_ratio
        # Locate optimizer(s): EnsembleKGE keeps its own, otherwise the trainer does.
        if isinstance(model, EnsembleKGE):
            optimizers = getattr(model, "optimizers", [])
        elif hasattr(trainer, "optimizers") and trainer.optimizers:
            optimizers = trainer.optimizers if isinstance(trainer.optimizers, list) else [trainer.optimizers]
        elif hasattr(trainer, "optimizer") and trainer.optimizer is not None:
            optimizers = [trainer.optimizer]
        else:
            optimizers = None
        if optimizers is not None:
            for optimizer in optimizers:
                for param_group in optimizer.param_groups:
                    param_group['lr'] = self.lr_init * factor

    @rank_zero_only
    def on_train_epoch_end(self, trainer, model):
        """Apply SWA averaging if the schedule conditions are met."""
        if self.current_epoch < self.swa_start_epoch:
            return
        # The early return above already guarantees current_epoch >= swa_start_epoch,
        # so only the cadence check remains.
        if (self.current_epoch - self.swa_start_epoch) % self.swa_c_epochs == 0:
            running_model = model._orig_mod if isinstance(model, OptimizedModule) else model
            if self.swa_model is None:
                # Lazily create the SWA model as a structural copy of the running model.
                if isinstance(running_model, EnsembleKGE):
                    self.swa_model = type(running_model)(running_model.models)
                else:
                    self.swa_model = type(running_model)(running_model.args)
                self.swa_model.load_state_dict(running_model.state_dict())
            if isinstance(running_model, EnsembleKGE):
                # Update each submodel and its SWA counterpart.
                for submodel, swa_submodel in zip(running_model.models, self.swa_model.models):
                    self.moving_average(swa_submodel, submodel, 1.0 / (self.swa_n + 1))
            else:
                # Single-model case.
                self.moving_average(self.swa_model, running_model, 1.0 / (self.swa_n + 1))
            self.swa_n += 1
            # Use .get to avoid KeyError when the args keys are absent
            # (consistent with the EMA and TWA callbacks in this module).
            if model.args.get("eval_every_n_epochs", 0) > 0 or model.args.get("eval_at_epochs") is not None:
                trainer.wa_model = self.swa_model

    @rank_zero_only
    def on_fit_end(self, trainer, model):
        """Replace the main model with the SWA model at the end of training."""
        if self.swa_model is not None and self.swa_n > 0:
            # Copy SWA weights back into the main model directly.
            model.load_state_dict(self.swa_model.state_dict())
class SWAG(AbstractCallback):
    """Stochastic Weight Averaging - Gaussian (SWAG).

    Collects a running mean, a running squared mean, and a low-rank set of
    deviations of the flattened model weights, from which Gaussian weight
    samples can be drawn at the end of training.

    Parameters
    ----------
    swa_start_epoch : int
        Epoch at which to start collecting weights.
    swa_c_epochs : int
        Interval of epochs between updates.
    lr_init : float
        Initial LR.
    swa_lr : float
        LR in the SWA / SWAG phase.
    max_epochs : int
        Total number of epochs.
    max_num_models : int
        Number of models to keep for the low-rank covariance approximation.
    var_clamp : float
        Lower clamp on the variance, for numerical stability.
    """

    def __init__(self, swa_start_epoch, swa_c_epochs: int = 1,
                 lr_init: float = 0.1, swa_lr: float = 0.05,
                 max_epochs: int = None, max_num_models: int = 20, var_clamp: float = 1e-30):
        super().__init__()
        self.swa_start_epoch = swa_start_epoch
        self.swa_c_epochs = swa_c_epochs
        self.swa_lr = swa_lr
        self.lr_init = lr_init
        self.max_epochs = max_epochs
        self.max_num_models = max_num_models
        self.var_clamp = var_clamp
        # Statistics for the Gaussian posterior approximation.
        self.mean = None            # running mean of flattened weights
        self.sq_mean = None         # running mean of squared flattened weights
        self.deviations = []        # last max_num_models deviations from the mean
        self.gswa_n = 0             # number of weight snapshots collected
        self.current_epoch = -1

    def _collect_stats(self, model):
        """Collect weights to update mean, sq_mean, and covariance deviations.

        Math:
            mu_{n+1}  = (n * mu_n  + theta_{n+1})   / (n + 1)   (running mean)
            mu2_{n+1} = (n * mu2_n + theta_{n+1}^2) / (n + 1)   (running squared mean)
            dev_{n+1} = theta_{n+1} - mu_{n+1}                  (low-rank deviation)
        Only the last `max_num_models` deviations are kept.
        """
        # Collect the current model weights as a flat vector on CPU.
        vec = nn.utils.parameters_to_vector(model.parameters()).detach().cpu()
        if self.mean is None:
            self.mean = vec.clone()
            self.sq_mean = vec.clone() ** 2
        else:
            self.mean = (self.mean * self.gswa_n + vec) / (self.gswa_n + 1)
            self.sq_mean = (self.sq_mean * self.gswa_n + vec ** 2) / (self.gswa_n + 1)
        # Low-rank covariance information.
        dev = (vec - self.mean).unsqueeze(1)
        self.deviations.append(dev)
        if len(self.deviations) > self.max_num_models:
            self.deviations.pop(0)
        self.gswa_n += 1

    def get_mean_and_var(self):
        """Return the weight mean and the (clamped) diagonal variance, or (None, None)."""
        if self.mean is None:
            return None, None
        var = torch.clamp(self.sq_mean - self.mean ** 2, min=self.var_clamp)
        return self.mean, var

    def sample(self, base_model, scale=0.5):
        """Sample new weights from the SWAG posterior and load them into `base_model`.

        Math:
            theta ~ N(mean, Sigma), Sigma ~= diag(var) + (1/(K-1)) * D D^T
            1. theta_diag    = mean + scale * std * eps,        eps ~ N(0, I)
            2. theta_lowrank = theta_diag + (D z) / sqrt(K-1),  z   ~ N(0, I_K)

        Raises
        ------
        RuntimeError
            If no statistics have been collected yet.
        """
        if self.mean is None:
            raise RuntimeError("No SWAG stats collected yet.")
        mean, var = self.get_mean_and_var()
        std = torch.sqrt(var)
        # Diagonal Gaussian perturbation.
        sample_vec = mean + scale * std * torch.randn_like(std)
        # Low-rank covariance perturbation.
        if self.deviations:
            D = torch.cat(self.deviations, dim=1)           # (num_params, K)
            z = torch.randn(D.shape[1], device=D.device)    # K-dim Gaussian
            denom = max(1, len(self.deviations) - 1) ** 0.5
            sample_vec += (D @ z) / denom
        # Load the sampled weights into the provided model (mutates it).
        nn.utils.vector_to_parameters(sample_vec, base_model.parameters())
        return base_model

    def on_train_epoch_start(self, trainer, model):
        """Update the LR schedule (same piecewise schedule as SWA)."""
        if hasattr(trainer, 'current_epoch'):
            self.current_epoch = trainer.current_epoch
        else:
            self.current_epoch += 1
        if self.current_epoch < self.swa_start_epoch:
            return
        t = self.current_epoch / self.max_epochs
        lr_ratio = self.swa_lr / self.lr_init
        if t <= 0.5:
            factor = 1.0
        elif t <= 0.9:
            factor = 1.0 - (1.0 - lr_ratio) * (t - 0.5) / 0.4
        else:
            factor = lr_ratio
        # Update the trainer's optimizer(s).
        if hasattr(trainer, "optimizers") and trainer.optimizers:
            optimizers = trainer.optimizers if isinstance(trainer.optimizers, list) else [trainer.optimizers]
        elif hasattr(trainer, "optimizer") and trainer.optimizer is not None:
            optimizers = [trainer.optimizer]
        else:
            optimizers = []
        for optimizer in optimizers:
            for pg in optimizer.param_groups:
                pg['lr'] = self.lr_init * factor

    @rank_zero_only
    def on_train_epoch_end(self, trainer, model):
        """Collect Gaussian stats at the end of epochs after swa_start."""
        if self.current_epoch >= self.swa_start_epoch and \
                (self.current_epoch - self.swa_start_epoch) % self.swa_c_epochs == 0:
            running_model = model._orig_mod if isinstance(model, OptimizedModule) else model
            self._collect_stats(running_model)

    @rank_zero_only
    def on_fit_end(self, trainer, model):
        """Evaluate an ensemble of SWAG samples, then set the model to the SWAG mean."""
        # Guard: if no statistics were collected (training ended before
        # swa_start_epoch), self.sample() would raise — simply do nothing.
        if self.mean is None:
            return
        sample_models = []
        for _ in range(self.max_num_models):
            model_copy = copy.deepcopy(model)
            sample_models.append(self.sample(model_copy))
        ensemble_eval_report = evaluate_ensemble_link_prediction_performance(
            models=sample_models,
            triples=trainer.dataset.test_set,
            er_vocab=trainer.dataset.er_vocab.result(),
            weights=None,
            batch_size=trainer.num_training_batches,
            weighted_averaging=False, normalize_scores=False)
        ensemble_eval_report_path = os.path.join(model.args["full_storage_path"],
                                                 "swag_eval_report.json")
        # Write the report dictionary to the JSON file.
        with open(ensemble_eval_report_path, 'w', encoding='utf-8') as f:
            json.dump(ensemble_eval_report, f, indent=4, ensure_ascii=False)
        nn.utils.vector_to_parameters(self.mean.to(next(model.parameters()).device),
                                      model.parameters())
class EMA(AbstractCallback):
    """Exponential Moving Average (EMA) callback.

    Parameters
    ----------
    ema_start_epoch : int
        Epoch at which EMA tracking begins.
    decay : float
        EMA decay rate (typical: 0.99 - 0.9999).
        Math: theta_ema <- decay * theta_ema + (1 - decay) * theta
    max_epochs : int
        Maximum number of epochs.
    ema_c_epochs : int
        Interval of epochs between EMA updates.
    """

    def __init__(self, ema_start_epoch: int, decay: float = 0.999,
                 max_epochs: int = None, ema_c_epochs: int = 1):
        super().__init__()
        self.ema_start_epoch = ema_start_epoch
        self.decay = decay
        self.max_epochs = max_epochs
        self.ema_c_epochs = ema_c_epochs
        self.ema_model = None       # shadow model holding the averaged weights
        self.current_epoch = -1

    @staticmethod
    def ema_update(ema_model, running_model, decay: float):
        """Fold the running model's weights into the EMA model, in place.

        Math:
            theta_ema <- decay * theta_ema + (1 - decay) * theta
        The decay is fixed here; a schedule could be added later.
        """
        with torch.no_grad():
            target_device = next(running_model.parameters()).device
            ema_model.to(target_device)
            for ema_param, param in zip(ema_model.parameters(), running_model.parameters()):
                ema_param.data.mul_(decay).add_(param.data, alpha=1.0 - decay)

    @rank_zero_only
    def on_train_epoch_start(self, trainer, model):
        """Track the current epoch (fallback counter if the trainer lacks one)."""
        if hasattr(trainer, 'current_epoch'):
            self.current_epoch = trainer.current_epoch
        else:
            self.current_epoch += 1

    @rank_zero_only
    def on_train_epoch_end(self, trainer, model):
        """Update the EMA model once past the start epoch, every ema_c_epochs."""
        if self.current_epoch < self.ema_start_epoch:
            return
        if (self.current_epoch - self.ema_start_epoch) % self.ema_c_epochs != 0:
            return
        source = model._orig_mod if isinstance(model, OptimizedModule) else model
        if self.ema_model is None:
            # Lazily create the EMA model as a structural copy of the running one.
            ctor_arg = source.models if isinstance(source, EnsembleKGE) else source.args
            self.ema_model = type(source)(ctor_arg)
            self.ema_model.load_state_dict(source.state_dict())
        # Fixed decay, since EMA tracking starts late in training.
        if isinstance(source, EnsembleKGE):
            for sub, ema_sub in zip(source.models, self.ema_model.models):
                self.ema_update(ema_sub, sub, self.decay)
        else:
            self.ema_update(self.ema_model, source, self.decay)
        # Expose the EMA model for periodic evaluation.
        # NOTE(review): this assigns trainer.swa_model, while SWA/TWA assign
        # trainer.wa_model — presumably intentional, but worth confirming.
        if model.args.get("eval_every_n_epochs", 0) > 0 or model.args.get("eval_at_epochs") is not None:
            trainer.swa_model = self.ema_model

    @rank_zero_only
    def on_fit_end(self, trainer, model):
        """Replace the main model with the EMA model at the end of training."""
        if self.ema_model is not None:
            model.load_state_dict(self.ema_model.state_dict())
class TWA(AbstractCallback):
    """Train with Weight Averaging (TWA) using subspace projection + averaging.

    Checkpoints collected before `twa_start_epoch` span a low-dimensional
    subspace; afterwards, training gradients are projected into that subspace
    and accumulated into coefficients beta.

    Parameters
    ----------
    twa_start_epoch : int
        Epoch to start TWA.
    lr_init : float
        Learning rate used for beta updates.
    num_samples : int
        Number of sampled weight snapshots used to build the projection subspace.
    reg_lambda : float
        Ridge regularization coefficient for beta updates.
    max_epochs : int
        Total number of training epochs.
    twa_c_epochs : int
        Interval of epochs between TWA updates.
    """

    def __init__(self, twa_start_epoch: int, lr_init: float,
                 num_samples: int = 5, reg_lambda: float = 0.0,
                 max_epochs: int = None, twa_c_epochs: int = 1):
        super().__init__()
        self.twa_start_epoch = twa_start_epoch
        self.num_samples = num_samples
        self.reg_lambda = reg_lambda
        self.max_epochs = max_epochs
        self.lr_init = lr_init
        self.twa_c_epochs = twa_c_epochs
        # State variables.
        self.current_epoch = -1
        self.weight_samples = []    # rolling buffer of flattened checkpoints
        self.twa_model = None       # model holding the reconstructed TWA weights
        self.base_weights = None    # mean of sampled checkpoints, shape (D,)
        self.P = None               # projection matrix, shape (D, k)
        self.beta = None            # subspace coefficients, shape (k,)

    def sample_weights(self, model):
        """Append the flattened current weights to the rolling sample buffer."""
        w = torch.cat([p.data.view(-1).clone() for p in model.parameters()])
        self.weight_samples.append(w)
        # Keep at most num_samples snapshots.
        if len(self.weight_samples) > self.num_samples:
            self.weight_samples.pop(0)
        return w  # latest sample (sometimes useful)

    def build_projection(self, weight_samples, k=None):
        """Build a projection subspace from the collected weight samples.

        Parameters
        ----------
        weight_samples : list of flat weight tensors [(D,), ...]
        k : int, optional
            Number of basis vectors to keep. Defaults to min(N, D).

        Returns
        -------
        mean_w : (D,) base weight vector (average of the samples)
        P : (D, k) projection matrix with the top-k right-singular directions
        """
        W = torch.stack(weight_samples)     # (N, D)
        mean_w = W.mean(dim=0)              # (D,)
        centered = W - mean_w               # (N, D)
        # SVD of the centered samples; right-singular vectors span the subspace.
        U, S, Vh = torch.linalg.svd(centered, full_matrices=False)
        V = Vh.T                            # (D, min(N, D))
        if k is None:
            k = min(W.size(0), V.size(1))   # by default use N basis vectors
        return mean_w, V[:, :k]

    @rank_zero_only
    def on_train_epoch_start(self, trainer, model):
        """Track the current epoch (fallback counter if the trainer lacks one)."""
        if hasattr(trainer, 'current_epoch'):
            self.current_epoch = trainer.current_epoch
        else:
            self.current_epoch += 1

    @rank_zero_only
    def on_train_epoch_end(self, trainer, model):
        """Main TWA logic: build the subspace once, then update in beta space.

        Math:
            w_twa  = mean_w + P @ beta
            mean_w = (1/n) * sum_i w_i                      (SWA baseline)
            beta  <- (1 - eta * lambda) * beta - eta * P^T g
            g      = gradient of the training loss w.r.t. full model weights
            eta    = lr_init, lambda = reg_lambda
            P      = orthonormal basis spanning sampled checkpoints {w_i}
        """
        running_model = model._orig_mod if isinstance(model, OptimizedModule) else model
        # Before TWA starts: only collect checkpoints (unwrapped model, so the
        # pre-start samples match what the update phase operates on).
        if self.current_epoch < self.twa_start_epoch:
            self.sample_weights(running_model)
            return
        # Only update every twa_c_epochs.
        if (self.current_epoch - self.twa_start_epoch) % self.twa_c_epochs != 0:
            return
        if self.twa_model is None:
            if isinstance(running_model, EnsembleKGE):
                self.twa_model = type(running_model)(running_model.models)
            else:
                self.twa_model = type(running_model)(running_model.args)
            self.twa_model.load_state_dict(running_model.state_dict())
            # Build the projection subspace ONCE: the sample buffer is frozen
            # after twa_start_epoch, so rebuilding it each epoch only repeated
            # an identical SVD, and re-zeroing beta each epoch discarded all
            # accumulated subspace updates (making the (1 - eta*lambda)*beta
            # term a no-op). beta now persists across epochs.
            self.base_weights, self.P = self.build_projection(self.weight_samples)
            self.beta = torch.zeros(self.P.size(1), device=self.base_weights.device)
        # Training gradients of the running model; missing grads count as zero.
        grads = torch.cat([
            (p.grad if p.grad is not None else torch.zeros_like(p)).view(-1)
            for p in running_model.parameters()
        ])
        with torch.no_grad():
            # Project the gradient into the subspace and apply ridge regularization.
            self.beta = (1 - self.lr_init * self.reg_lambda) * self.beta \
                        - self.lr_init * (self.P.t() @ grads)
            # Reconstruct the full TWA weight vector and load it into twa_model.
            new_w = self.base_weights + self.P @ self.beta
            idx = 0
            for p in self.twa_model.parameters():
                numel = p.numel()
                p.data.copy_(new_w[idx: idx + numel].view_as(p))
                idx += numel
        # Make the TWA model available for periodic evaluation.
        if model.args.get("eval_every_n_epochs", 0) > 0 or model.args.get("eval_at_epochs") is not None:
            trainer.wa_model = self.twa_model

    @rank_zero_only
    def on_fit_end(self, trainer, model):
        """Replace the main model with the TWA model at the end of training."""
        if self.twa_model is not None:
            model.load_state_dict(self.twa_model.state_dict())