dicee.trainer.model_parallelism

Classes

TensorParallel

Abstract trainer class for knowledge graph embedding models.

Functions

extract_input_outputs(z[, device])

find_good_batch_size(train_loader, tp_ensemble_model)

forward_backward_update_loss(z, ensemble_model) → float

Module Contents

dicee.trainer.model_parallelism.extract_input_outputs(z: list, device=None)
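From its name and signature, `extract_input_outputs` appears to split a mini-batch `z` into model inputs and targets, optionally moving them to `device`. The following is a minimal sketch under stated assumptions, not dicee's implementation: it assumes `z` is either `[x, y]` or `[x, y_idx, y]` and that each element has a PyTorch-style `.to(device)` method. `FakeTensor` and `extract_input_outputs_sketch` are hypothetical names used only for illustration.

```python
from typing import Any, List, Optional, Tuple


class FakeTensor:
    """Hypothetical stand-in for a tensor with a PyTorch-style .to(device) method."""

    def __init__(self, data: Any, device: str = "cpu") -> None:
        self.data = data
        self.device = device

    def to(self, device: str) -> "FakeTensor":
        # Return a copy "placed" on the requested device.
        return FakeTensor(self.data, device)


def _move(t: FakeTensor, device: Optional[str]) -> FakeTensor:
    # Leave the tensor where it is when no device is given.
    return t if device is None else t.to(device)


def extract_input_outputs_sketch(z: List[FakeTensor], device: Optional[str] = None) -> Tuple[Any, FakeTensor]:
    """Split a batch into (inputs, targets). Assumed layouts: [x, y] or [x, y_idx, y]."""
    if len(z) == 2:
        x, y = z
        return _move(x, device), _move(y, device)
    if len(z) == 3:
        x, y_idx, y = z
        # Assumption: inputs are paired with target indices; targets are returned separately.
        return (_move(x, device), _move(y_idx, device)), _move(y, device)
    raise ValueError(f"Unexpected batch length: {len(z)}")
```

Raising on any other batch length makes an unexpected data-loader layout fail loudly instead of silently mis-assigning inputs and targets.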
dicee.trainer.model_parallelism.find_good_batch_size(train_loader, tp_ensemble_model)
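`find_good_batch_size` takes the training loader and the tensor-parallel ensemble model, which suggests it searches for a batch size that the available devices can handle. Below is a minimal sketch of one common strategy, doubling the batch size until a probe fails. The `fits` probe, the doubling schedule, and the name `find_good_batch_size_sketch` are assumptions, not details confirmed by this page, and the sketch assumes the initial batch size itself fits.

```python
from typing import Callable


def find_good_batch_size_sketch(initial_batch_size: int, dataset_size: int,
                                fits: Callable[[int], bool]) -> int:
    """Grow the batch size by doubling until it no longer fits or covers the dataset.

    `fits` is a hypothetical probe (e.g. a trial forward/backward pass) that returns
    True when a batch of the given size can be processed without running out of memory.
    """
    if initial_batch_size >= dataset_size:
        # A single batch already covers the whole training set.
        return dataset_size
    best = initial_batch_size  # assumed to fit
    candidate = initial_batch_size * 2
    while candidate <= dataset_size and fits(candidate):
        best = candidate
        candidate *= 2
    # Doubling may overshoot the dataset size; try the dataset size itself in that case.
    if best < dataset_size < candidate and fits(dataset_size):
        best = dataset_size
    return best
```

For example, with a probe that accepts batches of at most 1000 examples, a start of 32 on a 10,000-example dataset settles on 512, the largest power-of-two multiple of 32 that fits.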
dicee.trainer.model_parallelism.forward_backward_update_loss(z: Tuple, ensemble_model) → float
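`forward_backward_update_loss` returns a `float`, which suggests one optimization step per call: forward pass, loss, backward pass, parameter update, and the loss handed back as a plain Python number. The sketch below shows that pattern on a hypothetical one-feature logistic model in pure Python. `ToyLogisticModel`, the binary cross-entropy loss, and plain gradient descent are assumptions standing in for `ensemble_model` and whatever loss and optimizer dicee actually uses.

```python
import math
from typing import List, Tuple


class ToyLogisticModel:
    """Hypothetical stand-in for `ensemble_model`: score = w * x + b, trained by gradient descent."""

    def __init__(self, lr: float = 0.1) -> None:
        self.w, self.b, self.lr = 0.0, 0.0, lr
        self.grad_w, self.grad_b = 0.0, 0.0

    def forward(self, xs: List[float]) -> List[float]:
        return [self.w * x + self.b for x in xs]

    def step(self) -> None:
        # Gradient-descent update, then clear the stored gradients.
        self.w -= self.lr * self.grad_w
        self.b -= self.lr * self.grad_b
        self.grad_w, self.grad_b = 0.0, 0.0


def forward_backward_update_loss_sketch(z: Tuple[List[float], List[float]],
                                        model: ToyLogisticModel) -> float:
    xs, ys = z
    n = len(xs)
    # Forward pass: raw scores (logits).
    logits = model.forward(xs)
    # Mean binary cross-entropy on logits: log(1 + e^s) - y * s.
    loss = sum(math.log1p(math.exp(s)) - y * s for s, y in zip(logits, ys)) / n
    # Backward pass: dL/ds = sigmoid(s) - y, averaged over the batch.
    residuals = [1.0 / (1.0 + math.exp(-s)) - y for s, y in zip(logits, ys)]
    model.grad_w = sum(r * x for r, x in zip(residuals, xs)) / n
    model.grad_b = sum(residuals) / n
    # Parameter update, then return the loss as a plain float.
    model.step()
    return float(loss)
```

Returning a plain `float` rather than a tensor keeps the caller's loss bookkeeping free of device or autograd state.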
class dicee.trainer.model_parallelism.TensorParallel(args, callbacks)

Bases: dicee.abstracts.AbstractTrainer

Abstract trainer class for knowledge graph embedding models.

Parameters

args: str

?

callbacks: list

?

fit(*args, **kwargs)

Train the model.
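`fit` is documented only as training the model. A common shape for such a method is an epoch loop that applies a per-batch step (for instance, something like `forward_backward_update_loss`) to every batch and tracks the average loss. The sketch below assumes that shape; `fit_sketch`, `train_step`, and `num_epochs` are hypothetical names, and this is not a description of `TensorParallel.fit` itself.

```python
from typing import Callable, List, Sequence


def fit_sketch(batches: Sequence, train_step: Callable[[object], float],
               num_epochs: int) -> List[float]:
    """Hypothetical epoch loop: run `train_step` on every batch and record the mean loss per epoch."""
    if not batches:
        raise ValueError("No training batches provided")
    epoch_losses: List[float] = []
    for _ in range(num_epochs):
        total = 0.0
        for batch in batches:
            # Each call is assumed to perform one forward/backward/update step and return its loss.
            total += train_step(batch)
        epoch_losses.append(total / len(batches))
    return epoch_losses
```

Recording one averaged loss per epoch gives a compact training curve without retaining every per-batch value.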