dicee.models.adopt
ADOPT Optimizer Implementation.
This module implements the ADOPT (Adaptive Optimization with Precise Tracking) algorithm, an advanced optimization method for training neural networks.
ADOPT Overview:
ADOPT is an adaptive learning rate optimization algorithm that combines the benefits of momentum-based methods with per-parameter learning rate adaptation. Unlike Adam, which applies momentum to raw gradients, ADOPT normalizes gradients first and then applies momentum, leading to more stable training dynamics.
Key Features:
- Gradient normalization before momentum application
- Adaptive per-parameter learning rates
- Optional gradient clipping that grows with training steps
- Support for decoupled weight decay (AdamW-style)
- Multiple execution modes: single-tensor, multi-tensor (foreach), and fused (planned)
Algorithm Comparison:
Adam:  m = β₁*m + (1-β₁)*g,      θ = θ - α*m/√v
ADOPT: m = β₁*m + (1-β₁)*g/√v,   θ = θ - α*m
The key difference is that ADOPT normalizes gradients before momentum, which provides better stability and can lead to improved convergence.
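To make the contrast concrete, here is a minimal, illustrative sketch of both update rules on plain tensors (the variable names are chosen for exposition and are not part of this module):
>>> import torch
>>> g, m, v = torch.randn(3), torch.zeros(3), torch.ones(3)
>>> beta1, eps, lr = 0.9, 1e-6, 1e-3
>>> # Adam: momentum on the raw gradient, normalization applied at the update
>>> m_adam = beta1 * m + (1 - beta1) * g
>>> update_adam = lr * m_adam / (v.sqrt() + eps)
>>> # ADOPT: normalize the gradient first, then apply momentum
>>> m_adopt = beta1 * m + (1 - beta1) * (g / (v.sqrt() + eps))
>>> update_adopt = lr * m_adopt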
Classes:
ADOPT: Main optimizer class (extends torch.optim.Optimizer)
Functions:
adopt: Functional API for ADOPT algorithm computation
_single_tensor_adopt: Single-tensor implementation (TorchScript compatible)
_multi_tensor_adopt: Multi-tensor implementation using foreach operations
Performance:
Single-tensor: Default, compatible with torch.jit.script
Multi-tensor (foreach): 2-3x faster on GPU through vectorization (see the snippet after this list)
Fused (planned): Would provide maximum performance via specialized kernels
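For example, the foreach path can be requested explicitly through the constructor; leaving it at None lets the optimizer auto-select (a hedged usage sketch, assuming a model is already defined):
>>> optimizer = ADOPT(model.parameters(), lr=0.001, foreach=True)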
Example:
>>> import torch
>>> from dicee.models.adopt import ADOPT
>>>
>>> model = torch.nn.Linear(10, 1)
>>> criterion = torch.nn.MSELoss()
>>> inputs, targets = torch.randn(32, 10), torch.randn(32, 1)
>>> num_epochs = 10
>>> optimizer = ADOPT(model.parameters(), lr=0.001, weight_decay=0.01, decouple=True)
>>>
>>> # Training loop
>>> for epoch in range(num_epochs):
...     optimizer.zero_grad()
...     output = model(inputs)
...     loss = criterion(output, targets)
...     loss.backward()
...     optimizer.step()
References:
Original implementation: https://github.com/iShohei220/adopt
Notes:
This implementation is based on the original ADOPT implementation and adapted to work with the PyTorch optimizer interface and the dicee framework.
Classes
ADOPT – ADOPT Optimizer.
Functions
adopt – Functional API that performs ADOPT algorithm computation.
Module Contents
- class dicee.models.adopt.ADOPT(params: torch.optim.optimizer.ParamsT, lr: float | torch.Tensor = 0.001, betas: Tuple[float, float] = (0.9, 0.9999), eps: float = 1e-06, clip_lambda: Callable[[int], float] | None = lambda step: ..., weight_decay: float = 0.0, decouple: bool = False, *, foreach: bool | None = None, maximize: bool = False, capturable: bool = False, differentiable: bool = False, fused: bool | None = None)[source]
Bases:
torch.optim.optimizer.Optimizer

ADOPT Optimizer.
ADOPT is an adaptive learning rate optimization algorithm that combines momentum-based updates with adaptive per-parameter learning rates. It uses exponential moving averages of gradients and squared gradients, with gradient clipping for stability.
The algorithm performs the following key operations:
1. Normalizes gradients by the square root of the second moment estimate
2. Applies optional gradient clipping based on the training step
3. Updates parameters using momentum-smoothed normalized gradients
4. Supports decoupled weight decay (AdamW-style) or L2 regularization
- Mathematical formulation:
m_t = β₁ * m_{t-1} + (1 - β₁) * clip(g_t / √(v_t))
v_t = β₂ * v_{t-1} + (1 - β₂) * g_t²
θ_t = θ_{t-1} - α * m_t
- where:
θ_t: parameter at step t
g_t: gradient at step t
m_t: first moment estimate (momentum)
v_t: second moment estimate (variance)
α: learning rate
β₁, β₂: exponential decay rates
clip(): optional gradient clipping function
- Reference:
Original implementation: https://github.com/iShohei220/adopt
- Parameters:
params (ParamsT) – Iterable of parameters to optimize or dicts defining parameter groups.
lr (float or Tensor, optional) – Learning rate. Can be a float or 1-element Tensor. Default: 1e-3
betas (Tuple[float, float], optional) – Coefficients (β₁, β₂) for computing running averages of gradient and its square. β₁ controls momentum, β₂ controls variance. Default: (0.9, 0.9999)
eps (float, optional) – Term added to denominator to improve numerical stability. Default: 1e-6
clip_lambda (Callable[[int], float], optional) – Function that takes the step number and returns the gradient clipping threshold. Common choices:
- lambda step: step**0.25 (default, gradually increases the clipping threshold)
- lambda step: 1.0 (constant clipping)
- None (no clipping)
Default: lambda step: step**0.25
weight_decay (float, optional) – Weight decay coefficient (L2 penalty). Default: 0.0
decouple (bool, optional) – If True, uses decoupled weight decay (AdamW-style), applying weight decay directly to parameters. If False, adds weight decay to gradients (L2 regularization). Default: False
foreach (bool, optional) – If True, uses the faster foreach implementation for multi-tensor operations. Default: None (auto-select)
maximize (bool, optional) – If True, maximizes parameters instead of minimizing. Useful for reinforcement learning. Default: False
capturable (bool, optional) – If True, the optimizer is safe to capture in a CUDA graph. Requires learning rate as Tensor. Default: False
differentiable (bool, optional) – If True, the optimization step can be differentiated. Useful for meta-learning. Default: False
fused (bool, optional) – If True, uses fused kernel implementation (currently not supported). Default: None
- Raises:
ValueError – If learning rate, epsilon, betas, or weight_decay are invalid.
RuntimeError – If fused is enabled (not currently supported).
RuntimeError – If lr is a Tensor with foreach=True and capturable=False.
Example
>>> # Basic usage
>>> optimizer = ADOPT(model.parameters(), lr=0.001)
>>> optimizer.zero_grad()
>>> loss.backward()
>>> optimizer.step()

>>> # With decoupled weight decay
>>> optimizer = ADOPT(model.parameters(), lr=0.001, weight_decay=0.01, decouple=True)

>>> # Custom gradient clipping
>>> optimizer = ADOPT(model.parameters(), clip_lambda=lambda step: max(1.0, step**0.5))
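The two weight-decay modes behave differently; the following hedged sketch only illustrates the conceptual difference described for decouple, and the exact form and placement inside the optimizer's update loop may differ:
>>> p, g = torch.ones(3), torch.full((3,), 0.5)
>>> lr, weight_decay = 1e-3, 0.01
>>> # decouple=False: L2 penalty folded into the gradient
>>> g_l2 = g + weight_decay * p
>>> # decouple=True: AdamW-style decay applied directly to the parameter
>>> p_decoupled = p * (1 - lr * weight_decay)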
Note
For most use cases, the default hyperparameters work well
Consider using decouple=True for better generalization (similar to AdamW)
The clip_lambda function helps stabilize training in early steps
- clip_lambda
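For intuition, the default clip_lambda of lambda step: step**0.25 yields thresholds that grow slowly with the step count (plain Python arithmetic):
>>> [round(step ** 0.25, 2) for step in (1, 10, 100, 1000, 10000)]
[1.0, 1.78, 3.16, 5.62, 10.0]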
- __setstate__(state)[source]
Restore optimizer state from a checkpoint.
This method handles backward compatibility when loading optimizer state from older versions. It ensures all required fields are present with default values and properly converts step counters to tensors if needed.
Key responsibilities:
1. Set default values for newly added hyperparameters
2. Convert old-style scalar step counters to tensor format
3. Place step tensors on appropriate devices based on capturable/fused modes
- Parameters:
state (dict) – Optimizer state dictionary (typically from torch.load()).
Note
This enables loading checkpoints saved with older ADOPT versions
Step counters are converted to appropriate device/dtype for compatibility
Capturable and fused modes require step tensors on parameter devices
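A minimal checkpoint round-trip sketch (the file name is illustrative); in stock PyTorch optimizers, load_state_dict() routes the restored state through __setstate__(), which is where the compatibility handling above is applied:
>>> torch.save(optimizer.state_dict(), "adopt_state.pt")
>>> restored = ADOPT(model.parameters(), lr=0.001)
>>> restored.load_state_dict(torch.load("adopt_state.pt"))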
- step(closure=None)[source]
Perform a single optimization step.
This method executes one iteration of the ADOPT optimization algorithm across all parameter groups. It orchestrates the following workflow:
Optionally evaluates a closure to recompute the loss (useful for algorithms like LBFGS or when loss needs multiple evaluations)
For each parameter group:
- Collects parameters with gradients and their associated state
- Extracts hyperparameters (betas, learning rate, etc.)
- Calls the functional adopt() API to perform the actual update
Returns the loss value if a closure was provided
The functional API (adopt()) handles three execution modes:
- Single-tensor: Updates one parameter at a time (default, JIT-compatible)
- Multi-tensor (foreach): Batches operations for better performance
- Fused: Uses fused CUDA kernels (not yet implemented)
Gradient scaling support: This method is compatible with automatic mixed precision (AMP) training. It can access grad_scale and found_inf attributes for gradient unscaling and inf/nan detection when used with GradScaler.
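A hedged sketch of how step() fits into an AMP loop with torch.cuda.amp.GradScaler; model, criterion, and loader are assumed to exist and a CUDA device is assumed:
>>> scaler = torch.cuda.amp.GradScaler()
>>> optimizer = ADOPT(model.parameters(), lr=0.001)
>>> for inputs, targets in loader:
...     optimizer.zero_grad()
...     with torch.autocast(device_type="cuda"):
...         loss = criterion(model(inputs), targets)
...     scaler.scale(loss).backward()   # scale the loss before backward
...     scaler.step(optimizer)          # unscales gradients, skips the step on inf/nan
...     scaler.update()                 # adjusts the scale factor for the next iteration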
- Parameters:
closure (Callable, optional) – A callable that reevaluates the model and returns the loss. The closure should:
- Enable gradients (torch.enable_grad())
- Compute forward pass
- Compute loss
- Compute backward pass
- Return the loss value
Example: lambda: (loss := model(x), loss.backward(), loss)[-1]
Default: None
- Returns:
The loss value returned by the closure, or None if no closure was provided.
- Return type:
Optional[Tensor]
Example
>>> # Standard usage
>>> loss = criterion(model(input), target)
>>> loss.backward()
>>> optimizer.step()

>>> # With closure (e.g., for line search)
>>> def closure():
...     optimizer.zero_grad()
...     output = model(input)
...     loss = criterion(output, target)
...     loss.backward()
...     return loss
>>> loss = optimizer.step(closure)
Note
Call zero_grad() before computing gradients for the next step
CUDA graph capture is checked for safety when capturable=True
The method is thread-safe for different parameter groups
- dicee.models.adopt.adopt(params: List[torch.Tensor], grads: List[torch.Tensor], exp_avgs: List[torch.Tensor], exp_avg_sqs: List[torch.Tensor], state_steps: List[torch.Tensor], foreach: bool | None = None, capturable: bool = False, differentiable: bool = False, fused: bool | None = None, grad_scale: torch.Tensor | None = None, found_inf: torch.Tensor | None = None, has_complex: bool = False, *, beta1: float, beta2: float, lr: float | torch.Tensor, clip_lambda: Callable[[int], float] | None, weight_decay: float, decouple: bool, eps: float, maximize: bool)[source]
Functional API that performs ADOPT algorithm computation.
This is the main functional interface for the ADOPT optimization algorithm. It dispatches to one of three implementations based on the execution mode:
Single-tensor mode (default): Updates parameters one at a time
- Compatible with torch.jit.script
- More flexible but slower
- Used when foreach=False or automatically for small models

Multi-tensor (foreach) mode: Batches operations across tensors
- 2-3x faster on GPU through vectorization
- Groups tensors by device/dtype automatically
- Used when foreach=True

Fused mode: Uses specialized fused kernels (not yet implemented)
- Would provide maximum performance
- Currently raises RuntimeError if enabled
Algorithm overview (ADOPT):
ADOPT adapts learning rates per-parameter while using momentum on normalized gradients. The key innovation is normalizing gradients before momentum, which provides more stable training than standard Adam.
- Mathematical formulation:
# Normalize gradient by its historical variance
normed_g_t = g_t / √(v_t + ε)

# Optional gradient clipping for stability
normed_g_t = clip(normed_g_t, threshold(t))

# Momentum on normalized gradients (key difference from Adam)
m_t = β₁ * m_{t-1} + (1 - β₁) * normed_g_t

# Parameter update
θ_t = θ_{t-1} - α * m_t

# Update variance estimate
v_t = β₂ * v_{t-1} + (1 - β₂) * g_t²
- where:
θ: parameters
g: gradients
m: first moment (momentum of normalized gradients)
v: second moment (variance of raw gradients)
α: learning rate
β₁, β₂: exponential decay rates
ε: numerical stability constant
clip(): gradient clipping function based on step
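A minimal sketch of one update for a single parameter, following the formulation above; adopt_step_sketch and its in-place operations are illustrative, not the module's internal helpers:
>>> import torch
>>>
>>> def adopt_step_sketch(p, g, m, v, step, lr=1e-3, beta1=0.9, beta2=0.9999,
...                       eps=1e-6, clip_lambda=lambda s: s ** 0.25):
...     """Illustrative single-parameter ADOPT update."""
...     normed_g = g / torch.sqrt(v + eps)              # normalize by historical variance
...     if clip_lambda is not None:                     # optional step-dependent clipping
...         c = clip_lambda(step)
...         normed_g = normed_g.clamp(-c, c)
...     m.mul_(beta1).add_(normed_g, alpha=1 - beta1)   # momentum on normalized gradients
...     p.add_(m, alpha=-lr)                            # parameter update
...     v.mul_(beta2).addcmul_(g, g, value=1 - beta2)   # variance update from raw gradients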
Automatic mode selection:
When foreach and fused are both None (default), the function automatically selects the best implementation based on:
- Parameter types and devices
- Whether differentiable mode is enabled
- Learning rate type (float vs Tensor)
- Capturable mode requirements
- Parameters:
params (List[Tensor]) – Parameters to optimize.
grads (List[Tensor]) – Gradients for each parameter.
exp_avgs (List[Tensor]) – First moment estimates (momentum).
exp_avg_sqs (List[Tensor]) – Second moment estimates (variance).
state_steps (List[Tensor]) – Step counters (must be singleton tensors).
foreach (Optional[bool]) – Whether to use the multi-tensor implementation. None: auto-select based on configuration (default).
capturable (bool) – If True, ensure CUDA graph capture safety.
differentiable (bool) – If True, allow gradients through the optimization step.
fused (Optional[bool]) – If True, use fused kernels (not implemented).
grad_scale (Optional[Tensor]) – Gradient scaler for AMP training.
found_inf (Optional[Tensor]) – Flag for inf/nan detection in AMP.
has_complex (bool) – Whether any parameters are complex-valued.
beta1 (float) – Exponential decay rate for the first moment (momentum). Typical range: 0.9-0.95.
beta2 (float) – Exponential decay rate for the second moment (variance). Typical range: 0.999-0.9999 (higher than Adam).
lr (float or Tensor) – Learning rate. Can be a scalar Tensor for a dynamic learning rate with capturable=True.
clip_lambda (Optional[Callable[[int], float]]) – Function that maps the step number to the gradient clipping threshold. None disables clipping.
weight_decay (float) – Weight decay coefficient (L2 penalty).
decouple (bool) – If True, use decoupled weight decay (AdamW-style). Recommended for better generalization.
eps (float) – Small constant for numerical stability in normalization.
maximize (bool) – If True, maximize the objective instead of minimizing it.
- Raises:
RuntimeError – If torch.jit.script is used with foreach or fused.
RuntimeError – If state_steps contains non-tensor elements.
RuntimeError – If fused=True (not yet implemented).
RuntimeError – If lr is a Tensor with foreach=True and capturable=False.
Example
>>> # Typically called by the ADOPT optimizer, not directly
>>> adopt(
...     params=[p1, p2],
...     grads=[g1, g2],
...     exp_avgs=[m1, m2],
...     exp_avg_sqs=[v1, v2],
...     state_steps=[step1, step2],
...     beta1=0.9,
...     beta2=0.9999,
...     lr=0.001,
...     clip_lambda=lambda s: s**0.25,
...     weight_decay=0.01,
...     decouple=True,
...     eps=1e-6,
...     maximize=False,
... )
Note
For distributed training, this API is compatible with torch/distributed/optim
The foreach mode is generally preferred for GPU training
Complex parameters are handled transparently by viewing them as real (see the snippet after this list)
First optimization step only initializes variance, doesn’t update parameters
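As noted above, complex tensors are optimized through their real view; a quick illustration of that view in plain PyTorch (independent of this module):
>>> z = torch.randn(4, dtype=torch.cfloat)
>>> torch.view_as_real(z).shape
torch.Size([4, 2])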
See also
ADOPT class: High-level optimizer interface
_single_tensor_adopt: Single-tensor implementation details
_multi_tensor_adopt: Multi-tensor implementation details