dicee.trainer.model_parallelism

Classes

TensorParallel

Trainer for knowledge graph embedding models that uses tensor (model) parallelism.

Functions

extract_input_outputs(z[, device])

find_good_batch_size(train_loader, tp_ensemble_model)

forward_backward_update_loss(→ float)

Module Contents

dicee.trainer.model_parallelism.extract_input_outputs(z: list, device=None)
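The reference gives no description for extract_input_outputs. Its name and signature suggest it unpacks a mini-batch z and places each part on device. The sketch below is a framework-free illustration of that idea, not the dicee implementation. It assumes a batch is either (x, y) or (x, y_idx, y). The function name extract_input_outputs_sketch and the to_device hook are hypothetical; with PyTorch tensors the hook would typically call Tensor.to(device).

```python
from typing import Any, Callable, Optional, Sequence, Tuple


def extract_input_outputs_sketch(
    z: Sequence[Any],
    device: Optional[str] = None,
    to_device: Callable[[Any, Optional[str]], Any] = lambda t, d: t,
) -> Tuple[Any, ...]:
    """Hypothetical sketch: unpack a mini-batch and move each part to `device`.

    Assumption (not taken from the dicee source): a batch is either
    (x, y) or (x, y_idx, y).
    """
    if not isinstance(z, (list, tuple)):
        raise TypeError(f"expected a list or tuple batch, got {type(z).__name__}")
    if len(z) not in (2, 3):
        raise ValueError(f"expected 2 or 3 batch elements, got {len(z)}")
    # Apply the same placement hook to every component, preserving order.
    return tuple(to_device(part, device) for part in z)


# Usage: a two-element batch with the default (identity) hook.
x, y = extract_input_outputs_sketch([[0, 1, 2], [1.0]])
```

Keeping the placement logic in a hook lets the unpacking rule be checked without a GPU.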
dicee.trainer.model_parallelism.find_good_batch_size(train_loader, tp_ensemble_model)
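find_good_batch_size is also undocumented. One common way to pick a batch size is to keep doubling it until a trial step no longer fits in memory, then keep the last size that worked. Whether dicee uses this strategy is an assumption. In the sketch below, try_step is a hypothetical callable that runs one training step at a given batch size and raises MemoryError on failure. With PyTorch on a GPU, the failure would surface as a CUDA out-of-memory exception instead.

```python
from typing import Callable


def find_good_batch_size_sketch(
    try_step: Callable[[int], None],
    start: int = 1,
    max_batch_size: int = 1 << 20,
) -> int:
    """Hypothetical sketch: double the batch size until a trial step fails.

    Returns the largest tried size that succeeded (capped by max_batch_size).
    """
    if start < 1:
        raise ValueError("start must be >= 1")
    good = None
    size = start
    while size <= max_batch_size:
        try:
            try_step(size)  # one forward/backward pass at this batch size
        except MemoryError:
            break  # this size does not fit; keep the previous one
        good = size
        size *= 2
    if good is None:
        raise RuntimeError(f"even batch size {start} does not fit")
    return good


# Usage: a fake step that "runs out of memory" above 300 examples.
def fake_step(batch_size: int) -> None:
    if batch_size > 300:
        raise MemoryError


best = find_good_batch_size_sketch(fake_step)
```

Doubling needs only a logarithmic number of trial steps. The trade-off is that the returned size can be up to half of the true maximum.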
dicee.trainer.model_parallelism.forward_backward_update_loss(z: Tuple, ensemble_model) → float
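The name forward_backward_update_loss and its float return type suggest the function performs one training step and returns the loss as a Python number. That step would consist of a forward pass, a backward pass, and a parameter update. The sketch below shows that pattern without a deep-learning framework. It is not the dicee implementation. The real function receives an ensemble model, while this sketch assumes a hypothetical one-parameter linear model trained with mean squared error and plain gradient descent. LinearModel and forward_backward_update_loss_sketch are illustrative names.

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class LinearModel:
    """Hypothetical toy model y_hat = w * x, updated with plain SGD."""
    w: float = 0.0
    lr: float = 0.1


def forward_backward_update_loss_sketch(
    z: Tuple[List[float], List[float]], model: LinearModel
) -> float:
    x, y = z
    n = len(x)
    # Forward pass: predictions and mean squared error.
    y_hat = [model.w * xi for xi in x]
    loss = sum((p - t) ** 2 for p, t in zip(y_hat, y)) / n
    # Backward pass: d(loss)/dw = mean(2 * (y_hat - y) * x).
    grad = sum(2 * (p - t) * xi for p, t, xi in zip(y_hat, y, x)) / n
    # Update step: one gradient-descent step on w.
    model.w -= model.lr * grad
    # Return the loss computed before the update, as a Python float.
    return float(loss)


model = LinearModel()
batch = ([1.0, 2.0], [2.0, 4.0])
first = forward_backward_update_loss_sketch(batch, model)   # loss 10.0, w becomes 1.0
second = forward_backward_update_loss_sketch(batch, model)  # loss 2.5, w becomes 1.5
```

Returning a plain float, rather than a framework tensor, makes it easy for a training loop to log or average losses.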
class dicee.trainer.model_parallelism.TensorParallel(args, callbacks)

Bases: dicee.abstracts.AbstractTrainer

Trainer for knowledge graph embedding models that uses tensor (model) parallelism.

Parameters

args: str

?

callbacks: list

?

fit(*args, **kwargs)

Train the model.