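jax.nn.logmeanexp

jax.nn.logmeanexp(x, axis=None, where=None, keepdims=False)

Log mean exp.

Computes the function:

\[\text{logmeanexp}(x) = \log \frac{1}{n} \sum_{i=1}^n \exp x_i = \text{logsumexp}(x) - \log n\]

where n is the number of elements included in the reduction. The right-hand form, expressed through logsumexp, avoids overflow in \(\exp x_i\) for large inputs.

Parameters:
  x (ArrayLike) – Input array.
  axis (int | tuple[int, ...] | None) – Axis or axes along which to reduce. If None, reduce over all elements.
  where (ArrayLike | None) – Elements to include in the reduction. Optional.
  keepdims (bool) – If True, the reduced axes are kept with size one, so the result has the same number of dimensions as the input.

Returns:
  An array containing the log-mean-exp of x over the reduced axes.

Return type:
  Array

See also:
  jax.nn.logsumexp()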
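The identity logsumexp(x) - log n can be illustrated with a minimal sketch built on jax.nn.logsumexp. This is not JAX's own implementation of jax.nn.logmeanexp: the helper name logmeanexp_sketch is hypothetical, and the masking strategy (excluded entries set to -inf, n taken as the count of included entries) is an assumption about how `where` interacts with n.

```python
import jax.numpy as jnp
from jax.nn import logsumexp


def logmeanexp_sketch(x, axis=None, where=None, keepdims=False):
    """Hypothetical helper: log(mean(exp(x))) computed as logsumexp(x) - log(n).

    Illustrative only; not the library's implementation of jax.nn.logmeanexp.
    """
    x = jnp.asarray(x, dtype=jnp.float32)
    if where is None:
        mask = jnp.ones(x.shape, dtype=bool)
    else:
        # Assumption: `where` broadcasts against x, True means "include".
        mask = jnp.broadcast_to(jnp.asarray(where, dtype=bool), x.shape)
    # Excluded entries become -inf, so exp() contributes exactly 0 to the sum.
    masked = jnp.where(mask, x, -jnp.inf)
    # n counts only the included elements along the reduced axes.
    n = jnp.sum(mask, axis=axis, keepdims=keepdims)
    return logsumexp(masked, axis=axis, keepdims=keepdims) - jnp.log(n)


x = jnp.array([[0.0, 1.0, 2.0],
               [1000.0, 1000.0, 1000.0]])

# Naive evaluation overflows on the second row: exp(1000.0) is inf in float32.
naive = jnp.log(jnp.mean(jnp.exp(x), axis=1))
stable = logmeanexp_sketch(x, axis=1)
print(naive)   # second entry is inf
print(stable)  # approximately [1.309, 1000.0]

# With a mask, n is the number of included elements (2 in the first row here).
mask = jnp.array([[True, True, False],
                  [True, True, True]])
print(logmeanexp_sketch(x, axis=1, where=mask))  # first entry ≈ log((1 + e) / 2) ≈ 0.620

# keepdims=True keeps the reduced axis with size 1.
print(logmeanexp_sketch(x, axis=1, keepdims=True).shape)  # (2, 1)
```

The design choice worth noting is that the shift by log n happens after the stable logsumexp reduction, so no intermediate exp(x_i) is ever materialized at full magnitude. On a JAX version that provides jax.nn.logmeanexp, one could compare the sketch against the library function directly.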