from typing import Tuple
import torch
from ..abstracts import AbstractTrainer
from ..models.ensemble import EnsembleKGE
from ..static_funcs_training import make_iterable_verbose
from .auto_batch_finder import find_good_batch_size
[docs]
def forward_backward_update_loss(z: Tuple, ensemble_model) -> float:
    """Run one training step on a single batch and return the batch loss.

    The batch is unpacked into inputs and labels, the model is called on the
    inputs, the binary cross-entropy (with logits) between predictions and
    labels is back-propagated, and ``ensemble_model.step()`` is invoked to
    update the parameters.

    Parameters
    ----------
    z : Tuple
        A batch as yielded by the training dataloader; it is unpacked by
        ``extract_input_outputs``.
    ensemble_model
        Callable model that returns logits for the inputs and exposes a
        ``step()`` method performing the parameter update.

    Returns
    -------
    float
        The loss of this batch as a Python float.
    """
    # NOTE(review): extract_input_outputs is not imported in the visible part
    # of this module -- confirm that it is in scope at runtime.
    x_batch, y_batch = extract_input_outputs(z)
    # () Forward pass on the batch of input data points.
    yhat = ensemble_model(x_batch)
    # () Put the labels on whichever device holds the predictions instead of
    #    assuming "cuda:0"; this is the same placement when yhat is on GPU-0.
    y_batch = y_batch.to(yhat.device)
    # () Compute the loss.
    loss = torch.nn.functional.binary_cross_entropy_with_logits(yhat, y_batch)
    # () Compute the gradient of the loss w.r.t. parameters.
    loss.backward()
    # () Parameter update.
    ensemble_model.step()
    return loss.item()
[docs]
class TensorParallel(AbstractTrainer):
    """Trainer that optimises an ``EnsembleKGE`` one batch at a time.

    Every batch is processed by ``forward_backward_update_loss``. The number
    of visible CUDA devices must equal ``len(ensemble_model)``.
    """

    def __init__(self, args, callbacks):
        """Pass ``args`` and ``callbacks`` unchanged to ``AbstractTrainer``."""
        super().__init__(args, callbacks)

    def fit(self, *args, **kwargs):
        """Train the ensemble model and return it.

        Parameters
        ----------
        *args
            Exactly one positional argument: the ``EnsembleKGE`` to train.
        **kwargs
            Must contain ``train_dataloaders``, the dataloader that yields the
            training batches.

        Returns
        -------
        EnsembleKGE
            The model that was passed in. After every epoch the sum of that
            epoch's batch losses is appended to its ``loss_history``.

        Raises
        ------
        AssertionError
            If not exactly one positional argument is given, if it is not an
            ``EnsembleKGE``, or if the number of CUDA devices differs from
            ``len(ensemble_model)``.
        KeyError
            If ``train_dataloaders`` is missing from ``kwargs``.
        """
        assert len(args) == 1
        ensemble_model, = args
        assert isinstance(ensemble_model, EnsembleKGE), (
            f"Selected model must be an instance of EnsembleKGE, "
            f"got {type(ensemble_model)}")
        # () Run on_fit_start callbacks.
        self.on_fit_start(self, ensemble_model)
        # () Sanity checking.
        assert torch.cuda.device_count() == len(ensemble_model), (
            f"Number of CUDA devices ({torch.cuda.device_count()}) must equal "
            f"the size of the ensemble ({len(ensemble_model)})")
        # () Get DataLoader.
        train_dataloader = kwargs['train_dataloaders']
        # () Find a batch size so that available GPU memory is *almost* fully used.
        if self.attributes.auto_batch_finding:
            def _tp_training_step(batch):
                return forward_backward_update_loss(batch, ensemble_model)

            # The second returned value is not needed here.
            batch_size, _ = find_good_batch_size(
                train_dataloader, _tp_training_step, device=torch.device("cuda", 0)
            )
            # Rebuild the loader over the same dataset with the batch size found;
            # all arguments not listed keep the DataLoader defaults.
            train_dataloader = torch.utils.data.DataLoader(
                train_dataloader.dataset,
                batch_size=batch_size,
                shuffle=True,
                num_workers=self.attributes.num_core,
                collate_fn=train_dataloader.dataset.collate_fn,
            )
        # () Number of batches to reach a single epoch.
        num_of_batches = len(train_dataloader)
        # () Start training.
        tqdm_bar = make_iterable_verbose(range(self.attributes.num_epochs),
                                         verbose=True, position=0, leave=True)
        # Only report progress when the iterable offers tqdm-style display methods.
        can_report = hasattr(tqdm_bar, 'set_description_str')
        for epoch in tqdm_bar:
            self.on_train_epoch_start(self, ensemble_model)
            if can_report:
                tqdm_bar.set_description_str(f"Epoch:{epoch + 1}")
            # () Accumulate the batch losses.
            epoch_loss = 0
            # () Iterate over batches; batches_done counts from 1 so that it is
            #    the correct divisor for the running mean of the epoch loss.
            for batches_done, z in enumerate(train_dataloader, start=1):
                # () Forward, Loss, Backward, and Update on a given batch of data points.
                batch_loss = forward_backward_update_loss(z, ensemble_model)
                # () Accumulate the batch losses to compute the epoch loss.
                epoch_loss += batch_loss
                if can_report:
                    tqdm_bar.set_postfix_str(
                        f"batch={batches_done} | {num_of_batches}, "
                        f"loss_step={batch_loss:.5f}, "
                        f"loss_epoch={epoch_loss / batches_done:.5f}")
            # Store the epoch loss (sum of the batch losses of this epoch).
            ensemble_model.loss_history.append(epoch_loss)
            self.on_train_epoch_end(self, ensemble_model)
        # Run on_fit_end callbacks after the training is done.
        self.on_fit_end(self, ensemble_model)
        # TODO: Later, maybe we should write a callback to save the models in disk
        return ensemble_model
"""
def batchwisefit(self, *args, **kwargs):
assert len(args) == 1
model, = args
# (1) Run the fit the start callback.
self.on_fit_start(self, model)
# (2) Setup DDP.
optimizer = model.configure_optimizers()
num_gpus = torch.cuda.device_count()
for epoch in (tqdm_bar := make_iterable_verbose(range(self.attributes.num_epochs),
verbose=True, position=0, leave=True)):
epoch_loss = 0
num_of_batches = len(kwargs['train_dataloaders'])
for i, (x_batch, y_batch) in enumerate(kwargs['train_dataloaders']):
# Define a large batch into small batches
x_splits = torch.chunk(x_batch, num_gpus)
y_splits = torch.chunk(y_batch, num_gpus)
# Forward pass. We need to paralelize it
gpu_losses = []
for gpu_id, (x_split, y_split) in enumerate(zip(x_splits, y_splits)):
y_split = y_split.to(f"cuda:{gpu_id}")
h_emb, r_emb, t_emb = model.get_triple_representation(x_split)
h_emb, r_emb, t_emb = h_emb.pin_memory().to(f"cuda:{gpu_id}",
non_blocking=True), r_emb.pin_memory().to(f"cuda:{gpu_id}", non_blocking=True), t_emb.pin_memory().to(f"cuda:{gpu_id}", non_blocking=True)
yhat = model.score(h_emb, r_emb, t_emb)
gpu_losses.append(torch.nn.functional.binary_cross_entropy_with_logits(yhat, y_split).to("cuda:0"))
loss = sum(gpu_losses) / len(gpu_losses)
loss.backward()
batch_loss = loss.item()
optimizer.step()
optimizer.zero_grad(set_to_none=True)
epoch_loss += batch_loss
if hasattr(tqdm_bar, 'set_description_str'):
tqdm_bar.set_description_str(f"Epoch:{epoch + 1}")
if i > 0:
tqdm_bar.set_postfix_str(
f"batch={i} | {num_of_batches}, loss_step={batch_loss:.5f}, loss_epoch={epoch_loss / i:.5f}")
else:
tqdm_bar.set_postfix_str(f"loss_step={batch_loss:.5f}, loss_epoch={batch_loss:.5f}")
def torch_buggy_fit(self, *args, **kwargs):
assert len(args) == 1
model, = args
# () Run the fit the start callback.
self.on_fit_start(self, model)
# () Init Process Group with NCCL.
torch.distributed.init_process_group(backend="nccl")
# () Get Rank and World Size.
rank = dist.get_rank()
world_size = dist.get_world_size()
# () Reinitialize Rank based on manuel seed rank.
torch.manual_seed(rank)
model.param_init(model.entity_embeddings.weight.data)
model.param_init(model.relation_embeddings.weight.data)
# () .
device = torch.device(f'cuda:{rank}')
model.to(device)
# () .
optimizer = model.configure_optimizers()
# () .
for epoch in (tqdm_bar := make_iterable_verbose(range(self.attributes.num_epochs),
verbose=True, position=0, leave=True)):
epoch_loss = 0
num_of_batches = len(kwargs['train_dataloaders'])
# () .
for i, z in enumerate(kwargs['train_dataloaders']):
optimizer.zero_grad()
# () Get batch and move it on GPUs .
inputs,targets = extract_input_outputs(z,device)
# () Predict .
yhats = model(inputs)
# () TODO: Pytorch Bug https://github.com/pytorch/pytorch/issues/58005 .
dist.all_reduce(yhats,op=dist.ReduceOp.SUM)
# () Compute loss .
loss = torch.nn.functional.binary_cross_entropy_with_logits(yhats, targets)
# () Backward .
loss.backward()
# () .
batch_loss = loss.item()
# () .
optimizer.step()
# () .
epoch_loss +=batch_loss
# () .
if rank==0 and hasattr(tqdm_bar, 'set_description_str'):
tqdm_bar.set_description_str(f"Epoch:{epoch + 1}")
if i > 0:
tqdm_bar.set_postfix_str(f"batch={i} | {num_of_batches}, loss_step={batch_loss:.5f}, loss_epoch={epoch_loss / i:.5f}")
else:
tqdm_bar.set_postfix_str(f"loss_step={batch_loss:.5f}, loss_epoch={batch_loss:.5f}")
# () .
torch.distributed.destroy_process_group()
# () .
self.on_fit_end(self, model)
def create_and_evaluate_combined_model( trainer,ensemble_model):
# Create and evaluate a combined model from the ensemble model
combined_model_args = ensemble_model.models[0].args
combined_entity_embeddings, combined_relation_embeddings = ensemble_model.get_embeddings()
combined_model_args["embedding_dim"] = combined_entity_embeddings.shape[1]
combined_model, form_of_labelling = intialize_model(combined_model_args)
combined_model.entity_embeddings.weight.data = combined_entity_embeddings
combined_model.relation_embeddings.weight.data = combined_relation_embeddings
combined_model.eval()
combined_model.to("cpu")
print(f"Evaluating combined Ensemble of {combined_model_args['model']}")
eval_result = trainer.evaluator.eval(dataset=trainer.dataset,
trained_model=combined_model,
form_of_labelling=form_of_labelling,
during_training=False)
trainer.evaluator.report["combined_model"] = eval_result
torch.save(combined_model.state_dict(), f'{trainer.attributes.full_storage_path}/model.pt')
return
"""