Source code for dicee.callbacks

import datetime
import time
import numpy as np
import torch
import os
import json

import dicee.models.base_model
from .static_funcs import save_checkpoint_model, save_pickle
from .abstracts import AbstractCallback
import pandas as pd
from collections import defaultdict
import math
from torch.optim.lr_scheduler import LambdaLR
from .eval_static_funcs import evaluate_ensemble_link_prediction_performance


class AccumulateEpochLossCallback(AbstractCallback):
    """Persist the per-epoch losses of a model as a CSV file when fitting ends."""

    def __init__(self, path: str):
        super().__init__()
        # Directory into which ``epoch_losses.csv`` is written.
        self.path = path

    def on_fit_end(self, trainer, model) -> None:
        """
        Store epoch loss.

        Parameters
        ----------
        trainer:
            Not used by this callback.
        model:
            Object whose ``loss_history`` holds one entry per epoch.

        Returns
        -------
        None
        """
        loss_frame = pd.DataFrame(model.loss_history, columns=['EpochLoss'])
        destination = f'{self.path}/epoch_losses.csv'
        loss_frame.to_csv(destination)
class PrintCallback(AbstractCallback):
    """Print a start banner and the elapsed wall-clock time of a training run."""

    def __init__(self):
        super().__init__()
        # The clock starts when the callback is constructed, not when fit starts.
        self.start_time = time.time()

    def on_fit_start(self, trainer, pl_module):
        """Announce the start of training together with the current timestamp."""
        print(f"\nTraining is starting {datetime.datetime.now()}...")

    def on_fit_end(self, trainer, pl_module):
        """Print the elapsed time since construction in seconds, minutes or hours."""
        training_time = time.time() - self.start_time
        # Half-open intervals [0, 60), [60, 3600), [3600, inf): a runtime of
        # exactly 60 s or 3600 s is reported in the larger unit instead of
        # falling through to a catch-all "seconds" branch.
        if training_time < 60:
            message = f'{training_time:.3f} seconds.'
        elif training_time < 60 * 60:
            message = f'{training_time / 60:.3f} minutes.'
        else:
            message = f'{training_time / (60 * 60):.3f} hours.'
        print(f"Training Runtime: {message}\n")

    def on_train_batch_end(self, *args, **kwargs):
        """No-op."""
        return

    def on_train_epoch_end(self, *args, **kwargs):
        """No-op."""
        return
class KGESaveCallback(AbstractCallback):
    """Checkpoint the model every ``every_x_epoch`` epochs via ``on_epoch_end``."""

    def __init__(self, every_x_epoch: int, max_epochs: int, path: str):
        super().__init__()
        self.max_epochs = max_epochs
        self.epoch_counter = 0
        self.path = path
        # Without an explicit period, checkpoint at half of the run (at least 1).
        if every_x_epoch is None:
            every_x_epoch = max(self.max_epochs // 2, 1)
        self.every_x_epoch = every_x_epoch

    def on_train_batch_end(self, *args, **kwargs):
        """No-op."""
        return

    def on_fit_start(self, trainer, pl_module):
        """No-op."""

    def on_train_epoch_end(self, *args, **kwargs):
        """No-op."""

    def on_fit_end(self, *args, **kwargs):
        """No-op."""

    def on_epoch_end(self, model, trainer, **kwargs):
        """Save a timestamped checkpoint when the epoch counter hits the period.

        Counters 0 and 1 never trigger a save. The counter is advanced on
        every call, whether or not a checkpoint was written.
        """
        on_period = self.epoch_counter % self.every_x_epoch == 0
        if on_period and self.epoch_counter > 1:
            print(f'\nStoring model {self.epoch_counter}...')
            stamp = datetime.datetime.now()
            checkpoint_path = self.path + f'/model_at_{self.epoch_counter}_epoch_{stamp}.pt'
            save_checkpoint_model(model, path=checkpoint_path)
        self.epoch_counter += 1
class PseudoLabellingCallback(AbstractCallback):
    """Extend the training set with confidently scored unlabelled triples.

    After each epoch a random batch is drawn from ``kg.unlabelled_set``,
    scored by the model, and every triple whose sigmoid score is at least
    0.90 is appended to ``data_module.train_set_idx``.
    """

    def __init__(self, data_module, kg, batch_size):
        super().__init__()
        self.data_module = data_module
        self.kg = kg
        self.num_of_epochs = 0
        self.unlabelled_size = len(self.kg.unlabelled_set)
        self.batch_size = batch_size

    def create_random_data(self):
        """Return ``batch_size`` random (head, relation, tail) index triples."""
        entities = torch.randint(low=0, high=self.kg.num_entities, size=(self.batch_size, 2))
        relations = torch.randint(low=0, high=self.kg.num_relations, size=(self.batch_size,))
        # unlabelled triples
        return torch.stack((entities[:, 0], relations, entities[:, 1]), dim=1)

    def on_epoch_end(self, trainer, model):
        """Pseudo-label a random unlabelled batch and grow the training set."""
        model.eval()
        with torch.no_grad():
            # (1) Sample a batch of unlabelled triples.
            sampled_rows = torch.randint(low=0, high=self.unlabelled_size, size=(self.batch_size,))
            unlabelled_input_batch = self.kg.unlabelled_set[sampled_rows]
            # (2) Predict the batch and use the predictions as pseudo-labels.
            pseudo_label = torch.sigmoid(model(unlabelled_input_batch))
            selected_triples = unlabelled_input_batch[pseudo_label >= .90]
        if len(selected_triples) > 0:
            # Update dataset. `.cpu()` is required before `.numpy()`: calling
            # `.numpy()` on a tensor that lives on an accelerator raises.
            self.data_module.train_set_idx = np.concatenate(
                (self.data_module.train_set_idx, selected_triples.detach().cpu().numpy()), axis=0)
            trainer.train_dataloader = self.data_module.train_dataloader()
            print(f'\tEpoch:{trainer.current_epoch}: Pseudo-labelling\t |D|= {len(self.data_module.train_set_idx)}')
        model.train()
def estimate_q(eps):
    """Estimate the rate of convergence q from the sequence ``eps``.

    A straight line is fitted to ``log|diff(log(eps))|`` against the step
    index; the exponential of its slope is returned as q.
    """
    log_of_steps = np.log(np.abs(np.diff(np.log(eps))))
    step_index = np.arange(len(eps) - 1)
    # Degree-1 fit: coefficients come back as (slope, intercept).
    slope, _intercept = np.polyfit(step_index, log_of_steps, 1)
    return np.exp(slope)
def compute_convergence(seq, i):
    """Estimate the convergence rate of the last ``i`` entries of ``seq``.

    The k-th of those entries (1-based) is divided by k before the estimate
    is computed by ``estimate_q``.
    """
    assert len(seq) >= i > 0
    divisors = np.arange(1, i + 1)
    tail = seq[-i:]
    return estimate_q(tail / divisors)
class ASWA(AbstractCallback):
    """Adaptive stochastic weight averaging.

    ASWA keeps track of the validation performance and updates the ensemble
    model (stored in ``{path}/aswa.pt``) accordingly.
    """

    def __init__(self, num_epochs, path):
        super().__init__()
        self.path = path
        self.num_epochs = num_epochs
        # Evaluation setting of the trainer before this callback changed it.
        self.initial_eval_setting = None
        self.epoch_count = 0
        # Weights of the models averaged into the stored ensemble.
        self.alphas = []
        # Validation MRR of the stored ensemble; -1 means "not initialised".
        self.val_aswa = -1

    def on_fit_end(self, trainer, model):
        """Restore the evaluation setting and load the ensemble into ``model``."""
        if self.initial_eval_setting:
            # ADD this info back
            trainer.evaluator.args.eval_model = self.initial_eval_setting

        if trainer.global_rank == trainer.local_rank == 0:
            param_ensemble = torch.load(f"{self.path}/aswa.pt", torch.device("cpu"))
            model.load_state_dict(param_ensemble)

    @staticmethod
    def compute_mrr(trainer, model) -> float:
        """Return the validation MRR of ``model``, evaluated on CPU in eval mode.

        The model is moved back to its original device and put back into
        train mode afterwards.
        """
        # (1) Enable eval mode.
        model.eval()
        # (2) MRR performance on the validation data of running model.
        device_name = model.device
        model.to("cpu")
        last_val_mrr_running_model = trainer.evaluator.eval(dataset=trainer.dataset,
                                                            trained_model=model,
                                                            form_of_labelling=trainer.form_of_labelling,
                                                            during_training=True)["Val"]["MRR"]
        model.to(device_name)
        # (3) Enable train mode.
        model.train()
        return last_val_mrr_running_model

    def get_aswa_state_dict(self, model):
        """Return the stored ensemble averaged provisionally with ``model``.

        Only ``torch.float`` entries are averaged; all other entries keep the
        value stored in ``aswa.pt``.
        """
        ensemble_state_dict = torch.load(f"{self.path}/aswa.pt", torch.device(model.device))
        ensemble_weight = sum(self.alphas)
        # Perform provisional parameter update.
        with torch.no_grad():
            for k, parameters in model.state_dict().items():
                if parameters.dtype == torch.float:
                    ensemble_state_dict[k] = (ensemble_state_dict[k] * ensemble_weight + parameters) / (
                            1 + ensemble_weight)
        return ensemble_state_dict

    def decide(self, running_model_state_dict, ensemble_state_dict, val_running_model, mrr_updated_ensemble_model):
        """
        Perform a hard update, a soft update or a rejection.

        Parameters
        ----------
        running_model_state_dict
            State dict of the model being trained.
        ensemble_state_dict
            State dict of the look-ahead ensemble (stored ensemble averaged
            with the running model).
        val_running_model
            Validation MRR of the running model.
        mrr_updated_ensemble_model
            Validation MRR of the look-ahead ensemble.

        Returns
        -------
        bool
            Always ``True``.
        """
        # (1) HARD UPDATE:
        # The running model beats both the look-ahead ensemble and the stored ensemble.
        if val_running_model > mrr_updated_ensemble_model and val_running_model > self.val_aswa:
            # (1.1) Save the running model as ASWA
            torch.save(running_model_state_dict, f=f"{self.path}/aswa.pt")
            # (1.2) Reset alphas/ensemble weights
            # NOTE(review): after clear() sum(alphas) is 0, so the next look-ahead
            # average gives the model saved here zero weight; initialisation uses
            # alphas=[1.0] instead - confirm which is intended.
            self.alphas.clear()
            # (1.3) Store the validation performance of ASWA
            self.val_aswa = val_running_model
            return True

        # (2) SOFT UPDATE:
        # The look-ahead ensemble beats the stored ensemble.
        if mrr_updated_ensemble_model > self.val_aswa:
            self.val_aswa = mrr_updated_ensemble_model
            torch.save(ensemble_state_dict, f=f"{self.path}/aswa.pt")
            self.alphas.append(1.0)
            return True

        # (3) REJECTION: the look-ahead ensemble is not better than the stored
        # one. A tie is rejected as well, so every call records a decision.
        self.alphas.append(0)
        return True

    def on_train_epoch_end(self, trainer, model):
        """Evaluate the running model and update the stored ensemble if warranted."""
        if not (trainer.global_rank == trainer.local_rank == 0):
            return None

        # (1) Increment epoch counter
        self.epoch_count += 1
        # (2) Save the given eval setting if it is not saved.
        if self.initial_eval_setting is None:
            self.initial_eval_setting = trainer.evaluator.args.eval_model
            trainer.evaluator.args.eval_model = "val"
        # (3) Compute MRR of the running model.
        val_running_model = self.compute_mrr(trainer, model)

        # (4) Initialize ASWA if it is not initialized.
        if self.val_aswa == -1:
            torch.save(model.state_dict(), f=f"{self.path}/aswa.pt")
            self.alphas.append(1.0)
            self.val_aswa = val_running_model
            return True

        # (5) Load ASWA ensemble parameters.
        ensemble_state_dict = self.get_aswa_state_dict(model)
        # (6) Initialize ASWA ensemble with (5).
        ensemble = type(model)(model.args)
        ensemble.load_state_dict(ensemble_state_dict)
        # A freshly constructed module is in train mode; switch to eval mode so
        # the look-ahead score is measured like the running model's in compute_mrr.
        ensemble.eval()
        # (7) Evaluate (6) on the validation data, i.e., perform the lookahead operation.
        mrr_updated_ensemble_model = trainer.evaluator.eval(dataset=trainer.dataset,
                                                            trained_model=ensemble,
                                                            form_of_labelling=trainer.form_of_labelling,
                                                            during_training=True)["Val"]["MRR"]
        # (8) Decide whether ASWA should be updated via the current running model.
        self.decide(model.state_dict(), ensemble_state_dict, val_running_model, mrr_updated_ensemble_model)
class Eval(AbstractCallback):
    """Evaluate the model every ``epoch_ratio`` epochs and pickle the reports."""

    def __init__(self, path, epoch_ratio: int = None):
        super().__init__()
        self.path = path
        self.reports = []
        # Evaluate after every epoch unless a ratio is given.
        self.epoch_ratio = 1 if epoch_ratio is None else epoch_ratio
        self.epoch_counter = 0

    def on_fit_start(self, trainer, model):
        """No-op."""

    def on_fit_end(self, trainer, model):
        """Pickle all collected reports under the trainer's storage path."""
        target = trainer.attributes.full_storage_path + '/evals_per_epoch'
        save_pickle(data=self.reports, file_path=target)

    def on_train_epoch_end(self, trainer, model):
        """Run the evaluator on every ``epoch_ratio``-th epoch and keep its report."""
        self.epoch_counter += 1
        if self.epoch_counter % self.epoch_ratio != 0:
            return
        model.eval()
        report = trainer.evaluator.eval(dataset=trainer.dataset,
                                        trained_model=model,
                                        form_of_labelling=trainer.form_of_labelling,
                                        during_training=True)
        model.train()
        self.reports.append(report)

    def on_train_batch_end(self, *args, **kwargs):
        """No-op."""
        return
class KronE(AbstractCallback):
    """Wrap a model's triple representation with Kronecker-product features."""

    def __init__(self):
        super().__init__()
        # Original ``get_triple_representation`` of the model, set in on_fit_start.
        self.f = None

    @staticmethod
    def batch_kronecker_product(a, b):
        """
        Kronecker product of matrices a and b with leading batch dimensions.
        Batch dimensions are broadcast. The result is flattened from
        dimension 1 onwards.

        :type a: torch.Tensor
        :type b: torch.Tensor
        :rtype: torch.Tensor
        """
        a = a.unsqueeze(1)
        b = b.unsqueeze(1)
        # Shape of a single Kronecker product: element-wise product of the
        # trailing two dimensions of both operands.
        kron_shape = torch.Size((a.shape[-2] * b.shape[-2], a.shape[-1] * b.shape[-1]))
        outer = a.unsqueeze(-1).unsqueeze(-3) * b.unsqueeze(-2).unsqueeze(-4)
        batch_shape = outer.shape[:-4]
        return outer.reshape(batch_shape + kron_shape).flatten(1)

    def get_kronecker_triple_representation(self, indexed_triple: torch.LongTensor):
        """
        Get kronecker embeddings: each embedding returned by the wrapped
        function is concatenated with the Kronecker product of its two halves.
        """
        _, width = indexed_triple.shape
        assert width == 3
        # Get the embeddings
        head_ent_emb, rel_ent_emb, tail_ent_emb = self.f(indexed_triple)
        return tuple(
            torch.cat((emb, self.batch_kronecker_product(*torch.hsplit(emb, 2))), dim=1)
            for emb in (head_ent_emb, rel_ent_emb, tail_ent_emb)
        )

    def on_fit_start(self, trainer, model):
        """Swap the model's ``get_triple_representation`` for the Kronecker variant."""
        uses_identity = isinstance(model.normalize_head_entity_embeddings, dicee.models.base_model.IdentityClass)
        if not uses_identity:
            raise NotImplementedError('Normalizer should be reinitialized')
        self.f = model.get_triple_representation
        model.get_triple_representation = self.get_kronecker_triple_representation
class Perturb(AbstractCallback):
    """A callback for a three-level perturbation.

    Input perturbation: during training an input x is perturbed by randomly
    replacing one of its elements. In the context of knowledge graph
    embedding models, x can denote a triple, a tuple of an entity and a
    relation, or a tuple of two entities. A perturbation means that a
    component of x is randomly replaced by an entity or a relation.

    Parameter perturbation: noise is added to the embedding rows addressed by
    the selected data points ("GN": Gaussian noise, "RN": uniform noise).

    Output perturbation: the labels of the selected data points are smoothed
    ("Soft") or flipped ("Hard").
    """

    def __init__(self, level: str = "input", ratio: float = 0.0, method: str = None, scaler: float = None,
                 frequency=None):
        """
        level: one of {"input", "param", "out"}.
        ratio: float >= 0, fraction of mini-batch data points to be perturbed.
        method: "GN" or "RN" for level "param"; "Soft" or "Hard" for level
            "out"; not read for level "input".
        scaler: noise scale for level "param" and for the "Soft" method.
        frequency: stored only; not read by this class.
        """
        super().__init__()

        assert level in {"input", "param", "out"}
        assert ratio >= 0.0
        self.level = level
        self.ratio = ratio
        self.method = method
        self.scaler = scaler
        self.frequency = frequency  # per epoch, per mini-batch ?

    def on_train_batch_start(self, trainer, model, batch, batch_idx):
        """Perturb ``int(n * ratio)`` randomly chosen data points of ``batch`` in place.

        Raises
        ------
        RuntimeError
            If ``level`` is unknown or, for level "param", ``method`` is unknown.
        NotImplementedError
            If ``method`` is unknown for level "out".
        """
        # Modifications should be in-place
        # (1) Extract the input and output data points in a given batch.
        x, y = batch
        n, _ = x.shape
        assert n > 0
        # (2) Compute the number of perturbed data points.
        num_of_perturbed_data = int(n * self.ratio)
        if num_of_perturbed_data == 0:
            return None
        # (3) Sample num_of_perturbed_data distinct row indices on the device of x.
        random_indices = torch.randperm(n, device=x.device)[:num_of_perturbed_data]
        # (4) Apply perturbation depending on the level.
        if self.level == "input":
            self._perturb_input(model, x, random_indices)
        elif self.level == "param":
            self._perturb_parameters(model, x, random_indices)
        elif self.level == "out":
            self._perturb_output(model, batch, random_indices)
        else:
            raise RuntimeError(f"--level is given as {self.level}!")

    def _perturb_input(self, model, x, random_indices):
        """Replace either the heads or the relations of the selected rows with random indices."""
        count = len(random_indices)
        if torch.rand(1) > 0.5:
            # Perturb input via heads: sample random entity indices.
            perturbation = torch.randint(low=0, high=model.num_entities, size=(count,), device=x.device)
            x[random_indices] = torch.column_stack((perturbation, x[:, 1][random_indices]))
        else:
            # Perturb input via relations: sample random relation indices.
            perturbation = torch.randint(low=0, high=model.num_relations, size=(count,), device=x.device)
            x[random_indices] = torch.column_stack((x[:, 0][random_indices], perturbation))

    def _perturb_parameters(self, model, x, random_indices):
        """Add Gaussian ("GN") or uniform ("RN") noise to the embeddings of the selected rows."""
        h, r = torch.hsplit(x, 2)
        if self.method not in {"GN", "RN"}:
            raise RuntimeError(f"--method is given as {self.method}!")
        # NOTE(review): `torch.rand(1) > 0.0` holds for practically every draw, so
        # relations are effectively never perturbed (the input level uses 0.5) -
        # confirm whether that is intended.
        if torch.rand(1) > 0.0:
            embeddings, selected = model.entity_embeddings, h[random_indices]
        else:
            embeddings, selected = model.relation_embeddings, r[random_indices]
        with torch.no_grad():
            # The noise is shaped after the table that is being perturbed, so
            # relation indices are never used to index the entity table.
            noise_shape = embeddings.weight[selected].shape
            if self.method == "GN":
                noise = torch.normal(mean=0, std=self.scaler, size=noise_shape, device=model.device)
            else:
                noise = torch.rand(size=noise_shape, device=model.device) * self.scaler
            embeddings.weight[selected] += noise

    def _perturb_output(self, model, batch, random_indices):
        """Smooth ("Soft") or flip ("Hard") the labels of the selected rows."""
        labels = batch[1]
        if self.method == "Soft":
            # Output level soft perturbation resembles label smoothing:
            # 1.0 => 1.0 - perturb and every other value => perturb.
            perturb = torch.rand(1, device=model.device) * self.scaler
            labels[random_indices] = torch.where(labels[random_indices] == 1.0, 1.0 - perturb, perturb)
        elif self.method == "Hard":
            # Output level hard perturbation flips 1s to 0s and everything else to 1s.
            labels[random_indices] = torch.where(labels[random_indices] == 1.0, 0.0, 1.0)
        else:
            # Report the unsupported method (the level is always "out" here).
            raise NotImplementedError(f"--method is given as {self.method}!")
class PeriodicEvalCallback(AbstractCallback):
    """
    Callback to periodically evaluate the model and optionally save checkpoints during training.

    Evaluates at regular intervals (every N epochs) and/or at explicitly specified epochs.
    Stores evaluation reports and model checkpoints.
    """

    def __init__(self, experiment_path: str, max_epochs: int,
                 eval_every_n_epoch: int = 0, eval_at_epochs: list = None,
                 save_model_every_n_epoch: bool = True, n_epochs_eval_model: str = "val_test"):
        """
        Initialize the PeriodicEvalCallback.

        Parameters
        ----------
        experiment_path : str
            Directory where evaluation reports and model checkpoints will be saved.
        max_epochs : int
            Maximum number of training epochs.
        eval_every_n_epoch : int, optional
            Evaluate every N epochs (N, 2N, ... up to max_epochs). Values <= 0 disable
            interval-based evaluation.
        eval_at_epochs : list, optional
            Specific epochs at which to evaluate. These are combined with the epochs
            derived from eval_every_n_epoch.
        save_model_every_n_epoch : bool, optional
            Whether to save a model checkpoint at each evaluation epoch.
        n_epochs_eval_model : str, optional
            Evaluation setting used for the periodic evaluations. Default is "val_test".
        """
        super().__init__()
        self.experiment_dir = experiment_path
        self.max_epochs = max_epochs
        self.epoch_counter = 0
        self.save_model_every_n_epoch = save_model_every_n_epoch
        self.reports = defaultdict(dict)
        self.n_epochs_eval_model = n_epochs_eval_model
        # Evaluation setting of the trainer, captured at the first scheduled epoch.
        self.default_eval_model = None

        # Determine evaluation epochs: union of the explicit list and the interval.
        eval_epochs_set = set()
        if eval_at_epochs:
            eval_epochs_set.update(eval_at_epochs)
        if eval_every_n_epoch > 0:
            eval_epochs_set.update(range(eval_every_n_epoch, max_epochs + 1, eval_every_n_epoch))
        self.eval_epochs = eval_epochs_set

        # Prepare directory for model checkpoints if needed
        if self.save_model_every_n_epoch:
            self.n_epochs_storage_path = os.path.join(self.experiment_dir, 'models_n_epochs')
            os.makedirs(self.n_epochs_storage_path, exist_ok=True)

    def on_fit_end(self, trainer, model):
        """Called at the end of training. Writes all collected reports to a JSON file."""
        report_path = os.path.join(self.experiment_dir, 'eval_report_n_epochs.json')
        with open(report_path, 'w') as f:
            json.dump(self.reports, f, indent=4)

    def on_train_epoch_end(self, trainer, model):
        """
        Called at the end of each training epoch. Performs evaluation and checkpointing if scheduled.
        """
        self.epoch_counter += 1

        # Check if current epoch is scheduled for evaluation
        if self.epoch_counter not in self.eval_epochs:
            return

        # Store default evaluation mode once
        if self.default_eval_model is None:
            self.default_eval_model = trainer.evaluator.args.eval_model

        # Skip evaluation if the default setting already covers all required
        # splits and this is the final epoch.
        if self.epoch_counter == self.max_epochs:
            default_splits = set(self.default_eval_model.split('_'))
            required_splits = set(self.n_epochs_eval_model.split('_'))
            if required_splits.issubset(default_splits):
                return

        # Set evaluation mode for this scheduled epoch. The try/finally blocks
        # guarantee that the trainer's setting and the model's device and mode
        # are restored even if loading or evaluating raises.
        trainer.evaluator.args.eval_model = self.n_epochs_eval_model
        try:
            # Prepare evaluation model
            if model.args.get("swa"):
                eval_model = trainer.swa_model
            elif model.args.get("adaptive_swa"):
                # Load ASWA weights into a fresh instance of the model class.
                aswa_path = os.path.join(self.experiment_dir, "aswa.pt")
                aswa_ensemble_params = torch.load(aswa_path, map_location="cpu")
                eval_model = type(model)(model.args)
                eval_model.load_state_dict(aswa_ensemble_params)
            else:
                eval_model = model

            # Save device and move to CPU for evaluation to save memory
            device = model.device
            eval_model.to('cpu')
            eval_model.eval()
            try:
                report = trainer.evaluator.eval(dataset=trainer.dataset,
                                                trained_model=eval_model,
                                                form_of_labelling=trainer.form_of_labelling,
                                                during_training=True)
            finally:
                # Restore model to original device and mode
                eval_model.to(device)
                eval_model.train()
        finally:
            # Restore evaluation mode
            trainer.evaluator.args.eval_model = self.default_eval_model

        # Store evaluation report
        self.reports[f'epoch_{self.epoch_counter}_eval'] = report

        # Save model checkpoint if needed
        if self.save_model_every_n_epoch:
            save_path = os.path.join(self.n_epochs_storage_path, f'model_at_epoch_{self.epoch_counter}.pt')
            save_checkpoint_model(eval_model, path=save_path)
class LRScheduler(AbstractCallback):
    """
    Callback for managing learning rate scheduling and model snapshots.

    Supports cosine annealing ("cca"), MMCCLR ("mmcclr"), and their deferred (warmup) variants:
    - "deferred_cca"
    - "deferred_mmcclr"

    At snapshot steps the model state is saved; at the end of training the snapshots are
    evaluated as an ensemble.
    """

    def __init__(
            self,
            adaptive_lr_config: dict,
            total_epochs: int,
            experiment_dir: str,
            eta_max: float = 0.1,
            snapshot_dir: str = "snapshots",
    ):
        """
        Initialize the LR scheduler callback.

        Args:
            adaptive_lr_config (dict): Configuration dictionary containing LR scheduling parameters.
                Can include: scheduler_name, lr_min, num_cycles, weighted_ensemble, n_snapshots
            total_epochs (int): Total number of training epochs (args.num_epochs)
            experiment_dir (str): Directory for the experiment, used as base for snapshots.
            eta_max (float, optional): Maximum learning rate at the start of each cycle.
                passed from `args.lr`. Default is 0.1.
            snapshot_dir (str, optional): Subdirectory inside experiment_dir where snapshots will be saved.
                Default is "snapshots".
        """
        super().__init__()
        # Validate and set defaults for configuration
        self._validate_and_set_config(adaptive_lr_config, eta_max)

        self.total_epochs = total_epochs
        self.experiment_dir = experiment_dir
        self.snapshot_dir = os.path.join(experiment_dir, snapshot_dir)
        os.makedirs(self.snapshot_dir, exist_ok=True)

        assert self.eta_max > self.eta_min, \
            f"Max Learning Rate ({self.eta_max}) must be greater than Min Learning Rate ({self.eta_min})"

        # Calculate warmup epochs only for deferred schedulers
        if self.scheduler_name.startswith("deferred"):
            # Use formula: defer for (n_cycles - n_snapshots) cycles
            deferred_cycles = self.n_cycles - self.n_snapshots
            self.warmup_epochs = int(deferred_cycles / self.n_cycles * self.total_epochs)
        else:
            # Non-deferred schedulers don't use warmup
            self.warmup_epochs = 0

        # Placeholders to be set during training
        self.batches_per_epoch = None
        self.total_steps = None
        self.cycle_length = None
        self.warmup_steps = None
        self.lr_lambda = None
        self.scheduler = None
        self.step_count = 0
        self.snapshot_loss = defaultdict(float)

    def _validate_and_set_config(self, config: dict, eta_max: float):
        """
        Validate the adaptive_lr_config and set default values for missing parameters.

        Raises:
            ValueError: If config is not a dict or one of its entries is invalid.
        """
        # Default configuration
        defaults = {
            "scheduler_name": "cca",
            "lr_min": 0.01,
            "num_cycles": 10,
            "weighted_ensemble": True,
            "n_snapshots": 5
        }

        # Validate config is a dictionary
        if not isinstance(config, dict):
            raise ValueError("adaptive_lr_config must be a dictionary")

        # Validate scheduler_name. The name is lowercased when stored, so it is
        # compared case-insensitively here as well.
        if "scheduler_name" in config:
            valid_schedulers = ["cca", "mmcclr", "deferred_cca", "deferred_mmcclr"]
            name = config["scheduler_name"]
            if not isinstance(name, str) or name.lower() not in valid_schedulers:
                raise ValueError(f"Invalid scheduler_name '{config['scheduler_name']}'. "
                                 f"Must be one of: {valid_schedulers}")

        # Validate lr_min
        if "lr_min" in config:
            lr_min = config["lr_min"]
            if not isinstance(lr_min, (int, float)) or lr_min <= 0:
                raise ValueError(f"lr_min must be a positive number, got: {lr_min}")
            if lr_min >= eta_max:
                raise ValueError(f"lr_min ({lr_min}) must be less than eta_max ({eta_max})")

        # Validate num_cycles
        if "num_cycles" in config:
            num_cycles = config["num_cycles"]
            if not isinstance(num_cycles, (int, float)) or num_cycles <= 0:
                raise ValueError(f"num_cycles must be a positive number, got: {num_cycles}")

        # Validate n_snapshots
        if "n_snapshots" in config:
            n_snapshots = config["n_snapshots"]
            if not isinstance(n_snapshots, int) or n_snapshots <= 0:
                raise ValueError(f"n_snapshots must be a positive integer, got: {n_snapshots}")

        # Validate weighted_ensemble
        if "weighted_ensemble" in config:
            weighted_ensemble = config["weighted_ensemble"]
            if not isinstance(weighted_ensemble, bool):
                raise ValueError(f"weighted_ensemble must be a boolean, got: {weighted_ensemble}")

        # Set attributes with defaults for missing values
        self.scheduler_name = config.get("scheduler_name", defaults["scheduler_name"]).lower()
        self.eta_min = config.get("lr_min", defaults["lr_min"])
        self.n_cycles = config.get("num_cycles", defaults["num_cycles"])
        self.weighted_ensemble = config.get("weighted_ensemble", defaults["weighted_ensemble"])
        self.n_snapshots = config.get("n_snapshots", defaults["n_snapshots"])
        self.eta_max = eta_max

        assert self.n_snapshots <= self.n_cycles, \
            f"n_snapshots ({self.n_snapshots}) must be less than or equal to num_cycles ({self.n_cycles})"

        print(f"LRScheduler initialized with config: {config}")
        print(f"Using: scheduler_name={self.scheduler_name}, eta_min={self.eta_min}, "
              f"n_cycles={self.n_cycles}, weighted_ensemble={self.weighted_ensemble}, "
              f"n_snapshots={self.n_snapshots}")

    def _initialize_training_params(self, num_training_batches):
        """Set batches per epoch, total steps, cycle length, and warmup steps."""
        self.batches_per_epoch = num_training_batches
        self.total_steps = self.total_epochs * self.batches_per_epoch
        self.cycle_length = self.total_steps // self.n_cycles

        # Ensure cycle length is at least 1 to avoid division by zero
        if self.cycle_length < 1:
            raise ValueError(f"Cycle length ({self.cycle_length}) must be at least 1. "
                             f"Total steps: {self.total_steps}, n_cycles: {self.n_cycles}")

        assert self.total_steps > self.n_cycles, \
            f"Total steps ({self.total_steps}) must be greater than Total Cycles ({self.n_cycles})."

        # The epoch-based schedule takes `epoch % (total_epochs // n_cycles)`;
        # fail early with a clear message instead of a ZeroDivisionError later.
        if self.scheduler_name.endswith("mmcclr") and self.total_epochs // self.n_cycles < 1:
            raise ValueError(f"Scheduler '{self.scheduler_name}' needs at least one epoch per cycle. "
                             f"Total epochs: {self.total_epochs}, n_cycles: {self.n_cycles}")

        # Always an integer (0 without warmup): the deferred schedule and
        # `_is_snapshot_step` compare steps against it, which would fail on None.
        self.warmup_steps = int(self.warmup_epochs * self.batches_per_epoch)
        if self.warmup_steps >= self.total_steps:
            raise ValueError(f"Warmup steps ({self.warmup_steps}) must be less than total steps ({self.total_steps}).")

    def _get_lr_schedule(self):
        """Return the LR multiplier function (step -> factor) for ``scheduler_name``."""

        def cosine_annealing(step):
            # NOTE(review): this cycle length is rounded up while `self.cycle_length`
            # (used for snapshots) is rounded down - confirm the mismatch is intended.
            cycle_length = math.ceil(self.total_steps / self.n_cycles)
            cycle_step = step % cycle_length
            # Return multiplier: cosine annealing between eta_min/base_lr and eta_max/base_lr
            # Assuming base_lr is eta_max, we scale between eta_min/eta_max and 1.0
            cosine_factor = 0.5 * (1 + np.cos(np.pi * cycle_step / cycle_length))
            min_multiplier = self.eta_min / self.eta_max
            return min_multiplier + (1.0 - min_multiplier) * cosine_factor

        def mmcclr(step):
            # Convert step to epoch-based calculation
            current_epoch = step // self.batches_per_epoch
            cycle_length_epochs = self.total_epochs // self.n_cycles
            cycle_step = current_epoch % cycle_length_epochs
            # Return multiplier: cosine annealing between eta_min/base_lr and eta_max/base_lr
            # Assuming base_lr is eta_max, we scale between eta_min/eta_max and 1.0
            cosine_factor = 0.5 * (1 + np.cos(np.pi * cycle_step / cycle_length_epochs))
            min_multiplier = self.eta_min / self.eta_max
            return min_multiplier + (1.0 - min_multiplier) * cosine_factor

        def deferred(base_schedule):
            # Warmup returns 1.0; afterwards use base schedule shifted by warmup steps
            return lambda step: 1.0 if step < self.warmup_steps else base_schedule(step - self.warmup_steps)

        sched_map = {
            "cca": cosine_annealing,
            "mmcclr": mmcclr,
            "deferred_cca": deferred(cosine_annealing),
            "deferred_mmcclr": deferred(mmcclr),
        }

        if self.scheduler_name not in sched_map:
            raise ValueError(f"Unknown scheduler name: {self.scheduler_name}")

        return sched_map[self.scheduler_name]

    def _calculate_snap_weights(self):
        """
        Calculate weights for model snapshots based on their loss values.
        A snapshot with a lower loss receives a larger weight; weights sum to 1.
        """
        # Get losses in the order of the model names used in the ensemble:
        model_names = list(self.snapshot_loss.keys())
        losses = np.array([self.snapshot_loss[name] for name in model_names])

        min_loss = losses.min()
        max_loss = losses.max()

        # SnapE weights: (max+min) - model_loss
        raw_weights = (max_loss + min_loss) - losses

        # Clip to avoid negative weights
        raw_weights = np.clip(raw_weights, a_min=0, a_max=None)

        # Normalize weights to sum to 1; fall back to uniform weights if all are 0.
        if raw_weights.sum() > 0:
            weights = raw_weights / raw_weights.sum()
        else:
            weights = np.ones_like(raw_weights) / len(raw_weights)

        self.snapshot_weights = dict(zip(model_names, weights))

    def on_train_start(self, trainer, model):
        """Initialize training parameters and LR scheduler at start of training."""
        self._initialize_training_params(trainer.num_training_batches)
        self.lr_lambda = self._get_lr_schedule()
        self.scheduler = LambdaLR(trainer.optimizers[0], lr_lambda=self.lr_lambda)
        self.step_count = 0

        print(f"Using learning rate scheduler: {self.scheduler_name}")

    def on_train_batch_end(self, trainer, model, outputs, batch, batch_idx):
        """Step the LR scheduler and save model snapshot if needed after each batch."""
        self.scheduler.step()
        self.step_count += 1

        if self._is_snapshot_step(self.step_count):
            # Snapshots are named by epoch, so a later snapshot taken within the
            # same epoch replaces the earlier file and its recorded loss.
            snapshot_path = os.path.join(
                self.snapshot_dir, f"snapshot_epoch_{trainer.current_epoch}.pt"
            )
            torch.save(model.state_dict(), snapshot_path)
            # Store loss at snapshot step
            self.snapshot_loss[os.path.basename(snapshot_path)] = outputs['loss'].item()

    def _is_snapshot_step(self, step):
        """
        Determine if the current step is a snapshot step.

        For deferred schedulers: Take n_snapshots evenly distributed in the active scheduling phase.
        For regular schedulers: Take snapshots at the end of each cycle.
        """
        if self.scheduler_name.startswith("deferred"):
            # Skip snapshots during warmup
            if step < self.warmup_steps:
                return False

            # Take n_snapshots evenly distributed in the remaining steps after warmup
            remaining_steps = self.total_steps - self.warmup_steps
            snapshot_interval = remaining_steps // self.n_snapshots
            steps_after_warmup = step - self.warmup_steps

            # Check if we're at a snapshot interval boundary
            return (steps_after_warmup + 1) % snapshot_interval == 0

        # For non-deferred schedulers, use cycle-based snapshots
        return (step + 1) % self.cycle_length == 0

    def on_fit_end(self, trainer, model):
        """Evaluate all stored snapshots as an ensemble and write the report as JSON."""
        # Load all model snapshots from the snapshot directory
        self.ensemble_weights = None
        snapshot_files = sorted(
            [f for f in os.listdir(self.snapshot_dir) if f.endswith('.pt')]
        )
        self.model_snapshots = []

        for checkpoint in snapshot_files:
            checkpoint_path = os.path.join(self.snapshot_dir, checkpoint)
            state_dict = torch.load(checkpoint_path, map_location="cpu")
            model_copy = type(model)(model.args)
            model_copy.load_state_dict(state_dict)
            self.model_snapshots.append(model_copy)

        if self.snapshot_loss and self.weighted_ensemble:
            self._calculate_snap_weights()
            # Build the weight list aligned to snapshot_files order:
            self.ensemble_weights = [self.snapshot_weights[fname] for fname in snapshot_files]

        # NOTE(review): `batch_size` receives the number of training batches, not a
        # batch size - confirm this is the intended value.
        ensemble_eval_report = evaluate_ensemble_link_prediction_performance(
            models=self.model_snapshots,
            triples=trainer.dataset.test_set,
            er_vocab=trainer.dataset.er_vocab.result(),
            weights=self.ensemble_weights,
            batch_size=trainer.num_training_batches,
            weighted_averaging=self.weighted_ensemble
        )
        # Prepare a single dictionary with LR scheduling info and nested ensemble eval report
        self.ensemble_eval_report = {
            "scheduler_name": self.scheduler_name,
            "total_epochs": self.total_epochs,
            "num_cycles": self.n_cycles,
            "warmup_epochs": self.warmup_epochs,
            "lr_max": self.eta_max,
            "lr_min": self.eta_min,
            "batches_per_epoch": self.batches_per_epoch,
            "total_steps": self.total_steps,
            "cycle_length": self.cycle_length,
            "warmup_steps": self.warmup_steps,
            "weighted_ensemble": self.weighted_ensemble,
            "n_snapshots": self.n_snapshots,
            "ensemble_eval_report": ensemble_eval_report,
            "snapshot_loss": self.snapshot_loss
        }

        ensemble_eval_report_path = os.path.join(self.experiment_dir, "ensemble_eval_report.json")
        # Write the dictionary to the JSON file
        with open(ensemble_eval_report_path, 'w', encoding='utf-8') as f:
            json.dump(self.ensemble_eval_report, f, indent=4, ensure_ascii=False)
        print(f"Ensemble Evaluations: Evaluate {model.name} on Test Set with an ensemble "
              f"of {len(self.model_snapshots)} models: \n{ensemble_eval_report}")
class SWA(AbstractCallback):
    """Stochastic Weight Averaging callbacks."""

    def __init__(self, swa_start_epoch, swa_c_epochs: int = 1, lr_init: float = 0.1,
                 swa_lr: float = 0.05, max_epochs: int = None):
        """
        Initialize SWA callback.

        Parameters
        ----------
        swa_start_epoch: int
            The epoch at which to start SWA.
        swa_c_epochs: int
            Averaging period: the running model is averaged in every
            swa_c_epochs epochs once swa_start_epoch is reached.
        lr_init: float
            The initial learning rate.
        swa_lr: float
            The learning rate to use during SWA.
        max_epochs: int
            The maximum number of epochs (args.num_epochs). Required by the
            learning-rate schedule in ``on_train_epoch_start``.
        """
        super().__init__()
        self.swa_start_epoch = swa_start_epoch
        self.swa_c_epochs = swa_c_epochs
        self.swa_lr = swa_lr
        self.lr_init = lr_init
        self.max_epochs = max_epochs
        self.swa_model = None
        # Number of models averaged into swa_model so far.
        self.swa_n = 0
        self.current_epoch = -1

    def moving_average(self, swa_model, model, alpha):
        """Update SWA model with moving average of current model.

        Parameters are blended as ``(1 - alpha) * swa + alpha * model``.
        Buffers are copied from ``model`` rather than averaged (they may hold
        integer counters). Without this the SWA model would keep the buffers it
        was created with and ``on_fit_end`` would overwrite the trained
        model's buffers with them.
        """
        with torch.no_grad():
            for swa_param, param in zip(swa_model.parameters(), model.parameters()):
                swa_param.data = (1.0 - alpha) * swa_param.data + alpha * param.data
            for swa_buffer, buffer in zip(swa_model.buffers(), model.buffers()):
                swa_buffer.copy_(buffer)

    def on_fit_start(self, trainer, model):
        """Initialize SWA model with same architecture as main model.

        Raises
        ------
        AttributeError
            If the trainer exposes neither ``optimizer`` nor a non-empty ``optimizers``.
        """
        self.swa_model = type(model)(model.args)
        self.swa_model.load_state_dict(model.state_dict())
        self.swa_model = self.swa_model.to(model.device)
        # Check if trainer has optimizer attribute, if not, try to get from optimizers list
        optimizer = getattr(trainer, 'optimizer', None)
        if optimizer is None:
            if not (hasattr(trainer, 'optimizers') and trainer.optimizers):
                raise AttributeError("Trainer does not have a valid optimizer or optimizers list.")

    def on_train_epoch_start(self, trainer, model):
        """Update learning rate according to SWA schedule.

        The rate stays at lr_init for the first half of training, decays
        linearly to swa_lr between 50% and 90%, and stays at swa_lr afterwards.

        Raises
        ------
        ValueError
            If ``max_epochs`` was not given (the schedule divides by it).
        """
        if not self.max_epochs:
            raise ValueError(f"SWA requires a positive max_epochs, got: {self.max_epochs}")

        # Get current epoch - simplified with fallback
        if hasattr(trainer, 'current_epoch'):
            self.current_epoch = trainer.current_epoch
        else:
            self.current_epoch += 1

        # Calculate learning rate using the schedule
        t = self.current_epoch / self.max_epochs
        lr_ratio = self.swa_lr / self.lr_init

        if t <= 0.5:
            factor = 1.0
        elif t <= 0.9:
            factor = 1.0 - (1.0 - lr_ratio) * (t - 0.5) / 0.4
        else:
            factor = lr_ratio

        new_lr = self.lr_init * factor
        # Get the optimizer from the trainer
        optimizer = getattr(trainer, 'optimizer', None)
        if optimizer is None and hasattr(trainer, 'optimizers') and trainer.optimizers:
            optimizer = trainer.optimizers[0] if isinstance(trainer.optimizers, list) else trainer.optimizers

        if optimizer is not None:
            for param_group in optimizer.param_groups:
                param_group['lr'] = new_lr

    def on_train_epoch_end(self, trainer, model):
        """Apply SWA averaging if conditions are met."""
        # Expose swa_model on the trainer if eval_every_n_epochs or eval_at_epochs
        # is set. Missing keys are treated as "not set" instead of raising KeyError.
        eval_interval = model.args.get("eval_every_n_epochs") or 0
        if eval_interval > 0 or model.args.get("eval_at_epochs") is not None:
            trainer.swa_model = self.swa_model

        # Check if we should apply SWA
        if self.current_epoch >= self.swa_start_epoch and \
                (self.current_epoch - self.swa_start_epoch) % self.swa_c_epochs == 0:
            # alpha = 1 / (n + 1) keeps swa_model the plain mean of all averaged models.
            self.moving_average(self.swa_model, model, 1.0 / (self.swa_n + 1))
            self.swa_n += 1

    def on_fit_end(self, trainer, model):
        """Replace main model with SWA model at the end of training."""
        if self.swa_model is not None and self.swa_n > 0:
            # Copy SWA weights back to main model directly
            model.load_state_dict(self.swa_model.state_dict())