dicee.trainer.auto_batch_finder

Functions

find_good_batch_size(→ Tuple[int, Optional[float]])

Find a batch size that uses GPU memory efficiently.

Module Contents

dicee.trainer.auto_batch_finder.find_good_batch_size(train_loader: torch.utils.data.DataLoader, training_step_fn: Callable, device) → Tuple[int, float | None]

Find a batch size that uses GPU memory efficiently.

Progressively increases the batch size (by a tunable delta) until GPU memory usage exceeds 90% or a CUDA out-of-memory (OOM) error is raised, then backs off to the last batch size that fit safely. Only supported on CUDA devices; on CPU the initial batch size is returned unchanged.
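The search loop described above can be sketched in plain Python. This is not the dicee implementation: the `probe` callback, the additive growth rule for `delta`, the `max_batch_size` cap, and the fallback used when even the initial size fails are all assumptions. `MemoryError` stands in for PyTorch's CUDA out-of-memory exception so the sketch runs without a GPU.

```python
import time
from typing import Callable, Optional, Tuple


def find_batch_size_sketch(
    initial_batch_size: int,
    probe: Callable[[int], float],
    is_cuda: bool,
    delta: int = 64,
    max_memory_fraction: float = 0.9,
    max_batch_size: int = 1 << 20,
) -> Tuple[int, Optional[float]]:
    """Hypothetical sketch, not dicee's actual code.

    `probe(batch_size)` is an assumed callback: it runs one training step at
    that batch size and returns the fraction of device memory in use
    (0.0 to 1.0), raising MemoryError if the step does not fit.
    """
    if not is_cuda:
        # Mirror the documented CPU behaviour: skip the search entirely.
        return initial_batch_size, None

    safe_size: Optional[int] = None
    safe_runtime: Optional[float] = None
    batch_size = initial_batch_size
    while batch_size <= max_batch_size:
        start = time.perf_counter()
        try:
            used_fraction = probe(batch_size)
        except MemoryError:
            break  # OOM: the previous size is the last safe one
        elapsed = time.perf_counter() - start
        if used_fraction > max_memory_fraction:
            break  # over the memory threshold: do not accept this size
        safe_size, safe_runtime = batch_size, elapsed
        batch_size += delta  # assumed additive growth rule

    if safe_size is None:
        # Assumption: if even the initial size fails, return it unchanged.
        return initial_batch_size, None
    return safe_size, safe_runtime
```

Additive growth keeps each back-off small but needs many steps to reach large sizes; a doubling rule would converge faster at the cost of a coarser final answer.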

Parameters:
  • train_loader – DataLoader wrapping the training dataset.

  • training_step_fn – Callable[[batch], float] that runs one forward and backward pass on a batch and returns the scalar loss. Each trainer passes its own implementation, which keeps this function trainer-agnostic.

  • device – torch.device (or anything accepted by torch.device()) used to monitor GPU memory.
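To make the `training_step_fn` contract concrete, here is a framework-free stand-in. The hypothetical `make_training_step` factory builds a callable for a one-feature linear model: it takes a batch, runs a forward pass and an analytic backward pass, applies a plain SGD update, and returns the loss as a Python float. The batch format (a list of `(x, y)` pairs) is an assumption; a real dicee trainer would operate on tensors and would typically return something like `loss.item()`.

```python
from typing import Callable, List, Tuple

# Assumed batch format for this sketch: a list of (x, y) pairs.
Batch = List[Tuple[float, float]]


def make_training_step(lr: float = 0.01) -> Callable[[Batch], float]:
    """Hypothetical factory returning a Callable[[batch], float]."""
    params = {"w": 0.0, "b": 0.0}

    def training_step(batch: Batch) -> float:
        n = len(batch)
        # Forward pass: mean squared error of y_hat = w * x + b.
        errors = [(params["w"] * x + params["b"]) - y for x, y in batch]
        loss = sum(e * e for e in errors) / n
        # Backward pass: analytic gradients of the MSE w.r.t. w and b.
        grad_w = 2.0 * sum(e * x for e, (x, _) in zip(errors, batch)) / n
        grad_b = 2.0 * sum(errors) / n
        # Optimizer step: plain SGD.
        params["w"] -= lr * grad_w
        params["b"] -= lr * grad_b
        return loss  # scalar loss, as the contract requires

    return training_step
```

Because the parameters live in the closure, the caller only ever sees "batch in, loss out", which is the shape that lets the batch finder stay trainer-agnostic.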

Returns:

(batch_size, runtime), where runtime is the wall-clock time in seconds of the last successful batch, or None when batch finding was skipped.
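Because runtime is optional, callers should check for None before using it. A minimal sketch with a hypothetical `describe_batch_search` helper, which also turns the runtime into a rough samples-per-second estimate (assuming every batch holds exactly `batch_size` samples):

```python
from typing import Optional, Tuple


def describe_batch_search(result: Tuple[int, Optional[float]]) -> str:
    """Hypothetical helper for reporting the (batch_size, runtime) tuple."""
    batch_size, runtime = result
    if runtime is None:
        # Search was skipped (e.g. non-CUDA device): batch size is unchanged.
        return f"batch_size={batch_size} (search skipped)"
    # Rough throughput estimate from the last successful batch.
    throughput = batch_size / runtime if runtime > 0 else float("inf")
    return f"batch_size={batch_size}, {runtime:.3f}s/batch, ~{throughput:.0f} samples/s"
```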