dicee.trainer.dice_trainer
Classes
- DICE_Trainer: implements a PyTorch Lightning trainer, a multi-GPU (DistributedDataParallel) trainer, and a CPU trainer.
Functions
- initialize_trainer(args, callbacks): initialize the trainer from input arguments.
Module Contents
- dicee.trainer.dice_trainer.initialize_trainer(args, callbacks) dicee.trainer.torch_trainer.TorchTrainer | dicee.trainer.model_parallelism.TensorParallel | dicee.trainer.torch_trainer_ddp.TorchDDPTrainer | lightning.Trainer [source]
- class dicee.trainer.dice_trainer.DICE_Trainer(args, is_continual_training: bool, storage_path, evaluator=None)[source]
- DICE_Trainer implements three training backends (a construction sketch follows the attribute list below):
  1. PyTorch Lightning trainer (https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html)
  2. Multi-GPU trainer via DistributedDataParallel (https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html)
  3. CPU trainer
args
is_continual_training: bool
storage_path: str
evaluator
report: dict
- report
- args
- trainer = None
- is_continual_training
- storage_path
- evaluator = None
- form_of_labelling = None
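A minimal construction sketch, assuming a configuration object with the fields shown; the field names and values are illustrative, not dicee's exact defaults:

```python
from types import SimpleNamespace

from dicee.trainer.dice_trainer import DICE_Trainer

# Hypothetical configuration; real runs typically build `args` from
# dicee's argument parser, so field names and values may differ.
args = SimpleNamespace(trainer="torchCPUTrainer", num_epochs=10, batch_size=1024)

trainer = DICE_Trainer(
    args=args,
    is_continual_training=False,      # start a fresh run rather than resuming
    storage_path="Experiments/Run1",  # where reports and checkpoints are stored
    evaluator=None,                   # optional evaluator instance
)
```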
- continual_start(knowledge_graph)[source]
(1) Initialize training
(2) Load model
(3) Load trainer
(4) Fit model
- Parameters:
knowledge_graph
- Returns:
model
form_of_labelling (str)
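A hedged sketch of resuming from a stored experiment; the `KG` constructor arguments, configuration fields, and paths below are assumptions for illustration:

```python
from types import SimpleNamespace

from dicee.knowledge_graph import KG
from dicee.trainer.dice_trainer import DICE_Trainer

args = SimpleNamespace(trainer="torchCPUTrainer", num_epochs=5)  # hypothetical fields
kg = KG(dataset_dir="KGs/UMLS")  # hypothetical dataset location and argument name

trainer = DICE_Trainer(args=args, is_continual_training=True,
                       storage_path="Experiments/Run1", evaluator=None)

# Loads the stored model and a trainer, then continues fitting on `kg`.
model, form_of_labelling = trainer.continual_start(kg)
```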
- initialize_trainer(callbacks: List) lightning.Trainer | dicee.trainer.model_parallelism.TensorParallel | dicee.trainer.torch_trainer.TorchTrainer | dicee.trainer.torch_trainer_ddp.TorchDDPTrainer [source]
Initialize the trainer from input arguments.
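The return type enumerates the possible backends; conceptually the call dispatches on the configured trainer, roughly as in the sketch below. The flag name `args.trainer`, its values, and the constructor signatures are assumptions, not dicee's exact API:

```python
def pick_trainer(args, callbacks):
    """Illustrative dispatch only; dicee's own initialize_trainer may differ."""
    backend = getattr(args, "trainer", "torchCPUTrainer")  # hypothetical flag
    if backend == "PL":
        import lightning
        # Plain PyTorch Lightning trainer; these keyword arguments are standard.
        return lightning.Trainer(max_epochs=args.num_epochs, callbacks=callbacks)
    if backend == "torchDDP":
        from dicee.trainer.torch_trainer_ddp import TorchDDPTrainer
        return TorchDDPTrainer(args, callbacks=callbacks)  # signature assumed
    from dicee.trainer.torch_trainer import TorchTrainer
    return TorchTrainer(args, callbacks=callbacks)          # signature assumed
```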
- start(knowledge_graph: dicee.knowledge_graph.KG | numpy.memmap) Tuple[dicee.models.base_model.BaseKGE, str] [source]
Start the training.
(1) Initialize the trainer.
(2) Initialize or load a pretrained KGE model. In a DDP setup, the memory map of the already read/indexed KG needs to be loaded.
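A minimal end-to-end sketch; the configuration fields and dataset path are hypothetical, as in the construction sketch above:

```python
from types import SimpleNamespace

from dicee.knowledge_graph import KG
from dicee.trainer.dice_trainer import DICE_Trainer

args = SimpleNamespace(trainer="torchCPUTrainer", num_epochs=10)  # hypothetical fields
kg = KG(dataset_dir="KGs/UMLS")                                   # hypothetical path

trainer = DICE_Trainer(args=args, is_continual_training=False,
                       storage_path="Experiments/Run1")
# Returns the trained BaseKGE model and a string describing the labelling scheme.
model, form_of_labelling = trainer.start(kg)
```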
- k_fold_cross_validation(dataset) Tuple[dicee.models.base_model.BaseKGE, str] [source]
Perform K-fold cross-validation.
1. Obtain K train and test splits.
2. For each split:
   2.1. Initialize the trainer and the model.
   2.2. Train the model with the configuration provided in args.
   2.3. Compute the mean reciprocal rank (MRR) of the model on the respective test split.
3. Report the average MRR across the splits (a generic sketch of this pattern follows the Returns section below).
- Parameters:
self
dataset
- Returns:
model
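A generic sketch of the K-fold/MRR reporting pattern described above, using scikit-learn's KFold; the model training of steps 2.1 and 2.2 is replaced by placeholder ranks, so only the cross-validation and aggregation structure is shown:

```python
import numpy as np
from sklearn.model_selection import KFold

# Toy index array standing in for an indexed triple set (assumption).
triples = np.arange(1000)

fold_mrrs = []
for train_idx, test_idx in KFold(n_splits=10, shuffle=True, random_state=1).split(triples):
    # 2.1 / 2.2: a real run would initialize a trainer and fit a KGE model on triples[train_idx].
    # 2.3: compute MRR on the held-out split; placeholder ranks are used here instead.
    ranks = np.random.default_rng(len(test_idx)).integers(1, 100, size=len(test_idx))
    fold_mrrs.append(float(np.mean(1.0 / ranks)))

# 3: report the average MRR across the K splits.
print(f"Mean MRR over folds: {np.mean(fold_mrrs):.4f} (std {np.std(fold_mrrs):.4f})")
```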