dicee.callbacks
Callbacks for training lifecycle events.
Provides callback classes for various training events including epoch end, model saving, weight averaging, and evaluation.
Classes

- AccumulateEpochLossCallback: Callback to store epoch losses to a CSV file.
- PrintCallback: Callback that prints training start/end times and total runtime.
- KGESaveCallback: Callback that periodically saves model checkpoints during training.
- PseudoLabellingCallback: Callback that augments the training set with pseudo-labelled triples.
- Eval: Abstract base class for KGE training lifecycle callbacks (inherited docstring).
- KronE: Abstract base class for KGE training lifecycle callbacks (inherited docstring).
- Perturb: A callback for three-level perturbation.
- PeriodicEvalCallback: Callback to periodically evaluate the model and optionally save checkpoints during training.
- LRScheduler: Callback for managing learning rate scheduling and model snapshots.
Functions

- estimate_q: Estimate the rate of convergence q from the sequence eps.
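The convergence-rate estimate can be sketched as follows. This is a hypothetical re-implementation, not dicee's actual code; it assumes the standard finite-difference estimate of the order of convergence from successive errors.

```python
import math

def estimate_q(eps):
    """Estimate the rate (order) of convergence q from a sequence of errors eps.

    Uses the standard estimate
        q ~= log(e[n+1] / e[n]) / log(e[n] / e[n-1])
    averaged over all interior points of the sequence.
    """
    qs = []
    for n in range(1, len(eps) - 1):
        num = math.log(eps[n + 1] / eps[n])
        den = math.log(eps[n] / eps[n - 1])
        if den != 0:
            qs.append(num / den)
    return sum(qs) / len(qs) if qs else float("nan")

# A linearly convergent sequence e_n = 0.5**n gives q ~= 1;
# a quadratically convergent one (e_{n+1} = e_n**2) gives q ~= 2.
print(round(estimate_q([0.5 ** n for n in range(1, 10)]), 3))
print(round(estimate_q([1e-1, 1e-2, 1e-4, 1e-8]), 3))
```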
Module Contents
- class dicee.callbacks.AccumulateEpochLossCallback(path: str)[source]
Bases:
dicee.abstracts.AbstractCallback
Callback to store epoch losses to a CSV file.
- Parameters:
path – Directory path where the loss file will be saved.
- path
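The behaviour of such a loss-accumulating callback can be sketched as below. The class name, hook signature, and the file name `epoch_losses.csv` are assumptions for illustration, not dicee's actual implementation.

```python
import csv
import os

class EpochLossLogger:
    """Minimal sketch of an epoch-loss logging callback.

    Writes a header on construction, then appends one row per epoch
    to epoch_losses.csv inside `path` (file name is an assumption).
    """

    def __init__(self, path: str):
        os.makedirs(path, exist_ok=True)
        self.file = os.path.join(path, "epoch_losses.csv")
        with open(self.file, "w", newline="") as f:
            csv.writer(f).writerow(["epoch", "loss"])

    def on_train_epoch_end(self, epoch: int, loss: float):
        # Append-only writes keep the file valid even if training is interrupted.
        with open(self.file, "a", newline="") as f:
            csv.writer(f).writerow([epoch, loss])
```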
- class dicee.callbacks.PrintCallback[source]
Bases:
dicee.abstracts.AbstractCallback
Callback that prints training start/end times and total runtime.
- start_time
- on_fit_start(trainer, pl_module)[source]
Called once before the first training epoch.
- Parameters:
trainer (AbstractTrainer or pl.Trainer) – The active trainer instance.
model (BaseKGE) – The model about to be trained.
- on_fit_end(trainer, pl_module)[source]
Called once after the final training epoch completes.
Override to perform post-training actions such as saving the final model state, computing evaluation metrics, or cleaning up resources.
- on_train_batch_end(*args, **kwargs)[source]
Called after each mini-batch gradient update.
Override to inspect or modify the model at a finer granularity than epoch-level hooks.
- on_train_epoch_end(*args, **kwargs)[source]
Called at the end of each training epoch.
- Parameters:
trainer (AbstractTrainer or pl.Trainer) – The active trainer instance.
model (BaseKGE) – The model being trained.
model.loss_history contains the per-epoch average losses accumulated so far.
- class dicee.callbacks.KGESaveCallback(every_x_epoch: int, max_epochs: int, path: str)[source]
Bases:
dicee.abstracts.AbstractCallback
Callback that periodically saves model checkpoints during training.
- Parameters:
every_x_epoch (int or None) – Save a checkpoint every every_x_epoch epochs. When None, the interval defaults to max(max_epochs // 2, 1).
max_epochs (int) – Total number of training epochs (used to compute the default interval when every_x_epoch is None).
path (str) – Directory where checkpoint files will be written.
- every_x_epoch
- max_epochs
- epoch_counter = 0
- path
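The default-interval rule described above amounts to a one-liner; the helper name here is hypothetical:

```python
def checkpoint_interval(every_x_epoch, max_epochs):
    """Resolve the checkpoint interval.

    When no interval is given, save once at the halfway point of training,
    but at least every epoch for very short runs: max(max_epochs // 2, 1).
    """
    return every_x_epoch if every_x_epoch is not None else max(max_epochs // 2, 1)

print(checkpoint_interval(None, 100))  # → 50
print(checkpoint_interval(None, 1))    # → 1
print(checkpoint_interval(5, 100))     # → 5
```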
- on_train_batch_end(*args, **kwargs)[source]
Called after each mini-batch gradient update.
Override to inspect or modify the model at a finer granularity than epoch-level hooks allow.
- on_fit_start(trainer, pl_module)[source]
Called once before the first training epoch.
- Parameters:
trainer (AbstractTrainer or pl.Trainer) – The active trainer instance.
model (BaseKGE) – The model about to be trained.
- on_train_epoch_end(*args, **kwargs)[source]
Called at the end of each training epoch.
- Parameters:
trainer (AbstractTrainer or pl.Trainer) – The active trainer instance.
model (BaseKGE) – The model being trained.
model.loss_history contains the per-epoch average losses accumulated so far.
- class dicee.callbacks.PseudoLabellingCallback(data_module, kg, batch_size)[source]
Bases:
dicee.abstracts.AbstractCallback
Callback that augments the training set with pseudo-labelled triples.
At the end of each epoch the current model scores a batch of unlabelled triples and those with a predicted probability >= 0.90 are appended to the training dataset (semi-supervised self-training).
- Parameters:
data_module (object) – Dataset module exposing a train_set_idx attribute and a train_dataloader() method.
kg (KG) – The knowledge graph, providing num_entities, num_relations, and unlabelled_set.
batch_size (int) – Number of unlabelled triples to sample and score each epoch.
- data_module
- kg
- num_of_epochs = 0
- unlabelled_size
- batch_size
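The confidence-threshold selection step described above (keep unlabelled triples with predicted probability >= 0.90) can be sketched as follows. The function name is hypothetical, and the use of a sigmoid over raw scores is an assumption:

```python
import torch

def select_pseudo_labels(model_scores: torch.Tensor,
                         unlabelled_triples: torch.Tensor,
                         threshold: float = 0.90) -> torch.Tensor:
    """Keep unlabelled triples whose predicted probability meets the threshold.

    `model_scores` are raw logits; probabilities come from a sigmoid,
    matching the >= 0.90 confidence rule for self-training.
    """
    probs = torch.sigmoid(model_scores)
    return unlabelled_triples[probs >= threshold]

scores = torch.tensor([4.0, -1.0, 3.0])  # sigmoid ≈ 0.982, 0.269, 0.953
triples = torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]])
print(select_pseudo_labels(scores, triples))  # keeps rows 0 and 2
```

The selected triples would then be appended to the training set before the next epoch.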
- class dicee.callbacks.Eval(path, epoch_ratio: int = None)[source]
Bases:
dicee.abstracts.AbstractCallback
Abstract base class for KGE training lifecycle callbacks.
Concrete sub-classes override one or more hook methods to perform custom actions at specific points during training (e.g. weight averaging, periodic evaluation, model checkpointing). All hooks have empty default implementations so sub-classes only need to override the hooks they care about.
Callbacks are registered by passing them to the trainer’s callbacks list. They are also compatible with PyTorch Lightning trainers because this class extends lightning.pytorch.callbacks.Callback.
- path
- reports = []
- epoch_ratio
- epoch_counter = 0
- on_fit_start(trainer, model)[source]
Called once before the first training epoch.
- Parameters:
trainer (AbstractTrainer or pl.Trainer) – The active trainer instance.
model (BaseKGE) – The model about to be trained.
- on_fit_end(trainer, model)[source]
Called once after the final training epoch completes.
Override to perform post-training actions such as saving the final model state, computing evaluation metrics, or cleaning up resources.
- on_train_epoch_end(trainer, model)[source]
Called at the end of each training epoch.
- Parameters:
trainer (AbstractTrainer or pl.Trainer) – The active trainer instance.
model (BaseKGE) – The model being trained.
model.loss_history contains the per-epoch average losses accumulated so far.
- class dicee.callbacks.KronE[source]
Bases:
dicee.abstracts.AbstractCallback
Abstract base class for KGE training lifecycle callbacks.
Concrete sub-classes override one or more hook methods to perform custom actions at specific points during training (e.g. weight averaging, periodic evaluation, model checkpointing). All hooks have empty default implementations so sub-classes only need to override the hooks they care about.
Callbacks are registered by passing them to the trainer’s callbacks list. They are also compatible with PyTorch Lightning trainers because this class extends lightning.pytorch.callbacks.Callback.
- f = None
- static batch_kronecker_product(a, b)[source]
Kronecker product of matrices a and b with leading batch dimensions. Batch dimensions are broadcast; the number of batch dimensions must match.
- Parameters:
a (torch.Tensor) – First batch of matrices.
b (torch.Tensor) – Second batch of matrices.
- Return type:
torch.Tensor
- get_kronecker_triple_representation(indexed_triple: torch.LongTensor)[source]
Compute Kronecker-product embeddings for the given indexed triple.
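A batched Kronecker product as described above can be sketched with the well-known unsqueeze-and-reshape pattern; the actual dicee implementation may differ in details:

```python
import torch

def batch_kronecker_product(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Kronecker product of matrices with leading batch dimensions.

    a: (..., m, n), b: (..., p, q) -> (..., m*p, n*q).
    Leading batch dimensions are broadcast against each other.
    """
    # Insert singleton axes so that an elementwise product yields
    # res[..., i, k, j, l] = a[..., i, j] * b[..., k, l].
    res = a.unsqueeze(-1).unsqueeze(-3) * b.unsqueeze(-2).unsqueeze(-4)
    m, n = a.shape[-2:]
    p, q = b.shape[-2:]
    # Collapse (i, k) into the row index and (j, l) into the column index.
    return res.reshape(res.shape[:-4] + (m * p, n * q))
```

For a single pair of matrices this agrees with `torch.kron`; the batched version simply applies it to every leading index in parallel.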
- on_fit_start(trainer, model)[source]
Called once before the first training epoch.
- Parameters:
trainer (AbstractTrainer or pl.Trainer) – The active trainer instance.
model (BaseKGE) – The model about to be trained.
- class dicee.callbacks.Perturb(level: str = 'input', ratio: float = 0.0, method: str = None, scaler: float = None, frequency=None)[source]
Bases:
dicee.abstracts.AbstractCallback
A callback for three-level perturbation.
Input Perturbation: During training an input x is perturbed by randomly replacing one of its elements. In the context of knowledge graph embedding models, x can denote a triple, a tuple of an entity and a relation, or a tuple of two entities. A perturbation means that a component of x is randomly replaced by an entity or a relation.
Parameter Perturbation:
Output Perturbation:
- level = 'input'
- ratio = 0.0
- method = None
- scaler = None
- frequency = None
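The input-level perturbation described above can be sketched as follows. The function name is hypothetical, and for brevity only the head entity is corrupted; the real callback may also replace relations or tail entities:

```python
import torch

def perturb_input_triples(triples: torch.LongTensor, ratio: float,
                          num_entities: int,
                          generator: torch.Generator = None) -> torch.LongTensor:
    """Input-level perturbation sketch: replace the head entity of a
    `ratio` fraction of triples with a uniformly random entity.
    """
    out = triples.clone()
    k = int(triples.shape[0] * ratio)
    if k == 0:
        return out
    # Pick k distinct rows, then overwrite their head-entity column.
    rows = torch.randperm(triples.shape[0], generator=generator)[:k]
    out[rows, 0] = torch.randint(0, num_entities, (k,), generator=generator)
    return out
```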
- class dicee.callbacks.PeriodicEvalCallback(experiment_path: str, max_epochs: int, eval_every_n_epoch: int = 0, eval_at_epochs: list = None, save_model_every_n_epoch: bool = True, n_epochs_eval_model: str = 'val_test')[source]
Bases:
dicee.abstracts.AbstractCallback
Callback to periodically evaluate the model and optionally save checkpoints during training.
Evaluates at regular intervals (every N epochs) or at explicitly specified epochs. Stores evaluation reports and model checkpoints.
- experiment_dir
- max_epochs
- epoch_counter = 0
- save_model_every_n_epoch = True
- reports
- n_epochs_eval_model = 'val_test'
- default_eval_model = None
- eval_epochs
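Combining interval-based and explicitly listed evaluation epochs, as this callback does, can be sketched as below; the helper name and the exact merging rule are assumptions:

```python
def build_eval_epochs(max_epochs, eval_every_n_epoch=0, eval_at_epochs=None):
    """Merge interval-based and explicit evaluation epochs into one sorted list.

    Epochs outside [1, max_epochs] are dropped; duplicates collapse.
    """
    epochs = set(eval_at_epochs or [])
    if eval_every_n_epoch and eval_every_n_epoch > 0:
        epochs.update(range(eval_every_n_epoch, max_epochs + 1, eval_every_n_epoch))
    return sorted(e for e in epochs if 1 <= e <= max_epochs)

print(build_eval_epochs(10, eval_every_n_epoch=4, eval_at_epochs=[1, 7]))  # → [1, 4, 7, 8]
```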
- class dicee.callbacks.LRScheduler(adaptive_lr_config: dict, total_epochs: int, experiment_dir: str, eta_max: float = 0.1, snapshot_dir: str = 'snapshots')[source]
Bases:
dicee.abstracts.AbstractCallback
Callback for managing learning rate scheduling and model snapshots.
Supports cosine annealing (“cca”), MMCCLR (“mmcclr”), and their deferred (warmup) variants:
- “deferred_cca”
- “deferred_mmcclr”
At the end of each learning rate cycle, the model can optionally be saved as a snapshot.
- total_epochs
- experiment_dir
- snapshot_dir
- batches_per_epoch = None
- total_steps = None
- cycle_length = None
- warmup_steps = None
- lr_lambda = None
- scheduler = None
- step_count = 0
- snapshot_loss
- on_train_start(trainer, model)[source]
Initialize training parameters and LR scheduler at start of training.
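A cyclical cosine-annealing schedule of the kind the “cca” variant names can be sketched as follows. This is a generic warm-restart formula, not dicee's exact scheduler; `eta_min` and the per-step (rather than per-epoch) granularity are assumptions:

```python
import math

def cosine_cycle_lr(step, cycle_length, eta_max=0.1, eta_min=0.0):
    """Cosine-annealing learning rate within a cycle (warm restarts).

    The LR starts at eta_max at each cycle boundary, decays to eta_min
    over `cycle_length` steps, then restarts at the next cycle.
    """
    t = step % cycle_length
    return eta_min + 0.5 * (eta_max - eta_min) * (1 + math.cos(math.pi * t / cycle_length))

# LR is maximal at a cycle start, halves mid-cycle, and restarts at eta_max.
print(round(cosine_cycle_lr(0, 100), 4))    # → 0.1
print(round(cosine_cycle_lr(50, 100), 4))   # → 0.05
print(round(cosine_cycle_lr(100, 100), 4))  # → 0.1
```

Saving a snapshot at each cycle boundary (step % cycle_length == 0) then yields the snapshot-ensemble behaviour the class describes.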