dicee.trainer.torch_trainer_ddp

Classes

TorchDDPTrainer

Multi-process KGE model trainer built on PyTorch DistributedDataParallel (DDP).

NodeTrainer

Functions

make_iterable_verbose(iterable_object, verbose, desc, position, leave) → Iterable

Module Contents

dicee.trainer.torch_trainer_ddp.make_iterable_verbose(iterable_object, verbose, desc='Default', position=None, leave=True) → Iterable
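The desc, position and leave parameters match keyword arguments of tqdm, so the helper most likely wraps the iterable in a tqdm progress bar when verbose is set. The sketch below is a hypothetical re-implementation under that assumption (make_iterable_verbose_sketch is not part of dicee); it is not the library's source.

```python
from typing import Iterable, Optional


def make_iterable_verbose_sketch(iterable_object: Iterable, verbose: bool,
                                 desc: str = "Default",
                                 position: Optional[int] = None,
                                 leave: bool = True) -> Iterable:
    """Wrap an iterable in a tqdm progress bar when verbose is truthy.

    Hypothetical re-implementation; assumes the real helper delegates to tqdm.
    """
    if not verbose:
        # Silent mode: hand back the original iterable untouched.
        return iterable_object
    # Imported lazily so the silent path works even without tqdm installed.
    from tqdm import tqdm
    return tqdm(iterable_object, desc=desc, position=position, leave=leave)
```

Under DDP a common pattern is to pass verbose=(global_rank == 0), so that only one process draws a progress bar instead of every worker printing its own.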
class dicee.trainer.torch_trainer_ddp.TorchDDPTrainer(args, callbacks)

Bases: dicee.abstracts.AbstractTrainer

Multi-process KGE model trainer built on PyTorch DistributedDataParallel (DDP). Its base class, AbstractTrainer, is the abstract base class for KGE model trainers.

AbstractTrainer provides the callback dispatch mechanism shared by all concrete trainer implementations (TorchTrainer, TorchDDPTrainer, etc.). Subclasses call the on_* hooks at the appropriate points in the training loop so that every registered AbstractCallback can react.

Parameters:
  • args (argparse.Namespace or similar) – Processed configuration object. Must expose at least random_seed (int).

  • callbacks (list of AbstractCallback) – Ordered list of callback instances to invoke at each lifecycle hook.
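The dispatch mechanism described above amounts to each on_* hook on the trainer forwarding its arguments to the same-named hook on every callback, in list order. The following is a minimal sketch of that pattern; the class names and the specific hook names (on_fit_start, on_train_epoch_end, on_fit_end) are illustrative assumptions, not a copy of dicee's AbstractTrainer or AbstractCallback.

```python
from typing import List


class AbstractCallbackSketch:
    """Hypothetical callback base: every hook is a no-op by default."""

    def on_fit_start(self, trainer, model):
        pass

    def on_train_epoch_end(self, trainer, model):
        pass

    def on_fit_end(self, trainer, model):
        pass


class AbstractTrainerSketch:
    """Hypothetical trainer base showing the on_* dispatch pattern."""

    def __init__(self, args, callbacks: List[AbstractCallbackSketch]):
        self.attributes = args
        self.callbacks = callbacks

    def on_fit_start(self, *args, **kwargs):
        # Forward the hook to every callback, in registration order.
        for c in self.callbacks:
            c.on_fit_start(*args, **kwargs)

    def on_train_epoch_end(self, *args, **kwargs):
        for c in self.callbacks:
            c.on_train_epoch_end(*args, **kwargs)

    def on_fit_end(self, *args, **kwargs):
        for c in self.callbacks:
            c.on_fit_end(*args, **kwargs)


class RecordingCallback(AbstractCallbackSketch):
    """Toy callback that records which hooks fired (on_fit_end stays a no-op)."""

    def __init__(self, name, log):
        self.name, self.log = name, log

    def on_fit_start(self, trainer, model):
        self.log.append(f"{self.name}:fit_start")

    def on_train_epoch_end(self, trainer, model):
        self.log.append(f"{self.name}:epoch_end")
```

Because callbacks inherit no-op defaults, a callback only overrides the hooks it cares about, and the order of the callbacks list fixes the order in which they react to each event.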

fit(*args, **kwargs)
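In PyTorch DDP, each process normally reads its own slice of the training data through torch.utils.data.DistributedSampler. Whether fit builds its loader this way is an assumption; the stdlib sketch below (hypothetical shard_indices) reproduces the sampler's partitioning for shuffle=False and drop_last=False: pad the index list by repeating it from the front, then give rank r every num_replicas-th index starting at r.

```python
import math
from typing import List


def shard_indices(dataset_len: int, num_replicas: int, rank: int) -> List[int]:
    """Return the dataset indices one rank would see (no shuffling, drop_last=False).

    Mirrors the padding-and-striding scheme of torch.utils.data.DistributedSampler.
    Assumes dataset_len > 0.
    """
    indices = list(range(dataset_len))
    num_samples = math.ceil(dataset_len / num_replicas)
    total_size = num_samples * num_replicas
    # Pad by repeating indices from the front so every rank gets equally many.
    padding = total_size - len(indices)
    indices += (indices * math.ceil(padding / len(indices)))[:padding]
    # Rank r takes every num_replicas-th index starting at r.
    return indices[rank:total_size:num_replicas]
```

With 10 samples and 4 processes, every rank receives 3 indices and samples 0 and 1 are seen twice per epoch; the padding keeps all ranks in lockstep, since each all-reduce of gradients needs every process to run the same number of steps.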
class dicee.trainer.torch_trainer_ddp.NodeTrainer(trainer, model: torch.nn.Module, train_dataset_loader: torch.utils.data.DataLoader, callbacks, num_epochs: int)
trainer
local_rank
global_rank
optimizer
train_dataset_loader
loss_func
callbacks
model
num_epochs
loss_history = []
ctx
scaler
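When a job is launched with torchrun, each worker process receives the environment variables LOCAL_RANK (its index on the node), RANK (its index across all nodes) and WORLD_SIZE (the total number of processes). A plausible reading, not confirmed by this page, is that local_rank and global_rank are populated from these variables. The helper below (read_ranks, hypothetical) shows that lookup with a single-process fallback.

```python
import os
from typing import Mapping, Tuple


def read_ranks(env: Mapping[str, str] = os.environ) -> Tuple[int, int, int]:
    """Return (local_rank, global_rank, world_size) from torchrun's environment.

    LOCAL_RANK: index of this process on its node (typically selects the GPU).
    RANK: index of this process across all nodes.
    WORLD_SIZE: total number of processes.
    Falls back to a single-process setup when the variables are absent.
    """
    local_rank = int(env.get("LOCAL_RANK", 0))
    global_rank = int(env.get("RANK", 0))
    world_size = int(env.get("WORLD_SIZE", 1))
    return local_rank, global_rank, world_size
```

The two ranks serve different purposes: local_rank picks the device on the current machine, while global_rank == 0 is the usual test for "this is the one process that should log or save checkpoints".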
extract_input_outputs(z: list)
train()
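extract_input_outputs receives one collated batch z as a list. A common convention in KGE trainers, assumed here rather than taken from dicee's source, is that a 2-element batch holds (inputs, targets) and a 3-element batch holds (inputs, target indices, targets). The sketch below (extract_input_outputs_sketch, hypothetical) shows that split; the real method would presumably also move each tensor to the device given by local_rank.

```python
from typing import Any, List, Tuple


def extract_input_outputs_sketch(z: List[Any]) -> Tuple[Any, Any]:
    """Split one collated batch into (model inputs, targets).

    Hypothetical layout: a 2-element batch is (x, y); a 3-element batch is
    (x, y_idx, y), where y_idx selects which scores to compare against y.
    """
    if len(z) == 2:
        x_batch, y_batch = z
        return x_batch, y_batch
    if len(z) == 3:
        x_batch, y_idx_batch, y_batch = z
        # Pack the selector with the inputs so the model scores only those entries.
        return (x_batch, y_idx_batch), y_batch
    raise ValueError(f"Unexpected batch of length {len(z)}")
```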
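train() is documented only by name, but its attributes (num_epochs, loss_history, callbacks, optimizer, loss_func) suggest the usual epoch/batch loop. The sketch below shows that control flow with the PyTorch calls written out as plain arithmetic for a single weight, so it runs without torch; train_sketch and all of its parameters are hypothetical.

```python
from typing import Callable, List, Sequence, Tuple


def train_sketch(batches: Sequence[Tuple[List[float], List[float]]],
                 num_epochs: int, lr: float = 0.1,
                 global_rank: int = 0,
                 on_epoch_end: Callable[[int, float], None] = lambda e, l: None
                 ) -> Tuple[float, List[float]]:
    """Fit y = w * x with mini-batch SGD; return (w, per-epoch loss history).

    Plain-Python stand-in for a PyTorch loop: the forward pass, loss.backward()
    and optimizer.step() are written out by hand for a single weight w.
    """
    w = 0.0
    loss_history: List[float] = []
    for epoch in range(num_epochs):
        epoch_loss = 0.0
        for xs, ys in batches:
            n = len(xs)
            preds = [w * x for x in xs]  # forward pass
            loss = sum((p - y) ** 2 for p, y in zip(preds, ys)) / n  # MSE loss
            # Gradient of the MSE with respect to w (what loss.backward() computes).
            grad = sum(2 * (p - y) * x for p, y, x in zip(preds, ys, xs)) / n
            w -= lr * grad  # plain SGD update (what optimizer.step() applies)
            epoch_loss += loss
        avg = epoch_loss / len(batches)
        loss_history.append(avg)
        on_epoch_end(epoch, avg)  # callback hook after each epoch
        if global_rank == 0:
            # Under DDP every rank runs this loop; only rank 0 reports.
            print(f"epoch {epoch}: loss={avg:.4f}")
    return w, loss_history
```

In a PyTorch version, the ctx and scaler attributes suggest mixed precision: the forward pass would run inside an autocast context, and the update would go through torch.amp's GradScaler (scaler.scale(loss).backward(), scaler.step(optimizer), scaler.update()) instead of a bare backward and step.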