dicee.trainer.auto_batch_finder
Functions

| find_good_batch_size(train_loader, training_step_fn, device) | Find a batch size that uses GPU memory efficiently. |
Module Contents
- dicee.trainer.auto_batch_finder.find_good_batch_size(train_loader: torch.utils.data.DataLoader, training_step_fn: Callable, device) → Tuple[int, float | None]
Find a batch size that uses GPU memory efficiently.
Progressively increases the batch size (by a tunable delta) until GPU memory usage exceeds 90% or a CUDA out-of-memory error is raised, then backs off to the last safe batch size. Only supported on CUDA devices; on CPU the initial batch size is returned unchanged.
- Parameters:
train_loader – DataLoader wrapping the training dataset.
training_step_fn – Callable[[batch], float] that runs one forward and backward pass and returns the scalar loss. Each trainer supplies its own implementation so this function stays trainer-agnostic.
device – torch.device (or anything accepted by torch.device()) used to monitor GPU memory.
- Returns:
(batch_size, runtime) where runtime is the wall-clock seconds of the last successful batch, or None when batch finding was skipped.
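A minimal usage sketch, assuming a CUDA device is available and that the function is called with the DataLoader, step callable, and device as in the signature above. The model, optimizer, and dataset below are hypothetical placeholders for whatever a trainer already holds:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from dicee.trainer.auto_batch_finder import find_good_batch_size

# Hypothetical setup: any model, optimizer, and dataset will do.
model = torch.nn.Linear(10, 1).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
dataset = TensorDataset(torch.randn(10_000, 10), torch.randn(10_000, 1))
train_loader = DataLoader(dataset, batch_size=32)

def training_step_fn(batch):
    # One forward+backward pass on a single batch; returns the scalar loss.
    x, y = (t.cuda() for t in batch)
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()
    return loss.item()

batch_size, runtime = find_good_batch_size(
    train_loader, training_step_fn, device=torch.device("cuda")
)
# runtime is None when batch finding was skipped (e.g. on CPU).
print(f"Selected batch size: {batch_size}, last step took {runtime}s")

# Rebuild the DataLoader with the discovered batch size before training.
train_loader = DataLoader(dataset, batch_size=batch_size)
```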