dicee.trainer.dice_trainer
DICE Trainer module for knowledge graph embedding training.
Provides the DICE_Trainer class which supports multiple training backends including PyTorch Lightning, DDP, and custom CPU/GPU trainers.
Classes
- DICE_Trainer: Trainer implementing PyTorch Lightning, multi-GPU (DDP), and CPU training backends.
Functions
- load_term_mapping: Load term-to-index mapping from CSV file.
- initialize_trainer: Initialize the appropriate trainer based on configuration.
- get_callbacks: Create list of callbacks based on configuration.
Module Contents
- dicee.trainer.dice_trainer.load_term_mapping(file_path: str) → polars.DataFrame
Load term-to-index mapping from CSV file.
- Parameters:
file_path – Base path without extension.
- Returns:
Polars DataFrame containing the mapping.
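A minimal usage sketch; the base path below is illustrative, and per the parameter description the file extension is resolved by the function itself:
```python
from dicee.trainer.dice_trainer import load_term_mapping

# Illustrative base path without extension (see the parameter description above).
entity_to_idx = load_term_mapping("Experiments/run_0/entity_to_idx")
print(entity_to_idx.head())  # polars.DataFrame holding the term-to-index mapping
```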
- dicee.trainer.dice_trainer.initialize_trainer(args, callbacks: List) → dicee.trainer.torch_trainer.TorchTrainer | dicee.trainer.model_parallelism.TensorParallel | dicee.trainer.torch_trainer_ddp.TorchDDPTrainer | lightning.Trainer
Initialize the appropriate trainer based on configuration.
- Parameters:
args – Configuration arguments containing trainer type.
callbacks – List of training callbacks.
- Returns:
Initialized trainer instance.
- Raises:
AssertionError – If trainer is None after initialization.
- dicee.trainer.dice_trainer.get_callbacks(args) → List
Create list of callbacks based on configuration.
- Parameters:
args – Configuration arguments.
- Returns:
List of callback instances.
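A sketch of combining the two helpers; `args` is assumed to be the configuration object produced by dicee's argument parsing, and its exact attributes are not shown here:
```python
from dicee.trainer.dice_trainer import get_callbacks, initialize_trainer

# `args` is assumed to exist already; which backend is selected depends on its
# trainer setting (an assumption, not shown here).
callbacks = get_callbacks(args)
trainer = initialize_trainer(args, callbacks)
# trainer is a TorchTrainer, TensorParallel, TorchDDPTrainer, or lightning.Trainer.
```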
- class dicee.trainer.dice_trainer.DICE_Trainer(args, is_continual_training: bool, storage_path, evaluator=None)
- DICE_Trainer implements three training backends:
(1) PyTorch Lightning trainer (https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html)
(2) Multi-GPU trainer via DistributedDataParallel (https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html)
(3) CPU trainer
args – Configuration arguments.
is_continual_training (bool) – Whether training continues from a previous run.
storage_path (str) – Path under which training artifacts are stored.
evaluator – Optional evaluator instance (default None).
report (dict) – Dictionary holding the training report.
- report
- args
- trainer = None
- is_continual_training
- storage_path
- evaluator = None
- form_of_labelling = None
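A construction sketch, assuming a parsed configuration object `args` already exists; the storage path is illustrative:
```python
from dicee.trainer.dice_trainer import DICE_Trainer

# `args` is assumed to be dicee's parsed configuration object (not built here).
trainer = DICE_Trainer(
    args=args,
    is_continual_training=False,
    storage_path="Experiments/run_0",   # illustrative directory for artifacts
    evaluator=None,
)
```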
- continual_start(knowledge_graph)
Initialize continual training.
(1) Load model
(2) Load trainer
(3) Fit model
- Parameters:
knowledge_graph – Knowledge graph to continue training on.
- Returns:
model
form_of_labelling (str)
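A sketch of resuming training from a previous experiment; `args` and `knowledge_graph` are assumed to exist, and the storage path is illustrative:
```python
from dicee.trainer.dice_trainer import DICE_Trainer

# Resume from an earlier run: the model and trainer are loaded, then fitted further.
continual_trainer = DICE_Trainer(
    args=args,                                # assumed parsed dicee configuration
    is_continual_training=True,
    storage_path="Experiments/previous_run",  # illustrative path of the earlier run
)
model, form_of_labelling = continual_trainer.continual_start(knowledge_graph)
```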
- initialize_trainer(callbacks: List) → lightning.Trainer | dicee.trainer.model_parallelism.TensorParallel | dicee.trainer.torch_trainer.TorchTrainer | dicee.trainer.torch_trainer_ddp.TorchDDPTrainer
Initialize the trainer from input arguments.
- initialize_or_load_model()
- init_dataloader(dataset: torch.utils.data.Dataset) → torch.utils.data.DataLoader
- init_dataset() → torch.utils.data.Dataset
- start(knowledge_graph: dicee.knowledge_graph.KG | numpy.memmap) → Tuple[dicee.models.base_model.BaseKGE, str]
Start the training.
(1) Initialize the trainer.
(2) Initialize or load a pretrained KGE model.
In a DDP setup, the memory map of the already read/indexed KG needs to be loaded.
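A sketch of the main entry point, reusing the trainer constructed in the class-level example above and assuming a knowledge graph `kg` (dicee.knowledge_graph.KG, or a numpy.memmap in a DDP setup) has been prepared elsewhere:
```python
# `kg` is assumed to be a prepared dicee.knowledge_graph.KG (or numpy.memmap for DDP).
model, form_of_labelling = trainer.start(kg)
model.eval()  # trained BaseKGE model, ready for evaluation / link prediction
```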
- k_fold_cross_validation(dataset) → Tuple[dicee.models.base_model.BaseKGE, str]
Perform K-fold cross-validation.
(1) Obtain K train and test splits.
(2) For each split:
2.1. Initialize trainer and model.
2.2. Train the model with the configuration provided in args.
2.3. Compute the mean reciprocal rank (MRR) of the model on the respective test split.
(3) Report the average MRR across splits.
- Parameters:
dataset – Dataset used for cross-validation.
- Returns:
model
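A sketch assuming the configuration in `args` requests K-fold cross-validation and that `dataset` holds the indexed training data; both names are placeholders, and the trainer from the class-level example above is reused:
```python
# `dataset` is assumed to hold the training data used to build the K splits.
model, form_of_labelling = trainer.k_fold_cross_validation(dataset)
# Per-fold MRR scores are computed during the run; the trained model is returned.
```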