Source code for dicee.trainer.model_parallelism

import torch
from ..abstracts import AbstractTrainer
from ..static_funcs_training import make_iterable_verbose
from ..models.ensemble import EnsembleKGE
from typing import Tuple
import time

def extract_input_outputs(z: list, device=None):
    if len(z) == 2:
        x_batch, y_batch = z
        # Pin the arrays x, y, which allows us to move them to the GPU asynchronously (non_blocking=True).
        if device:
            x_batch, y_batch = x_batch.to(device, non_blocking=True), y_batch.pin_memory().to(device,
                                                                                              non_blocking=True)
        return x_batch, y_batch
    elif len(z) == 3:
        x_batch, y_idx_batch, y_batch = z
        if device:
            x_batch, y_batch, y_idx_batch = x_batch.pin_memory().to(device, non_blocking=True), \
                y_batch.pin_memory().to(device, non_blocking=True), \
                y_idx_batch.pin_memory().to(device, non_blocking=True)
        return (x_batch, y_idx_batch), y_batch
    else:
        raise ValueError('Unexpected batch shape.')

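# Illustrative sketch (not part of the module): the two batch layouts accepted by
# extract_input_outputs. The tensor shapes below are made up for demonstration only.
"""
import torch

x = torch.randint(0, 10, (4, 2))        # (head, relation) index pairs
y = torch.rand(4, 10)                   # multi-label targets over all entities

# 2-element batch: returned unchanged when no device is given.
x_batch, y_batch = extract_input_outputs([x, y])

# 3-element batch: the input indices and the per-example label indices are grouped together.
y_idx = torch.randint(0, 10, (4, 3))
(x_batch, y_idx_batch), y_batch = extract_input_outputs([x, y_idx, torch.rand(4, 3)])
"""
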
def find_good_batch_size(train_loader, tp_ensemble_model):
    # () Initial batch size.
    initial_batch_size = train_loader.batch_size
    # () Number of training data points.
    training_dataset_size = len(train_loader.dataset)
    # () Batch is already large enough.
    if initial_batch_size >= training_dataset_size:
        return training_dataset_size, None
    # () Log the number of training data points.
    print("Number of training data points:", training_dataset_size)

    def increase_batch_size_until_cuda_out_of_memory(ensemble_model, train_loader, batch_size, delta: int = None):
        assert delta is not None, "delta cannot be None."
        assert isinstance(delta, int), "delta must be a positive integer."
        # () Store the batch sizes and runtimes as tuples.
        batch_sizes_and_mem_usages = []
        # () Increase the batch size until a stopping criterion is reached.
        try:
            while True:
                start_time = time.time()
                # () Initialize a dataloader with the current batch_size.
                train_dataloaders = torch.utils.data.DataLoader(train_loader.dataset,
                                                                batch_size=batch_size,
                                                                shuffle=True,
                                                                sampler=None,
                                                                batch_sampler=None,
                                                                num_workers=train_loader.num_workers,
                                                                collate_fn=train_loader.dataset.collate_fn,
                                                                pin_memory=False,
                                                                drop_last=False,
                                                                timeout=0,
                                                                worker_init_fn=None,
                                                                persistent_workers=False)
                batch_loss = None
                for i, batch_of_training_data in enumerate(train_dataloaders):
                    batch_loss = forward_backward_update_loss(batch_of_training_data, ensemble_model)
                    break
                global_free_memory, total_memory = torch.cuda.mem_get_info(device="cuda:0")
                percentage_used_gpu_memory = (total_memory - global_free_memory) / total_memory
                rt = time.time() - start_time
                print(f"Random Batch Loss: {batch_loss:0.4}\tGPU Usage: {percentage_used_gpu_memory:0.3}\tRuntime: {rt:.3f}\tBatch Size: {batch_size}")
                # () Store the batch size and the runtime.
                batch_sizes_and_mem_usages.append((batch_size, rt))
                # () https://github.com/pytorch/pytorch/issues/21819
                # CD: As GPU memory usage approaches 1.0, we observe RuntimeError: CUDA error: an illegal memory access was encountered.
                # CD: To avoid this problem, we add the following condition as a temporary solution.
                if percentage_used_gpu_memory > 0.9:
                    # Mimic an out-of-memory error.
                    return batch_sizes_and_mem_usages, False

                if batch_size < training_dataset_size:
                    # Increase the batch size.
                    batch_size += int(batch_size / delta)
                else:
                    return batch_sizes_and_mem_usages, True

        except torch.OutOfMemoryError as e:
            print(f"torch.OutOfMemoryError caught! {e}\n\n")
            return batch_sizes_and_mem_usages, False

    history_batch_sizes_and_mem_usages = []
    batch_size = initial_batch_size
    for delta in range(1, 5, 1):
        result, flag = increase_batch_size_until_cuda_out_of_memory(tp_ensemble_model, train_loader, batch_size,
                                                                    delta=delta)
        history_batch_sizes_and_mem_usages.extend(result)

        if flag:
            batch_size, batch_rt = history_batch_sizes_and_mem_usages[-1]
        else:
            # CUDA error observed: fall back to the last batch size that worked.
            # https://github.com/pytorch/pytorch/issues/21819
            assert len(history_batch_sizes_and_mem_usages) > 2, "GPU memory error in the first try"
            batch_size, batch_rt = history_batch_sizes_and_mem_usages[-2]
            break

        if batch_size >= training_dataset_size:
            batch_size = training_dataset_size
            break
        else:
            continue

    return batch_size, batch_rt

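# Illustrative sketch (not part of the module): the batch-size growth schedule used above.
# With batch_size += int(batch_size / delta), delta=1 doubles the batch size at every step,
# while larger delta values grow it more slowly. The starting value 1024 is arbitrary.
"""
def growth_schedule(start: int, delta: int, steps: int) -> list:
    sizes = [start]
    for _ in range(steps):
        sizes.append(sizes[-1] + int(sizes[-1] / delta))
    return sizes

growth_schedule(1024, delta=1, steps=4)  # [1024, 2048, 4096, 8192, 16384]
growth_schedule(1024, delta=4, steps=4)  # [1024, 1280, 1600, 2000, 2500]
"""
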
def forward_backward_update_loss(z: Tuple, ensemble_model) -> float:
    # () Get a random batch of data points (z).
    x_batch, y_batch = extract_input_outputs(z)
    # () Move the batch of labels onto the master GPU: GPU-0.
    y_batch = y_batch.to("cuda:0")
    # () Forward pass on the batch of input data points (yhat lives on the master GPU).
    yhat = ensemble_model(x_batch)
    # () Compute the loss.
    loss = torch.nn.functional.binary_cross_entropy_with_logits(yhat, y_batch)
    # () Compute the gradient of the loss w.r.t. the parameters.
    loss.backward()
    # () Parameter update.
    ensemble_model.step()
    return loss.item()

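# Illustrative sketch (not part of the module): the same forward/backward/update cycle on a
# plain torch.nn module. It only shows what forward_backward_update_loss relies on, namely a
# callable forward pass and a step() that applies the gradients; the toy model and data are made up.
"""
import torch

model = torch.nn.Linear(8, 3)
optimizer = torch.optim.Adam(model.parameters())
x, y = torch.randn(16, 8), torch.rand(16, 3)

logits = model(x)
loss = torch.nn.functional.binary_cross_entropy_with_logits(logits, y)
loss.backward()
optimizer.step()
optimizer.zero_grad()
"""
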
class TensorParallel(AbstractTrainer):
    def __init__(self, args, callbacks):
        super().__init__(args, callbacks)

    def fit(self, *args, **kwargs):
        """ Train model """
        assert len(args) == 1
        ensemble_model, = args
        assert isinstance(ensemble_model, EnsembleKGE), (f"Selected model must be an instance of EnsembleKGE, "
                                                         f"but received {type(ensemble_model)}")
        # () Run on_fit_start callbacks.
        self.on_fit_start(self, ensemble_model)
        # () Sanity checking: one GPU per model in the ensemble.
        assert torch.cuda.device_count() == len(ensemble_model)
        # () Get the DataLoader.
        train_dataloader = kwargs['train_dataloaders']
        # () Find a batch size so that the available GPU memory is *almost* fully used.
        if self.attributes.auto_batch_finding:
            batch_size, batch_rt = find_good_batch_size(train_dataloader, ensemble_model)

            train_dataloader = torch.utils.data.DataLoader(train_dataloader.dataset,
                                                           batch_size=batch_size,
                                                           shuffle=True,
                                                           sampler=None,
                                                           batch_sampler=None,
                                                           num_workers=self.attributes.num_core,
                                                           collate_fn=train_dataloader.dataset.collate_fn,
                                                           pin_memory=False,
                                                           drop_last=False,
                                                           timeout=0,
                                                           worker_init_fn=None,
                                                           persistent_workers=False)
            # if batch_rt is not None:
            #     expected_training_time = batch_rt * len(train_dataloader) * self.attributes.num_epochs
            #     print(f"Exp.Training Runtime: {expected_training_time/60 :.3f} in mins\t|\tBatch Size:{batch_size}\t|\tBatch RT:{batch_rt:.3f}\t|\t# of batches:{len(train_dataloader)}\t|\t# of epochs:{self.attributes.num_epochs}")

        # () Number of batches needed for a single epoch.
        num_of_batches = len(train_dataloader)
        # () Start training.
        for epoch in (tqdm_bar := make_iterable_verbose(range(self.attributes.num_epochs),
                                                        verbose=True, position=0, leave=True)):
            # () Accumulate the batch losses.
            epoch_loss = 0
            # () Iterate over batches.
            for i, z in enumerate(train_dataloader):
                # () Forward, loss, backward, and update on a given batch of data points.
                batch_loss = forward_backward_update_loss(z, ensemble_model)
                # () Accumulate the batch losses to compute the epoch loss.
                epoch_loss += batch_loss
                # If verbose=True, show progress info.
                if hasattr(tqdm_bar, 'set_description_str'):
                    tqdm_bar.set_description_str(f"Epoch:{epoch + 1}")
                    if i > 0:
                        tqdm_bar.set_postfix_str(
                            f"batch={i} | {num_of_batches}, loss_step={batch_loss:.5f}, loss_epoch={epoch_loss / i:.5f}")
                    else:
                        tqdm_bar.set_postfix_str(f"loss_step={batch_loss:.5f}, loss_epoch={batch_loss:.5f}")
            # Store the epoch loss.
            ensemble_model.loss_history.append(epoch_loss)

        # Run on_fit_end callbacks after the training is done.
        self.on_fit_end(self, ensemble_model)
        # TODO: Later, maybe we should write a callback to save the models on disk.
        return ensemble_model

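# Illustrative sketch (not part of the module): how TensorParallel.fit might be driven. The
# attribute names mirror the ones read above (num_epochs, num_core, auto_batch_finding); the
# SimpleNamespace args, ensemble_model and train_loader are placeholders, not the library's
# exact construction API.
"""
from types import SimpleNamespace

args = SimpleNamespace(num_epochs=10, num_core=4, auto_batch_finding=True)
trainer = TensorParallel(args, callbacks=[])
# ensemble_model: an EnsembleKGE holding one copy of the KGE model per visible GPU.
# train_loader: a torch DataLoader over the training data whose dataset provides a collate_fn.
trained_model = trainer.fit(ensemble_model, train_dataloaders=train_loader)
"""
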
""" def batchwisefit(self, *args, **kwargs): assert len(args) == 1 model, = args # (1) Run the fit the start callback. self.on_fit_start(self, model) # (2) Setup DDP. optimizer = model.configure_optimizers() num_gpus = torch.cuda.device_count() for epoch in (tqdm_bar := make_iterable_verbose(range(self.attributes.num_epochs), verbose=True, position=0, leave=True)): epoch_loss = 0 num_of_batches = len(kwargs['train_dataloaders']) for i, (x_batch, y_batch) in enumerate(kwargs['train_dataloaders']): # Define a large batch into small batches x_splits = torch.chunk(x_batch, num_gpus) y_splits = torch.chunk(y_batch, num_gpus) # Forward pass. We need to paralelize it gpu_losses = [] for gpu_id, (x_split, y_split) in enumerate(zip(x_splits, y_splits)): y_split = y_split.to(f"cuda:{gpu_id}") h_emb, r_emb, t_emb = model.get_triple_representation(x_split) h_emb, r_emb, t_emb = h_emb.pin_memory().to(f"cuda:{gpu_id}", non_blocking=True), r_emb.pin_memory().to(f"cuda:{gpu_id}", non_blocking=True), t_emb.pin_memory().to(f"cuda:{gpu_id}", non_blocking=True) yhat = model.score(h_emb, r_emb, t_emb) gpu_losses.append(torch.nn.functional.binary_cross_entropy_with_logits(yhat, y_split).to("cuda:0")) loss = sum(gpu_losses) / len(gpu_losses) loss.backward() batch_loss = loss.item() optimizer.step() optimizer.zero_grad(set_to_none=True) epoch_loss += batch_loss if hasattr(tqdm_bar, 'set_description_str'): tqdm_bar.set_description_str(f"Epoch:{epoch + 1}") if i > 0: tqdm_bar.set_postfix_str( f"batch={i} | {num_of_batches}, loss_step={batch_loss:.5f}, loss_epoch={epoch_loss / i:.5f}") else: tqdm_bar.set_postfix_str(f"loss_step={batch_loss:.5f}, loss_epoch={batch_loss:.5f}") def torch_buggy_fit(self, *args, **kwargs): assert len(args) == 1 model, = args # () Run the fit the start callback. self.on_fit_start(self, model) # () Init Process Group with NCCL. torch.distributed.init_process_group(backend="nccl") # () Get Rank and World Size. rank = dist.get_rank() world_size = dist.get_world_size() # () Reinitialize Rank based on manuel seed rank. torch.manual_seed(rank) model.param_init(model.entity_embeddings.weight.data) model.param_init(model.relation_embeddings.weight.data) # () . device = torch.device(f'cuda:{rank}') model.to(device) # () . optimizer = model.configure_optimizers() # () . for epoch in (tqdm_bar := make_iterable_verbose(range(self.attributes.num_epochs), verbose=True, position=0, leave=True)): epoch_loss = 0 num_of_batches = len(kwargs['train_dataloaders']) # () . for i, z in enumerate(kwargs['train_dataloaders']): optimizer.zero_grad() # () Get batch and move it on GPUs . inputs,targets = extract_input_outputs(z,device) # () Predict . yhats = model(inputs) # () TODO: Pytorch Bug https://github.com/pytorch/pytorch/issues/58005 . dist.all_reduce(yhats,op=dist.ReduceOp.SUM) # () Compute loss . loss = torch.nn.functional.binary_cross_entropy_with_logits(yhats, targets) # () Backward . loss.backward() # () . batch_loss = loss.item() # () . optimizer.step() # () . epoch_loss +=batch_loss # () . if rank==0 and hasattr(tqdm_bar, 'set_description_str'): tqdm_bar.set_description_str(f"Epoch:{epoch + 1}") if i > 0: tqdm_bar.set_postfix_str(f"batch={i} | {num_of_batches}, loss_step={batch_loss:.5f}, loss_epoch={epoch_loss / i:.5f}") else: tqdm_bar.set_postfix_str(f"loss_step={batch_loss:.5f}, loss_epoch={batch_loss:.5f}") # () . torch.distributed.destroy_process_group() # () . self.on_fit_end(self, model) """