Source code for dicee.trainer.torch_trainer_ddp

import os
import torch
from typing import Iterable
from dicee.abstracts import AbstractTrainer
from torch.utils.data import DataLoader
from tqdm import tqdm

torch.set_float32_matmul_precision('high')

def make_iterable_verbose(iterable_object, verbose, desc="Default", position=None, leave=True) -> Iterable:
    if verbose:
        return tqdm(iterable_object, desc=desc, position=position, leave=leave)
    else:
        return iterable_object
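# Illustrative usage (not part of the original module; `rank` and `train_loader` are
# placeholders): wrap an iterable in a progress bar only on the coordinating process.
#   for batch in make_iterable_verbose(train_loader, verbose=(rank == 0), desc="Training"):
#       ...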
class TorchDDPTrainer(AbstractTrainer):
    """
    A trainer based on torch.nn.parallel.DistributedDataParallel.

    Arguments
    ----------
    train_set_idx
        Indexed triples for the training.
    entity_idxs
        mapping.
    relation_idxs
        mapping.
    form
        ?
    store
        ?
    label_smoothing_rate
        Using hard targets (0, 1) drives weights to infinity; an outlier then produces enormous gradients.

    Returns
    -------
    torch.utils.data.Dataset
    """

    def __init__(self, args, callbacks):
        super().__init__(args, callbacks)
    def fit(self, *args, **kwargs):
        """ Train model """
        assert len(args) == 1
        model, = args
        # (1) Run the on_fit_start callbacks.
        self.on_fit_start(self, model)
        # (2) Initialize the DDP process group.
        torch.distributed.init_process_group(backend="nccl")
        # (3) Rebuild the data loader with a DistributedSampler so each rank sees its own shard.
        train_dataset_loader = kwargs['train_dataloaders']
        train_dataset_loader = DataLoader(train_dataset_loader.dataset,
                                          batch_size=self.attributes.batch_size,
                                          pin_memory=True,
                                          shuffle=False,
                                          num_workers=self.attributes.num_core,
                                          persistent_workers=False,
                                          collate_fn=kwargs['train_dataloaders'].dataset.collate_fn,
                                          sampler=torch.utils.data.distributed.DistributedSampler(
                                              train_dataset_loader.dataset))
        # (4) Run the per-node trainer, then tear down the process group.
        NodeTrainer(self, model, train_dataset_loader, self.callbacks, self.attributes.num_epochs).train()
        torch.distributed.destroy_process_group()
        self.on_fit_end(self, model)
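# Launch sketch (illustrative, not part of this module): fit() assumes a multi-process
# launcher has already set the usual DDP environment variables (RANK, LOCAL_RANK,
# WORLD_SIZE, MASTER_ADDR, MASTER_PORT), e.g. torchrun; the script name is a placeholder.
#   torchrun --standalone --nproc_per_node=2 your_training_script.py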
class NodeTrainer:
    def __init__(self, trainer, model: torch.nn.Module,
                 train_dataset_loader: DataLoader,
                 callbacks, num_epochs: int) -> None:
        # (1) Trainer.
        self.trainer = trainer
        # (2) Local and global ranks provided by the launcher (e.g. torchrun).
        self.local_rank = int(os.environ["LOCAL_RANK"])
        self.global_rank = int(os.environ["RANK"])
        self.optimizer = model.configure_optimizers()
        # (3) Compile the model, move it to the local device, and wrap it in DDP.
        self.train_dataset_loader = train_dataset_loader
        self.loss_func = model.loss
        self.callbacks = callbacks
        self.model = torch.compile(model, mode="reduce-overhead").to(self.local_rank)
        self.model = torch.nn.parallel.DistributedDataParallel(self.model,
                                                               device_ids=[self.local_rank],
                                                               output_device=self.local_rank)
        self.num_epochs = num_epochs
        self.loss_history = []

    def _load_snapshot(self, snapshot_path):
        raise NotImplementedError

    def _run_batch(self, source: torch.LongTensor, targets: torch.FloatTensor):
        """
        Forward + Backward + Update over a single batch.

        Parameters
        ----------
        source
        targets

        Returns
        -------
        batch loss
        """
        self.optimizer.zero_grad()
        output = self.model(source)
        loss = self.loss_func(output, targets)
        batch_loss = loss.item()
        loss.backward()
        self.optimizer.step()
        return batch_loss
    def extract_input_outputs(self, z: list):
        if len(z) == 2:
            x_batch, y_batch = z
            return x_batch.to(self.local_rank), y_batch.to(self.local_rank)
        elif len(z) == 3:
            x_batch, y_idx_batch, y_batch = z
            x_batch, y_idx_batch, y_batch = x_batch.to(self.local_rank), y_idx_batch.to(self.local_rank), y_batch.to(
                self.local_rank)
            return (x_batch, y_idx_batch), y_batch
        else:
            raise ValueError('Unexpected batch shape.')
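    # Illustrative batch layouts handled above (an assumption inferred from the two
    # branches, not stated in the original source):
    #   len(z) == 2: (x_batch, y_batch)              -> inputs and dense targets
    #   len(z) == 3: (x_batch, y_idx_batch, y_batch) -> inputs, target indices, target values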
    def _run_epoch(self, epoch: int) -> float:
        """
        Single pass/iteration over the training dataset.

        Parameters
        ----------
        epoch: int
            Epoch number passed to the DistributedSampler.

        Returns
        -------
        Average mini-batch loss over the training dataset
        """
        self.train_dataset_loader.sampler.set_epoch(epoch)
        epoch_loss = 0
        i = 0
        for i, z in enumerate(self.train_dataset_loader):
            source, targets = self.extract_input_outputs(z)
            batch_loss = self._run_batch(source, targets)
            epoch_loss += batch_loss
        return epoch_loss / (i + 1)
    def train(self):
        """
        Training loop for DDP.

        Returns
        -------
        """
        for epoch in (tqdm_bar := make_iterable_verbose(range(self.num_epochs),
                                                        verbose=self.local_rank == self.global_rank == 0,
                                                        position=0, leave=True)):
            self.train_dataset_loader.sampler.set_epoch(epoch)
            epoch_loss = 0
            for i, z in enumerate(self.train_dataset_loader):
                source, targets = self.extract_input_outputs(z)
                batch_loss = self._run_batch(source, targets)
                epoch_loss += batch_loss
                # Only rank 0 holds a tqdm bar; other ranks iterate over a plain range.
                if hasattr(tqdm_bar, 'set_description_str'):
                    tqdm_bar.set_description_str(f"Epoch:{epoch + 1}")
                    if i > 0:
                        tqdm_bar.set_postfix_str(f"loss_step={batch_loss:.5f}, loss_epoch={epoch_loss / i:.5f}")
                    else:
                        tqdm_bar.set_postfix_str(f"loss_step={batch_loss:.5f}, loss_epoch={batch_loss:.5f}")
            avg_epoch_loss = epoch_loss / len(self.train_dataset_loader)
            if self.local_rank == self.global_rank == 0:
                self.model.module.loss_history.append(avg_epoch_loss)
            for c in self.callbacks:
                c.on_train_epoch_end(self.trainer, self.model.module)
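# Minimal end-to-end sketch (illustrative; `args`, `callbacks`, `model`, and
# `train_loader` are placeholders the caller would supply, matching the fit() contract above):
#   trainer = TorchDDPTrainer(args, callbacks=[])
#   trainer.fit(model, train_dataloaders=train_loader)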