JAX Refs: Controlled Mutation

June 26, 2026

JAX Refs are mutable array cells. They let you read and write array state while remaining inside JAX transformations such as jit and grad.

A jax.Array is a value. In JAX, values do not change. To “update” an array, you compute another array:

x = x.at[i].set(y)

This strictness is one reason JAX composes well with compilation, automatic differentiation, vectorization, and parallel execution. The compiler sees values and dependencies, not hidden writes.
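A minimal illustration of that strictness, using only standard jax.numpy. The functional update returns a new array and leaves x alone, while NumPy-style item assignment is rejected.

```python
import jax.numpy as jnp

x = jnp.array([1.0, 2.0, 3.0])

# Functional update: builds a new array and leaves x untouched.
x_new = x.at[1].set(10.0)

print(x)      # [1. 2. 3.]
print(x_new)  # [ 1. 10.  3.]

# NumPy-style in-place assignment is rejected: jax.Array values are immutable.
try:
    x[1] = 10.0
    mutated = True
except TypeError as err:
    mutated = False
    print("TypeError:", err)
```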

But immutability is not always the clearest way to write a program.

Sometimes the thing you want is not a new value. You want a place to put data: a loss computed deep inside a training step, a scratch buffer used inside a loop, a running statistic, a diagnostic from a backward pass, or an output buffer filled in-place. Refs are best for this kind of operational state: logs, buffers, scratch space, and diagnostics.

Ref is JAX’s small, explicit answer to that problem (see the JAX documentation, “Ref: mutable arrays for data plumbing and memory control”).

Values and locations

The distinction is simple:

Table 1: arrays are values; refs are mutable locations.

| Object    | Mental model       | Operation                      |
|-----------|--------------------|--------------------------------|
| jax.Array | A value            | Compute a new value            |
| jax.Ref   | A mutable location | Read or write the stored value |

An array expression does not update its inputs:

x = jnp.array([1., 2., 3.])
y = x + 1

The original x is still [1., 2., 3.]. The expression produced a new value, y.

A Ref is different. It names a location that contains an array value:

x_ref = jax.new_ref(jnp.array([1., 2., 3.]))

x_ref[1] += 10

print(x_ref[...])  # [ 1. 12.  3.]

The syntax follows NumPy indexing. Read from a ref by indexing it:

x = x_ref[...]

Write the whole value with assignment:

x_ref[...] = x + 1

Update a slice in-place:

x_ref[i] += delta

The notation is standard, but the meaning is not. A Ref introduces controlled mutation into JAX’s staged computation model.

Why refs exist

The usual JAX style is functional:

new_state, output = f(old_state, input)

That style is often the right one. It makes dependencies explicit, keeps functions easy to transform, and gives the compiler a clean program to analyze.

But not all data in a program has the same status. Model parameters, optimizer state, batch inputs, and PRNG keys usually belong in the function signature. They define the computation.

Metrics, scratch space, and diagnostics are different. They are often operational. You need them, but they are not always part of the function you are trying to express.

A Ref gives that operational data an explicit home:

This state is mutable, and its location is visible.

That is the compromise. Mutation is allowed, but it is attached to a particular reference, not hidden in arbitrary Python state.

Recording losses inside a loop

Suppose we run a few steps of gradient descent and want to record the loss at each step.

import jax
import jax.numpy as jnp

def mse(w, x, y):
    pred = x @ w
    return jnp.mean((pred - y) ** 2)

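steps = 5
lr = 0.1

# Illustrative toy regression data; any w0, x, y with matching shapes works.
x = jax.random.normal(jax.random.PRNGKey(0), (32, 3))
y = x @ jnp.array([1.0, -2.0, 0.5])
w0 = jnp.zeros((3,))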

Without Ref: carry the log through the loop

In ordinary functional JAX, the loss history must be part of the loop state.

@jax.jit
def train_without_refs(w, x, y):
    losses0 = jnp.zeros((steps,))

    def body(carry, t):
        w, losses = carry

        loss, grad = jax.value_and_grad(mse)(w, x, y)
        w = w - lr * grad

        # This returns a new array value.
        losses = losses.at[t].set(loss)

        return (w, losses), None

    (w, losses), _ = jax.lax.scan(
        body,
        (w, losses0),
        jnp.arange(steps),
    )
    return w, losses

This is good JAX. It is explicit and transformable. But the log has become part of the loop carry. In a small example that is harmless. In a real training step, this pattern spreads: every function that computes an interesting diagnostic must return it, and every caller must forward it. The algorithm becomes harder to read because the bookkeeping travels with the state.

With Ref: put the log in an output buffer

With a Ref, the loss history can live in a separate mutable buffer.

@jax.jit
def train_with_refs(w, x, y, losses_ref):
    def body(w, t):
        loss, grad = jax.value_and_grad(mse)(w, x, y)
        w = w - lr * grad

        # The loss is diagnostic data, not optimization state.
        losses_ref[t] = jax.lax.stop_gradient(loss)

        return w, None

    w, _ = jax.lax.scan(
        body,
        w,
        jnp.arange(steps),
    )
    return w

The caller owns the buffer:

losses_ref = jax.new_ref(jnp.zeros((steps,)))

w_final = train_with_refs(w0, x, y, losses_ref)
losses = losses_ref[...]

print(losses)

The structural change is the point. Without refs:

(w, losses) -> body -> (w, losses)

With refs:

w -> body -> w

The loop carry is now just the optimization state. The loss history is still explicit, but it lives beside the algorithm rather than inside it. That is the main use of Ref: separating the value flow of the computation from the data plumbing around it.

Why stop_gradient appears

This line matters:

losses_ref[t] = jax.lax.stop_gradient(loss)

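The loss buffer is a diagnostic. We want to record the value, not make the mutable write part of the differentiated computation. In this example nothing differentiates through train_with_refs itself (the inner value_and_grad only differentiates mse), so stop_gradient acts as a guard: if the training step is later differentiated, the recorded value still carries no gradient.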

A useful rule is:

If a ref is used only for metrics, logging, or diagnostics, write stopped values into it.

This keeps the mathematical function visible:

w_final = train(w, x, y)

The ref records what happened. It does not define what the function is.

A Ref is not an array

The right mental model is a box: a ref holds an array value, but it is not itself an array.

x_ref = jax.new_ref(1.0)

# jnp.sin(x_ref)     # wrong: x_ref is the box, not the array inside it
jnp.sin(x_ref[...])  # right: read the value, then compute on it

A ref is the mutable location. The array value is stored inside it.

The two basic operations are:

x_ref[...]      # read the stored value
x_ref[...] = y  # replace the stored value

For indexed state:

x_ref[i]        # read one slice
x_ref[i] = y    # write one slice
x_ref[i] += y   # update one slice

The compact rule is:

Do math on ref[...], not on ref.

The brackets are not decoration. They are the boundary between the mutable cell and the array value inside it. The role is similar to dereferencing a pointer in C or C++: ref[...] plays the part of *p.

When to use refs

Use refs when the program is clearer with an explicit mutable location.

Table 2: operational state can live in refs; mathematical state should usually remain an array.

| Prefer a Ref for               | Prefer an array for                                        |
|--------------------------------|------------------------------------------------------------|
| Metric buffers                 | Model parameters                                           |
| Temporary scratch arrays       | Optimizer state                                            |
| Output buffers                 | Batch inputs                                               |
| Running statistics and logs    | Return values that define the function                     |
| Diagnostics from compiled code | Values you want to batch, serialize, or pass around freely |

If a value is part of the mathematical interface of the function, pass it and return it. If it is operational data that needs a stable location, a ref may be the clearer tool.

End note

A Ref is a controlled mutable cell for JAX arrays.

An array is a value. A ref is a location containing a value. The compact rule is:

ref[...]      # read
ref[...] = x  # write

A ref created and consumed inside a function is invisible to its callers, so the function keeps a pure interface. Ref arguments and closed-over refs, by contrast, are side effects, but explicit ones: the caller can see which location gets written. The restrictions around refs exist to rule out ambiguous aliasing and transformation behavior.

The point is not that JAX becomes imperative. The point is narrower: JAX now has a small, explicit place for mutation when mutation is the clearest way to express data plumbing or memory control.

JAX Refs: Controlled Mutation - June 26, 2026 - clay harmon