dicee.trainer.auto_batch_finder
===============================

.. py:module:: dicee.trainer.auto_batch_finder


Functions
---------

.. autoapisummary::

   dicee.trainer.auto_batch_finder.find_good_batch_size


Module Contents
---------------

.. py:function:: find_good_batch_size(train_loader: torch.utils.data.DataLoader, training_step_fn: Callable, device) -> Tuple[int, Optional[float]]

   Find a batch size that uses GPU memory efficiently.

   Progressively increases the batch size (by a tunable delta) until GPU memory
   usage exceeds 90% or a CUDA OOM error is raised, then backs off to the last
   safe batch size. Only supported on CUDA devices; on CPU the initial batch
   size is returned unchanged.

   :param train_loader: DataLoader wrapping the training dataset.
   :param training_step_fn: Callable[[batch], float] -- runs one forward+backward
       pass and returns the scalar loss. Each trainer passes its own
       implementation so this function stays trainer-agnostic.
   :param device: torch.device (or anything accepted by torch.device()) used to
       monitor GPU memory.
   :returns: (batch_size, runtime) where runtime is the wall-clock seconds of the
       last successful batch, or None when batch finding was skipped.
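The grow-then-back-off search described above can be sketched in isolation. The snippet below is illustrative only: ``find_good_batch_size_sketch``, ``try_batch``, ``delta``, and ``max_batch_size`` are hypothetical names, and the ``try_batch`` predicate stands in for the real CUDA memory/OOM check performed by ``find_good_batch_size``.

.. code-block:: python

   # Illustrative sketch of the search loop (hypothetical helper, not dicee's API):
   # grow the batch size by `delta` until a trial step fails (standing in for the
   # >90% memory or CUDA OOM condition), then return the last safe size.
   import time
   from typing import Callable, Optional, Tuple


   def find_good_batch_size_sketch(
       initial_batch_size: int,
       try_batch: Callable[[int], bool],
       delta: int = 1024,
       max_batch_size: int = 1 << 20,
   ) -> Tuple[int, Optional[float]]:
       batch_size = initial_batch_size
       runtime = None  # None until at least one trial step succeeds
       while batch_size + delta <= max_batch_size:
           candidate = batch_size + delta
           start = time.perf_counter()
           if not try_batch(candidate):  # stand-in for the OOM / memory check
               break  # back off: keep the last size that succeeded
           runtime = time.perf_counter() - start
           batch_size = candidate
       return batch_size, runtime


   # Pretend the device fits at most 5000 samples per batch.
   bs, rt = find_good_batch_size_sketch(256, lambda b: b <= 5000, delta=1024)
   # bs is 4352: 256 -> 1280 -> 2304 -> 3328 -> 4352, then 5376 fails.

The real function wires this loop to ``training_step_fn`` and live GPU memory statistics; the sketch only shows the control flow of the search.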