dicee.models.adopt

ADOPT Optimizer Implementation.

This module implements the ADOPT algorithm, an adaptive gradient optimization method for training neural networks.

ADOPT Overview:

ADOPT is an adaptive learning rate optimization algorithm that combines the benefits of momentum-based methods with per-parameter learning rate adaptation. Unlike Adam, which applies momentum to raw gradients, ADOPT normalizes gradients first and then applies momentum, leading to more stable training dynamics.

Key Features:

  • Gradient normalization before momentum application

  • Adaptive per-parameter learning rates

  • Optional gradient clipping that grows with training steps

  • Support for decoupled weight decay (AdamW-style)

  • Multiple execution modes: single-tensor, multi-tensor (foreach), and fused (planned)

Algorithm Comparison:

Adam:  m = β₁*m + (1-β₁)*g,     θ = θ - α*m/√v
ADOPT: m = β₁*m + (1-β₁)*g/√v,  θ = θ - α*m

The key difference is that ADOPT normalizes gradients before momentum, which provides better stability and can lead to improved convergence.
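
The contrast is easiest to see side by side. Below is a minimal sketch of both update rules on a single tensor in plain torch operations; bias correction, eps terms, and clipping are deliberately omitted, and the warmed-up second moment is an assumption for illustration — this is not the optimizer's actual code path.

>>> import torch
>>>
>>> torch.manual_seed(0)
>>> theta = torch.randn(4)    # parameter
>>> g = torch.randn(4)        # gradient
>>> m = torch.zeros(4)        # first moment
>>> v = torch.ones(4)         # second moment (assumed already warmed up)
>>> lr, beta1, beta2 = 1e-3, 0.9, 0.9999
>>>
>>> # Adam-style: momentum on raw gradients, normalization at update time
>>> m_adam = beta1 * m + (1 - beta1) * g
>>> v_adam = beta2 * v + (1 - beta2) * g**2
>>> theta_adam = theta - lr * m_adam / v_adam.sqrt()
>>>
>>> # ADOPT-style: normalize by the previous second moment, then apply momentum
>>> m_adopt = beta1 * m + (1 - beta1) * g / v.sqrt()
>>> theta_adopt = theta - lr * m_adopt
>>> v_adopt = beta2 * v + (1 - beta2) * g**2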

Classes:

  • ADOPT: Main optimizer class (extends torch.optim.Optimizer)

Functions:

  • adopt: Functional API for ADOPT algorithm computation

  • _single_tensor_adopt: Single-tensor implementation (TorchScript compatible)

  • _multi_tensor_adopt: Multi-tensor implementation using foreach operations

Performance:

  • Single-tensor: Default, compatible with torch.jit.script

  • Multi-tensor (foreach): 2-3x faster on GPU through vectorization (see the snippet after this list)

  • Fused (planned): Would provide maximum performance via specialized kernels
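
The execution mode can be requested explicitly through the constructor; foreach and fused are the arguments documented in the class signature below (a brief sketch, assuming a model as in the example that follows):

>>> # Default: foreach=None lets the optimizer auto-select an implementation
>>> optimizer = ADOPT(model.parameters(), lr=0.001)
>>> # Explicitly request the multi-tensor (foreach) path, e.g. for GPU training
>>> optimizer = ADOPT(model.parameters(), lr=0.001, foreach=True)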

Example:

>>> import torch
>>> from dicee.models.adopt import ADOPT
>>>
>>> model = torch.nn.Linear(10, 1)
>>> optimizer = ADOPT(model.parameters(), lr=0.001, weight_decay=0.01, decouple=True)
>>>
>>> # Training loop
>>> for epoch in range(num_epochs):
...     optimizer.zero_grad()
...     output = model(input)
...     loss = criterion(output, target)
...     loss.backward()
...     optimizer.step()

References:

Original implementation: https://github.com/iShohei220/adopt

Notes:

This implementation is based on the original ADOPT implementation and adapted to work with the PyTorch optimizer interface and the dicee framework.

Classes

ADOPT

ADOPT Optimizer.

Functions

adopt(params, grads, exp_avgs, exp_avg_sqs, state_steps)

Functional API that performs ADOPT algorithm computation.

Module Contents

class dicee.models.adopt.ADOPT(params: torch.optim.optimizer.ParamsT, lr: float | torch.Tensor = 0.001, betas: Tuple[float, float] = (0.9, 0.9999), eps: float = 1e-06, clip_lambda: Callable[[int], float] | None = lambda step: ..., weight_decay: float = 0.0, decouple: bool = False, *, foreach: bool | None = None, maximize: bool = False, capturable: bool = False, differentiable: bool = False, fused: bool | None = None)[source]

Bases: torch.optim.optimizer.Optimizer

ADOPT Optimizer.

ADOPT is an adaptive learning rate optimization algorithm that combines momentum-based updates with adaptive per-parameter learning rates. It uses exponential moving averages of gradients and squared gradients, with gradient clipping for stability.

The algorithm performs the following key operations:

  1. Normalizes gradients by the square root of the second moment estimate

  2. Applies optional gradient clipping based on the training step

  3. Updates parameters using momentum-smoothed normalized gradients

  4. Supports decoupled weight decay (AdamW-style) or L2 regularization

Mathematical formulation:

m_t = β₁ * m_{t-1} + (1 - β₁) * clip(g_t / √(v_t))
v_t = β₂ * v_{t-1} + (1 - β₂) * g_t²
θ_t = θ_{t-1} - α * m_t

where:
  • θ_t: parameter at step t

  • g_t: gradient at step t

  • m_t: first moment estimate (momentum)

  • v_t: second moment estimate (variance)

  • α: learning rate

  • β₁, β₂: exponential decay rates

  • clip(): optional gradient clipping function

Reference:

Original implementation: https://github.com/iShohei220/adopt

Parameters:
  • params (ParamsT) – Iterable of parameters to optimize or dicts defining parameter groups.

  • lr (float or Tensor, optional) – Learning rate. Can be a float or 1-element Tensor. Default: 1e-3

  • betas (Tuple[float, float], optional) – Coefficients (β₁, β₂) for computing running averages of gradient and its square. β₁ controls momentum, β₂ controls variance. Default: (0.9, 0.9999)

  • eps (float, optional) – Term added to denominator to improve numerical stability. Default: 1e-6

  • clip_lambda (Callable[[int], float], optional) – Function that takes the step number and returns the gradient clipping threshold. Common choices: lambda step: step**0.25 (default, gradually increases the clipping threshold), lambda step: 1.0 (constant clipping), or None (no clipping). Default: lambda step: step**0.25

  • weight_decay (float, optional) – Weight decay coefficient (L2 penalty). Default: 0.0

  • decouple (bool, optional) – If True, uses decoupled weight decay (AdamW-style), applying weight decay directly to parameters. If False, adds weight decay to gradients (L2 regularization); see the sketch after this parameter list. Default: False

  • foreach (bool, optional) – If True, uses the faster foreach implementation for multi-tensor operations. Default: None (auto-select)

  • maximize (bool, optional) – If True, maximizes the objective with respect to the parameters instead of minimizing it. Useful for reinforcement learning. Default: False

  • capturable (bool, optional) – If True, the optimizer is safe to capture in a CUDA graph. Requires learning rate as Tensor. Default: False

  • differentiable (bool, optional) – If True, the optimization step can be differentiated. Useful for meta-learning. Default: False

  • fused (bool, optional) – If True, uses fused kernel implementation (currently not supported). Default: None
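
The decouple flag only changes where the weight-decay term enters the update. A minimal illustrative sketch of the two conventions (plain torch code, with the decoupled case following the usual AdamW convention; this is not the optimizer's internals):

>>> import torch
>>> param = torch.tensor([1.0, -2.0])
>>> grad = torch.tensor([0.1, 0.3])
>>> lr, weight_decay = 0.001, 0.01
>>>
>>> # decouple=False: L2 regularization, the penalty is folded into the gradient
>>> grad_l2 = grad + weight_decay * param
>>>
>>> # decouple=True: AdamW-style, the parameter is shrunk directly, independent of the gradient
>>> param_decoupled = param * (1 - lr * weight_decay)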

Raises:
  • ValueError – If learning rate, epsilon, betas, or weight_decay are invalid.

  • RuntimeError – If fused is enabled (not currently supported).

  • RuntimeError – If lr is a Tensor with foreach=True and capturable=False.

Example

>>> # Basic usage
>>> optimizer = ADOPT(model.parameters(), lr=0.001)
>>> optimizer.zero_grad()
>>> loss.backward()
>>> optimizer.step()
>>> # With decoupled weight decay
>>> optimizer = ADOPT(model.parameters(), lr=0.001, weight_decay=0.01, decouple=True)
>>> # Custom gradient clipping
>>> optimizer = ADOPT(model.parameters(), clip_lambda=lambda step: max(1.0, step**0.5))

Note

  • For most use cases, the default hyperparameters work well

  • Consider using decouple=True for better generalization (similar to AdamW)

  • The clip_lambda function helps stabilize training in early steps

clip_lambda
__setstate__(state)[source]

Restore optimizer state from a checkpoint.

This method handles backward compatibility when loading optimizer state from older versions. It ensures all required fields are present with default values and properly converts step counters to tensors if needed.

Key responsibilities:

  1. Set default values for newly added hyperparameters

  2. Convert old-style scalar step counters to tensor format

  3. Place step tensors on appropriate devices based on capturable/fused modes

Parameters:

state (dict) – Optimizer state dictionary (typically from torch.load()).
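
A typical checkpoint round trip that ends up calling __setstate__ through Optimizer.load_state_dict (a sketch; model and the optimizer settings are assumed from the examples above, and ckpt.pt is an illustrative path):

>>> # Save model and optimizer state together
>>> torch.save({"model": model.state_dict(), "optim": optimizer.state_dict()}, "ckpt.pt")
>>>
>>> # Later: rebuild both and restore the saved state
>>> checkpoint = torch.load("ckpt.pt")
>>> model.load_state_dict(checkpoint["model"])
>>> optimizer = ADOPT(model.parameters(), lr=0.001)
>>> optimizer.load_state_dict(checkpoint["optim"])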

Note

  • This enables loading checkpoints saved with older ADOPT versions

  • Step counters are converted to appropriate device/dtype for compatibility

  • Capturable and fused modes require step tensors on parameter devices

step(closure=None)[source]

Perform a single optimization step.

This method executes one iteration of the ADOPT optimization algorithm across all parameter groups. It orchestrates the following workflow:

  1. Optionally evaluates a closure to recompute the loss (useful for algorithms like LBFGS or when loss needs multiple evaluations)

  2. For each parameter group:
       - Collects parameters with gradients and their associated state
       - Extracts hyperparameters (betas, learning rate, etc.)
       - Calls the functional adopt() API to perform the actual update

  3. Returns the loss value if a closure was provided

The functional API (adopt()) handles three execution modes:

  • Single-tensor: Updates one parameter at a time (default, JIT-compatible)

  • Multi-tensor (foreach): Batches operations for better performance

  • Fused: Uses fused CUDA kernels (not yet implemented)

Gradient scaling support: This method is compatible with automatic mixed precision (AMP) training. It can access grad_scale and found_inf attributes for gradient unscaling and inf/nan detection when used with GradScaler.
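
A minimal AMP training-step sketch using GradScaler (standard torch.cuda.amp usage; model, criterion, data_loader, and an ADOPT optimizer are assumed from the examples above, and nothing here is ADOPT-specific beyond passing the optimizer to scaler.step):

>>> scaler = torch.cuda.amp.GradScaler()
>>> for input, target in data_loader:
...     optimizer.zero_grad()
...     with torch.autocast(device_type="cuda"):
...         loss = criterion(model(input), target)
...     scaler.scale(loss).backward()      # scale the loss, backprop scaled gradients
...     scaler.step(optimizer)             # unscales gradients; skips the step on inf/nan
...     scaler.update()                    # adjust the scale factor for the next iteration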

Parameters:

closure (Callable, optional) – A callable that re-evaluates the model and returns the loss. The closure should enable gradients (torch.enable_grad()), run the forward pass, compute the loss, run the backward pass, and return the loss value (see the closure example below). Default: None

Returns:

The loss value returned by the closure, or None if no closure was provided.

Return type:

Optional[Tensor]

Example

>>> # Standard usage
>>> loss = criterion(model(input), target)
>>> loss.backward()
>>> optimizer.step()
>>> # With closure (e.g., for line search)
>>> def closure():
...     optimizer.zero_grad()
...     output = model(input)
...     loss = criterion(output, target)
...     loss.backward()
...     return loss
>>> loss = optimizer.step(closure)

Note

  • Call zero_grad() before computing gradients for the next step

  • CUDA graph capture is checked for safety when capturable=True

  • The method is thread-safe for different parameter groups

dicee.models.adopt.adopt(params: List[torch.Tensor], grads: List[torch.Tensor], exp_avgs: List[torch.Tensor], exp_avg_sqs: List[torch.Tensor], state_steps: List[torch.Tensor], foreach: bool | None = None, capturable: bool = False, differentiable: bool = False, fused: bool | None = None, grad_scale: torch.Tensor | None = None, found_inf: torch.Tensor | None = None, has_complex: bool = False, *, beta1: float, beta2: float, lr: float | torch.Tensor, clip_lambda: Callable[[int], float] | None, weight_decay: float, decouple: bool, eps: float, maximize: bool)[source]

Functional API that performs ADOPT algorithm computation.

This is the main functional interface for the ADOPT optimization algorithm. It dispatches to one of three implementations based on the execution mode:

  1. Single-tensor mode (default): Updates parameters one at a time
       - Compatible with torch.jit.script
       - More flexible but slower
       - Used when foreach=False or automatically for small models

  2. Multi-tensor (foreach) mode: Batches operations across tensors
       - 2-3x faster on GPU through vectorization
       - Groups tensors by device/dtype automatically
       - Used when foreach=True

  3. Fused mode: Uses specialized fused kernels (not yet implemented)
       - Would provide maximum performance
       - Currently raises RuntimeError if enabled

Algorithm overview (ADOPT):

ADOPT adapts learning rates per-parameter while using momentum on normalized gradients. The key innovation is normalizing gradients before momentum, which provides more stable training than standard Adam.

Mathematical formulation:

# Normalize gradient by its historical variance
normed_g_t = g_t / √(v_t + ε)

# Optional gradient clipping for stability
normed_g_t = clip(normed_g_t, threshold(t))

# Momentum on normalized gradients (key difference from Adam)
m_t = β₁ * m_{t-1} + (1 - β₁) * normed_g_t

# Parameter update
θ_t = θ_{t-1} - α * m_t

# Update variance estimate
v_t = β₂ * v_{t-1} + (1 - β₂) * g_t²

where:
  • θ: parameters

  • g: gradients

  • m: first moment (momentum of normalized gradients)

  • v: second moment (variance of raw gradients)

  • α: learning rate

  • β₁, β₂: exponential decay rates

  • ε: numerical stability constant

  • clip(): gradient clipping function based on step
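
For reference, a minimal single-tensor sketch of one such update in plain torch operations, following the formulas above (an illustration only — the real _single_tensor_adopt additionally handles maximize, weight decay, complex views, and the capturable/differentiable paths):

>>> import torch
>>>
>>> def adopt_step_sketch(param, grad, m, v, step, *, lr=1e-3, beta1=0.9,
...                       beta2=0.9999, eps=1e-6, clip_lambda=lambda s: s**0.25):
...     """One illustrative ADOPT update on a single tensor (in-place)."""
...     if step == 0:
...         v.copy_(grad * grad)                    # first step: only initialize the variance
...         return
...     normed_g = grad / (v + eps).sqrt()          # normalize by the historical variance
...     if clip_lambda is not None:
...         c = clip_lambda(step)
...         normed_g = normed_g.clamp(-c, c)        # step-dependent clipping
...     m.mul_(beta1).add_(normed_g, alpha=1 - beta1)        # momentum on normalized gradients
...     param.add_(m, alpha=-lr)                             # parameter update
...     v.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)  # update the variance estimate
>>>
>>> p, g = torch.zeros(3), torch.full((3,), 0.5)
>>> m, v = torch.zeros(3), torch.zeros(3)
>>> adopt_step_sketch(p, g, m, v, step=0)   # step 0: only v is initialized
>>> adopt_step_sketch(p, g, m, v, step=1)   # step 1: first real parameter update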

Automatic mode selection:

When foreach and fused are both None (default), the function automatically selects the best implementation based on:

  • Parameter types and devices

  • Whether differentiable mode is enabled

  • Learning rate type (float vs Tensor)

  • Capturable mode requirements

Parameters:
  • params (List[Tensor]) – Parameters to optimize.

  • grads (List[Tensor]) – Gradients for each parameter.

  • exp_avgs (List[Tensor]) – First moment estimates (momentum).

  • exp_avg_sqs (List[Tensor]) – Second moment estimates (variance).

  • state_steps (List[Tensor]) – Step counters (must be singleton tensors).

  • foreach (Optional[bool]) – Whether to use the multi-tensor implementation. None: auto-select based on configuration (default).

  • capturable (bool) – If True, ensure CUDA graph capture safety.

  • differentiable (bool) – If True, allow gradients through the optimization step.

  • fused (Optional[bool]) – If True, use fused kernels (not implemented).

  • grad_scale (Optional[Tensor]) – Gradient scaler for AMP training.

  • found_inf (Optional[Tensor]) – Flag for inf/nan detection in AMP.

  • has_complex (bool) – Whether any parameters are complex-valued.

  • beta1 (float) – Exponential decay rate for the first moment (momentum). Typical range: 0.9-0.95.

  • beta2 (float) – Exponential decay rate for the second moment (variance). Typical range: 0.999-0.9999 (higher than Adam).

  • lr (float or Tensor) – Learning rate. Can be a scalar Tensor for a dynamic learning rate with capturable=True.

  • clip_lambda (Optional[Callable[[int], float]]) – Function that maps the step number to a gradient clipping threshold. None disables clipping.

  • weight_decay (float) – Weight decay coefficient (L2 penalty).

  • decouple (bool) – If True, use decoupled weight decay (AdamW-style). Recommended for better generalization.

  • eps (float) – Small constant for numerical stability in normalization.

  • maximize (bool) – If True, maximize the objective instead of minimizing it.

Raises:
  • RuntimeError – If torch.jit.script is used with foreach or fused.

  • RuntimeError – If state_steps contains non-tensor elements.

  • RuntimeError – If fused=True (not yet implemented).

  • RuntimeError – If lr is a Tensor with foreach=True and capturable=False.

Example

>>> # Typically called by ADOPT optimizer, not directly
>>> adopt(
...     params=[p1, p2],
...     grads=[g1, g2],
...     exp_avgs=[m1, m2],
...     exp_avg_sqs=[v1, v2],
...     state_steps=[step1, step2],
...     beta1=0.9,
...     beta2=0.9999,
...     lr=0.001,
...     clip_lambda=lambda s: s**0.25,
...     weight_decay=0.01,
...     decouple=True,
...     eps=1e-6,
...     maximize=False,
... )

Note

  • For distributed training, this API is compatible with torch/distributed/optim

  • The foreach mode is generally preferred for GPU training

  • Complex parameters are handled transparently by viewing as real

  • First optimization step only initializes variance, doesn’t update parameters

See also

  • ADOPT class: High-level optimizer interface

  • _single_tensor_adopt: Single-tensor implementation details

  • _multi_tensor_adopt: Multi-tensor implementation details