dicee.callbacks

Classes

AccumulateEpochLossCallback

Callback that stores the epoch loss when training ends

PrintCallback

Callback class for knowledge graph embedding models

KGESaveCallback

Callback class for knowledge graph embedding models

PseudoLabellingCallback

Callback class for knowledge graph embedding models

ASWA

Adaptive stochastic weight averaging

Eval

Callback class for knowledge graph embedding models

KronE

Callback class for knowledge graph embedding models

Perturb

A callback for three-level perturbation

Functions

estimate_q(eps)

Estimate the rate of convergence q from the sequence eps

compute_convergence(seq, i)

Module Contents

class dicee.callbacks.AccumulateEpochLossCallback(path: str)[source]

Bases: dicee.abstracts.AbstractCallback

Callback that stores the epoch loss when training ends

Parameter

path
on_fit_end(trainer, model) → None[source]

Store epoch loss

Parameter

trainer:

model:

rtype:

None

class dicee.callbacks.PrintCallback[source]

Bases: dicee.abstracts.AbstractCallback

Callback class for knowledge graph embedding models

Parameter

start_time
on_fit_start(trainer, pl_module)[source]

Call at the beginning of the training.

Parameter

trainer:

model:

rtype:

None

on_fit_end(trainer, pl_module)[source]

Call at the end of the training.

Parameter

trainer:

model:

rtype:

None

on_train_batch_end(*args, **kwargs)[source]

Call at the end of each mini-batch during the training.

Parameter

trainer:

model:

rtype:

None

on_train_epoch_end(*args, **kwargs)[source]

Call at the end of each epoch during training.

Parameter

trainer:

model:

rtype:

None
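The four hooks listed for PrintCallback fire at different points of a training run. The following is a minimal stand-alone sketch of that ordering. RecordingCallback and toy_fit are hypothetical names introduced for illustration; they are not part of dicee, and the real trainer may pass different arguments to each hook.

```python
class RecordingCallback:
    """Hypothetical stand-in for a callback; it only records which hook fired."""

    def __init__(self):
        self.events = []

    def on_fit_start(self, trainer, model):
        self.events.append("fit_start")

    def on_train_batch_end(self, *args, **kwargs):
        self.events.append("batch_end")

    def on_train_epoch_end(self, *args, **kwargs):
        self.events.append("epoch_end")

    def on_fit_end(self, trainer, model):
        self.events.append("fit_end")


def toy_fit(callback, num_epochs, num_batches):
    """Toy loop (no real model or data) that calls the hooks in training order."""
    trainer, model = object(), object()  # placeholders, assumed for the sketch
    callback.on_fit_start(trainer, model)
    for _ in range(num_epochs):
        for _ in range(num_batches):
            callback.on_train_batch_end(trainer, model)
        callback.on_train_epoch_end(trainer, model)
    callback.on_fit_end(trainer, model)
    return callback.events


events = toy_fit(RecordingCallback(), num_epochs=2, num_batches=2)
```

The batch hook runs once per mini-batch, the epoch hook once per epoch, and the two fit hooks bracket the whole run.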

class dicee.callbacks.KGESaveCallback(every_x_epoch: int, max_epochs: int, path: str)[source]

Bases: dicee.abstracts.AbstractCallback

Callback class for knowledge graph embedding models

Parameter

every_x_epoch
max_epochs
epoch_counter = 0
path
on_train_batch_end(*args, **kwargs)[source]

Call at the end of each mini-batch during the training.

Parameter

trainer:

model:

rtype:

None

on_fit_start(trainer, pl_module)[source]

Call at the beginning of the training.

Parameter

trainer:

model:

rtype:

None

on_train_epoch_end(*args, **kwargs)[source]

Call at the end of each epoch during training.

Parameter

trainer:

model:

rtype:

None
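The attributes every_x_epoch and epoch_counter suggest a periodic checkpoint. Below is a minimal sketch of such a schedule, assuming the counter is incremented at the end of each epoch and a save is triggered whenever it is a multiple of every_x_epoch. PeriodicSaver and the file-name pattern are hypothetical; the actual behaviour of KGESaveCallback may differ.

```python
class PeriodicSaver:
    """Hypothetical sketch of an 'every x epochs' save schedule (not dicee code)."""

    def __init__(self, every_x_epoch, max_epochs, path):
        self.every_x_epoch = every_x_epoch
        self.max_epochs = max_epochs
        self.path = path
        self.epoch_counter = 0
        self.saved = []  # names of the files a real callback would write

    def on_train_epoch_end(self, *args, **kwargs):
        self.epoch_counter += 1
        if self.epoch_counter % self.every_x_epoch == 0:
            # Assumption: a real callback would serialise the model here.
            # The file name below is made up for the example.
            self.saved.append(f"{self.path}/model_at_{self.epoch_counter}_epoch.pt")


saver = PeriodicSaver(every_x_epoch=2, max_epochs=5, path="Experiments")
for _ in range(saver.max_epochs):
    saver.on_train_epoch_end()
```

With every_x_epoch=2 and five epochs, the sketch records a save after epochs 2 and 4.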

on_fit_end(*args, **kwargs)[source]

Call at the end of the training.

Parameter

trainer:

model:

rtype:

None

on_epoch_end(model, trainer, **kwargs)[source]

class dicee.callbacks.PseudoLabellingCallback(data_module, kg, batch_size)[source]

Bases: dicee.abstracts.AbstractCallback

Callback class for knowledge graph embedding models

Parameter

data_module
kg
num_of_epochs = 0
unlabelled_size
batch_size
create_random_data()[source]
on_epoch_end(trainer, model)[source]

dicee.callbacks.estimate_q(eps)[source]

Estimate the rate of convergence q from the sequence eps
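One common way to estimate an order of convergence from a sequence of errors is the ratio estimate q ≈ log(e[n+1] / e[n]) / log(e[n] / e[n-1]). The sketch below implements that textbook estimate under the assumption that eps holds positive, strictly changing error values. estimate_order is a hypothetical helper; it is not claimed to match the internals of estimate_q.

```python
import math


def estimate_order(eps):
    """Average the ratio estimate of the convergence order over a sequence.

    Assumes every entry of eps is positive and consecutive entries differ.
    """
    if len(eps) < 3:
        raise ValueError("need at least three error values")
    estimates = []
    for prev, cur, nxt in zip(eps, eps[1:], eps[2:]):
        estimates.append(math.log(nxt / cur) / math.log(cur / prev))
    return sum(estimates) / len(estimates)


# Each error is the square of the previous one, i.e. quadratic convergence.
quadratic = [0.5, 0.25, 0.0625, 0.00390625]
q = estimate_order(quadratic)
```

For the quadratically converging sequence above, the estimate is approximately 2.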

dicee.callbacks.compute_convergence(seq, i)[source]

class dicee.callbacks.ASWA(num_epochs, path)[source]

Bases: dicee.abstracts.AbstractCallback

Adaptive stochastic weight averaging (ASWA). ASWA keeps track of the validation performance and updates the ensemble model accordingly.

path
num_epochs
initial_eval_setting = None
epoch_count = 0
alphas = []
val_aswa
on_fit_end(trainer, model)[source]

Call at the end of the training.

Parameter

trainer:

model:

rtype:

None

static compute_mrr(trainer, model) → float[source]
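MRR stands for mean reciprocal rank: the average of 1/rank over all evaluated queries, where rank is the position of the correct answer. A minimal sketch of the metric itself follows. mean_reciprocal_rank is a hypothetical helper that takes precomputed ranks; how compute_mrr obtains ranks from the trainer and model is not shown here.

```python
def mean_reciprocal_rank(ranks):
    """Mean of 1/rank, where each rank is a 1-based position of the true answer."""
    if not ranks:
        raise ValueError("ranks must be non-empty")
    return sum(1.0 / r for r in ranks) / len(ranks)


# Three queries whose correct answers were ranked 1st, 2nd and 4th.
mrr = mean_reciprocal_rank([1, 2, 4])
```

The value lies in (0, 1], and higher is better.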
get_aswa_state_dict(model)[source]

decide(running_model_state_dict, ensemble_state_dict, val_running_model, mrr_updated_ensemble_model)[source]

Perform a hard update, a soft update, or a rejection

Parameters:
  • running_model_state_dict

  • ensemble_state_dict

  • val_running_model

  • mrr_updated_ensemble_model
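A minimal sketch of a three-way decision between a hard update, a soft update, and a rejection is given below. It assumes that a hard update replaces the ensemble with the running model, a soft update folds the running model into a running mean, and a rejection leaves the ensemble unchanged. The thresholds, the function names, and the use of plain dicts of floats in place of state dicts are all assumptions made for illustration, not the documented behaviour of ASWA.decide.

```python
def soft_average(ensemble, running, n_averaged):
    """Running mean: fold `running` into an ensemble that already averages n models."""
    return {
        name: (ensemble[name] * n_averaged + running[name]) / (n_averaged + 1)
        for name in ensemble
    }


def choose_update(val_ensemble, val_running, val_candidate):
    """Pick an action from three validation scores (higher is better).

    val_ensemble:  score of the current ensemble
    val_running:   score of the running model on its own
    val_candidate: score of the ensemble after a tentative soft update
    """
    if val_running > val_ensemble and val_running > val_candidate:
        return "hard"    # assumed: restart the ensemble from the running model
    if val_candidate > val_ensemble:
        return "soft"    # assumed: keep the averaged parameters
    return "reject"      # assumed: keep the ensemble as it was


averaged = soft_average({"w": 1.0}, {"w": 3.0}, n_averaged=1)
actions = [
    choose_update(0.30, 0.50, 0.40),
    choose_update(0.30, 0.20, 0.35),
    choose_update(0.30, 0.20, 0.25),
]
```

Comparing against a tentative soft update first means the ensemble only changes when the validation score improves.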

on_train_epoch_end(trainer, model)[source]

Call at the end of each epoch during training.

Parameter

trainer:

model:

rtype:

None

class dicee.callbacks.Eval(path, epoch_ratio: int = None)[source]

Bases: dicee.abstracts.AbstractCallback

Callback class for knowledge graph embedding models

Parameter

path
reports = []
epoch_ratio
epoch_counter = 0
on_fit_start(trainer, model)[source]

Call at the beginning of the training.

Parameter

trainer:

model:

rtype:

None

on_fit_end(trainer, model)[source]

Call at the end of the training.

Parameter

trainer:

model:

rtype:

None

on_train_epoch_end(trainer, model)[source]

Call at the end of each epoch during training.

Parameter

trainer:

model:

rtype:

None

on_train_batch_end(*args, **kwargs)[source]

Call at the end of each mini-batch during the training.

Parameter

trainer:

model:

rtype:

None

class dicee.callbacks.KronE[source]

Bases: dicee.abstracts.AbstractCallback

Callback class for knowledge graph embedding models

Parameter

f = None
static batch_kronecker_product(a, b)[source]

Kronecker product of matrices a and b with leading batch dimensions. Batch dimensions are broadcast. The number of them must …

Parameters:
  • a (torch.Tensor)

  • b (torch.Tensor)

rtype:

torch.Tensor
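To make the index arithmetic concrete, here is a dependency-free sketch of a batched Kronecker product over nested lists. It assumes a single leading batch dimension and broadcasts only a batch of size 1. kron and batch_kron are hypothetical helpers; the method documented above operates on torch.Tensor inputs and is not reproduced here.

```python
def kron(a, b):
    """Kronecker product of an m x n matrix a and a p x q matrix b (lists of lists).

    Entry (i*p + k, j*q + l) of the result equals a[i][j] * b[k][l].
    """
    return [
        [a[i][j] * b[k][l] for j in range(len(a[0])) for l in range(len(b[0]))]
        for i in range(len(a))
        for k in range(len(b))
    ]


def batch_kron(batch_a, batch_b):
    """Apply kron pairwise along one leading batch dimension.

    A batch of length 1 is broadcast against the other batch.
    """
    n = max(len(batch_a), len(batch_b))
    if len(batch_a) not in (1, n) or len(batch_b) not in (1, n):
        raise ValueError("batch sizes must match or be 1")
    return [
        kron(batch_a[i % len(batch_a)], batch_b[i % len(batch_b)])
        for i in range(n)
    ]


a = [[1, 2], [3, 4]]
out = batch_kron([a], [[[0, 1], [1, 0]], [[1]]])
```

Here the single matrix a is broadcast against a batch of two matrices, so out holds two products: a 4 x 4 block matrix and a copy of a.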

get_kronecker_triple_representation(indexed_triple: torch.LongTensor)[source]

Get kronecker embeddings

on_fit_start(trainer, model)[source]

Call at the beginning of the training.

Parameter

trainer:

model:

rtype:

None

class dicee.callbacks.Perturb(level: str = 'input', ratio: float = 0.0, method: str = None, scaler: float = None, frequency=None)[source]

Bases: dicee.abstracts.AbstractCallback

A callback for three-level perturbation

Input Perturbation: During training, an input x is perturbed by randomly replacing one of its elements. In the context of knowledge graph embedding models, x can denote a triple, a tuple of an entity and a relation, or a tuple of two entities. A perturbation means that a component of x is randomly replaced by an entity or a relation.

Parameter Perturbation:

Output Perturbation:
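Input perturbation as described above can be sketched as follows, assuming that inputs are (head, relation, tail) index triples, that ratio is the fraction of the batch to perturb, and that only the head entity is replaced. perturb_heads is a hypothetical helper; which component the Perturb callback actually replaces, and how method, scaler, and frequency are used, is not specified in this section.

```python
import random


def perturb_heads(triples, num_entities, ratio, seed=0):
    """Replace the head of a random subset of triples with a random entity index.

    Note: the replacement is drawn uniformly, so it can coincide with the
    original head by chance.
    """
    rng = random.Random(seed)
    n_perturb = int(len(triples) * ratio)
    chosen = set(rng.sample(range(len(triples)), n_perturb))
    perturbed = []
    for i, (h, r, t) in enumerate(triples):
        if i in chosen:
            h = rng.randrange(num_entities)
        perturbed.append((h, r, t))
    return perturbed


batch = [(0, 0, 1), (2, 1, 3), (4, 0, 5), (1, 1, 0)]
untouched = perturb_heads(batch, num_entities=6, ratio=0.0)
fully_perturbed = perturb_heads(batch, num_entities=6, ratio=1.0)
```

With ratio=0.0 the batch is returned unchanged; with ratio=1.0 every head is redrawn while relations and tails are kept.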

level
ratio
method
scaler
frequency
on_train_batch_start(trainer, model, batch, batch_idx)[source]

Called when the train batch begins.