jax.ref.addupdate
- jax.ref.addupdate(ref, idx, x)
Add to an element in an ArrayRef in-place.
This is analogous to ref[idx] += value for a NumPy array ref and NumPy-style indexer idx. However, for an ArrayRef ref, executing ref[idx] += value actually performs a ref_get, add, and ref_set, so using this function can be more efficient under autodiff. For more on mutable array refs, refer to the ArrayRef guide.
- Parameters:
ref (AbstractRef) – a jax.ref.ArrayRef object. On return, the buffer will be mutated by this operation.
idx (Indexer | tuple[Indexer, ...] | None) – a NumPy-style indexer.
x (Array) – a jax.Array object (note, not a jax.ref.ArrayRef) containing the values to add at the specified indices.
- Returns:
None
- Return type:
None
Examples
>>> import jax
>>> ref = jax.array_ref(jax.numpy.arange(5))
>>> jax.ref.addupdate(ref, 2, 10)
>>> ref
ArrayRef([ 0,  1, 12,  3,  4], dtype=int32)
Equivalent operation via indexing syntax:
>>> ref = jax.array_ref(jax.numpy.arange(5))
>>> ref[2] += 10
>>> ref
ArrayRef([ 0,  1, 12,  3,  4], dtype=int32)
Use ... to add to a scalar ref:
>>> ref = jax.array_ref(jax.numpy.int32(2))
>>> ref[...] += 10
>>> ref
ArrayRef(12, dtype=int32)
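The description notes that the indexing form performs a separate get, add, and set. As a rough illustration of where those steps come from, the pure-Python sketch below records what Python itself does for buf[idx] += value. LoggingBuffer and its addupdate method are hypothetical stand-ins, not part of JAX; the sketch only models Python's handling of augmented assignment and makes no claim about how jax.ref.addupdate is implemented.

```python
class LoggingBuffer:
    """Hypothetical stand-in for a mutable buffer; records each access."""

    def __init__(self, values):
        self.values = list(values)
        self.log = []

    def __getitem__(self, idx):
        self.log.append("get")
        return self.values[idx]

    def __setitem__(self, idx, value):
        self.log.append("set")
        self.values[idx] = value

    def addupdate(self, idx, x):
        # Hypothetical fused operation: one step, no separate read
        # is exposed to the caller.
        self.log.append("addupdate")
        self.values[idx] += x


buf = LoggingBuffer(range(5))

# Python expands this into buf.__getitem__(2), an add, and
# buf.__setitem__(2, result).
buf[2] += 10
print(buf.log)  # ['get', 'set']

# The fused form is recorded as a single operation.
buf.addupdate(2, 10)
print(buf.log)     # ['get', 'set', 'addupdate']
print(buf.values)  # [0, 1, 22, 3, 4]
```

The two recorded entries for the indexing form correspond to the get and set mentioned in the description; the add happens between them on the value that was read.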