jax.ref module

The jax.ref module provides the API for creating, reading, and writing ArrayRef objects, which are mutable array references.

API

AbstractRef(inner_aval[, memory_space])

Abstract value (type) of a mutable array reference. inner_aval is the abstract value of the referenced array.

ArrayRef(aval, buf)

Mutable array reference.

array_ref(init_val, *[, memory_space])

Create a mutable array reference with initial value init_val.

freeze(ref)

Invalidate a given reference and return its final value.

get(ref[, idx])

Read a value from an ArrayRef.

set(ref, idx, value)

Set a value in an ArrayRef in-place.

swap(ref, idx, value[, _function_name])

Set a value in an ArrayRef in-place and return the value it replaced.

addupdate(ref, idx, x)

Add x in-place to the entries of an ArrayRef selected by idx.