jax.ref.addupdate

jax.ref.addupdate(ref, idx, x)

Add to an element in an ArrayRef in-place.

This is analogous to ref[idx] += x for a NumPy array ref and NumPy-style indexer idx. However, for an ArrayRef ref, executing ref[idx] += x actually performs three separate operations: a ref_get, an add, and a ref_set. This function performs the update as a single operation, so using it can be more efficient under autodiff. For more on mutable array refs, refer to the ArrayRef guide.

Parameters:
  • ref (AbstractRef) – a jax.ref.ArrayRef object. On return, the buffer will be mutated by this operation.

  • idx (Indexer | tuple[Indexer, ...] | None) – a NumPy-style indexer selecting the location(s) in ref to update.

  • x (Array) – a jax.Array object (note, not a jax.ref.ArrayRef) containing the values to add at the specified indices.

Returns:

None

Return type:

None

Examples

>>> import jax
>>> ref = jax.array_ref(jax.numpy.arange(5))
>>> jax.ref.addupdate(ref, 2, 10)
>>> ref
ArrayRef([ 0,  1, 12,  3,  4], dtype=int32)

Equivalent operation via indexing syntax:

>>> ref = jax.array_ref(jax.numpy.arange(5))
>>> ref[2] += 10
>>> ref
ArrayRef([ 0,  1, 12,  3,  4], dtype=int32)

Use an Ellipsis (...) as the index to add to a scalar (0-dimensional) ref:

>>> ref = jax.array_ref(jax.numpy.int32(2))
>>> ref[...] += 10
>>> ref
ArrayRef(12, dtype=int32)