dicee.trainer.torch_trainer_ddp

Classes

TorchDDPTrainer

A Trainer based on torch.nn.parallel.DistributedDataParallel

NodeTrainer

Functions

make_iterable_verbose(→ Iterable)

Module Contents

dicee.trainer.torch_trainer_ddp.make_iterable_verbose(iterable_object, verbose, desc='Default', position=None, leave=True) → Iterable
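Three of the parameters (desc, position and leave) have the same names as keyword arguments of tqdm's progress bar. That suggests the function wraps an iterable in a progress bar when verbose is true. The sketch below assumes this. make_iterable_verbose_sketch is a hypothetical name, and the silent fallback when tqdm is not installed is an illustrative choice, not documented dicee behaviour.

```python
from typing import Iterable, Optional


def make_iterable_verbose_sketch(iterable_object: Iterable, verbose: bool,
                                 desc: str = "Default", position: Optional[int] = None,
                                 leave: bool = True) -> Iterable:
    """Hypothetical sketch: optionally wrap an iterable in a tqdm progress bar."""
    if not verbose:
        # Quiet mode: hand back the iterable untouched.
        return iterable_object
    try:
        from tqdm import tqdm
    except ImportError:
        # Assumption: without tqdm, iterate silently instead of failing.
        return iterable_object
    # desc/position/leave are passed straight through to tqdm.
    return tqdm(iterable_object, desc=desc, position=position, leave=leave)
```

Either way the caller iterates over the result exactly as it would over the original object.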

class dicee.trainer.torch_trainer_ddp.TorchDDPTrainer(args, callbacks)

Bases: dicee.abstracts.AbstractTrainer

A Trainer based on torch.nn.parallel.DistributedDataParallel
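The sketch below illustrates what "based on DistributedDataParallel" involves at minimum. It joins a default process group and wraps a model in DDP. wrap_in_ddp and _free_port are hypothetical helpers. The gloo backend and the single-process fallback (rank 0, world size 1) are assumptions chosen so the snippet runs on one CPU. They do not describe how TorchDDPTrainer configures itself.

```python
import os
import socket

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP


def _free_port() -> int:
    # Ask the OS for an unused TCP port to use as the rendezvous port.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("127.0.0.1", 0))
        return s.getsockname()[1]


def wrap_in_ddp(model: nn.Module) -> DDP:
    """Hypothetical helper: join the process group (once) and wrap a CPU model in DDP."""
    if not dist.is_initialized():
        # torchrun exports MASTER_ADDR/MASTER_PORT/RANK/WORLD_SIZE;
        # otherwise fall back to a single local process.
        os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
        os.environ.setdefault("MASTER_PORT", str(_free_port()))
        dist.init_process_group(
            backend="gloo",  # CPU-friendly backend; "nccl" is the usual choice on GPUs
            rank=int(os.environ.get("RANK", 0)),
            world_size=int(os.environ.get("WORLD_SIZE", 1)),
        )
    # For a CPU module, device_ids must be left as None.
    return DDP(model)
```

On GPUs under torchrun, the usual pattern differs in three ways:
- use the "nccl" backend,
- move the model to cuda:LOCAL_RANK,
- construct DDP(model, device_ids=[local_rank]).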

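Arguments

entity_idxs – entity-to-index mapping.
relation_idxs – relation-to-index mapping.
form – ?
store – ?
label_smoothing_rate – Label smoothing rate. Training against hard targets (0 or 1) drives the weights toward infinity, because the model's outputs reach exactly 0 or 1 only in the limit of infinitely large logits. A single outlier can then produce enormous gradients.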

Return type:

torch.utils.data.Dataset
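Label smoothing is commonly formulated as follows. A hard target vector y over K classes is replaced with (1 − ε)·y + ε/K, where ε is the smoothing rate. After this change no target is exactly 0 or 1, so the optimal logits stay finite. Whether dicee applies exactly this formula for label_smoothing_rate is an assumption. smooth_labels is a hypothetical helper, written in plain Python for clarity.

```python
def smooth_labels(targets: list, rate: float) -> list:
    """Hypothetical helper: (1 - rate) * y + rate / K for a K-class target vector."""
    k = len(targets)
    # Shrink every hard target toward the uniform distribution 1/K.
    return [(1.0 - rate) * t + rate / k for t in targets]
```

For example, with ε = 0.1 and K = 4, the one-hot target [0, 1, 0, 0] becomes [0.025, 0.925, 0.025, 0.025], which still sums to 1.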

fit(*args, **kwargs)

Train the model.
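Under DDP, each process normally trains on its own shard of the data. PyTorch's DistributedSampler is the standard way to build those shards. The sketch below shows which indices each rank receives, including the padding added when the dataset size is not divisible by the world size. shard_indices is a hypothetical helper, and whether fit uses DistributedSampler internally is an assumption.

```python
from torch.utils.data.distributed import DistributedSampler


def shard_indices(dataset_len: int, world_size: int, rank: int) -> list:
    """Hypothetical helper: indices that `rank` would see out of `world_size` processes."""
    # Passing num_replicas and rank explicitly means no process group is required.
    sampler = DistributedSampler(range(dataset_len), num_replicas=world_size,
                                 rank=rank, shuffle=False)
    return list(sampler)
```

For example, with 10 samples and 4 ranks, rank 2 receives [2, 6, 0]. Indices 0 and 1 are repeated so that every rank gets 3 samples.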

class dicee.trainer.torch_trainer_ddp.NodeTrainer(trainer, model: torch.nn.Module, train_dataset_loader: torch.utils.data.DataLoader, callbacks, num_epochs: int)
trainer
local_rank
global_rank
optimizer
train_dataset_loader
loss_func
callbacks
model
num_epochs
loss_history = []
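In DDP, local_rank and global_rank are the standard names for a process's index on its own machine and across all machines. When a job is started with torchrun, these values are exported as the environment variables LOCAL_RANK and RANK, alongside WORLD_SIZE. The sketch below reads them. RankInfo and read_rank_info are hypothetical names, and the defaults for a non-distributed run are an assumption.

```python
import os
from dataclasses import dataclass
from typing import Mapping, Optional


@dataclass(frozen=True)
class RankInfo:
    local_rank: int   # index of this process on its own node
    global_rank: int  # index of this process across all nodes
    world_size: int   # total number of processes


def read_rank_info(env: Optional[Mapping[str, str]] = None) -> RankInfo:
    """Hypothetical helper: read torchrun's rank variables, defaulting to one process."""
    env = os.environ if env is None else env
    return RankInfo(
        local_rank=int(env.get("LOCAL_RANK", 0)),
        global_rank=int(env.get("RANK", 0)),
        world_size=int(env.get("WORLD_SIZE", 1)),
    )
```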
extract_input_outputs(z: list)
train()

Training loop for DDP
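A DDP training loop usually follows these steps:
- wrap the model in DDP,
- shard the data with a DistributedSampler,
- call set_epoch at the start of every epoch,
- run the ordinary forward/backward/step cycle while DDP synchronizes gradients.

The sketch below mirrors NodeTrainer's attributes (optimizer, loss_func, num_epochs, loss_history), but it is not dicee's implementation. ddp_train_sketch, the SGD optimizer, the MSE loss and the single-process gloo setup are all assumptions.

```python
import os
import socket

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, Dataset, TensorDataset
from torch.utils.data.distributed import DistributedSampler


def _ensure_process_group() -> None:
    # Join the default process group once; torchrun would normally provide these env vars.
    if dist.is_initialized():
        return
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    if "MASTER_PORT" not in os.environ:
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            s.bind(("127.0.0.1", 0))
            os.environ["MASTER_PORT"] = str(s.getsockname()[1])
    dist.init_process_group(
        backend="gloo",
        rank=int(os.environ.get("RANK", 0)),
        world_size=int(os.environ.get("WORLD_SIZE", 1)),
    )


def ddp_train_sketch(model: nn.Module, dataset: Dataset, num_epochs: int,
                     lr: float = 0.05, batch_size: int = 4) -> list:
    """Hypothetical DDP training loop; returns this rank's mean loss per epoch."""
    _ensure_process_group()
    ddp_model = DDP(model)  # CPU model, so no device_ids
    sampler = DistributedSampler(dataset, shuffle=True)  # each rank gets its own shard
    loader = DataLoader(dataset, batch_size=batch_size, sampler=sampler)
    optimizer = torch.optim.SGD(ddp_model.parameters(), lr=lr)
    loss_func = nn.MSELoss()
    loss_history = []
    for epoch in range(num_epochs):
        sampler.set_epoch(epoch)  # new shuffle every epoch, identical across ranks
        epoch_loss = 0.0
        for x, y in loader:
            optimizer.zero_grad()
            loss = loss_func(ddp_model(x), y)
            loss.backward()  # DDP averages gradients across ranks during backward
            optimizer.step()
            epoch_loss += loss.item()
        loss_history.append(epoch_loss / len(loader))
    return loss_history
```

The sketch can be called with any map-style dataset, for example a TensorDataset of inputs and targets. Here loss_history holds each rank's local mean loss. Averaging it across ranks would need an extra all_reduce.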