dicee.trainer.model_parallelism
Classes
TensorParallel | Abstract base class for KGE model trainers.
Functions
Module Contents
- dicee.trainer.model_parallelism.forward_backward_update_loss(z: Tuple, ensemble_model) -> float
- class dicee.trainer.model_parallelism.TensorParallel(args, callbacks)
Bases: dicee.abstracts.AbstractTrainer
Abstract base class for KGE model trainers.
Provides the callback dispatch mechanism shared by all concrete trainer implementations (TorchTrainer, TorchDDPTrainer, etc.). Sub-classes call the on_* hooks at the appropriate points in the training loop so that any registered AbstractCallback can react.
Parameters:
- args (argparse.Namespace or similar) – Processed configuration object. Must expose at least random_seed (int).
- callbacks (list of AbstractCallback) – Ordered list of callback instances to invoke at each lifecycle hook.
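The dispatch pattern described above can be sketched as follows. This is a minimal, self-contained illustration under stated assumptions. SketchCallback, SketchTrainer, EventLogger and LoopTrainer are hypothetical stand-ins for AbstractCallback, AbstractTrainer and a concrete trainer. The hook names on_fit_start, on_train_batch_end and on_fit_end are also assumed for illustration and are not taken from the dicee API documented here. The sketch shows two behaviours from the docstring: args only needs to expose random_seed, and each on_* call is forwarded to every callback in list order.

```python
import argparse
import random
from typing import List


class SketchCallback:
    """Hypothetical stand-in for AbstractCallback: every hook is a no-op by default."""

    def on_fit_start(self, trainer, model) -> None:
        pass

    def on_train_batch_end(self, trainer, model) -> None:
        pass

    def on_fit_end(self, trainer, model) -> None:
        pass


class SketchTrainer:
    """Hypothetical stand-in for AbstractTrainer illustrating callback dispatch."""

    def __init__(self, args: argparse.Namespace, callbacks: List[SketchCallback]):
        self.args = args
        self.callbacks = callbacks
        # The only configuration field this sketch relies on is random_seed (int).
        random.seed(args.random_seed)

    def _dispatch(self, hook: str, model) -> None:
        # Forward the event to every registered callback, preserving list order.
        for cb in self.callbacks:
            getattr(cb, hook)(self, model)

    def on_fit_start(self, model) -> None:
        self._dispatch("on_fit_start", model)

    def on_train_batch_end(self, model) -> None:
        self._dispatch("on_train_batch_end", model)

    def on_fit_end(self, model) -> None:
        self._dispatch("on_fit_end", model)


class EventLogger(SketchCallback):
    """Hypothetical callback that records (name, event) for each hook it receives."""

    def __init__(self, name: str, log: list):
        self.name = name
        self.log = log

    def on_fit_start(self, trainer, model) -> None:
        self.log.append((self.name, "fit_start"))

    def on_train_batch_end(self, trainer, model) -> None:
        self.log.append((self.name, "batch_end"))

    def on_fit_end(self, trainer, model) -> None:
        self.log.append((self.name, "fit_end"))


class LoopTrainer(SketchTrainer):
    """Hypothetical concrete trainer: calls the hooks at fixed points of its loop."""

    def fit(self, model, num_batches: int) -> None:
        self.on_fit_start(model)
        for _ in range(num_batches):
            # A real trainer would run its forward/backward/update step here.
            self.on_train_batch_end(model)
        self.on_fit_end(model)


if __name__ == "__main__":
    log = []
    trainer = LoopTrainer(argparse.Namespace(random_seed=42),
                          [EventLogger("a", log), EventLogger("b", log)])
    trainer.fit(model=None, num_batches=2)
    print(log)
```

Because dispatch iterates over the list in order, callback "a" always sees an event before callback "b". This matches the docstring's description of callbacks as an ordered list. Keeping the hooks as no-ops in the base class lets a callback override only the events it cares about.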