Source code for dicee.trainer.auto_batch_finder

import time
from typing import Callable, Optional, Tuple

import torch


def find_good_batch_size(
    train_loader: torch.utils.data.DataLoader,
    training_step_fn: Callable,
    device,
) -> Tuple[int, Optional[float]]:
    """Find a batch size that uses GPU memory efficiently.

    Progressively increases the batch size (with a tunable delta) until GPU
    memory usage exceeds 90% or a CUDA OOM error is raised, then backs off
    to the last safe batch size. Only supported on CUDA devices; on CPU the
    initial batch size is returned unchanged.

    Args:
        train_loader: DataLoader wrapping the training dataset.
        training_step_fn: Callable[[batch], float] that runs one
            forward+backward pass and returns the scalar loss. Each trainer
            passes its own implementation so this function stays
            trainer-agnostic.
        device: torch.device (or anything accepted by torch.device()) used
            to monitor GPU memory.

    Returns:
        (batch_size, runtime) where runtime is the wall-clock seconds of the
        last successful batch, or None when batch finding was skipped.
    """
    initial_batch_size = train_loader.batch_size
    training_dataset_size = len(train_loader.dataset)
    if initial_batch_size >= training_dataset_size:
        return training_dataset_size, None

    # Normalise device
    if not isinstance(device, torch.device):
        device = torch.device(device)
    if device.type != "cuda":
        print("Auto batch finding requires a CUDA device; skipping.")
        return initial_batch_size, None

    print(f"Auto batch finding. Training data points: {training_dataset_size}")

    def _try_increasing(batch_size: int, delta: int):
        """Increase batch_size until OOM or >90% GPU memory; return history."""
        history = []
        try:
            while True:
                start = time.time()
                loader = torch.utils.data.DataLoader(
                    train_loader.dataset,
                    batch_size=batch_size,
                    shuffle=True,
                    num_workers=train_loader.num_workers,
                    # Reuse the original loader's collate function; every
                    # DataLoader exposes one, unlike the dataset itself.
                    collate_fn=train_loader.collate_fn,
                    pin_memory=False,
                    drop_last=False,
                    persistent_workers=False,
                )
                batch_loss = None
                # A single step is enough to probe peak memory at this size.
                for batch in loader:
                    batch_loss = training_step_fn(batch)
                    break
                free, total = torch.cuda.mem_get_info(device=device)
                pct_used = (total - free) / total
                rt = time.time() - start
                print(
                    f"Batch Loss: {batch_loss:.4f}\t"
                    f"GPU Usage: {pct_used:.3f}\t"
                    f"Runtime: {rt:.3f}s\t"
                    f"Batch Size: {batch_size}"
                )
                history.append((batch_size, rt))
                # Stay below 90% to avoid the illegal-memory-access bug:
                # https://github.com/pytorch/pytorch/issues/21819
                if pct_used > 0.9:
                    return history, False
                if batch_size < training_dataset_size:
                    batch_size += int(batch_size / delta)
                else:
                    return history, True
        except torch.cuda.OutOfMemoryError:
            # Catch torch.cuda.OutOfMemoryError, which is available in older
            # PyTorch releases as well; torch.OutOfMemoryError is a newer
            # top-level alias for the same class.
            gpu_mem = torch.cuda.get_device_properties(device).total_memory / (1024 ** 3)
            allocated = torch.cuda.memory_allocated(device) / (1024 ** 3)
            print(
                f"CUDA OOM at batch_size={batch_size} "
                f"({gpu_mem:.2f} GB total, {allocated:.2f} GB allocated)"
            )
            return history, False

    full_history = []
    batch_size = initial_batch_size
    batch_rt = None
    for delta in range(1, 5):
        result, completed = _try_increasing(batch_size, delta=delta)
        full_history.extend(result)
        if completed:
            batch_size, batch_rt = full_history[-1]
        else:
            # Backing off one step requires at least two successful batches.
            assert len(full_history) >= 2, "GPU memory error on the very first batch"
            # Back off to the last safe size and retry with a finer increment
            # (larger delta) to refine the estimate.
            batch_size, batch_rt = full_history[-2]
        if batch_size >= training_dataset_size:
            batch_size = training_dataset_size
            break
    return batch_size, batch_rt
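
# ---------------------------------------------------------------------------
# Usage sketch (not part of the module). The training_step_fn closure below is
# hypothetical: the linear model, Adam optimizer, and MSE loss are stand-ins
# for a trainer's real step. The only contract find_good_batch_size imposes is
# "take one batch, run forward + backward, return the scalar loss".
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from torch.utils.data import DataLoader, TensorDataset

    # Falls back to CPU, where find_good_batch_size skips and returns the
    # initial batch size unchanged.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = torch.nn.Linear(128, 1).to(device)  # stand-in model
    optimizer = torch.optim.Adam(model.parameters())

    # Toy data; a real trainer passes its own dataset through train_loader.
    xs, ys = torch.randn(10_000, 128), torch.randn(10_000, 1)
    train_loader = DataLoader(TensorDataset(xs, ys), batch_size=32)

    def training_step_fn(batch) -> float:
        x, y = batch
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        loss = torch.nn.functional.mse_loss(model(x), y)
        loss.backward()
        optimizer.step()
        return loss.item()

    batch_size, runtime = find_good_batch_size(train_loader, training_step_fn, device)
    print(f"Selected batch size: {batch_size} (last-step runtime: {runtime})")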