Source code for dicee.callbacks

"""Callbacks for training lifecycle events.

Provides callback classes for various training events including
epoch end, model saving, weight averaging, and evaluation.
"""
import copy
import datetime
import json
import math
import os
import time
from collections import defaultdict

import numpy as np
import pandas as pd
import torch
from pytorch_lightning.utilities import rank_zero_only
from torch._dynamo.eval_frame import OptimizedModule
from torch.optim.lr_scheduler import LambdaLR

import dicee.models.base_model
from .abstracts import AbstractCallback
from .evaluation.ensemble import evaluate_ensemble_link_prediction_performance
from .static_funcs import save_checkpoint_model, save_pickle


class AccumulateEpochLossCallback(AbstractCallback):
    """Callback to store epoch losses to a CSV file.

    Args:
        path: Directory path where the loss file will be saved.
    """

    def __init__(self, path: str):
        super().__init__()
        self.path = path

    def on_fit_end(self, trainer, model) -> None:
        """Store the epoch loss history to a CSV file.

        Args:
            trainer: The trainer instance.
            model: The model being trained.
        """
        loss_df = pd.DataFrame(model.loss_history, columns=['EpochLoss'])
        loss_df.to_csv(os.path.join(self.path, 'epoch_losses.csv'))

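# A minimal usage sketch (not part of the original module): the callback only needs
# an object exposing ``loss_history`` and a writable directory, so it can be exercised
# outside a full training run. ``DummyModel`` and the helper below are hypothetical.
def _example_accumulate_epoch_loss(tmp_dir: str) -> None:
    class DummyModel:
        loss_history = [0.9, 0.5, 0.3]

    cb = AccumulateEpochLossCallback(path=tmp_dir)
    # trainer is unused by on_fit_end, so None is sufficient here.
    cb.on_fit_end(trainer=None, model=DummyModel())
    # -> writes <tmp_dir>/epoch_losses.csv with one row per epoch.
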
class PrintCallback(AbstractCallback):
    def __init__(self):
        super().__init__()
        self.start_time = time.time()

    def on_fit_start(self, trainer, pl_module):
        # print(pl_module)
        # print(pl_module.summarize())
        # print(pl_module.selected_optimizer)
        print(f"\nTraining is starting {datetime.datetime.now()}...")

    def on_fit_end(self, trainer, pl_module):
        training_time = time.time() - self.start_time
        if 60 > training_time:
            message = f'{training_time:.3f} seconds.'
        elif 60 * 60 > training_time > 60:
            message = f'{training_time / 60:.3f} minutes.'
        elif training_time > 60 * 60:
            message = f'{training_time / (60 * 60):.3f} hours.'
        else:
            message = f'{training_time:.3f} seconds.'
        print(f"Training Runtime: {message}\n")

    def on_train_batch_end(self, *args, **kwargs):
        return

    def on_train_epoch_end(self, *args, **kwargs):
        return

class KGESaveCallback(AbstractCallback):
    def __init__(self, every_x_epoch: int, max_epochs: int, path: str):
        super().__init__()
        self.every_x_epoch = every_x_epoch
        self.max_epochs = max_epochs
        self.epoch_counter = 0
        self.path = path
        if self.every_x_epoch is None:
            self.every_x_epoch = max(self.max_epochs // 2, 1)

    def on_train_batch_end(self, *args, **kwargs):
        return

    def on_fit_start(self, trainer, pl_module):
        pass

    def on_train_epoch_end(self, *args, **kwargs):
        pass

    def on_fit_end(self, *args, **kwargs):
        pass

    def on_epoch_end(self, model, trainer, **kwargs):
        if self.epoch_counter % self.every_x_epoch == 0 and self.epoch_counter > 1:
            print(f'\nStoring model {self.epoch_counter}...')
            save_checkpoint_model(model,
                                  path=self.path + f'/model_at_{self.epoch_counter}_epoch_{datetime.datetime.now()}.pt')
        self.epoch_counter += 1

class PseudoLabellingCallback(AbstractCallback):
    def __init__(self, data_module, kg, batch_size):
        super().__init__()
        self.data_module = data_module
        self.kg = kg
        self.num_of_epochs = 0
        self.unlabelled_size = len(self.kg.unlabelled_set)
        self.batch_size = batch_size

    def create_random_data(self):
        entities = torch.randint(low=0, high=self.kg.num_entities, size=(self.batch_size, 2))
        relations = torch.randint(low=0, high=self.kg.num_relations, size=(self.batch_size,))
        # Unlabelled triples.
        return torch.stack((entities[:, 0], relations, entities[:, 1]), dim=1)

    def on_epoch_end(self, trainer, model):
        # Grow the training set with confidently pseudo-labelled triples.
        # if trainer.current_epoch < 10:
        #     return None
        model.eval()
        with torch.no_grad():
            # (1) Create random triples
            # unlabelled_input_batch = self.create_random_data()
            # (2) or sample a batch from the unlabelled set.
            unlabelled_input_batch = self.kg.unlabelled_set[
                torch.randint(low=0, high=self.unlabelled_size, size=(self.batch_size,))]
            # (3) Predict the unlabelled batch and use the predictions as pseudo-labels.
            pseudo_label = torch.sigmoid(model(unlabelled_input_batch))
            selected_triples = unlabelled_input_batch[pseudo_label >= .90]
        if len(selected_triples) > 0:
            # Update the dataset.
            self.data_module.train_set_idx = np.concatenate(
                (self.data_module.train_set_idx, selected_triples.detach().numpy()),
                axis=0)
            trainer.train_dataloader = self.data_module.train_dataloader()
            print(f'\tEpoch:{trainer.current_epoch}: Pseudo-labelling\t |D|= {len(self.data_module.train_set_idx)}')
        model.train()

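# A minimal sketch (not part of the original module) of the pseudo-label selection rule
# used above: raw scores are squashed with a sigmoid and only triples predicted with
# confidence >= 0.9 are kept. The tensors below are toy data.
def _example_pseudo_label_selection() -> torch.Tensor:
    unlabelled_batch = torch.tensor([[0, 0, 1], [2, 1, 3], [4, 0, 5]])  # (h, r, t) index triples
    raw_scores = torch.tensor([3.0, -1.0, 2.5])  # hypothetical model outputs
    pseudo_label = torch.sigmoid(raw_scores)
    # Keep only confidently predicted triples, as in on_epoch_end above.
    return unlabelled_batch[pseudo_label >= 0.90]
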
def estimate_q(eps):
    """Estimate the rate of convergence q from the sequence eps."""
    x = np.arange(len(eps) - 1)
    y = np.log(np.abs(np.diff(np.log(eps))))
    line = np.polyfit(x, y, 1)  # fit a degree-1 polynomial
    q = np.exp(line[0])  # the slope of the fitted line is log(q)
    return q

def compute_convergence(seq, i):
    assert len(seq) >= i > 0
    return estimate_q(seq[-i:] / (np.arange(i) + 1))

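# A worked example (not part of the original module): for a quadratically converging
# error sequence eps_k = 10 ** -(2 ** k), the magnitude of log-errors doubles at each
# step, so estimate_q recovers q ~= 2.
def _example_estimate_q() -> float:
    eps = np.array([1e-1, 1e-2, 1e-4, 1e-8, 1e-16])
    return estimate_q(eps)  # ~2.0, i.e. quadratic convergence
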
class Eval(AbstractCallback):
    def __init__(self, path, epoch_ratio: int = None):
        super().__init__()
        self.path = path
        self.reports = []
        self.epoch_ratio = epoch_ratio if epoch_ratio is not None else 1
        self.epoch_counter = 0

    def on_fit_start(self, trainer, model):
        pass

    def on_fit_end(self, trainer, model):
        save_pickle(data=self.reports, file_path=trainer.attributes.full_storage_path + '/evals_per_epoch')
        """
        fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(7, 7))
        for (p, q), mrr in pairs_to_train_mrr.items():
            ax1.plot(mrr, label=f'{p},{q}')
        ax1.set_ylabel('Train MRR')

        for (p, q), mrr in pairs_to_val_mrr.items():
            ax2.plot(mrr, label=f'{p},{q}')
        ax2.set_ylabel('Val MRR')

        plt.legend()
        plt.xlabel('Epochs')
        plt.savefig('{full_storage_path}train_val_mrr.pdf')
        plt.show()
        """

    def on_train_epoch_end(self, trainer, model):
        self.epoch_counter += 1
        if self.epoch_counter % self.epoch_ratio == 0:
            model.eval()
            report = trainer.evaluator.eval(dataset=trainer.dataset, trained_model=model,
                                            form_of_labelling=trainer.form_of_labelling, during_training=True)
            model.train()
            self.reports.append(report)

    def on_train_batch_end(self, *args, **kwargs):
        return

class KronE(AbstractCallback):
    def __init__(self):
        super().__init__()
        self.f = None

    @staticmethod
    def batch_kronecker_product(a, b):
        """
        Kronecker product of matrices a and b with leading batch dimensions.

        Batch dimensions are broadcast; the number of batch dimensions must match.

        :type a: torch.Tensor
        :type b: torch.Tensor
        :rtype: torch.Tensor
        """
        a, b = a.unsqueeze(1), b.unsqueeze(1)
        siz1 = torch.Size(torch.tensor(a.shape[-2:]) * torch.tensor(b.shape[-2:]))
        res = a.unsqueeze(-1).unsqueeze(-3) * b.unsqueeze(-2).unsqueeze(-4)
        siz0 = res.shape[:-4]
        res = res.reshape(siz0 + siz1)
        return res.flatten(1)

    def get_kronecker_triple_representation(self, indexed_triple: torch.LongTensor):
        """Get Kronecker embeddings."""
        n, d = indexed_triple.shape
        assert d == 3
        # Get the embeddings.
        head_ent_emb, rel_ent_emb, tail_ent_emb = self.f(indexed_triple)

        head_ent_kron_emb = self.batch_kronecker_product(*torch.hsplit(head_ent_emb, 2))
        rel_ent_kron_emb = self.batch_kronecker_product(*torch.hsplit(rel_ent_emb, 2))
        tail_ent_kron_emb = self.batch_kronecker_product(*torch.hsplit(tail_ent_emb, 2))

        return torch.cat((head_ent_emb, head_ent_kron_emb), dim=1), \
            torch.cat((rel_ent_emb, rel_ent_kron_emb), dim=1), \
            torch.cat((tail_ent_emb, tail_ent_kron_emb), dim=1)

    def on_fit_start(self, trainer, model):
        if isinstance(model.normalize_head_entity_embeddings, dicee.models.base_model.IdentityClass):
            self.f = model.get_triple_representation
            model.get_triple_representation = self.get_kronecker_triple_representation
        else:
            raise NotImplementedError('Normalizer should be reinitialized')

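# A shape-level sketch (not part of the original module): for two batches of row vectors
# of width 4, the batched Kronecker product returns rows of width 4 * 4 = 16, which is
# what get_kronecker_triple_representation concatenates onto the original embeddings.
def _example_batch_kronecker_product() -> torch.Size:
    a = torch.randn(2, 4)
    b = torch.randn(2, 4)
    out = KronE.batch_kronecker_product(a, b)
    return out.shape  # torch.Size([2, 16])
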
class Perturb(AbstractCallback):
    """A callback for a three-level perturbation.

    Input Perturbation: During training, an input x is perturbed by randomly replacing one of its
    components. In the context of knowledge graph embedding models, x can denote a triple, a tuple
    of an entity and a relation, or a tuple of two entities. A perturbation means that a component
    of x is randomly replaced by an entity or a relation.

    Parameter Perturbation: Noise is added to the embeddings of randomly selected entities or
    relations of the mini-batch (Gaussian noise via "GN", uniform noise via "RN").

    Output Perturbation: The labels of randomly selected data points are smoothed ("Soft") or
    flipped ("Hard").
    """

    def __init__(self, level: str = "input", ratio: float = 0.0, method: str = None,
                 scaler: float = None, frequency=None):
        """
        level: one of {"input", "param", "out"}.
        ratio: float in [0, 1]; the fraction of mini-batch data points to be perturbed.
        method: perturbation method; "GN" or "RN" for the parameter level, "Soft" or "Hard" for the output level.
        scaler: magnitude of the perturbation.
        frequency: how often to perturb (per epoch, per mini-batch ?).
        """
        super().__init__()
        assert level in {"input", "param", "out"}
        assert ratio >= 0.0
        self.level = level
        self.ratio = ratio
        self.method = method
        self.scaler = scaler
        self.frequency = frequency  # per epoch, per mini-batch ?

    def on_train_batch_start(self, trainer, model, batch, batch_idx):
        # Modifications should be in-place.
        # (1) Extract the input and output data points in a given batch.
        x, y = batch
        n, _ = x.shape
        assert n > 0
        # (2) Compute the number of perturbed data points.
        num_of_perturbed_data = int(n * self.ratio)
        if num_of_perturbed_data == 0:
            return None
        # (3) Detect the device on which the data points reside.
        device = x.get_device()
        if device == -1:
            device = "cpu"
        # (4) Sample random integers from 0 to n without replacement and take num_of_perturbed_data of them.
        random_indices = torch.randperm(n, device=device)[:num_of_perturbed_data]
        # (5) Apply perturbation depending on the level.
        if self.level == "input":
            # (5.1) Apply input-level perturbation.
            if torch.rand(1) > 0.5:
                # (5.1.1) Perturb the input via heads: sample random indices for heads.
                perturbation = torch.randint(low=0, high=model.num_entities,
                                             size=(num_of_perturbed_data,), device=device)
                # Replace the head entities with (5.1.1) on the randomly selected data points.
                x[random_indices] = torch.column_stack((perturbation, x[:, 1][random_indices]))
            else:
                # (5.1.2) Perturb the input via relations: sample random indices for relations.
                perturbation = torch.randint(low=0, high=model.num_relations,
                                             size=(num_of_perturbed_data,), device=device)
                # Replace the relations with (5.1.2) on the randomly selected data points.
                x[random_indices] = torch.column_stack((x[:, 0][random_indices], perturbation))
        elif self.level == "param":
            # (5.2) Apply parameter-level perturbation.
            h, r = torch.hsplit(x, 2)
            if self.method == "GN":
                # (5.2.1) Apply Gaussian perturbation.
                if torch.rand(1) > 0.0:
                    # (5.2.1.1) Apply Gaussian perturbation on heads.
                    h_selected = h[random_indices]
                    with torch.no_grad():
                        model.entity_embeddings.weight[h_selected] += torch.normal(
                            mean=0, std=self.scaler,
                            size=model.entity_embeddings.weight[h_selected].shape,
                            device=model.device)
                else:
                    # (5.2.1.2) Apply Gaussian perturbation on relations.
                    r_selected = r[random_indices]
                    with torch.no_grad():
                        model.relation_embeddings.weight[r_selected] += torch.normal(
                            mean=0, std=self.scaler,
                            size=model.relation_embeddings.weight[r_selected].shape,
                            device=model.device)
            elif self.method == "RN":
                # (5.2.2) Apply random (uniform) perturbation.
                if torch.rand(1) > 0.0:
                    # (5.2.2.1) Apply random perturbation on heads.
                    h_selected = h[random_indices]
                    with torch.no_grad():
                        model.entity_embeddings.weight[h_selected] += torch.rand(
                            size=model.entity_embeddings.weight[h_selected].shape,
                            device=model.device) * self.scaler
                else:
                    # (5.2.2.2) Apply random perturbation on relations.
                    r_selected = r[random_indices]
                    with torch.no_grad():
                        model.relation_embeddings.weight[r_selected] += torch.rand(
                            size=model.relation_embeddings.weight[r_selected].shape,
                            device=model.device) * self.scaler
            else:
                raise RuntimeError(f"--method is given as {self.method}!")
        elif self.level == "out":
            # (5.3) Apply output-level perturbation.
            if self.method == "Soft":
                # (5.3.1) Output-level soft perturbation resembles label smoothing.
                # Compute the perturbation rate.
                perturb = torch.rand(1, device=model.device) * self.scaler
                # https://pytorch.org/docs/stable/generated/torch.where.html
                # 1.0 => 1.0 - perturb
                # 0.0 => perturb
                # (5.3.2) Reduce 1s and increase 0s via (5.3.1).
                batch[1][random_indices] = torch.where(batch[1][random_indices] == 1.0, 1.0 - perturb, perturb)
            elif self.method == "Hard":
                # (5.3.3) Output-level hard perturbation flips 1s to 0s and 0s to 1s.
                batch[1][random_indices] = torch.where(batch[1][random_indices] == 1.0, 0.0, 1.0)
            else:
                raise NotImplementedError(f"{self.level}")
        else:
            raise RuntimeError(f"--level is given as {self.level}!")

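# A minimal sketch (not part of the original module) of the output-level "Soft" rule
# above: positive labels are pulled down by a random amount and negative labels are
# pushed up by the same amount, similar to label smoothing. Toy tensors only.
def _example_soft_label_perturbation() -> torch.Tensor:
    labels = torch.tensor([1.0, 0.0, 1.0, 0.0])
    perturb = torch.rand(1) * 0.1  # scaler = 0.1
    return torch.where(labels == 1.0, 1.0 - perturb, perturb)
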
class PeriodicEvalCallback(AbstractCallback):
    """
    Callback to periodically evaluate the model and optionally save checkpoints during training.

    Evaluates at regular intervals (every N epochs) or at explicitly specified epochs.
    Stores evaluation reports and model checkpoints.
    """

    def __init__(self, experiment_path: str, max_epochs: int, eval_every_n_epoch: int = 0,
                 eval_at_epochs: list = None, save_model_every_n_epoch: bool = True,
                 n_epochs_eval_model: str = "val_test"):
        """
        Initialize the PeriodicEvalCallback.

        Parameters
        ----------
        experiment_path : str
            Directory where evaluation reports and model checkpoints will be saved.
        max_epochs : int
            Maximum number of training epochs.
        eval_every_n_epoch : int, optional
            Evaluate every N epochs. Combined with eval_at_epochs if both are given.
        eval_at_epochs : list, optional
            List of specific epochs at which to evaluate. Combined with the
            eval_every_n_epoch interval if both are given.
        save_model_every_n_epoch : bool, optional
            Whether to save model checkpoints at each evaluation epoch.
        n_epochs_eval_model : str, optional
            Evaluation mode for N epochs. Default is "val_test".
        """
        super().__init__()
        self.experiment_dir = experiment_path
        self.max_epochs = max_epochs
        self.epoch_counter = 0
        self.save_model_every_n_epoch = save_model_every_n_epoch
        self.reports = defaultdict(dict)
        self.n_epochs_eval_model = n_epochs_eval_model
        self.default_eval_model = None

        # Determine evaluation epochs: combine the explicit list and the interval if provided.
        eval_epochs_set = set()
        if eval_at_epochs and len(eval_at_epochs) > 0:
            eval_epochs_set.update(eval_at_epochs)
        if eval_every_n_epoch > 0:
            eval_epochs_set.update(range(eval_every_n_epoch, max_epochs + 1, eval_every_n_epoch))
        self.eval_epochs = eval_epochs_set

        # Prepare the directory for model checkpoints if needed.
        if self.save_model_every_n_epoch:
            self.n_epochs_storage_path = os.path.join(self.experiment_dir, 'models_n_epochs')
            os.makedirs(self.n_epochs_storage_path, exist_ok=True)

    @rank_zero_only
    def on_fit_end(self, trainer, model):
        """Called at the end of training. Saves the final evaluation report."""
        report_path = os.path.join(self.experiment_dir, 'eval_report_n_epochs.json')
        with open(report_path, 'w') as f:
            json.dump(self.reports, f, indent=4)

    @rank_zero_only
    def on_train_epoch_end(self, trainer, model):
        """Called at the end of each training epoch. Performs evaluation and checkpointing if scheduled."""
        self.epoch_counter += 1
        # Check whether the current epoch is scheduled for evaluation.
        if self.epoch_counter not in self.eval_epochs:
            return
        # Store the default evaluation mode once.
        if self.default_eval_model is None:
            self.default_eval_model = trainer.evaluator.args.eval_model
        # Skip evaluation if the default mode already covers all eval splits and it is the final epoch.
        if self.epoch_counter == self.max_epochs:
            default_splits = set(self.default_eval_model.split('_'))
            required_splits = set(self.n_epochs_eval_model.split('_'))
            if required_splits.issubset(default_splits):
                return
        # Set the evaluation mode for this scheduled epoch.
        trainer.evaluator.args.eval_model = self.n_epochs_eval_model
        # Prepare the evaluation model.
        eval_model = None
        if any(model.args.get(k) for k in ("swa", "ema", "twa")):
            try:
                eval_model = copy.deepcopy(trainer.wa_model)
            except Exception:
                # If the SWA epoch is not reached, the trainer has no swa model;
                # fall back to the original model.
                eval_model = copy.deepcopy(model)
        elif model.args.get("adaptive_swa"):
            # Load the ASWA weights and apply them to a deep copy of the model.
            aswa_path = os.path.join(self.experiment_dir, "aswa.pt")
            aswa_ensemble_params = torch.load(aswa_path, map_location="cpu")
            # Clone the model and apply the ASWA weights.
            if isinstance(model, OptimizedModule):
                eval_model = type(model._orig_mod)(model.args)
            else:
                eval_model = type(model)(model.args)
            eval_model.load_state_dict(aswa_ensemble_params)
        else:
            eval_model = copy.deepcopy(model)
        eval_model.to('cpu')
        eval_model.eval()
        report = trainer.evaluator.eval(dataset=trainer.dataset, trained_model=eval_model,
                                        form_of_labelling=trainer.form_of_labelling, during_training=True)
        # Restore the evaluation mode.
        trainer.evaluator.args.eval_model = self.default_eval_model
        # Store the evaluation report.
        self.reports[f'epoch_{self.epoch_counter}_eval'] = report
        # Save a model checkpoint if needed.
        if self.save_model_every_n_epoch:
            save_path = os.path.join(self.n_epochs_storage_path, f'model_at_epoch_{self.epoch_counter}.pt')
            torch.save(eval_model.state_dict(), save_path)
        # Free memory only if eval_model is a separate instance (ASWA case)
        # if model.args.get("adaptive_swa") and eval_model is not model:
        #     del eval_model
        del eval_model
        torch.cuda.empty_cache()

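# A small scheduling sketch (not part of the original module): with max_epochs=100,
# eval_every_n_epoch=25 and eval_at_epochs=[1], the callback evaluates at epochs
# {1, 25, 50, 75, 100}. Checkpoint saving is disabled so no directories are created.
def _example_periodic_eval_schedule(tmp_dir: str) -> set:
    cb = PeriodicEvalCallback(experiment_path=tmp_dir, max_epochs=100,
                              eval_every_n_epoch=25, eval_at_epochs=[1],
                              save_model_every_n_epoch=False)
    return cb.eval_epochs  # {1, 25, 50, 75, 100}
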
class LRScheduler(AbstractCallback):
    """
    Callback for managing learning rate scheduling and model snapshots.

    Supports cosine annealing ("cca"), MMCCLR ("mmcclr"), and their deferred (warmup) variants:
    - "deferred_cca"
    - "deferred_mmcclr"

    At the end of each learning rate cycle, the model can optionally be saved as a snapshot.
    """

    def __init__(
            self,
            adaptive_lr_config: dict,
            total_epochs: int,
            experiment_dir: str,
            eta_max: float = 0.1,
            snapshot_dir: str = "snapshots",
    ):
        """
        Initialize the LR scheduler callback.

        Args:
            adaptive_lr_config (dict): Configuration dictionary containing LR scheduling parameters.
                Can include: scheduler_name, lr_min, num_cycles, weighted_ensemble, n_snapshots.
            total_epochs (int): Total number of training epochs (args.num_epochs).
            experiment_dir (str): Directory for the experiment, used as the base for snapshots.
            eta_max (float, optional): Maximum learning rate at the start of each cycle,
                passed from `args.lr`. Default is 0.1.
            snapshot_dir (str, optional): Subdirectory inside experiment_dir where snapshots
                will be saved. Default is "snapshots".
        """
        # Validate and set defaults for the configuration.
        self._validate_and_set_config(adaptive_lr_config, eta_max)

        self.total_epochs = total_epochs
        self.experiment_dir = experiment_dir
        self.snapshot_dir = os.path.join(experiment_dir, snapshot_dir)
        os.makedirs(self.snapshot_dir, exist_ok=True)

        assert self.eta_max > self.eta_min, \
            f"Max Learning Rate ({self.eta_max}) must be greater than Min Learning Rate ({self.eta_min})"

        # Calculate warmup epochs only for deferred schedulers.
        if self.scheduler_name.startswith("deferred"):
            # Use the formula: defer for (n_cycles - n_snapshots) cycles.
            deferred_cycles = self.n_cycles - self.n_snapshots
            self.warmup_epochs = int(deferred_cycles / self.n_cycles * self.total_epochs)
        else:
            # Non-deferred schedulers don't use warmup.
            self.warmup_epochs = 0

        # Placeholders to be set during training.
        self.batches_per_epoch = None
        self.total_steps = None
        self.cycle_length = None
        self.warmup_steps = None
        self.lr_lambda = None
        self.scheduler = None
        self.step_count = 0
        self.snapshot_loss = defaultdict(float)

    def _validate_and_set_config(self, config: dict, eta_max: float):
        """Validate the adaptive_lr_config and set default values for missing parameters."""
        # Default configuration.
        defaults = {
            "scheduler_name": "cca",
            "lr_min": 0.01,
            "num_cycles": 10,
            "weighted_ensemble": True,
            "n_snapshots": 5
        }

        # Validate that config is a dictionary.
        if not isinstance(config, dict):
            raise ValueError("adaptive_lr_config must be a dictionary")

        # Validate scheduler_name.
        if "scheduler_name" in config:
            valid_schedulers = ["cca", "mmcclr", "deferred_cca", "deferred_mmcclr"]
            if config["scheduler_name"] not in valid_schedulers:
                raise ValueError(f"Invalid scheduler_name '{config['scheduler_name']}'. "
                                 f"Must be one of: {valid_schedulers}")

        # Validate lr_min.
        if "lr_min" in config:
            lr_min = config["lr_min"]
            if not isinstance(lr_min, (int, float)) or lr_min <= 0:
                raise ValueError(f"lr_min must be a positive number, got: {lr_min}")
            if lr_min >= eta_max:
                raise ValueError(f"lr_min ({lr_min}) must be less than eta_max ({eta_max})")

        # Validate num_cycles.
        if "num_cycles" in config:
            num_cycles = config["num_cycles"]
            if not isinstance(num_cycles, (int, float)) or num_cycles <= 0:
                raise ValueError(f"num_cycles must be a positive number, got: {num_cycles}")

        # Validate n_snapshots.
        if "n_snapshots" in config:
            n_snapshots = config["n_snapshots"]
            if not isinstance(n_snapshots, int) or n_snapshots <= 0:
                raise ValueError(f"n_snapshots must be a positive integer, got: {n_snapshots}")

        # Validate weighted_ensemble.
        if "weighted_ensemble" in config:
            weighted_ensemble = config["weighted_ensemble"]
            if not isinstance(weighted_ensemble, bool):
                raise ValueError(f"weighted_ensemble must be a boolean, got: {weighted_ensemble}")

        # Set attributes, falling back to defaults for missing values.
        self.scheduler_name = config.get("scheduler_name", defaults["scheduler_name"]).lower()
        self.eta_min = config.get("lr_min", defaults["lr_min"])
        self.n_cycles = config.get("num_cycles", defaults["num_cycles"])
        self.weighted_ensemble = config.get("weighted_ensemble", defaults["weighted_ensemble"])
        self.n_snapshots = config.get("n_snapshots", defaults["n_snapshots"])
        self.eta_max = eta_max

        assert self.n_snapshots <= self.n_cycles, \
            f"n_snapshots ({self.n_snapshots}) must be less than or equal to num_cycles ({self.n_cycles})"

        print(f"LRScheduler initialized with config: {config}")
        print(f"Using: scheduler_name={self.scheduler_name}, eta_min={self.eta_min}, "
              f"n_cycles={self.n_cycles}, weighted_ensemble={self.weighted_ensemble}, "
              f"n_snapshots={self.n_snapshots}")

    def _initialize_training_params(self, num_training_batches):
        """Set batches per epoch, total steps, cycle length, and warmup steps."""
        self.batches_per_epoch = num_training_batches
        self.total_steps = self.total_epochs * self.batches_per_epoch
        self.cycle_length = self.total_steps // self.n_cycles

        # Ensure the cycle length is at least 1 to avoid division by zero.
        if self.cycle_length < 1:
            raise ValueError(f"Cycle length ({self.cycle_length}) must be at least 1. "
                             f"Total steps: {self.total_steps}, n_cycles: {self.n_cycles}")

        assert self.total_steps > self.n_cycles, \
            f"Total steps ({self.total_steps}) must be greater than Total Cycles ({self.n_cycles})."

        # Calculate warmup steps based on warmup epochs.
        if self.warmup_epochs > 0:
            self.warmup_steps = int(self.warmup_epochs * self.batches_per_epoch)
            if self.warmup_steps >= self.total_steps:
                raise ValueError(f"Warmup steps ({self.warmup_steps}) must be less than "
                                 f"total steps ({self.total_steps}).")

    def _get_lr_schedule(self):
        def cosine_annealing(step):
            cycle_length = math.ceil(self.total_steps / self.n_cycles)
            cycle_step = step % cycle_length
            # Return a multiplier: cosine annealing between eta_min/base_lr and eta_max/base_lr.
            # Assuming base_lr is eta_max, we scale between eta_min/eta_max and 1.0.
            cosine_factor = 0.5 * (1 + np.cos(np.pi * cycle_step / cycle_length))
            min_multiplier = self.eta_min / self.eta_max
            return min_multiplier + (1.0 - min_multiplier) * cosine_factor

        def mmcclr(step):
            # Convert the step to an epoch-based calculation.
            current_epoch = step // self.batches_per_epoch
            cycle_length_epochs = self.total_epochs // self.n_cycles
            cycle_step = current_epoch % cycle_length_epochs
            # Return a multiplier: cosine annealing between eta_min/base_lr and eta_max/base_lr.
            # Assuming base_lr is eta_max, we scale between eta_min/eta_max and 1.0.
            cosine_factor = 0.5 * (1 + np.cos(np.pi * cycle_step / cycle_length_epochs))
            min_multiplier = self.eta_min / self.eta_max
            return min_multiplier + (1.0 - min_multiplier) * cosine_factor

        def deferred(base_schedule):
            # During warmup return 1.0; afterwards use the base schedule shifted by the warmup steps.
            return lambda step: 1.0 if step < self.warmup_steps else base_schedule(step - self.warmup_steps)

        sched_map = {
            "cca": cosine_annealing,
            "mmcclr": mmcclr,
            "deferred_cca": deferred(cosine_annealing),
            "deferred_mmcclr": deferred(mmcclr),
        }
        if self.scheduler_name not in sched_map:
            raise ValueError(f"Unknown scheduler name: {self.scheduler_name}")
        return sched_map[self.scheduler_name]

    def _calculate_snap_weights(self):
        """
        Calculate weights for model snapshots based on their loss values.

        The weight for each snapshot is inversely proportional to its loss.
        """
        # Get losses in the order of the model names used in the ensemble.
        model_names = list(self.snapshot_loss.keys())
        losses = np.array([self.snapshot_loss[name] for name in model_names])

        min_loss = losses.min()
        max_loss = losses.max()
        # SnapE weights: (max + min) - model_loss
        raw_weights = (max_loss + min_loss) - losses
        # Clip to avoid negative weights.
        raw_weights = np.clip(raw_weights, a_min=0, a_max=None)
        # Normalize the weights to sum to 1.
        if raw_weights.sum() > 0:
            weights = raw_weights / raw_weights.sum()
        else:
            weights = np.ones_like(raw_weights) / len(raw_weights)
        self.snapshot_weights = dict(zip(model_names, weights))

    def on_train_start(self, trainer, model):
        """Initialize training parameters and the LR scheduler at the start of training."""
        self._initialize_training_params(trainer.num_training_batches)
        self.lr_lambda = self._get_lr_schedule()
        self.scheduler = LambdaLR(trainer.optimizers[0], lr_lambda=self.lr_lambda)
        self.step_count = 0
        print(f"Using learning rate scheduler: {self.scheduler_name}")

    def on_train_batch_end(self, trainer, model, outputs, batch, batch_idx):
        """Step the LR scheduler and save a model snapshot if needed after each batch."""
        self.scheduler.step()
        self.step_count += 1
        # Log the learning rate for this step
        # current_lr = self.scheduler.get_last_lr()[0] if hasattr(self.scheduler, "get_last_lr") else None
        if self._is_snapshot_step(self.step_count):
            snapshot_path = os.path.join(
                self.snapshot_dir, f"snapshot_epoch_{trainer.current_epoch}.pt"
            )
            torch.save(model.state_dict(), snapshot_path)
            # Store the loss at the snapshot step.
            self.snapshot_loss[os.path.basename(snapshot_path)] = outputs['loss'].item()

    def _is_snapshot_step(self, step):
        """
        Determine whether the current step is a snapshot step.

        For deferred schedulers: take n_snapshots evenly distributed over the active scheduling phase.
        For regular schedulers: take snapshots at the end of each cycle.
        """
        if self.scheduler_name.startswith("deferred"):
            # Skip snapshots during warmup.
            if step < self.warmup_steps:
                return False
            # Take n_snapshots evenly distributed over the remaining steps after warmup.
            remaining_steps = self.total_steps - self.warmup_steps
            snapshot_interval = remaining_steps // self.n_snapshots
            steps_after_warmup = step - self.warmup_steps
            # Check whether we are at a snapshot interval boundary.
            return (steps_after_warmup + 1) % snapshot_interval == 0
        else:
            # For non-deferred schedulers, use cycle-based snapshots.
            return (step + 1) % self.cycle_length == 0

    def on_fit_end(self, trainer, model):
        # Load all model snapshots from the snapshot directory.
        self.ensemble_weights = None
        snapshot_files = sorted(
            [f for f in os.listdir(self.snapshot_dir) if f.endswith('.pt')]
        )
        self.model_snapshots = []
        for checkpoint in snapshot_files:
            checkpoint_path = os.path.join(self.snapshot_dir, checkpoint)
            state_dict = torch.load(checkpoint_path, map_location="cpu")
            model_copy = type(model)(model.args)
            model_copy.load_state_dict(state_dict)
            self.model_snapshots.append(model_copy)

        if self.snapshot_loss and self.weighted_ensemble:
            self._calculate_snap_weights()
            # Build the weight list aligned to the snapshot_files order.
            self.ensemble_weights = [self.snapshot_weights[fname] for fname in snapshot_files]

        ensemble_eval_report = evaluate_ensemble_link_prediction_performance(
            models=self.model_snapshots,
            triples=trainer.dataset.test_set,
            er_vocab=trainer.dataset.er_vocab.result(),
            weights=self.ensemble_weights,
            batch_size=trainer.num_training_batches,
            weighted_averaging=self.weighted_ensemble
        )
        # Prepare a single dictionary with LR scheduling info and the nested ensemble eval report.
        self.ensemble_eval_report = {
            "scheduler_name": self.scheduler_name,
            "total_epochs": self.total_epochs,
            "num_cycles": self.n_cycles,
            "warmup_epochs": self.warmup_epochs,
            "lr_max": self.eta_max,
            "lr_min": self.eta_min,
            "batches_per_epoch": self.batches_per_epoch,
            "total_steps": self.total_steps,
            "cycle_length": self.cycle_length,
            "warmup_steps": self.warmup_steps,
            "weighted_ensemble": self.weighted_ensemble,
            "n_snapshots": self.n_snapshots,
            "ensemble_eval_report": ensemble_eval_report,
            "snapshot_loss": self.snapshot_loss
        }
        ensemble_eval_report_path = os.path.join(self.experiment_dir, "ensemble_eval_report.json")
        # Write the dictionary to the JSON file.
        with open(ensemble_eval_report_path, 'w', encoding='utf-8') as f:
            json.dump(self.ensemble_eval_report, f, indent=4, ensure_ascii=False)
        print(f"Ensemble Evaluations: Evaluate {model.name} on Test Set with an ensemble of "
              f"{len(self.model_snapshots)} models: \n{ensemble_eval_report}")

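# A self-contained sketch (not part of the original module) of the cyclical cosine
# multiplier used by the "cca" schedule above, plugged into LambdaLR on a dummy
# parameter. The step counts and learning rates below are toy values.
def _example_cyclical_cosine_schedule():
    total_steps, n_cycles = 100, 5
    eta_max, eta_min = 0.1, 0.01
    cycle_length = math.ceil(total_steps / n_cycles)

    def lr_lambda(step):
        cycle_step = step % cycle_length
        cosine_factor = 0.5 * (1 + np.cos(np.pi * cycle_step / cycle_length))
        min_multiplier = eta_min / eta_max
        return min_multiplier + (1.0 - min_multiplier) * cosine_factor

    param = torch.nn.Parameter(torch.zeros(1))
    optimizer = torch.optim.SGD([param], lr=eta_max)
    scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda)
    lrs = []
    for _ in range(total_steps):
        lrs.append(optimizer.param_groups[0]["lr"])
        optimizer.step()
        scheduler.step()
    # The learning rate restarts at eta_max every cycle_length steps and decays towards eta_min.
    return lrs
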