dicee.trainer.dice_trainer

DICE Trainer module for knowledge graph embedding training.

Provides the DICE_Trainer class which supports multiple training backends including PyTorch Lightning, DDP, and custom CPU/GPU trainers.

Classes

DICE_Trainer

Trainer supporting PyTorch Lightning, multi-GPU (DDP), and CPU backends.

Functions

load_term_mapping(→ polars.DataFrame)

Load term-to-index mapping from CSV file.

initialize_trainer(...)

Initialize the appropriate trainer based on configuration.

get_callbacks(→ List)

Create list of callbacks based on configuration.

Module Contents

dicee.trainer.dice_trainer.load_term_mapping(file_path: str) → polars.DataFrame

Load term-to-index mapping from CSV file.

Parameters:

file_path – Base path without extension.

Returns:

Polars DataFrame containing the mapping.

dicee.trainer.dice_trainer.initialize_trainer(args, callbacks: List) → dicee.trainer.torch_trainer.TorchTrainer | dicee.trainer.model_parallelism.TensorParallel | dicee.trainer.torch_trainer_ddp.TorchDDPTrainer | lightning.Trainer

Initialize the appropriate trainer based on configuration.

Parameters:
  • args – Configuration arguments containing trainer type.

  • callbacks – List of training callbacks.

Returns:

Initialized trainer instance.

Raises:

AssertionError – If trainer is None after initialization.
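The dispatch described above can be sketched as a lookup on the configured trainer type, followed by the final assertion. This is a minimal illustration, not dicee's implementation: the stub classes, the `TRAINER_REGISTRY` mapping, the `args.trainer` attribute, and the keys `"cpu"` and `"ddp"` are all hypothetical names chosen for the example.

```python
from types import SimpleNamespace
from typing import List


class CPUTrainerStub:
    """Hypothetical stand-in for a single-device trainer."""

    def __init__(self, args, callbacks: List):
        self.args = args
        self.callbacks = callbacks


class DDPTrainerStub(CPUTrainerStub):
    """Hypothetical stand-in for a distributed data parallel trainer."""


# Hypothetical registry: the keys are illustrative, not dicee's option names.
TRAINER_REGISTRY = {
    "cpu": CPUTrainerStub,
    "ddp": DDPTrainerStub,
}


def initialize_trainer_sketch(args, callbacks: List):
    """Pick a trainer class from the configuration and instantiate it."""
    trainer_cls = TRAINER_REGISTRY.get(args.trainer)
    trainer = trainer_cls(args, callbacks) if trainer_cls is not None else None
    # Mirrors the documented behaviour: fail loudly if nothing was selected.
    assert trainer is not None, f"Unknown trainer type: {args.trainer!r}"
    return trainer


args = SimpleNamespace(trainer="ddp")
trainer = initialize_trainer_sketch(args, callbacks=[])
```

An unrecognized value such as `SimpleNamespace(trainer="tpu")` leaves `trainer` as `None` in this sketch, so the assertion fires instead of returning an unusable object.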

dicee.trainer.dice_trainer.get_callbacks(args) → List

Create list of callbacks based on configuration.

Parameters:

args – Configuration arguments.

Returns:

List of callback instances.

class dicee.trainer.dice_trainer.DICE_Trainer(args, is_continual_training: bool, storage_path, evaluator=None)
DICE_Trainer implements:

  1. PyTorch Lightning trainer (https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html)

  2. Multi-GPU trainer (https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html)

  3. CPU trainer

Parameters:
  • args

  • is_continual_training (bool)

  • storage_path (str)

  • evaluator

report:dict

report
args
trainer = None
is_continual_training
storage_path
evaluator = None
form_of_labelling = None
continual_start(knowledge_graph)
  1. Initialize training.

  2. Load model.

  3. Load trainer.

  4. Fit model.

Parameters:

knowledge_graph

Returns:
  • model

  • form_of_labelling (str)

initialize_trainer(callbacks: List) → lightning.Trainer | dicee.trainer.model_parallelism.TensorParallel | dicee.trainer.torch_trainer.TorchTrainer | dicee.trainer.torch_trainer_ddp.TorchDDPTrainer

Initialize the trainer from the input arguments.

initialize_or_load_model()
init_dataloader(dataset: torch.utils.data.Dataset) → torch.utils.data.DataLoader
init_dataset() → torch.utils.data.Dataset
start(knowledge_graph: dicee.knowledge_graph.KG | numpy.memmap) → Tuple[dicee.models.base_model.BaseKGE, str]

Start the training.

  1. Initialize the trainer.

  2. Initialize or load a pretrained KGE model.

In a DDP setup, the memory map of the already read and indexed KG must be loaded.

k_fold_cross_validation(dataset) → Tuple[dicee.models.base_model.BaseKGE, str]

Perform K-fold cross-validation.

  1. Obtain K train and test splits.

  2. For each split,

    2.1. Initialize the trainer and model.

    2.2. Train the model with the configuration provided in args.

    2.3. Compute the mean reciprocal rank (MRR) score of the model on the respective test split.

  3. Report the mean MRR over the K splits.

Parameters:

dataset

Returns:

model
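The procedure above can be sketched with the standard library alone. This is an illustration under stated assumptions, not dicee's code: `k_fold_splits`, `mean_reciprocal_rank`, `cross_validate`, and the `train_and_rank` callback (which stands in for steps 2.1 to 2.3 and returns the rank of the correct answer for each test query) are hypothetical helpers. MRR is computed as the average of 1/rank over the test queries of a split.

```python
from statistics import mean
from typing import Callable, List, Sequence, Tuple


def k_fold_splits(n_samples: int, k: int) -> List[Tuple[List[int], List[int]]]:
    """Return k (train_indices, test_indices) pairs covering range(n_samples)."""
    indices = list(range(n_samples))
    # Spread the remainder over the first folds so sizes differ by at most one.
    fold_sizes = [n_samples // k + (1 if i < n_samples % k else 0) for i in range(k)]
    splits, start = [], 0
    for size in fold_sizes:
        test = indices[start:start + size]
        train = indices[:start] + indices[start + size:]
        splits.append((train, test))
        start += size
    return splits


def mean_reciprocal_rank(ranks: Sequence[int]) -> float:
    """Average of 1/rank over all test queries (ranks start at 1)."""
    return mean([1.0 / r for r in ranks])


def cross_validate(
    n_samples: int,
    k: int,
    train_and_rank: Callable[[List[int], List[int]], Sequence[int]],
) -> float:
    """Train and score once per split, then report the mean MRR over splits."""
    fold_mrrs = []
    for train_idx, test_idx in k_fold_splits(n_samples, k):
        ranks = train_and_rank(train_idx, test_idx)  # hypothetical training step
        fold_mrrs.append(mean_reciprocal_rank(ranks))
    return mean(fold_mrrs)


def dummy_train_and_rank(train_idx: List[int], test_idx: List[int]) -> List[int]:
    """Toy stand-in: no training, ranks alternate between 1 and 2."""
    return [1 + (i % 2) for i in test_idx]


score = cross_validate(n_samples=10, k=5, train_and_rank=dummy_train_and_rank)
```

With the toy callback every split yields the ranks `[1, 2]`, so each per-split MRR is (1 + 0.5) / 2 = 0.75 and `score` is 0.75 as well.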