Source code for dicee.trainer.model_parallelism

from typing import Tuple

import torch

from ..abstracts import AbstractTrainer
from ..models.ensemble import EnsembleKGE
from ..static_funcs_training import make_iterable_verbose
from .auto_batch_finder import find_good_batch_size


def _move_to_device(tensor, device):
    """Move ``tensor`` to ``device``, pinning host memory first when that is possible."""
    target = torch.device(device)
    # Pinning only makes sense for a CPU tensor headed to a CUDA device.
    # ``Tensor.pin_memory`` raises on hosts without CUDA and on tensors that
    # already live on a GPU, so guard it instead of calling it unconditionally.
    if target.type == "cuda" and tensor.device.type == "cpu" and torch.cuda.is_available():
        tensor = tensor.pin_memory()
    return tensor.to(target, non_blocking=True)


def extract_input_outputs(z: list, device=None):
    """Split a batch into model inputs and labels, optionally moving them to ``device``.

    Parameters
    ----------
    z : list
        Either ``[x_batch, y_batch]`` or ``[x_batch, y_idx_batch, y_batch]``.
    device : optional
        Target device (anything accepted by ``torch.device``). When ``None`` the
        tensors are returned untouched. Otherwise each tensor is moved with
        ``non_blocking=True``; CPU tensors are pinned first when the target is a
        CUDA device and CUDA is available.

    Returns
    -------
    tuple
        ``(x_batch, y_batch)`` for a 2-element batch, or
        ``((x_batch, y_idx_batch), y_batch)`` for a 3-element batch.

    Raises
    ------
    ValueError
        If ``z`` has neither 2 nor 3 elements.
    """
    if len(z) == 2:
        x_batch, y_batch = z
        # ``is not None`` (rather than truthiness) so that the device ordinal 0 is honoured.
        if device is not None:
            x_batch = _move_to_device(x_batch, device)
            y_batch = _move_to_device(y_batch, device)
        return x_batch, y_batch
    if len(z) == 3:
        x_batch, y_idx_batch, y_batch = z
        if device is not None:
            x_batch = _move_to_device(x_batch, device)
            y_batch = _move_to_device(y_batch, device)
            y_idx_batch = _move_to_device(y_idx_batch, device)
        return (x_batch, y_idx_batch), y_batch
    raise ValueError('Unexpected batch shape..')
def forward_backward_update_loss(z: Tuple, ensemble_model) -> float:
    """Run one training step on a batch and return its loss.

    Parameters
    ----------
    z : Tuple
        A batch as produced by the training dataloader; it is split into inputs
        and labels by ``extract_input_outputs``.
    ensemble_model
        Callable model that returns logits for ``x_batch`` and exposes ``step()``
        to apply the parameter update.

    Returns
    -------
    float
        Binary cross-entropy-with-logits loss of this batch.
    """
    # Split the batch into inputs and labels (no device transfer at this point).
    x_batch, y_batch = extract_input_outputs(z)
    # Forward pass on the batch of input data points.
    yhat = ensemble_model(x_batch)
    # Put the labels on the device that holds the predictions. This equals the
    # former hard-coded "cuda:0" whenever the model outputs on GPU-0, and keeps
    # working if the predictions live on another device.
    y_batch = y_batch.to(yhat.device)
    # Compute the loss.
    loss = torch.nn.functional.binary_cross_entropy_with_logits(yhat, y_batch)
    # Compute the gradient of the loss w.r.t. the parameters.
    loss.backward()
    # Parameter update.
    ensemble_model.step()
    return loss.item()
class TensorParallel(AbstractTrainer):
    """Trainer for an ``EnsembleKGE`` model.

    ``fit`` requires ``len(ensemble_model)`` to equal the number of visible CUDA devices.
    """

    def __init__(self, args, callbacks):
        super().__init__(args, callbacks)

    def fit(self, *args, **kwargs):
        """Train the model.

        Parameters
        ----------
        *args
            Exactly one positional argument: the ``EnsembleKGE`` instance to train.
        **kwargs
            Must contain ``train_dataloaders``, the dataloader to iterate over.

        Returns
        -------
        EnsembleKGE
            The same model instance, after training. The summed batch loss of
            every epoch is appended to ``ensemble_model.loss_history``.
        """
        assert len(args) == 1
        ensemble_model, = args
        assert isinstance(ensemble_model, EnsembleKGE), (
            f"Selected model must be an instance of EnsembleKGE, got {type(ensemble_model)}")
        # Run on_fit_start callbacks.
        self.on_fit_start(self, ensemble_model)
        # Sanity check: the ensemble size must match the number of visible GPUs.
        assert torch.cuda.device_count() == len(ensemble_model), (
            f"Number of CUDA devices ({torch.cuda.device_count()}) must match "
            f"the ensemble size ({len(ensemble_model)})")
        # Get the DataLoader.
        train_dataloader = kwargs['train_dataloaders']
        # Find a batch size so that available GPU memory is *almost* fully used.
        if self.attributes.auto_batch_finding:
            def _tp_training_step(batch):
                return forward_backward_update_loss(batch, ensemble_model)

            # The second returned value (previously bound to ``batch_rt``,
            # presumably a per-batch runtime) is currently unused.
            batch_size, _ = find_good_batch_size(
                train_dataloader,
                _tp_training_step,
                device=torch.device("cuda", 0),
            )
            # Rebuild the dataloader over the same dataset with the batch size found.
            train_dataloader = torch.utils.data.DataLoader(
                train_dataloader.dataset,
                batch_size=batch_size,
                shuffle=True,
                sampler=None,
                batch_sampler=None,
                num_workers=self.attributes.num_core,
                collate_fn=train_dataloader.dataset.collate_fn,
                pin_memory=False,
                drop_last=False,
                timeout=0,
                worker_init_fn=None,
                persistent_workers=False,
            )
        # Number of batches to reach a single epoch.
        num_of_batches = len(train_dataloader)
        # Start training.
        for epoch in (tqdm_bar := make_iterable_verbose(range(self.attributes.num_epochs),
                                                        verbose=True, position=0, leave=True)):
            self.on_train_epoch_start(self, ensemble_model)
            # Accumulates the batch losses of this epoch.
            epoch_loss = 0
            for i, z in enumerate(train_dataloader):
                # Forward, loss, backward, and update on a given batch of data points.
                batch_loss = forward_backward_update_loss(z, ensemble_model)
                epoch_loss += batch_loss
                # Only report progress when the iterable is a progress bar.
                if hasattr(tqdm_bar, 'set_description_str'):
                    tqdm_bar.set_description_str(f"Epoch:{epoch + 1}")
                    # ``i`` is zero-based, so ``i + 1`` batches have been
                    # accumulated; dividing by ``i`` overstated the running mean.
                    running_loss = epoch_loss / (i + 1)
                    if i > 0:
                        tqdm_bar.set_postfix_str(
                            f"batch={i} | {num_of_batches}, loss_step={batch_loss:.5f}, "
                            f"loss_epoch={running_loss:.5f}")
                    else:
                        tqdm_bar.set_postfix_str(
                            f"loss_step={batch_loss:.5f}, loss_epoch={running_loss:.5f}")
            # Store the epoch loss.
            ensemble_model.loss_history.append(epoch_loss)
            self.on_train_epoch_end(self, ensemble_model)
        # Run on_fit_end callbacks after the training is done.
        self.on_fit_end(self, ensemble_model)
        # TODO: Later, maybe we should write a callback to save the models in disk
        return ensemble_model
# NOTE(review): the string literal below is never assigned or executed; it only
# preserves disabled experimental code (batchwisefit, torch_buggy_fit,
# create_and_evaluate_combined_model). It references names such as `dist` and
# `intialize_model` that are not imported in the visible part of this file —
# confirm before reviving any of it.
""" def batchwisefit(self, *args, **kwargs): assert len(args) == 1 model, = args # (1) Run the fit the start callback. self.on_fit_start(self, model) # (2) Setup DDP. optimizer = model.configure_optimizers() num_gpus = torch.cuda.device_count() for epoch in (tqdm_bar := make_iterable_verbose(range(self.attributes.num_epochs), verbose=True, position=0, leave=True)): epoch_loss = 0 num_of_batches = len(kwargs['train_dataloaders']) for i, (x_batch, y_batch) in enumerate(kwargs['train_dataloaders']): # Define a large batch into small batches x_splits = torch.chunk(x_batch, num_gpus) y_splits = torch.chunk(y_batch, num_gpus) # Forward pass. We need to paralelize it gpu_losses = [] for gpu_id, (x_split, y_split) in enumerate(zip(x_splits, y_splits)): y_split = y_split.to(f"cuda:{gpu_id}") h_emb, r_emb, t_emb = model.get_triple_representation(x_split) h_emb, r_emb, t_emb = h_emb.pin_memory().to(f"cuda:{gpu_id}", non_blocking=True), r_emb.pin_memory().to(f"cuda:{gpu_id}", non_blocking=True), t_emb.pin_memory().to(f"cuda:{gpu_id}", non_blocking=True) yhat = model.score(h_emb, r_emb, t_emb) gpu_losses.append(torch.nn.functional.binary_cross_entropy_with_logits(yhat, y_split).to("cuda:0")) loss = sum(gpu_losses) / len(gpu_losses) loss.backward() batch_loss = loss.item() optimizer.step() optimizer.zero_grad(set_to_none=True) epoch_loss += batch_loss if hasattr(tqdm_bar, 'set_description_str'): tqdm_bar.set_description_str(f"Epoch:{epoch + 1}") if i > 0: tqdm_bar.set_postfix_str( f"batch={i} | {num_of_batches}, loss_step={batch_loss:.5f}, loss_epoch={epoch_loss / i:.5f}") else: tqdm_bar.set_postfix_str(f"loss_step={batch_loss:.5f}, loss_epoch={batch_loss:.5f}") def torch_buggy_fit(self, *args, **kwargs): assert len(args) == 1 model, = args # () Run the fit the start callback. self.on_fit_start(self, model) # () Init Process Group with NCCL. torch.distributed.init_process_group(backend="nccl") # () Get Rank and World Size. 
rank = dist.get_rank() world_size = dist.get_world_size() # () Reinitialize Rank based on manuel seed rank. torch.manual_seed(rank) model.param_init(model.entity_embeddings.weight.data) model.param_init(model.relation_embeddings.weight.data) # () . device = torch.device(f'cuda:{rank}') model.to(device) # () . optimizer = model.configure_optimizers() # () . for epoch in (tqdm_bar := make_iterable_verbose(range(self.attributes.num_epochs), verbose=True, position=0, leave=True)): epoch_loss = 0 num_of_batches = len(kwargs['train_dataloaders']) # () . for i, z in enumerate(kwargs['train_dataloaders']): optimizer.zero_grad() # () Get batch and move it on GPUs . inputs,targets = extract_input_outputs(z,device) # () Predict . yhats = model(inputs) # () TODO: Pytorch Bug https://github.com/pytorch/pytorch/issues/58005 . dist.all_reduce(yhats,op=dist.ReduceOp.SUM) # () Compute loss . loss = torch.nn.functional.binary_cross_entropy_with_logits(yhats, targets) # () Backward . loss.backward() # () . batch_loss = loss.item() # () . optimizer.step() # () . epoch_loss +=batch_loss # () . if rank==0 and hasattr(tqdm_bar, 'set_description_str'): tqdm_bar.set_description_str(f"Epoch:{epoch + 1}") if i > 0: tqdm_bar.set_postfix_str(f"batch={i} | {num_of_batches}, loss_step={batch_loss:.5f}, loss_epoch={epoch_loss / i:.5f}") else: tqdm_bar.set_postfix_str(f"loss_step={batch_loss:.5f}, loss_epoch={batch_loss:.5f}") # () . torch.distributed.destroy_process_group() # () . 
self.on_fit_end(self, model) def create_and_evaluate_combined_model( trainer,ensemble_model): # Create and evaluate a combined model from the ensemble model combined_model_args = ensemble_model.models[0].args combined_entity_embeddings, combined_relation_embeddings = ensemble_model.get_embeddings() combined_model_args["embedding_dim"] = combined_entity_embeddings.shape[1] combined_model, form_of_labelling = intialize_model(combined_model_args) combined_model.entity_embeddings.weight.data = combined_entity_embeddings combined_model.relation_embeddings.weight.data = combined_relation_embeddings combined_model.eval() combined_model.to("cpu") print(f"Evaluating combined Ensemble of {combined_model_args['model']}") eval_result = trainer.evaluator.eval(dataset=trainer.dataset, trained_model=combined_model, form_of_labelling=form_of_labelling, during_training=False) trainer.evaluator.report["combined_model"] = eval_result torch.save(combined_model.state_dict(), f'{trainer.attributes.full_storage_path}/model.pt') return """