dicee.trainer
Submodules
Classes
DICE_Trainer | DICE_Trainer implements the PyTorch Lightning, multi-GPU, and CPU trainers.
Package Contents
- class dicee.trainer.DICE_Trainer(args, is_continual_training, storage_path, evaluator=None)
- DICE_Trainer implements three training backends:
1- PyTorch Lightning trainer (https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html)
2- Multi-GPU trainer (https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html)
3- CPU trainer
- Parameters:
args
is_continual_training: bool
storage_path: str
evaluator: optional, defaults to None
- Returns:
report: dict
- report
- args
- trainer = None
- is_continual_training
- storage_path
- evaluator
- form_of_labelling = None
- continual_start()
(1) Initialize training.
(2) Load the model.
(3) Load the trainer.
(4) Fit the model.
- Parameters: none
- Returns:
model
form_of_labelling (str)
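The four steps of continual_start can be pictured without any deep-learning dependency. This is a minimal sketch, not dicee's code: `ToyModel`, `ToyTrainer`, `continual_start_sketch` and the `stored_state` dictionary are hypothetical stand-ins for a model restored from storage_path and a trainer built from the input arguments. What it illustrates is that continual training resumes from stored parameters and progress instead of re-initializing the model.

```python
from dataclasses import dataclass, field
from typing import Callable, List, Optional, Tuple


@dataclass
class ToyModel:
    """Hypothetical stand-in for a KGE model restored from storage_path."""
    weights: List[float]
    epochs_trained: int = 0


@dataclass
class ToyTrainer:
    """Hypothetical stand-in for the trainer built from the input arguments."""
    max_epochs: int
    callbacks: List[Callable[[ToyModel], None]] = field(default_factory=list)

    def fit(self, model: ToyModel) -> None:
        for _ in range(self.max_epochs):
            # A real trainer would run forward/backward passes here;
            # this toy only advances the epoch counter.
            model.epochs_trained += 1
            for callback in self.callbacks:
                callback(model)


def continual_start_sketch(
    stored_state: dict,
    max_epochs: int,
    form_of_labelling: str,
    callbacks: Optional[List[Callable[[ToyModel], None]]] = None,
) -> Tuple[ToyModel, str]:
    # (1) Initialize training: collect the callbacks to run during fitting.
    callbacks = list(callbacks or [])
    # (2) Load the model: resume from the stored state, not a fresh initialization.
    model = ToyModel(weights=list(stored_state["weights"]),
                     epochs_trained=stored_state["epochs_trained"])
    # (3) Load the trainer.
    trainer = ToyTrainer(max_epochs=max_epochs, callbacks=callbacks)
    # (4) Fit the model, continuing from where the stored run stopped.
    trainer.fit(model)
    return model, form_of_labelling


seen = []
model, labelling = continual_start_sketch(
    {"weights": [0.1, 0.2], "epochs_trained": 5},
    max_epochs=3,
    form_of_labelling="hypothetical-labelling",
    callbacks=[lambda m: seen.append(m.epochs_trained)],
)
print(model.epochs_trained)  # 8
print(seen)                  # [6, 7, 8]
```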
- initialize_trainer(callbacks: List) → lightning.Trainer
Initialize a Trainer from the input arguments.
- initialize_dataset(dataset: dicee.knowledge_graph.KG, form_of_labelling) → torch.utils.data.Dataset
- start(knowledge_graph: dicee.knowledge_graph.KG) → Tuple[dicee.models.base_model.BaseKGE, str]
Train the selected model via the selected training strategy.
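One way to picture "the selected training strategy" is a name-to-backend lookup. The sketch below is hypothetical: the backend functions only tag the model with the backend that would have trained it, and the keys `"PL"`, `"torchDDP"` and `"torchCPUTrainer"` are assumed trainer names (modeled on dicee's `--trainer` option) that should be checked against the installed version. The return value mirrors the (model, str) shape of start.

```python
from typing import Callable, Dict, Tuple


# Hypothetical backends, one per trainer listed in the class docstring.
# Each only records which backend "trained" the model.
def fit_with_lightning(model: dict) -> dict:
    return {**model, "backend": "PyTorch Lightning"}


def fit_with_ddp(model: dict) -> dict:
    return {**model, "backend": "DistributedDataParallel"}


def fit_on_cpu(model: dict) -> dict:
    return {**model, "backend": "CPU"}


# Assumed trainer names; verify against your dicee version.
TRAINERS: Dict[str, Callable[[dict], dict]] = {
    "PL": fit_with_lightning,
    "torchDDP": fit_with_ddp,
    "torchCPUTrainer": fit_on_cpu,
}


def start_sketch(trainer_name: str, model: dict, form_of_labelling: str) -> Tuple[dict, str]:
    """Train `model` with the backend selected by name; return (model, form_of_labelling)."""
    try:
        fit = TRAINERS[trainer_name]
    except KeyError:
        raise ValueError(
            f"Unknown trainer {trainer_name!r}; expected one of {sorted(TRAINERS)}"
        ) from None
    return fit(model), form_of_labelling


trained, labelling = start_sketch("torchCPUTrainer", {"name": "toy"}, "hypothetical-labelling")
print(trained["backend"])  # CPU
```

Keeping the backends behind one lookup makes an unknown name fail early with a clear error rather than silently falling back to a default.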
- k_fold_cross_validation(dataset) → Tuple[dicee.models.base_model.BaseKGE, str]
Perform K-fold cross-validation:
1. Obtain K train and test splits.
2. For each split:
2.1. Initialize the trainer and the model.
2.2. Train the model with the configuration provided in args.
2.3. Compute the mean reciprocal rank (MRR) of the model on the respective test split.
3. Report the mean MRR across the K splits.
- Parameters:
self
dataset
- Returns:
model
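The cross-validation procedure above can be sketched in plain Python. This is a minimal illustration under stated assumptions, not dicee's implementation: triples are integer (head, relation, tail) ids, `train_fn` is a hypothetical callable covering steps 2.1 and 2.2 (initialize and train a model, then return a scoring function), and MRR is the raw (unfiltered) tail-prediction score MRR = (1/|T|) Σ 1/rank_i over the test triples T, where rank_i is 1 plus the number of candidate tails scored strictly higher than the true tail.

```python
import random
from statistics import mean
from typing import Callable, List, Sequence, Tuple

Triple = Tuple[int, int, int]               # (head, relation, tail) as integer ids
ScoreFn = Callable[[int, int, int], float]  # hypothetical scorer returned by training


def k_fold_splits(data: Sequence[Triple], k: int, seed: int = 0) -> List[Tuple[List[Triple], List[Triple]]]:
    """Step 1: shuffle once, deal triples round-robin into K folds, hold each fold out in turn."""
    if not 2 <= k <= len(data):
        raise ValueError("k must satisfy 2 <= k <= len(data)")
    items = list(data)
    random.Random(seed).shuffle(items)
    folds = [items[i::k] for i in range(k)]
    splits = []
    for i, test in enumerate(folds):
        train = [x for j, fold in enumerate(folds) if j != i for x in fold]
        splits.append((train, test))
    return splits


def mean_reciprocal_rank(test: Sequence[Triple], score_fn: ScoreFn, num_entities: int) -> float:
    """Raw (unfiltered) MRR for tail prediction; ties are resolved optimistically."""
    reciprocal_ranks = []
    for h, r, t in test:
        true_score = score_fn(h, r, t)
        # Rank = 1 + number of candidate tails scored strictly higher than the true tail.
        rank = 1 + sum(1 for e in range(num_entities) if score_fn(h, r, e) > true_score)
        reciprocal_ranks.append(1.0 / rank)
    return mean(reciprocal_ranks)


def k_fold_cross_validation_sketch(
    data: Sequence[Triple],
    k: int,
    train_fn: Callable[[List[Triple]], ScoreFn],
    num_entities: int,
    seed: int = 0,
) -> Tuple[float, List[float]]:
    fold_mrrs = []
    for train, test in k_fold_splits(data, k, seed):                          # step 1
        score_fn = train_fn(train)                                             # steps 2.1 and 2.2
        fold_mrrs.append(mean_reciprocal_rank(test, score_fn, num_entities))  # step 2.3
    return mean(fold_mrrs), fold_mrrs                                          # step 3


# Toy data where the tail is always (head + relation) mod 10, and a
# "trainer" that returns a scorer which already knows this rule.
data = [(h, r, (h + r) % 10) for h in range(10) for r in range(3)]


def oracle_train(train: List[Triple]) -> ScoreFn:
    return lambda h, r, t: 1.0 if t == (h + r) % 10 else 0.0


avg, per_fold = k_fold_cross_validation_sketch(data, k=5, train_fn=oracle_train, num_entities=10)
print(avg, len(per_fold))  # 1.0 5
```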