dicee.callbacks

Classes

AccumulateEpochLossCallback

Callback that stores the epoch losses at the end of training.

PrintCallback

Callback for knowledge graph embedding models.

KGESaveCallback

Callback that periodically saves a knowledge graph embedding model during training.

PseudoLabellingCallback

Pseudo-labelling callback for knowledge graph embedding models.

ASWA

Adaptive stochastic weight averaging

Eval

Callback that evaluates a knowledge graph embedding model during training.

KronE

Callback that computes Kronecker-product representations for knowledge graph embedding models.

Perturb

A callback for three-level perturbation.

PeriodicEvalCallback

Callback to periodically evaluate the model and optionally save checkpoints during training.

LRScheduler

Callback for managing learning rate scheduling and model snapshots.

Functions

estimate_q(eps)

Estimate the rate of convergence q from the sequence eps.

compute_convergence(seq, i)

Module Contents

class dicee.callbacks.AccumulateEpochLossCallback(path: str)[source]

Bases: dicee.abstracts.AbstractCallback

Callback that stores the epoch losses at the end of training.

path
on_fit_end(trainer, model) None[source]

Store the epoch losses.

Parameters:
  • trainer – The training controller.

  • model – The model being trained.

Return type:
  None

class dicee.callbacks.PrintCallback[source]

Bases: dicee.abstracts.AbstractCallback

Callback for knowledge graph embedding models.

start_time
on_fit_start(trainer, pl_module)[source]

Called at the beginning of training.

Parameters:
  • trainer – The training controller.

  • pl_module – The model being trained.

Return type:
  None

on_fit_end(trainer, pl_module)[source]

Called at the end of training.

Parameters:
  • trainer – The training controller.

  • pl_module – The model being trained.

Return type:
  None

on_train_batch_end(*args, **kwargs)[source]

Called at the end of each mini-batch during training.

Parameters:
  • trainer – The training controller.

  • model – The model being trained.

Return type:
  None

on_train_epoch_end(*args, **kwargs)[source]

Called at the end of each training epoch.

Parameters:
  • trainer – The training controller.

  • model – The model being trained.

Return type:
  None

class dicee.callbacks.KGESaveCallback(every_x_epoch: int, max_epochs: int, path: str)[source]

Bases: dicee.abstracts.AbstractCallback

Callback that periodically saves a knowledge graph embedding model during training.

every_x_epoch
max_epochs
epoch_counter = 0
path
on_train_batch_end(*args, **kwargs)[source]

Called at the end of each mini-batch during training.

Parameters:
  • trainer – The training controller.

  • model – The model being trained.

Return type:
  None

on_fit_start(trainer, pl_module)[source]

Called at the beginning of training.

Parameters:
  • trainer – The training controller.

  • pl_module – The model being trained.

Return type:
  None

on_train_epoch_end(*args, **kwargs)[source]

Called at the end of each training epoch.

Parameters:
  • trainer – The training controller.

  • model – The model being trained.

Return type:
  None

on_fit_end(*args, **kwargs)[source]

Called at the end of training.

Parameters:
  • trainer – The training controller.

  • model – The model being trained.

Return type:
  None

on_epoch_end(model, trainer, **kwargs)[source]
class dicee.callbacks.PseudoLabellingCallback(data_module, kg, batch_size)[source]

Bases: dicee.abstracts.AbstractCallback

Pseudo-labelling callback for knowledge graph embedding models.

data_module
kg
num_of_epochs = 0
unlabelled_size
batch_size
create_random_data()[source]
on_epoch_end(trainer, model)[source]
dicee.callbacks.estimate_q(eps)[source]

Estimate the rate of convergence q from the sequence eps.
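The listing does not say how estimate_q computes q. As an illustration only, the sketch below uses the standard three-point estimate of the order of convergence, q ≈ log|e[n+1]/e[n]| / log|e[n]/e[n-1]|, applied to the last three errors; this is a swapped-in textbook technique that may differ from what estimate_q actually does, and the name estimate_order is hypothetical.

```python
import math
from typing import Sequence


def estimate_order(eps: Sequence[float]) -> float:
    """Estimate the order of convergence q from the last three errors in eps.

    Hypothetical helper: uses q ≈ log|e[n+1]/e[n]| / log|e[n]/e[n-1]|,
    which is not necessarily what dicee's estimate_q implements.
    """
    if len(eps) < 3:
        raise ValueError("at least three error values are required")
    e_prev, e_cur, e_next = (abs(x) for x in eps[-3:])
    return math.log(e_next / e_cur) / math.log(e_cur / e_prev)


# Errors of a quadratically converging sequence: e_n = 0.5 ** (2 ** n)
errors = [0.5 ** (2 ** n) for n in range(4)]
print(round(estimate_order(errors), 6))  # prints 2.0
```

A geometric error sequence such as 0.1, 0.01, 0.001, ... gives q close to 1 (linear convergence), while each squaring of the error, as above, gives q close to 2.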

dicee.callbacks.compute_convergence(seq, i)[source]
class dicee.callbacks.ASWA(num_epochs, path)[source]

Bases: dicee.abstracts.AbstractCallback

Adaptive stochastic weight averaging (ASWA). ASWA keeps track of the validation performance and updates the ensemble model accordingly.

path
num_epochs
initial_eval_setting = None
epoch_count = 0
alphas = []
val_aswa = -1
on_fit_end(trainer, model)[source]

Called at the end of training.

Parameters:
  • trainer – The training controller.

  • model – The model being trained.

Return type:
  None

static compute_mrr(trainer, model) float[source]
get_aswa_state_dict(model)[source]
decide(running_model_state_dict, ensemble_state_dict, val_running_model, mrr_updated_ensemble_model)[source]

Perform a hard update, a soft update, or a rejection.

Parameters:
  • running_model_state_dict

  • ensemble_state_dict

  • val_running_model

  • mrr_updated_ensemble_model
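The docstring names the three outcomes but not the rule that chooses between them. The sketch below shows one plausible rule, stated as an assumption: a hard update restarts the ensemble from the running model when the running model beats both the prospective updated ensemble and the current ensemble score (val_aswa); a soft update accepts the averaged parameters when they improve on val_aswa; anything else is rejected. ASWASketch and average_into are hypothetical names, state dicts are reduced to plain dictionaries of floats, and the validation scores in the usage are illustrative inputs, not results.

```python
from typing import Dict, List

StateDict = Dict[str, float]  # parameter name -> value; plain floats keep the sketch small


def average_into(ensemble: StateDict, running: StateDict, num_models: int) -> StateDict:
    """Running average of an ensemble with num_models members and the running model."""
    if num_models == 0:
        return dict(running)
    return {k: (num_models * ensemble[k] + running[k]) / (num_models + 1) for k in running}


class ASWASketch:
    """Hypothetical stand-in mirroring ASWA's documented val_aswa and alphas attributes."""

    def __init__(self) -> None:
        self.val_aswa = -1.0            # validation score of the current ensemble
        self.alphas: List[float] = []   # averaging weights of the ensemble members
        self.ensemble: StateDict = {}

    def decide(self, running: StateDict, updated_ensemble: StateDict,
               val_running_model: float, mrr_updated_ensemble_model: float) -> str:
        if val_running_model > mrr_updated_ensemble_model and val_running_model > self.val_aswa:
            # Hard update: restart the ensemble from the running model alone.
            self.ensemble = dict(running)
            self.alphas = [1.0]
            self.val_aswa = val_running_model
            return "hard"
        if mrr_updated_ensemble_model > self.val_aswa:
            # Soft update: accept the averaged parameters with uniform weights.
            self.ensemble = dict(updated_ensemble)
            n = len(self.alphas) + 1
            self.alphas = [1.0 / n] * n
            self.val_aswa = mrr_updated_ensemble_model
            return "soft"
        # Rejection: leave the ensemble unchanged.
        return "reject"


aswa = ASWASketch()
# (running parameters, score of running model, score of prospective updated ensemble)
epochs = [({"w": 1.0}, 0.30, 0.30), ({"w": 3.0}, 0.35, 0.33),
          ({"w": 5.0}, 0.34, 0.36), ({"w": 9.0}, 0.20, 0.30)]
for running, val_run, val_updated in epochs:
    updated = average_into(aswa.ensemble, running, len(aswa.alphas))
    print(aswa.decide(running, updated, val_run, val_updated))
# prints soft, hard, soft, reject (one per line)
```

The rejection branch is what makes the averaging adaptive: a running model whose inclusion would lower the validation score is simply not added to the ensemble.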

on_train_epoch_end(trainer, model)[source]

Called at the end of each training epoch.

Parameters:
  • trainer – The training controller.

  • model – The model being trained.

Return type:
  None

class dicee.callbacks.Eval(path, epoch_ratio: int = None)[source]

Bases: dicee.abstracts.AbstractCallback

Callback that evaluates a knowledge graph embedding model during training.

path
reports = []
epoch_ratio = None
epoch_counter = 0
on_fit_start(trainer, model)[source]

Called at the beginning of training.

Parameters:
  • trainer – The training controller.

  • model – The model being trained.

Return type:
  None

on_fit_end(trainer, model)[source]

Called at the end of training.

Parameters:
  • trainer – The training controller.

  • model – The model being trained.

Return type:
  None

on_train_epoch_end(trainer, model)[source]

Called at the end of each training epoch.

Parameters:
  • trainer – The training controller.

  • model – The model being trained.

Return type:
  None

on_train_batch_end(*args, **kwargs)[source]

Called at the end of each mini-batch during training.

Parameters:
  • trainer – The training controller.

  • model – The model being trained.

Return type:
  None

class dicee.callbacks.KronE[source]

Bases: dicee.abstracts.AbstractCallback

Callback that computes Kronecker-product representations for knowledge graph embedding models.

f = None
static batch_kronecker_product(a, b)[source]

Kronecker product of matrices a and b with leading batch dimensions. Batch dimensions are broadcast, so the leading dimensions of a and b must be broadcast-compatible.

Parameters:
  • a (torch.Tensor)

  • b (torch.Tensor)

Return type:
  torch.Tensor
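The docstring describes a Kronecker product over the last two axes with broadcast batch axes. A minimal NumPy sketch of that operation follows; dicee's method works on torch.Tensor, so this illustrates the technique rather than reproducing its code. It relies on the identity kron(A, B)[i*p + k, j*q + l] = A[i, j] * B[k, l] for A of shape (m, n) and B of shape (p, q).

```python
import numpy as np


def batch_kronecker_product(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Kronecker product over the last two axes; leading batch axes are broadcast.

    a: (..., m, n), b: (..., p, q) -> (..., m * p, n * q)
    """
    m, n = a.shape[-2:]
    p, q = b.shape[-2:]
    # Pairwise products a[..., i, j] * b[..., k, l], laid out as (..., i, k, j, l).
    res = a[..., :, None, :, None] * b[..., None, :, None, :]
    # Merging (i, k) and (j, l) in C order gives row i*p + k and column j*q + l.
    return res.reshape(res.shape[:-4] + (m * p, n * q))


A = np.arange(4).reshape(2, 2)
B = np.array([[0, 1], [1, 0]])
print(np.array_equal(batch_kronecker_product(A, B), np.kron(A, B)))  # prints True

batch = np.stack([A, 2 * A, 3 * A])     # shape (3, 2, 2)
out = batch_kronecker_product(batch, B)  # B is broadcast over the batch axis
print(out.shape)  # prints (3, 4, 4)
```

Because the batch axes are handled purely by broadcasting, the same indexing pattern should carry over to PyTorch tensors, which also support None-indexing and reshape.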

get_kronecker_triple_representation(indexed_triple: torch.LongTensor)[source]

Get the Kronecker embeddings of an indexed triple.

on_fit_start(trainer, model)[source]

Called at the beginning of training.

Parameters:
  • trainer – The training controller.

  • model – The model being trained.

Return type:
  None

class dicee.callbacks.Perturb(level: str = 'input', ratio: float = 0.0, method: str = None, scaler: float = None, frequency=None)[source]

Bases: dicee.abstracts.AbstractCallback

A callback for three-level perturbation.

Input Perturbation: During training, an input x is perturbed by randomly replacing one of its elements. In the context of knowledge graph embedding models, x can denote a triple, a tuple of an entity and a relation, or a tuple of two entities. A perturbation means that a component of x is randomly replaced by an entity or a relation.

Parameter Perturbation:

Output Perturbation:

level = 'input'
ratio = 0.0
method = None
scaler = None
frequency = None
on_train_batch_start(trainer, model, batch, batch_idx)[source]

Called when the train batch begins.
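The input-level perturbation described for Perturb can be sketched in plain Python. This is a hypothetical illustration, not dicee's implementation: it assumes triples are (head, relation, tail) index tuples, that ratio is the fraction of triples in a batch that get perturbed, and that each selected triple has exactly one randomly chosen component replaced by a random entity or relation index. The name perturb_input and its num_entities and num_relations parameters are assumptions.

```python
import random
from typing import List, Tuple

Triple = Tuple[int, int, int]  # (head, relation, tail) as integer indices


def perturb_input(batch: List[Triple], ratio: float, num_entities: int,
                  num_relations: int, rng: random.Random) -> List[Triple]:
    """Return a copy of batch where int(len(batch) * ratio) triples get one component replaced."""
    perturbed = list(batch)  # the input batch itself is left untouched
    num_perturbed = int(len(batch) * ratio)
    for idx in rng.sample(range(len(batch)), num_perturbed):
        h, r, t = perturbed[idx]
        position = rng.randrange(3)  # 0: head, 1: relation, 2: tail
        if position == 0:
            h = rng.randrange(num_entities)
        elif position == 1:
            r = rng.randrange(num_relations)
        else:
            t = rng.randrange(num_entities)
        perturbed[idx] = (h, r, t)
    return perturbed


batch = [(0, 0, 1), (1, 1, 2), (2, 0, 3), (3, 1, 0)]
noisy = perturb_input(batch, ratio=0.5, num_entities=4, num_relations=2, rng=random.Random(0))
print(noisy)
```

Since the replacement is drawn uniformly, it can coincide with the original value, so a selected triple occasionally comes out unchanged; a variant that excludes the current value would guarantee a change.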

class dicee.callbacks.PeriodicEvalCallback(experiment_path: str, max_epochs: int, eval_every_n_epoch: int = 0, eval_at_epochs: list = None, save_model_every_n_epoch: bool = True, n_epochs_eval_model: str = 'val_test')[source]

Bases: dicee.abstracts.AbstractCallback

Callback to periodically evaluate the model and optionally save checkpoints during training.

Evaluates at regular intervals (every N epochs) or at explicitly specified epochs. Stores evaluation reports and model checkpoints.
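The scheduling rule (every N epochs, or at explicitly listed epochs) can be illustrated with a small helper. It is a hypothetical sketch, assuming 1-based epoch counting, that eval_every_n_epoch = 0 disables interval-based evaluation, and that listed epochs outside 1..max_epochs are ignored; compute_eval_epochs is not part of dicee.

```python
from typing import List, Optional, Set


def compute_eval_epochs(max_epochs: int, eval_every_n_epoch: int = 0,
                        eval_at_epochs: Optional[List[int]] = None) -> Set[int]:
    """Hypothetical helper: the set of 1-based epochs at which to evaluate."""
    epochs: Set[int] = set()
    if eval_every_n_epoch > 0:
        # Regular interval: N, 2N, 3N, ... up to max_epochs.
        epochs.update(range(eval_every_n_epoch, max_epochs + 1, eval_every_n_epoch))
    if eval_at_epochs:
        # Explicitly requested epochs, restricted to the training horizon.
        epochs.update(e for e in eval_at_epochs if 1 <= e <= max_epochs)
    return epochs


eval_epochs = compute_eval_epochs(max_epochs=10, eval_every_n_epoch=3, eval_at_epochs=[1, 10])
evaluated = [epoch for epoch in range(1, 11) if epoch in eval_epochs]
print(evaluated)  # prints [1, 3, 6, 9, 10]
```

Precomputing a set keeps the per-epoch check in on_train_epoch_end a constant-time membership test against the incremented epoch counter.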

experiment_dir
max_epochs
epoch_counter = 0
save_model_every_n_epoch = True
reports
n_epochs_eval_model = 'val_test'
default_eval_model = None
eval_epochs
on_fit_end(trainer, model)[source]

Called at the end of training. Saves final evaluation report.

on_train_epoch_end(trainer, model)[source]

Called at the end of each training epoch. Performs evaluation and checkpointing if scheduled.

Parameters:
  • trainer (object) – The training controller.

  • model (torch.nn.Module) – The model being trained.

class dicee.callbacks.LRScheduler(adaptive_lr_config: dict, total_epochs: int, experiment_dir: str, eta_max: float = 0.1, snapshot_dir: str = 'snapshots')[source]

Bases: dicee.abstracts.AbstractCallback

Callback for managing learning rate scheduling and model snapshots.

Supports cosine annealing (“cca”), MMCCLR (“mmcclr”), and their deferred (warmup) variants:
  • “deferred_cca”

  • “deferred_mmcclr”

At the end of each learning rate cycle, the model can optionally be saved as a snapshot.
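For the cosine annealing option, a cyclic schedule in the style of snapshot ensembles can be sketched as below: within each cycle of C steps the learning rate follows η(t) = (η_max / 2) · (1 + cos(π · (t mod C) / C)), decaying from eta_max towards 0 and then restarting, and the last step of each cycle is the natural point for a snapshot. This is an assumption-labelled illustration, not dicee's lr_lambda: the cycle length in steps, the deferred behaviour (holding eta_max for warmup_steps before cycling), and the function names are hypothetical, and MMCCLR is not covered.

```python
import math


def cyclic_cosine_lr(step: int, cycle_length: int, eta_max: float = 0.1) -> float:
    """Half-cosine decay from eta_max that restarts every cycle_length steps."""
    t = step % cycle_length
    return 0.5 * eta_max * (1.0 + math.cos(math.pi * t / cycle_length))


def deferred_cyclic_cosine_lr(step: int, cycle_length: int, warmup_steps: int,
                              eta_max: float = 0.1) -> float:
    """Assumed deferred variant: hold eta_max for warmup_steps, then start cycling."""
    if step < warmup_steps:
        return eta_max
    return cyclic_cosine_lr(step - warmup_steps, cycle_length, eta_max)


def is_cycle_end(step: int, cycle_length: int, warmup_steps: int = 0) -> bool:
    """True on the last step of a cycle, where a snapshot could be saved."""
    return step >= warmup_steps and (step - warmup_steps + 1) % cycle_length == 0


print([step for step in range(30) if is_cycle_end(step, cycle_length=10)])  # prints [9, 19, 29]
```

In PyTorch, a function like this could be divided by the optimizer's base learning rate and passed as lr_lambda to torch.optim.lr_scheduler.LambdaLR, which sets each parameter group's learning rate to its initial value times the returned factor.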

total_epochs
experiment_dir
snapshot_dir
batches_per_epoch = None
total_steps = None
cycle_length = None
warmup_steps = None
lr_lambda = None
scheduler = None
step_count = 0
snapshot_loss
on_train_start(trainer, model)[source]

Initialize training parameters and LR scheduler at start of training.

on_train_batch_end(trainer, model, outputs, batch, batch_idx)[source]

Step the LR scheduler and save model snapshot if needed after each batch.

on_fit_end(trainer, model)[source]

Called at the end of training.

Parameters:
  • trainer – The training controller.

  • model – The model being trained.

Return type:
  None