dicee.trainer.model_parallelism

Classes

TensorParallel

Trainer for KGE models that uses tensor (model) parallelism.

Functions

extract_input_outputs(z[, device])

find_good_batch_size(train_loader, tp_ensemble_model)

forward_backward_update_loss(z, ensemble_model) → float

Module Contents

dicee.trainer.model_parallelism.extract_input_outputs(z: list, device=None)
dicee.trainer.model_parallelism.find_good_batch_size(train_loader, tp_ensemble_model)
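find_good_batch_size is documented here only by its signature. One common way to find a workable batch size is to keep doubling it until a trial step fails and then keep the last size that succeeded. The sketch below shows that approach. It is not necessarily what dicee does. The probe `fits` and the function name `find_largest_batch_size` are hypothetical stand-ins for running one step on the tensor-parallel model and catching an out-of-memory error.

```python
from typing import Callable


def find_largest_batch_size(fits: Callable[[int], bool], start: int = 1, limit: int = 1 << 20) -> int:
    """Double the batch size while `fits` succeeds and return the last size that did.

    `fits(batch_size)` is a hypothetical probe: in a real trainer it would run one
    training step at that batch size and return False on an out-of-memory error.
    """
    if not fits(start):
        raise ValueError(f"batch size {start} already fails")
    batch_size = start
    # Grow geometrically so only about log2(limit / start) probes are needed.
    while batch_size * 2 <= limit and fits(batch_size * 2):
        batch_size *= 2
    return batch_size


# Toy probe: pretend anything above 1000 examples exhausts memory.
print(find_largest_batch_size(lambda bs: bs <= 1000))  # 512
```

Doubling keeps the number of expensive probes small. The price is that the result can fall up to a factor of two short of the true maximum. A binary search between the last success and the first failure would narrow that gap.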
dicee.trainer.model_parallelism.forward_backward_update_loss(z: Tuple, ensemble_model) → float
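The name and the float return type suggest a single training step: a forward pass, a backward pass, and a parameter update, with the loss returned as a plain Python scalar. The sketch below shows that pattern on a toy linear model with a hand-written gradient instead of an ensemble_model. The batch layout (features, target), the function name, and the `lr` parameter are assumptions made for illustration.

```python
from typing import List, Tuple


def forward_backward_update_loss_sketch(
    z: Tuple[List[float], float], weights: List[float], lr: float = 0.1
) -> float:
    # Assumed batch layout: (features, target) for a single example.
    x, y = z
    # Forward: linear prediction and squared-error loss.
    pred = sum(w * xi for w, xi in zip(weights, x))
    err = pred - y
    loss = err * err
    # Backward: d(loss)/d(w_i) = 2 * err * x_i.
    grads = [2.0 * err * xi for xi in x]
    # Update: one in-place gradient-descent step on the weights.
    for i, g in enumerate(grads):
        weights[i] -= lr * g
    # Return the loss measured before the update, as a Python float.
    return float(loss)


w = [0.0, 0.0]
print(forward_backward_update_loss_sketch(([1.0, 2.0], 1.0), w))  # 1.0
```

In PyTorch the same three steps are usually written as optimizer.zero_grad(), loss.backward() and optimizer.step(), with loss.item() supplying the float.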
class dicee.trainer.model_parallelism.TensorParallel(args, callbacks)

Bases: dicee.abstracts.AbstractTrainer

Trainer for KGE models that uses tensor (model) parallelism.

It inherits from AbstractTrainer the callback dispatch mechanism shared by all concrete trainer implementations (TorchTrainer, TorchDDPTrainer, etc.). Sub-classes such as this one call the on_* hooks at the appropriate points in the training loop so that any registered AbstractCallback can react.
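The dispatch mechanism described above can be sketched as follows. SketchTrainer, SketchCallback, Recorder and the hook names on_fit_start, on_epoch_end and on_fit_end are hypothetical. They are not taken from dicee's AbstractTrainer or AbstractCallback. The sketch only shows callbacks being invoked in list order at each lifecycle point.

```python
from typing import List


class SketchCallback:
    """Hypothetical stand-in for AbstractCallback: every hook is a no-op by default."""

    def on_fit_start(self, trainer: "SketchTrainer") -> None:
        pass

    def on_epoch_end(self, trainer: "SketchTrainer") -> None:
        pass

    def on_fit_end(self, trainer: "SketchTrainer") -> None:
        pass


class SketchTrainer:
    def __init__(self, callbacks: List[SketchCallback]) -> None:
        self.callbacks = callbacks  # invoked in list order at every hook
        self.current_epoch = -1

    def _dispatch(self, hook: str) -> None:
        # Call the named hook on each registered callback, in order.
        for cb in self.callbacks:
            getattr(cb, hook)(self)

    def fit(self, epochs: int) -> None:
        self._dispatch("on_fit_start")
        for epoch in range(epochs):
            self.current_epoch = epoch
            # ... one epoch of training would run here ...
            self._dispatch("on_epoch_end")
        self._dispatch("on_fit_end")


class Recorder(SketchCallback):
    """Example callback that logs which hooks fired."""

    def __init__(self, name: str, log: List[str]) -> None:
        self.name = name
        self.log = log

    def on_fit_start(self, trainer: SketchTrainer) -> None:
        self.log.append(f"{self.name}:start")

    def on_epoch_end(self, trainer: SketchTrainer) -> None:
        self.log.append(f"{self.name}:epoch{trainer.current_epoch}")


log: List[str] = []
SketchTrainer([Recorder("a", log), Recorder("b", log)]).fit(epochs=2)
print(log)  # ['a:start', 'b:start', 'a:epoch0', 'b:epoch0', 'a:epoch1', 'b:epoch1']
```

Because the hooks have no-op defaults, a callback only overrides the points it cares about. Here Recorder ignores on_fit_end.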

Parameters:
  • args (argparse.Namespace or similar) – Processed configuration object. Must expose at least random_seed (int).

  • callbacks (list of AbstractCallback) – Ordered list of callback instances to invoke at each lifecycle hook.
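The parameter list states only that args must expose an integer random_seed and that callbacks run in list order. The sketch below builds a minimal args object that meets that documented requirement and uses it to seed Python's random module. make_args, validate_args and seed_python_rng are hypothetical helpers. The real trainer very likely reads further fields from args that are not documented on this page.

```python
import argparse
import random


def make_args(random_seed: int = 42) -> argparse.Namespace:
    # Smallest config object matching the documented requirement.
    return argparse.Namespace(random_seed=random_seed)


def validate_args(args) -> None:
    # Hypothetical check: args must expose an integer `random_seed`.
    seed = getattr(args, "random_seed", None)
    if not isinstance(seed, int):
        raise TypeError("args must expose an integer `random_seed`")


def seed_python_rng(args) -> None:
    # Hypothetical helper: make Python's `random` module reproducible.
    validate_args(args)
    random.seed(args.random_seed)


args = make_args(7)
seed_python_rng(args)
first = [random.random() for _ in range(3)]
seed_python_rng(args)
second = [random.random() for _ in range(3)]
print(first == second)  # True: same seed, same draws
```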

fit(*args, **kwargs)

Train the model.