dicee.trainer.dice_trainer

Classes

DICE_Trainer

DICE_Trainer implements three trainers: PyTorch Lightning, multi-GPU, and CPU.

Functions

initialize_trainer(args, callbacks)

get_callbacks(args)

Module Contents

dicee.trainer.dice_trainer.initialize_trainer(args, callbacks)[source]
dicee.trainer.dice_trainer.get_callbacks(args)[source]
class dicee.trainer.dice_trainer.DICE_Trainer(args, is_continual_training, storage_path, evaluator=None)[source]
DICE_Trainer implements three trainers:

  1. PyTorch Lightning trainer (https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html)

  2. Multi-GPU trainer (https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html)

  3. CPU trainer
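The choice among the three trainers can be pictured as a lookup keyed on a configuration string. The sketch below is only an illustration: the option names `PL`, `torchDDP`, and `torchCPUTrainer`, the `TRAINER_BACKENDS` table, and the `describe_backend` helper are all assumptions made for this example and may not match what dicee actually accepts or how it dispatches internally.

```python
from typing import Dict

# Hypothetical option names; the strings dicee actually accepts may differ.
TRAINER_BACKENDS: Dict[str, str] = {
    "PL": "PyTorch Lightning trainer",
    "torchDDP": "multi-GPU trainer (DistributedDataParallel)",
    "torchCPUTrainer": "CPU trainer",
}


def describe_backend(name: str) -> str:
    """Return a description of the backend selected by `name`.

    Raises ValueError for names that are not in the (assumed) table.
    """
    try:
        return TRAINER_BACKENDS[name]
    except KeyError:
        expected = sorted(TRAINER_BACKENDS)
        raise ValueError(f"Unknown trainer {name!r}; expected one of {expected}") from None
```

Failing loudly on an unknown name, rather than silently falling back to the CPU trainer, makes a mistyped option visible before any training starts.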

args

is_continual_training: bool

storage_path: str

evaluator:

report: dict

report
args
trainer = None
is_continual_training
storage_path
evaluator
form_of_labelling = None
continual_start()[source]
  1. Initialize training.

  2. Load model.

  3. Load trainer.

  4. Fit model.

Returns:
  • model

  • form_of_labelling (str)

initialize_trainer(callbacks: List) → lightning.Trainer[source]

Initialize Trainer from input arguments

initialize_or_load_model()[source]
initialize_dataloader(dataset: torch.utils.data.Dataset) → torch.utils.data.DataLoader[source]
initialize_dataset(dataset: dicee.knowledge_graph.KG, form_of_labelling) → torch.utils.data.Dataset[source]
start(knowledge_graph: dicee.knowledge_graph.KG) → Tuple[dicee.models.base_model.BaseKGE, str][source]

Train selected model via the selected training strategy

k_fold_cross_validation(dataset) → Tuple[dicee.models.base_model.BaseKGE, str][source]

Perform K-fold Cross-Validation

  1. Obtain K train and test splits.

  2. For each split:

    2.1. Initialize the trainer and model.

    2.2. Train the model with the configuration provided in args.

    2.3. Compute the mean reciprocal rank (MRR) score of the model on the respective test split.

  3. Report the mean MRR over all splits.

Parameters:
  • self

  • dataset

Returns:

model
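The K-fold procedure described for k_fold_cross_validation can be sketched in plain Python. This is a minimal illustration under stated assumptions, not dicee's implementation: `k_fold_indices`, `mean_reciprocal_rank`, `cross_validate`, and the `fit_and_rank` callback are hypothetical names introduced here, and `fit_and_rank` stands in for "initialize trainer and model, train, then rank the correct answer for each test item".

```python
import random
from statistics import mean
from typing import Callable, List, Sequence, Tuple


def k_fold_indices(n: int, k: int, seed: int = 0) -> List[Tuple[List[int], List[int]]]:
    """Step 1: split range(n) into k (train, test) index pairs with disjoint test folds."""
    idx = list(range(n))
    random.Random(seed).shuffle(idx)
    folds = [idx[i::k] for i in range(k)]
    splits = []
    for i in range(k):
        test = folds[i]
        train = [j for f, fold in enumerate(folds) if f != i for j in fold]
        splits.append((train, test))
    return splits


def mean_reciprocal_rank(ranks: Sequence[int]) -> float:
    """MRR over 1-based ranks of the correct answers."""
    return mean([1.0 / r for r in ranks])


def cross_validate(
    data: Sequence,
    k: int,
    fit_and_rank: Callable[[list, list], Sequence[int]],
) -> Tuple[float, List[float]]:
    """Steps 2 and 3: score each split, then report the mean MRR."""
    scores = []
    for train_idx, test_idx in k_fold_indices(len(data), k):
        train = [data[i] for i in train_idx]
        test = [data[i] for i in test_idx]
        # Steps 2.1 and 2.2 (hypothetical callback): build and train a fresh
        # model on `train`, then return the rank of each correct test answer.
        ranks = fit_and_rank(train, test)
        # Step 2.3: MRR on the respective test split.
        scores.append(mean_reciprocal_rank(ranks))
    return mean(scores), scores
```

A fresh model per split matters here: reusing a model across folds would let it see test items from earlier folds during training, which inflates the reported MRR.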