FlashOptim in JAX: a small journey
tldr: presenting flashoptim-jax -- a JAX-native version of FlashOptim
flashoptim-jax is ~10% faster than the PyTorch implementation for this update, and takes 50% less memory than the full-precision JAX baseline.
In mixed-precision training, you run your forward and backward passes in bf16 for speed and memory savings, but keep fp32 master weights and optimizer state for numerical stability. For Adam, that means storing — per parameter — a bf16 working copy (2 bytes), an fp32 master copy (4 bytes), and two fp32 momentum buffers (4 bytes each): 14 bytes per parameter total. On a 7B-parameter model, the optimizer state alone eats 56 GB while the bf16 working weights take just 14 GB.
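A quick back-of-the-envelope check of those numbers:

```python
# Bytes per parameter under bf16 mixed precision with Adam.
bf16_copy   = 2   # bf16 working weights
fp32_master = 4   # fp32 master copy
fp32_mu     = 4   # Adam first-moment buffer (fp32)
fp32_nu     = 4   # Adam second-moment buffer (fp32)
per_param = bf16_copy + fp32_master + fp32_mu + fp32_nu   # 14 bytes

n = 7e9                                      # 7B parameters
moments_gb = n * (fp32_mu + fp32_nu) / 1e9   # two fp32 moment buffers: 56 GB
working_gb = n * bf16_copy / 1e9             # bf16 working weights: 14 GB
```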
FlashOptim showed that you can compress this state dramatically through a few quantization tricks (see paper). Their implementation is in PyTorch with Triton kernels, and I want to use it in JAX.
The goal: Implement FlashOptim in JAX with the same memory and speed as PyTorch.
This post walks through four implementations, from least to most involved, until PyTorch FlashOptim is matched.
- Triton via DLPack: Just call the implemented PyTorch kernels using DLPack.
- JAX unfused: Reimplement FlashOptim in plain good ol' JAX ops, leaving fusion to XLA compiler.
- JAX fused: Fuse dequantize, Adam, and requantize into a single Pallas kernel calculating update.
- JAX FlashOptim: Fuse everything (including applying updates) into a single in-place kernel.
Baselines
We measure speed, steady-state memory, and peak memory use of each optimizer for single-step updates on an A100 on the largest parameter groups of Qwen3-8B (hidden_size=4096, intermediate_size=12288, GQA with 32 Q-heads / 8 KV-heads). Peak memory is measured after compilation/warmup for both frameworks. Note that JAX's XLA compilation itself can transiently allocate significantly more memory than the steady-state footprint.
Our target to beat, with a full-precision baseline for JAX and PyTorch:
We plot relative memory efficiency (1/memory) and relative speed (1/step time) so we can construct a Pareto frontier where good is up and to the right. Note that FlashOptim uses significantly less memory at comparable speed.
Implementation code: PyTorch baseline
The standard PyTorch mixed-precision setup: fp32 parameters, fp32 optimizer state, with autocast handling bf16 externally during the forward/backward pass. fused=True runs the entire Adam update in a single CUDA kernel.
# init
w = torch.nn.Parameter(torch.randn(shape, dtype=torch.float32, device="cuda"))
opt = torch.optim.AdamW([w], lr=1e-3, weight_decay=1e-2, fused=True)
# step (autocast handles bf16 for forward/backward)
w.grad = grad_w
opt.step()
opt.zero_grad(set_to_none=True)
Implementation code: JAX baseline
optax.adamw on fp32 params with donate_argnums for buffer reuse (lower memory, comparable speed).
# init
tx = optax.adamw(learning_rate=1e-3, weight_decay=1e-2)
opt_state = tx.init(params) # params are fp32
# step
def step(params, opt_state, grads):
    updates, new_state = tx.update(grads, opt_state, params)
    return optax.apply_updates(params, updates), new_state
step = jax.jit(step, donate_argnums=(0, 1))
params, opt_state = step(params, opt_state, grads)
Implementation code: PyTorch FlashOptim
The original FlashOptim implementation: bf16 parameters, quantized int8/uint8 optimizer state (~5 bytes/param). Uses fused Triton kernels that dequantize, update, and requantize in a single pass.
# init
from flashoptim import FlashAdamW
w = torch.nn.Parameter(torch.randn(shape, dtype=torch.bfloat16, device="cuda"))
opt = FlashAdamW([w], lr=1e-3)
# step
w.grad = grad_w
opt.step()
Benchmarking methodology
Each data point is the median of 200 timing samples, where each sample times a batch of 20 consecutive optimizer steps (4,000 total steps per measurement). A separate warmup phase runs first. Each configuration runs in its own subprocess to avoid cross-contamination of GPU state. Memory is measured after warmup/compilation: steady-state = bytes currently allocated; peak = max allocated after compilation (JAX via device.memory_stats(), PyTorch via torch.cuda.max_memory_allocated() after reset_peak_memory_stats()).
# PyTorch timing
for _ in range(warmup):
    for _ in range(steps_per_sync):  # steps_per_sync = 20
        w.grad = grad_w; opt.step(); opt.zero_grad(set_to_none=True)
torch.cuda.synchronize()
samples = []
for _ in range(steps):  # steps = 200
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    for _ in range(steps_per_sync):
        w.grad = grad_w; opt.step(); opt.zero_grad(set_to_none=True)
    torch.cuda.synchronize()
    samples.append((time.perf_counter() - t0) * 1e6 / steps_per_sync)  # µs
median_us = statistics.median(samples)
# JAX timing (identical structure, different sync)
for _ in range(warmup):
    for _ in range(steps_per_sync):
        params, opt_state = step(params, opt_state, grads)
jax.block_until_ready((params, opt_state))
samples = []
for _ in range(steps):
    t0 = time.perf_counter()
    for _ in range(steps_per_sync):
        params, opt_state = step(params, opt_state, grads)
    jax.block_until_ready((params, opt_state))
    samples.append((time.perf_counter() - t0) * 1e6 / steps_per_sync)
median_us = statistics.median(samples)
Implementation 1: Triton via DLPack
The simplest thing to try is: don't reimplement anything yet. Just call the existing PyTorch FlashOptim Triton kernels from JAX using DLPack, which lets the two frameworks exchange GPU tensors without copying the underlying storage.
Speed is slower than native PyTorch FlashOptim because each leaf tensor crosses the JAX→PyTorch boundary individually via jax.pure_callback, paying Python dispatch and DLPack conversion overhead per parameter. Memory is higher because pure_callback is opaque to XLA — it can't use donate_argnums (explained later) to recycle input buffers, so both old and new copies of params/state coexist during the step.
Implementation code: Triton via DLPack
# init
state = init_torch_flashoptim_state(params)
# step
def dlpack_step(params, state, grads):
    return jax.pure_callback(
        torch_flashoptim_step,
        (params, state),  # result_shape_dtypes (positional, before *args)
        params, state, grads,
    )
step = jax.jit(dlpack_step)
params, state = step(params, state, grads)
The actual update runs in PyTorch. JAX just hands tensors across the framework boundary, waits for the callback to finish, then resumes execution.
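For intuition, here's a minimal DLPack round trip between the two frameworks outside of jit — the conversion wraps the same buffer rather than copying it:

```python
import jax.numpy as jnp
import torch

x = jnp.arange(4, dtype=jnp.float32)
t = torch.from_dlpack(x)    # JAX -> PyTorch: wraps the same buffer, no copy
t2 = t * 2                  # do some work on the PyTorch side
y = jnp.from_dlpack(t2)     # PyTorch -> JAX
```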
Implementation 2: A drop-in optax-shaped optimizer
The next thing to try is to implement the FlashOptim algorithm in plain JAX ops. Dequantize the stored int8/uint8 state, run the standard Adam math in fp32, requantize and store. No custom kernels; leave fusion to XLA compilation.
Here's the core of the unfused path (still in the codebase as _flash_adamw_leaf_unfused, kept for readability and testing):
# Reconstruct fp32 parameter from bf16 + error correction code
if use_ecc:
    param_f32 = reconstruct_leaf(param, ecc)
else:
    param_f32 = jnp.asarray(param, dtype=jnp.float32)
# Dequantize optimizer state from int8/uint8 back to fp32
mu_f32 = dequantize_momentum(mu, group_size)  # int8 + fp16 scales → fp32
nu_f32 = dequantize_variance(nu, group_size)  # uint8 + fp16 scales → fp32
# Standard AdamW update
mu_f32 = b1 * mu_f32 + (1 - b1) * grad_f32
nu_f32 = b2 * nu_f32 + (1 - b2) * grad_f32 ** 2
param_f32 = param_f32 * (1 - lr * weight_decay)
param_f32 = param_f32 - (lr / bc1) * mu_f32 / (jnp.sqrt(nu_f32 / bc2) + eps)
# Requantize and store
update = param_f32.astype(param.dtype) - param  # delta in the param's dtype
new_mu = quantize_momentum(mu_f32, group_size)  # fp32 → int8 + fp16 scales
new_nu = quantize_variance(nu_f32, group_size)  # fp32 → uint8 + fp16 scales
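The quantization helpers follow the usual group-wise absmax pattern. A simplified sketch (the signatures and the linear mapping are illustrative, not the package's API — the real momentum quantizer also applies a softsign transform):

```python
import jax.numpy as jnp

def quantize_momentum(x_f32, group_size):
    # Split into groups, scale each group by its absmax, round to int8.
    groups = x_f32.reshape(-1, group_size)
    scales = jnp.max(jnp.abs(groups), axis=1, keepdims=True)
    safe = jnp.where(scales == 0, 1.0, scales)
    vals = jnp.round(groups / safe * 127.0).astype(jnp.int8)
    return vals.reshape(x_f32.shape), scales.astype(jnp.float16)

def dequantize_momentum(vals, scales, group_size):
    # Inverse: int8 values times the per-group fp16 scale, back to fp32.
    groups = vals.reshape(-1, group_size).astype(jnp.float32)
    return (groups * scales.astype(jnp.float32) / 127.0).reshape(vals.shape)
```

Round-trip error is bounded by half a quantization step per group, which is why per-group (rather than per-tensor) scales matter for heavy-tailed momentum values.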
This works correctly and is easy to wrap as a drop-in replacement for optax optimizers, but there is a lot of unnecessary data movement: each step materializes several full-size fp32 intermediate arrays, and the caller has to apply the delta afterward (new_params = params + result.update). XLA does not fuse these ops well.
Speed aside, the memory situation is also bad: naively, XLA keeps old and new copies of params and state alive simultaneously, roughly doubling memory. We can fix this with donate_argnums — JAX's mechanism for telling XLA that the caller won't use the input buffers after the call returns, so they can be overwritten:
def train_step(params, opt_state, grads):
    updates, new_state = tx.update(grads, opt_state, params)
    new_params = optax.apply_updates(params, updates)
    return new_params, new_state
train_step = jax.jit(train_step, donate_argnums=(0, 1))
Implementation code: unfused (with and without donation)
# init
tx = flash_adamw(learning_rate=1e-3, fused=False)
state = tx.init(params)
# step
def train_step(params, state, grads):
    updates, new_state = tx.update(grads, state, params)
    new_params = optax.apply_updates(params, updates)
    return new_params, new_state
# without donation:
step = jax.jit(train_step)
# with donation:
step = jax.jit(train_step, donate_argnums=(0, 1))
The only difference is donate_argnums=(0, 1). Inside the compiled function, XLA runs identical code — separate kernels for dequantize, Adam math, requantize. Donation just controls whether old buffers are reused afterward.
The unfused JAX path is 3× slower than PyTorch at the MLP shape that dominates real models. Donation closes the memory gap but does nothing for speed — it reclaims buffers at the jit boundary, but inside the step, XLA still materializes all the same fp32 intermediates through separate ops. Each quantize/dequantize is its own kernel launch with its own memory traffic, and XLA can't fuse them.
Implementation 3: Fusing into a single Pallas kernel
The key insight from the fused FlashOptim implementation is that the dequantize → update → requantize pipeline has no cross-parameter dependencies — each block of elements can be processed independently. So, we can fuse the entire thing into one GPU kernel: load the quantized state, do all the math in registers, and write back the requantized result. No intermediate fp32 arrays ever touch global memory.
In JAX, Pallas lets you write GPU kernels in a NumPy-like syntax that compiles to Triton and plays nicely with the XLA compiler. Here's the structure of the fused quantized kernel, mirroring the Triton implementation closely:
import jax.numpy as jnp
from jax.experimental import pallas as pl

def _flash_adamw_leaf_quantized_kernel(
    grad_ref, param_ref, ecc_ref,
    mu_values_ref, nu_values_ref, mu_scales_ref, nu_scales_ref,
    lr_ref, b1_ref, b2_ref, eps_ref, weight_decay_ref,
    bias_correction1_ref, bias_correction2_ref,
    update_ref, ecc_out_ref,
    mu_values_out_ref, mu_scales_out_ref,
    nu_values_out_ref, nu_scales_out_ref,
    *, block_size: int, group_size: int,  # ... (more static args elided)
):
    pid = pl.program_id(0)
    offsets = pid * block_size + jnp.arange(block_size)
    # Load quantized state from global memory
    mu_vals = pl.load(mu_values_ref, (offsets,))  # int8
    mu_sc = pl.load(mu_scales_ref, ...)           # fp16 per-group scales
    # ... (load everything else)
    # Dequantize in-register
    mu = inverse_softsign(mu_vals / 127.0) * mu_sc  # int8 → fp32
    nu = (nu_vals / 255.0 * nu_sc) ** 2             # uint8 → fp32
    # Adam math in-register (weight decay + bias correction omitted for clarity)
    mu = b1 * mu + (1 - b1) * grad
    nu = b2 * nu + (1 - b2) * grad ** 2
    param_f32 = param_f32 - lr * mu / (jnp.sqrt(nu) + eps)
    # Requantize in-register and write back
    mu_out = softsign(mu / absmax(mu)) * 127            # → int8
    nu_out = jnp.sqrt(nu) / absmax(jnp.sqrt(nu)) * 255  # → uint8
    pl.store(update_ref, (offsets,), new_param_lp - param_lp)  # delta
    pl.store(mu_values_out_ref, (offsets,), mu_out)
    # ... (store everything else)
This fused kernel follows the standard optax pattern: it returns a delta (new_param - old_param), and the caller applies it with new_params = params + result.update. Since it's a normal optax-shaped optimizer, donate_argnums works the same way as for the unfused path — XLA reclaims the old param/state buffers at the jit boundary:
Implementation code: fused delta (with and without donation)
inplace=False tells the Pallas kernel to write new_param - old_param (a delta). The state outputs (mu, nu, ecc, scales) use input_output_aliases so donation can reuse those buffers in-place. But the param buffer can't be aliased (the kernel still needs to read old params), so the caller adds the delta back in a separate pass.
# init
tx = flash_adamw(learning_rate=1e-3, fused=True, inplace=False)
state = tx.init(params)
# step — kernel outputs a delta, then we add it back
def train_step(params, state, grads):
    updates, new_state = tx.update(grads, state, params)
    new_params = optax.apply_updates(params, updates)
    return new_params, new_state
# without donation:
step = jax.jit(train_step)
# with donation:
step = jax.jit(train_step, donate_argnums=(0, 1))
Fusion takes JAX from 3× slower to ~75% of PyTorch's speed without donation, and 90% with donation. Donation helps both memory and speed: the kernel declares input_output_aliases for the state outputs (mu, nu, ecc, scales). This Pallas feature tells XLA that a kernel input and output share the same buffer — the kernel reads the old value and writes the new value to the same memory. So, paired with donate_argnums, no copies or extra allocations are needed. The one output that can't be aliased is the param buffer: the kernel writes a delta (new_param - old_param), not the new param value, so the output has a different meaning than the input. After the kernel returns, the caller computes params + delta via optax.apply_updates — an extra kernel launch and an extra round-trip through memory bandwidth — which limits how fast the delta path can get.
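input_output_aliases is easiest to see on a toy kernel. A minimal, runnable sketch (interpret=True so it runs on any backend; on GPU the same call lowers to Triton):

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def double_kernel(x_ref, o_ref):
    # Trivial elementwise kernel: read the input block, write 2x back.
    o_ref[...] = x_ref[...] * 2.0

double = pl.pallas_call(
    double_kernel,
    out_shape=jax.ShapeDtypeStruct((8,), jnp.float32),
    input_output_aliases={0: 0},  # input 0 and output 0 share one buffer
    interpret=True,               # run in interpreter mode (no GPU required)
)

x = jnp.arange(8, dtype=jnp.float32)
y = jax.jit(double, donate_argnums=(0,))(x)  # donation lets XLA honor the alias
```

Without donation, XLA inserts a defensive copy of the aliased input; with it, the kernel genuinely updates the buffer in place.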
Implementation 4: Direct-write output
There is one final optimization we will pursue: removing the round trip when computing and applying updates. As stated above, following the optax pattern of updates, new_state = tx.update(grads, opt_state, params) followed by optax.apply_updates(params, updates) adds an unnecessary kernel launch: we can instead just calculate the new parameter in a single pass.
The change is one line in the Pallas kernel and adding one more input-output alias:
# Before (delta):
pl.store(update_ref, (offsets,), new_param_lp - param_lp)  # delta
# After (direct-write):
pl.store(update_ref, (offsets,), new_param_lp)  # full new value
Now, the kernel's param output IS the new param value — so we can alias the param input buffer with the param output buffer too, not just the state buffers.
With all buffers aliased and donated, the kernel reads old params and writes new params into the same buffer — no copies, no extra allocations, no second pass. The entire step is one kernel launch with one pass through memory.
Implementation code: direct-write (with and without donation)
inplace=True tells the Pallas kernel to write the full new parameter value (not a delta). This enables input_output_aliases inside the kernel, so when combined with donate_argnums, the kernel reads old params and writes new params into the same buffer.
# init
tx = flash_adamw(learning_rate=1e-3, fused=True, inplace=True)
state = tx.init(params)
# step — kernel writes final param values directly (no delta)
def train_step(params, state, grads):
    return tx.step(params, state, grads)
# without donation (kernel writes to fresh buffer):
step = jax.jit(train_step)
# with donation (kernel aliases input/output buffers):
step = jax.jit(train_step, donate_argnums=(0, 1))
Direct-write alone doesn't help much — without donation, XLA can't alias the buffers, so the kernel writes to a fresh allocation anyway. The payoff only comes from the combination: direct-write gives the kernel the right output shape for input_output_aliases, and donation tells XLA those input buffers are available.
This final JAX version — which has now earned the title of JAX FlashOptim — is 6% faster than PyTorch FlashOptim at the tested MLP shape and is within 1% of memory. Across all shapes, we went from 3× slower and 60% more memory (unfused) to 6–59% faster and memory-equivalent.
The final API
The end result is a small API surface that hides all this complexity:
from flashoptim_jax import flash_adamw
tx = flash_adamw(learning_rate=1e-3, weight_decay=1e-2)
opt_state = tx.init(params)
def train_step(params, opt_state, batch):
    grads = jax.grad(loss_fn)(params, batch)
    return tx.step(params, opt_state, grads)
train_step = jax.jit(train_step, donate_argnums=(0, 1))
Under the hood: quantized int8/uint8 optimizer state, fused Pallas kernels, in-place parameter updates, and buffer donation. On the surface: two lines to set up, one function call per step.
The code is at github.com/samacqua/flashoptim-jax and installable via uv pip install flashoptim-jax.
Notes
Other optimizers + param dtypes
These follow a very similar trajectory and are implemented in the package; I used only bf16 and AdamW in this post for clarity.
Correctness
The repo has both unit tests and end-to-end examples showing it matches the reference implementation on real datasets/models.
Gradient Release
The tables above show that buffer donation closes the memory gap between JAX and PyTorch for the optimizer step itself. But PyTorch FlashOptim has one more trick that can't easily (afaik) be ported to JAX: gradient release.
PyTorch's enable_gradient_release() uses register_post_accumulate_grad_hook to step each parameter the moment its gradient is ready during the backward pass, then immediately sets p.grad = None to free the gradient. At any point during backprop, only a few gradients are alive (one per in-flight autograd node), not the full gradient tree. For a 7B model with fp32 gradients, this saves ~28 GB of peak memory.
In JAX, jax.grad is a pure function that returns the complete gradient pytree — XLA compiles the entire backward pass into a single fused program, and there's no hook mechanism to intercept individual gradients as they're computed. You can't insert Python callbacks into XLA-compiled code.
TPU Compatibility
Part of the motivation for porting this to JAX is to use it with TPUs. This is not yet implemented.
Citation
This is an independent reimplementation weekend project. If you use it, please cite the original paper: