dicee.callbacks

Callbacks for training lifecycle events.

Provides callback classes for various training events including epoch end, model saving, weight averaging, and evaluation.
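All of the classes below follow the same hook-dispatch pattern: the trainer invokes each registered callback's hook at the matching lifecycle point. A minimal sketch of that pattern, using stand-in trainer and model classes rather than dicee's actual implementations:

```python
class Callback:
    """Base class: every hook is a no-op, so subclasses override only what they need."""
    def on_fit_start(self, trainer, model): pass
    def on_train_epoch_end(self, trainer, model): pass
    def on_fit_end(self, trainer, model): pass

class LossLogger(Callback):
    """Records the latest per-epoch loss at the end of every epoch."""
    def __init__(self):
        self.losses = []
    def on_train_epoch_end(self, trainer, model):
        self.losses.append(model.loss_history[-1])

class ToyModel:
    def __init__(self):
        self.loss_history = []

class ToyTrainer:
    """Dispatches each hook to every registered callback at the right time."""
    def __init__(self, callbacks):
        self.callbacks = callbacks
    def fit(self, model, epoch_losses):
        for cb in self.callbacks:
            cb.on_fit_start(self, model)
        for loss in epoch_losses:
            model.loss_history.append(loss)
            for cb in self.callbacks:
                cb.on_train_epoch_end(self, model)
        for cb in self.callbacks:
            cb.on_fit_end(self, model)

logger = LossLogger()
ToyTrainer([logger]).fit(ToyModel(), epoch_losses=[0.9, 0.5, 0.3])
# logger.losses is now [0.9, 0.5, 0.3]
```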

Classes

AccumulateEpochLossCallback

Callback to store epoch losses to a CSV file.

PrintCallback

Callback that prints training start/end times and total runtime.

KGESaveCallback

Callback that periodically saves model checkpoints during training.

PseudoLabellingCallback

Callback that augments the training set with pseudo-labelled triples.

Eval

Abstract base class for KGE training lifecycle callbacks.

KronE

Abstract base class for KGE training lifecycle callbacks.

Perturb

Callback implementing a three-level perturbation (input, parameter, or output).

PeriodicEvalCallback

Callback to periodically evaluate the model and optionally save checkpoints during training.

LRScheduler

Callback for managing learning rate scheduling and model snapshots.

Functions

estimate_q(eps)

Estimate the rate of convergence q from the sequence eps.

compute_convergence(seq, i)

Module Contents

class dicee.callbacks.AccumulateEpochLossCallback(path: str)[source]

Bases: dicee.abstracts.AbstractCallback

Callback to store epoch losses to a CSV file.

Parameters:

path – Directory path where the loss file will be saved.

path
on_fit_end(trainer, model) None[source]

Store epoch loss history to CSV file.

Parameters:
  • trainer – The trainer instance.

  • model – The model being trained.
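The behaviour of this callback can be approximated as follows (a sketch; the exact column layout dicee writes may differ, and `save_epoch_losses` is a hypothetical helper name):

```python
import csv
import os
import tempfile

def save_epoch_losses(loss_history, directory, filename="epoch_losses.csv"):
    """Write per-epoch losses to <directory>/<filename> with an epoch index column."""
    path = os.path.join(directory, filename)
    with open(path, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["epoch", "loss"])
        for epoch, loss in enumerate(loss_history):
            writer.writerow([epoch, loss])
    return path

# Example: dump three epochs of losses into a temporary directory.
tmp = tempfile.mkdtemp()
csv_path = save_epoch_losses([0.82, 0.41, 0.27], tmp)
```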

class dicee.callbacks.PrintCallback[source]

Bases: dicee.abstracts.AbstractCallback

Callback that prints training start/end times and total runtime.

start_time
on_fit_start(trainer, pl_module)[source]

Called once before the first training epoch.

Parameters:
  • trainer (AbstractTrainer or pl.Trainer) – The active trainer instance.

  • model (BaseKGE) – The model about to be trained.

on_fit_end(trainer, pl_module)[source]

Called once after the final training epoch completes.

Override to perform post-training actions such as saving the final model state, computing evaluation metrics, or cleaning up resources.

on_train_batch_end(*args, **kwargs)[source]

Called after each mini-batch gradient update.

Override to inspect or modify the model at a finer granularity than epoch-level hooks.

on_train_epoch_end(*args, **kwargs)[source]

Called at the end of each training epoch.

Parameters:
  • trainer (AbstractTrainer or pl.Trainer) – The active trainer instance.

  • model (BaseKGE) – The model being trained. model.loss_history contains the per-epoch average losses accumulated so far.

class dicee.callbacks.KGESaveCallback(every_x_epoch: int, max_epochs: int, path: str)[source]

Bases: dicee.abstracts.AbstractCallback

Callback that periodically saves model checkpoints during training.

Parameters:
  • every_x_epoch (int or None) – Save a checkpoint every every_x_epoch epochs. When None, the interval defaults to max(max_epochs // 2, 1).

  • max_epochs (int) – Total number of training epochs (used to compute the default interval when every_x_epoch is None).

  • path (str) – Directory where checkpoint files will be written.
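The effective save interval described above can be sketched as (`checkpoint_interval` is a hypothetical helper, not part of dicee's API):

```python
def checkpoint_interval(every_x_epoch, max_epochs):
    """Return the effective save interval: the user-supplied value, or
    max(max_epochs // 2, 1) when every_x_epoch is None."""
    if every_x_epoch is not None:
        return every_x_epoch
    return max(max_epochs // 2, 1)

checkpoint_interval(None, 100)  # 50
checkpoint_interval(None, 1)    # 1
checkpoint_interval(10, 100)    # 10
```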

every_x_epoch
max_epochs
epoch_counter = 0
path
on_train_batch_end(*args, **kwargs)[source]

Called after each mini-batch gradient update.

Override to inspect or modify the model at a finer granularity than epoch-level hooks.

on_fit_start(trainer, pl_module)[source]

Called once before the first training epoch.

Parameters:
  • trainer (AbstractTrainer or pl.Trainer) – The active trainer instance.

  • model (BaseKGE) – The model about to be trained.

on_train_epoch_end(*args, **kwargs)[source]

Called at the end of each training epoch.

Parameters:
  • trainer (AbstractTrainer or pl.Trainer) – The active trainer instance.

  • model (BaseKGE) – The model being trained. model.loss_history contains the per-epoch average losses accumulated so far.

on_fit_end(*args, **kwargs)[source]

Called once after the final training epoch completes.

Override to perform post-training actions such as saving the final model state, computing evaluation metrics, or cleaning up resources.

on_epoch_end(model, trainer, **kwargs)[source]
class dicee.callbacks.PseudoLabellingCallback(data_module, kg, batch_size)[source]

Bases: dicee.abstracts.AbstractCallback

Callback that augments the training set with pseudo-labelled triples.

At the end of each epoch the current model scores a batch of unlabelled triples and those with a predicted probability >= 0.90 are appended to the training dataset (semi-supervised self-training).

Parameters:
  • data_module (object) – Dataset module exposing a train_set_idx attribute and a train_dataloader() method.

  • kg (KG) – The knowledge graph, providing num_entities, num_relations, and unlabelled_set.

  • batch_size (int) – Number of unlabelled triples to sample and score each epoch.
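The selection step can be sketched as follows; the 0.90 threshold matches the cut-off described above, while the raw scores here stand in for the model's triple logits:

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def select_pseudo_labels(triples, raw_scores, threshold=0.90):
    """Keep unlabelled triples whose predicted probability meets the threshold."""
    return [t for t, s in zip(triples, raw_scores) if sigmoid(s) >= threshold]

triples = [(0, 1, 2), (3, 1, 4), (5, 0, 6)]
raw_scores = [4.0, -1.0, 3.0]          # stand-in model logits, one per triple
kept = select_pseudo_labels(triples, raw_scores)
# sigmoid(4.0) ≈ 0.98 and sigmoid(3.0) ≈ 0.95 pass; sigmoid(-1.0) ≈ 0.27 does not
```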

data_module
kg
num_of_epochs = 0
unlabelled_size
batch_size
create_random_data()[source]
on_epoch_end(trainer, model)[source]
dicee.callbacks.estimate_q(eps)[source]

Estimate the rate of convergence q from the sequence eps.
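One standard estimator for the order of convergence q, given a sequence of errors eps, is the ratio of successive log-error ratios; this sketch may differ from dicee's exact implementation:

```python
import math

def estimate_q(eps):
    """Estimate the order of convergence q from errors eps = [e0, e1, ...] via
    q ≈ log(e_{i+1}/e_i) / log(e_i/e_{i-1}), averaged over the sequence."""
    qs = []
    for i in range(1, len(eps) - 1):
        qs.append(math.log(eps[i + 1] / eps[i]) / math.log(eps[i] / eps[i - 1]))
    return sum(qs) / len(qs)

# A quadratically convergent sequence (each error is the square of the previous)
# should yield q ≈ 2.
q = estimate_q([0.5, 0.25, 0.0625, 0.00390625])
```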

dicee.callbacks.compute_convergence(seq, i)[source]
class dicee.callbacks.Eval(path, epoch_ratio: int = None)[source]

Bases: dicee.abstracts.AbstractCallback

Abstract base class for KGE training lifecycle callbacks.

Concrete sub-classes override one or more hook methods to perform custom actions at specific points during training (e.g. weight averaging, periodic evaluation, model checkpointing). All hooks have empty default implementations so sub-classes only need to override the hooks they care about.

Callbacks are registered by passing them to the trainer’s callbacks list. They are also compatible with PyTorch Lightning trainers because this class extends lightning.pytorch.callbacks.Callback.

path
reports = []
epoch_ratio
epoch_counter = 0
on_fit_start(trainer, model)[source]

Called once before the first training epoch.

Parameters:
  • trainer (AbstractTrainer or pl.Trainer) – The active trainer instance.

  • model (BaseKGE) – The model about to be trained.

on_fit_end(trainer, model)[source]

Called once after the final training epoch completes.

Override to perform post-training actions such as saving the final model state, computing evaluation metrics, or cleaning up resources.

on_train_epoch_end(trainer, model)[source]

Called at the end of each training epoch.

Parameters:
  • trainer (AbstractTrainer or pl.Trainer) – The active trainer instance.

  • model (BaseKGE) – The model being trained. model.loss_history contains the per-epoch average losses accumulated so far.

on_train_batch_end(*args, **kwargs)[source]

Called after each mini-batch gradient update.

Override to inspect or modify the model at a finer granularity than epoch-level hooks.

class dicee.callbacks.KronE[source]

Bases: dicee.abstracts.AbstractCallback

Abstract base class for KGE training lifecycle callbacks.

Concrete sub-classes override one or more hook methods to perform custom actions at specific points during training (e.g. weight averaging, periodic evaluation, model checkpointing). All hooks have empty default implementations so sub-classes only need to override the hooks they care about.

Callbacks are registered by passing them to the trainer’s callbacks list. They are also compatible with PyTorch Lightning trainers because this class extends lightning.pytorch.callbacks.Callback.

f = None
static batch_kronecker_product(a, b)[source]

Kronecker product of matrices a and b with leading batch dimensions. Batch dimensions are broadcast; a and b must have the same number of batch dimensions.

Parameters:
  • a (torch.Tensor) – First batch of matrices.

  • b (torch.Tensor) – Second batch of matrices.

Return type:
torch.Tensor
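A self-contained sketch of a batched Kronecker product via einsum, assuming 3-D inputs of shape (B, m, n) and (B, p, q):

```python
import torch

def batch_kronecker_product(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Kronecker product per batch element: (B, m, n) x (B, p, q) -> (B, m*p, n*q)."""
    B, m, n = a.shape
    _, p, q = b.shape
    # Outer-multiply the matrix dimensions, then fold them into block layout.
    out = torch.einsum("bmn,bpq->bmpnq", a, b)
    return out.reshape(B, m * p, n * q)

a = torch.tensor([[[1.0, 2.0], [3.0, 4.0]]])   # shape (1, 2, 2)
b = torch.eye(2).unsqueeze(0)                  # shape (1, 2, 2)
result = batch_kronecker_product(a, b)         # shape (1, 4, 4)
```

Per batch element this agrees with `torch.kron`, which itself does not accept a batch dimension.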

get_kronecker_triple_representation(indexed_triple: torch.LongTensor)[source]

Get Kronecker embeddings for the given indexed triples.

on_fit_start(trainer, model)[source]

Called once before the first training epoch.

Parameters:
  • trainer (AbstractTrainer or pl.Trainer) – The active trainer instance.

  • model (BaseKGE) – The model about to be trained.

class dicee.callbacks.Perturb(level: str = 'input', ratio: float = 0.0, method: str = None, scaler: float = None, frequency=None)[source]

Bases: dicee.abstracts.AbstractCallback

Callback implementing a three-level perturbation (input, parameter, or output).

Input Perturbation: During training an input x is perturbed by randomly replacing its element. In the context of knowledge graph embedding models, x can denote a triple, a tuple of an entity and a relation, or a tuple of two entities. A perturbation means that a component of x is randomly replaced by an entity or a relation.

Parameter Perturbation:

Output Perturbation:
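Input perturbation on a batch of triples can be sketched as follows; replacing only the head entity and the sampling details are assumptions for illustration:

```python
import random

def perturb_heads(triples, ratio, num_entities, seed=1):
    """Randomly replace the head entity of a fraction `ratio` of the triples."""
    rng = random.Random(seed)
    triples = [list(t) for t in triples]
    n_perturb = int(len(triples) * ratio)
    for idx in rng.sample(range(len(triples)), n_perturb):
        triples[idx][0] = rng.randrange(num_entities)
    return [tuple(t) for t in triples]

batch = [(0, 0, 1), (1, 0, 2), (2, 1, 3), (3, 1, 0)]
perturbed = perturb_heads(batch, ratio=0.5, num_entities=10)
```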

level = 'input'
ratio = 0.0
method = None
scaler = None
frequency = None
on_train_batch_start(trainer, model, batch, batch_idx)[source]

Called when the train batch begins.

class dicee.callbacks.PeriodicEvalCallback(experiment_path: str, max_epochs: int, eval_every_n_epoch: int = 0, eval_at_epochs: list = None, save_model_every_n_epoch: bool = True, n_epochs_eval_model: str = 'val_test')[source]

Bases: dicee.abstracts.AbstractCallback

Callback to periodically evaluate the model and optionally save checkpoints during training.

Evaluates at regular intervals (every N epochs) or at explicitly specified epochs. Stores evaluation reports and model checkpoints.
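The evaluation schedule implied by these parameters can be sketched as (`build_eval_epochs` is a hypothetical helper, not dicee's API):

```python
def build_eval_epochs(max_epochs, eval_every_n_epoch=0, eval_at_epochs=None):
    """Combine a fixed interval and an explicit epoch list into one sorted schedule."""
    epochs = set(eval_at_epochs or [])
    if eval_every_n_epoch > 0:
        epochs.update(range(eval_every_n_epoch, max_epochs + 1, eval_every_n_epoch))
    return sorted(e for e in epochs if 1 <= e <= max_epochs)

build_eval_epochs(10, eval_every_n_epoch=3)                      # [3, 6, 9]
build_eval_epochs(10, eval_every_n_epoch=4, eval_at_epochs=[5])  # [4, 5, 8]
```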

experiment_dir
max_epochs
epoch_counter = 0
save_model_every_n_epoch = True
reports
n_epochs_eval_model = 'val_test'
default_eval_model = None
eval_epochs
on_fit_end(trainer, model)[source]

Called at the end of training. Saves final evaluation report.

on_train_epoch_end(trainer, model)[source]

Called at the end of each training epoch. Performs evaluation and checkpointing if scheduled.

class dicee.callbacks.LRScheduler(adaptive_lr_config: dict, total_epochs: int, experiment_dir: str, eta_max: float = 0.1, snapshot_dir: str = 'snapshots')[source]

Bases: dicee.abstracts.AbstractCallback

Callback for managing learning rate scheduling and model snapshots.

Supports cosine annealing (“cca”), MMCCLR (“mmcclr”), and their deferred (warmup) variants:
  • “deferred_cca”

  • “deferred_mmcclr”

At the end of each learning rate cycle, the model can optionally be saved as a snapshot.
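The per-step learning rate of a cyclical cosine schedule such as “cca” can be sketched as follows; eta_max matches the constructor default above, but the exact formula dicee uses may differ:

```python
import math

def cca_lr(step, cycle_length, eta_max=0.1, eta_min=0.0):
    """Cosine-annealed LR within a cycle: starts at eta_max, decays to eta_min,
    then restarts at the next cycle boundary."""
    t = step % cycle_length
    return eta_min + 0.5 * (eta_max - eta_min) * (1 + math.cos(math.pi * t / cycle_length))

cca_lr(0, 100)    # 0.1  (cycle start)
cca_lr(50, 100)   # 0.05 (cycle midpoint)
cca_lr(100, 100)  # 0.1  (restart of the next cycle)
```

Snapshots would then naturally be taken just before each restart, when the learning rate is at its minimum.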

total_epochs
experiment_dir
snapshot_dir
batches_per_epoch = None
total_steps = None
cycle_length = None
warmup_steps = None
lr_lambda = None
scheduler = None
step_count = 0
snapshot_loss
on_train_start(trainer, model)[source]

Initialize training parameters and LR scheduler at start of training.

on_train_batch_end(trainer, model, outputs, batch, batch_idx)[source]

Step the LR scheduler and save model snapshot if needed after each batch.

on_fit_end(trainer, model)[source]

Called once after the final training epoch completes.

Override to perform post-training actions such as saving the final model state, computing evaluation metrics, or cleaning up resources.