symjax.tensor.random

bernoulli(key, p, shape) Sample Bernoulli random values with given shape and mean.
beta(key, a, …) Sample Beta random values with given shape and float dtype.
cauchy(key[, shape, dtype]) Sample Cauchy random values with given shape and float dtype.
dirichlet(key, alpha[, shape, dtype]) Sample Dirichlet random values with given shape and float dtype.
gamma(key, a[, shape, dtype]) Sample Gamma random values with given shape and float dtype.
gumbel(key[, shape, dtype]) Sample Gumbel random values with given shape and float dtype.
laplace(key[, shape, dtype]) Sample Laplace random values with given shape and float dtype.
logistic(key[, shape, dtype]) Sample logistic random values with given shape and float dtype.
multivariate_normal(key, mean, cov, shape, dtype) Sample multivariate normal random values with given mean and covariance.
normal(key, shape, dtype) Sample standard normal random values with given shape and float dtype.
pareto(key, b[, shape, dtype]) Sample Pareto random values with given shape and float dtype.
randint(key, shape, minval, …) Sample uniform random integers in [minval, maxval) with given shape/dtype.
shuffle(key, x, axis) Shuffle the elements of an array uniformly at random along an axis.
truncated_normal(key, lower, …) Sample truncated standard normal random values with given shape and dtype.
uniform(key, shape, dtype, minval, …) Sample uniform random values in [minval, maxval) with given shape/dtype.

Detailed Description

symjax.tensor.random.bernoulli(key: jax.numpy.ndarray, p: jax.numpy.ndarray = 0.5, shape: Optional[Sequence[int]] = None) → jax.numpy.ndarray

Sample Bernoulli random values with given shape and mean.

Parameters:
  • key – a PRNGKey used as the random key.
  • p – optional, a float or array of floats for the mean of the random variables. Must be broadcast-compatible with shape. Default 0.5.
  • shape – optional, a tuple of nonnegative integers representing the result shape. Must be broadcast-compatible with p.shape. The default (None) produces a result shape equal to p.shape.
Returns:

A random array with boolean dtype and shape given by shape if shape is not None, or else p.shape.
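To illustrate the semantics, a Bernoulli draw with mean p can be produced by comparing a uniform variate on [0, 1) against p. The following is a pure-Python sketch using the standard `random` module, not symjax's implementation. The helper `bernoulli_sketch` is hypothetical, and it returns a flat list rather than an array of a given shape.

```python
import random


def bernoulli_sketch(rng: random.Random, p: float, n: int) -> list[bool]:
    """Hypothetical helper: n independent Bernoulli(p) draws as booleans."""
    # rng.random() is uniform on [0, 1), so P(u < p) == p for 0 <= p <= 1.
    return [rng.random() < p for _ in range(n)]


rng = random.Random(0)
draws = bernoulli_sketch(rng, 0.3, 10_000)
print(sum(draws) / len(draws))  # close to 0.3
```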

symjax.tensor.random.beta(key: jax.numpy.ndarray, a: Union[float, jax.numpy.ndarray], b: Union[float, jax.numpy.ndarray], shape: Optional[Sequence[int]] = None, dtype: numpy.dtype = numpy.float64) → jax.numpy.ndarray

Sample Beta random values with given shape and float dtype.

Parameters:
  • key – a PRNGKey used as the random key.
  • a – a float or array of floats broadcast-compatible with shape representing the first parameter “alpha”.
  • b – a float or array of floats broadcast-compatible with shape representing the second parameter “beta”.
  • shape – optional, a tuple of nonnegative integers specifying the result shape. Must be broadcast-compatible with a and b. The default (None) produces a result shape by broadcasting a and b.
  • dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).
Returns:

A random array with the specified dtype and shape given by shape if shape is not None, or else by broadcasting a and b.
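One standard way to obtain Beta(a, b) samples is to normalise two unit-scale Gamma draws: if X ~ Gamma(a) and Y ~ Gamma(b), then X / (X + Y) ~ Beta(a, b). The sketch below uses Python's `random.gammavariate` to show this construction. It is an illustration under that assumption, not necessarily how symjax or JAX computes `beta`, and `beta_sketch` is a hypothetical name.

```python
import random


def beta_sketch(rng: random.Random, a: float, b: float) -> float:
    """Hypothetical helper: one Beta(a, b) draw built from two unit-scale Gamma draws."""
    x = rng.gammavariate(a, 1.0)  # X ~ Gamma(a, scale=1)
    y = rng.gammavariate(b, 1.0)  # Y ~ Gamma(b, scale=1)
    return x / (x + y)            # X / (X + Y) ~ Beta(a, b)


rng = random.Random(0)
samples = [beta_sketch(rng, 2.0, 5.0) for _ in range(20_000)]
print(sum(samples) / len(samples))  # close to a / (a + b) = 2/7 ≈ 0.286
```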

symjax.tensor.random.cauchy(key, shape=(), dtype=numpy.float64)

Sample Cauchy random values with given shape and float dtype.

Parameters:
  • key – a PRNGKey used as the random key.
  • shape – optional, a tuple of nonnegative integers representing the result shape. Default ().
  • dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).
Returns:

A random array with the specified shape and dtype.
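A standard Cauchy variate can be obtained from a uniform variate through the inverse CDF. The sketch below shows that transform in plain Python. It is an assumption-level illustration of the distribution, not symjax's code, and `cauchy_from_uniform` is a hypothetical helper.

```python
import math


def cauchy_from_uniform(u: float) -> float:
    """Hypothetical helper: map u in (0, 1) to a standard Cauchy variate via the inverse CDF."""
    # The standard Cauchy CDF is F(x) = 1/2 + atan(x)/pi, so F^{-1}(u) = tan(pi * (u - 1/2)).
    return math.tan(math.pi * (u - 0.5))


print(cauchy_from_uniform(0.5))   # 0.0, the median
print(cauchy_from_uniform(0.75))  # ≈ 1.0, the upper quartile
```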

symjax.tensor.random.dirichlet(key, alpha, shape=None, dtype=numpy.float64)

Sample Dirichlet random values with given shape and float dtype.

Parameters:
  • key – a PRNGKey used as the random key.
  • alpha – an array of shape (..., n) used as the concentration parameter of the random variables.
  • shape – optional, a tuple of nonnegative integers specifying the result batch shape; that is, the prefix of the result shape excluding the last axis (of size n). Must be broadcast-compatible with alpha.shape[:-1]. The default (None) produces a result shape equal to alpha.shape.
  • dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).
Returns:

A random array with the specified dtype and shape given by shape + (alpha.shape[-1],) if shape is not None, or else alpha.shape.
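The result-shape rule above, together with the usual way of drawing a single Dirichlet vector (normalising independent unit-scale Gamma draws), can be sketched in plain Python. Both helpers are hypothetical. The sampler handles one `alpha` vector only and is not claimed to match symjax's implementation.

```python
import random
from typing import Optional, Sequence


def dirichlet_result_shape(alpha_shape: Sequence[int], shape: Optional[Sequence[int]] = None) -> tuple:
    """Result shape as documented: shape + (n,) if shape is given, else alpha.shape."""
    if shape is None:
        return tuple(alpha_shape)
    return tuple(shape) + (alpha_shape[-1],)


def dirichlet_sketch(rng: random.Random, alpha: Sequence[float]) -> list[float]:
    """Hypothetical helper: one Dirichlet(alpha) draw by normalising unit-scale Gamma draws."""
    g = [rng.gammavariate(a, 1.0) for a in alpha]
    total = sum(g)
    return [x / total for x in g]


print(dirichlet_result_shape((3,), (10, 2)))  # (10, 2, 3)
print(dirichlet_result_shape((4, 3)))         # (4, 3)
print(sum(dirichlet_sketch(random.Random(0), [1.0, 2.0, 3.0])))  # ≈ 1.0
```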

symjax.tensor.random.gamma(key, a, shape=None, dtype=numpy.float64)

Sample Gamma random values with given shape and float dtype.

Parameters:
  • key – a PRNGKey used as the random key.
  • a – a float or array of floats broadcast-compatible with shape representing the shape (concentration) parameter of the distribution; the scale is fixed at 1.
  • shape – optional, a tuple of nonnegative integers specifying the result shape. Must be broadcast-compatible with a. The default (None) produces a result shape equal to a.shape.
  • dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).
Returns:

A random array with the specified dtype and with shape given by shape if shape is not None, or else by a.shape.
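The following sanity check assumes, as in `jax.random.gamma`, that `a` is the shape parameter and the scale is 1. Under that assumption, both the mean and the variance of the samples equal `a`. The sketch uses Python's `random.gammavariate` rather than symjax, so it only illustrates the target distribution.

```python
import random
import statistics

rng = random.Random(0)
a = 3.0
# random.gammavariate(alpha, beta) treats beta as the scale; beta=1.0 gives a unit-scale Gamma(a).
samples = [rng.gammavariate(a, 1.0) for _ in range(20_000)]
print(statistics.fmean(samples))     # close to a = 3.0
print(statistics.variance(samples))  # close to a = 3.0
```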

symjax.tensor.random.gumbel(key, shape=(), dtype=numpy.float64)

Sample Gumbel random values with given shape and float dtype.

Parameters:
  • key – a PRNGKey used as the random key.
  • shape – optional, a tuple of nonnegative integers representing the result shape. Default ().
  • dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).
Returns:

A random array with the specified shape and dtype.
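A standard Gumbel variate (location 0, scale 1) follows from a uniform variate through the inverse of its CDF. The plain-Python sketch below shows that mapping. `gumbel_from_uniform` is a hypothetical helper and is not claimed to be symjax's method.

```python
import math


def gumbel_from_uniform(u: float) -> float:
    """Hypothetical helper: standard Gumbel variate from u in (0, 1) via the inverse CDF."""
    # Standard Gumbel CDF: F(x) = exp(-exp(-x)), hence F^{-1}(u) = -log(-log(u)).
    return -math.log(-math.log(u))


print(gumbel_from_uniform(math.exp(-1.0)))  # ≈ 0, the mode
```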

symjax.tensor.random.laplace(key, shape=(), dtype=numpy.float64)

Sample Laplace random values with given shape and float dtype.

Parameters:
  • key – a PRNGKey used as the random key.
  • shape – optional, a tuple of nonnegative integers representing the result shape. Default ().
  • dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).
Returns:

A random array with the specified shape and dtype.
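A standard Laplace variate (location 0, scale 1) can be obtained by inverting its piecewise CDF. The sketch below does this in plain Python. It assumes unit scale and zero location, and the helper name is hypothetical.

```python
import math


def laplace_from_uniform(u: float) -> float:
    """Hypothetical helper: standard Laplace variate (location 0, scale 1) from u in (0, 1)."""
    # Inverse of F(x) = 1/2 + sign(x) * (1 - exp(-|x|)) / 2.
    d = u - 0.5
    return -math.copysign(1.0, d) * math.log(1.0 - 2.0 * abs(d))


print(laplace_from_uniform(0.75))  # ≈ 0.693 (log 2)
print(laplace_from_uniform(0.25))  # ≈ -0.693 (-log 2)
```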

symjax.tensor.random.logistic(key, shape=(), dtype=numpy.float64)

Sample logistic random values with given shape and float dtype.

Parameters:
  • key – a PRNGKey used as the random key.
  • shape – optional, a tuple of nonnegative integers representing the result shape. Default ().
  • dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).
Returns:

A random array with the specified shape and dtype.
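The standard logistic CDF is the sigmoid function, so its inverse is the logit. Applying the logit to a uniform variate therefore yields a logistic sample. This is a hypothetical plain-Python sketch of that relationship, not symjax's implementation.

```python
import math


def sigmoid(x: float) -> float:
    """Standard logistic CDF."""
    return 1.0 / (1.0 + math.exp(-x))


def logistic_from_uniform(u: float) -> float:
    """Hypothetical helper: standard logistic variate from u in (0, 1) via the inverse CDF (logit)."""
    return math.log(u / (1.0 - u))


print(logistic_from_uniform(sigmoid(2.0)))  # ≈ 2.0
```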

symjax.tensor.random.multivariate_normal(key: jax.numpy.ndarray, mean: jax.numpy.ndarray, cov: jax.numpy.ndarray, shape: Optional[Sequence[int]] = None, dtype: numpy.dtype = numpy.float64) → jax.numpy.ndarray

Sample multivariate normal random values with given mean and covariance.

Parameters:
  • key – a PRNGKey used as the random key.
  • mean – a mean vector of shape (..., n).
  • cov – a positive definite covariance matrix of shape (..., n, n). The batch shape ... must be broadcast-compatible with that of mean.
  • shape – optional, a tuple of nonnegative integers specifying the result batch shape; that is, the prefix of the result shape excluding the last axis. Must be broadcast-compatible with mean.shape[:-1] and cov.shape[:-2]. The default (None) produces a result batch shape by broadcasting together the batch shapes of mean and cov.
  • dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).
Returns:

A random array with the specified dtype and shape given by shape + mean.shape[-1:] if shape is not None, or else broadcast_shapes(mean.shape[:-1], cov.shape[:-2]) + mean.shape[-1:].
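A common construction for a multivariate normal draw is mean + L z, where L is the lower Cholesky factor of the positive definite `cov` and z is a vector of independent standard normals. The pure-Python sketch below implements that construction for a single unbatched draw. It is one standard technique and not necessarily how symjax or JAX factorises `cov`. Both helper names are hypothetical.

```python
import math
import random


def cholesky(cov: list[list[float]]) -> list[list[float]]:
    """Lower-triangular L with L @ L.T == cov, for a symmetric positive definite cov."""
    n = len(cov)
    L = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1):
            s = sum(L[i][k] * L[j][k] for k in range(j))
            if i == j:
                L[i][j] = math.sqrt(cov[i][i] - s)
            else:
                L[i][j] = (cov[i][j] - s) / L[j][j]
    return L


def mvn_sketch(rng: random.Random, mean: list[float], cov: list[list[float]]) -> list[float]:
    """Hypothetical helper: one draw of mean + L z with z ~ N(0, I); refactorises cov on every call."""
    L = cholesky(cov)
    z = [rng.gauss(0.0, 1.0) for _ in mean]
    return [m + sum(L[i][k] * z[k] for k in range(i + 1)) for i, m in enumerate(mean)]


print(cholesky([[4.0, 2.0], [2.0, 3.0]]))  # [[2.0, 0.0], [1.0, 1.414...]]
```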

symjax.tensor.random.normal(key: jax.numpy.ndarray, shape: Sequence[int] = (), dtype: numpy.dtype = numpy.float64) → jax.numpy.ndarray

Sample standard normal random values with given shape and float dtype.

Parameters:
  • key – a PRNGKey used as the random key.
  • shape – optional, a tuple of nonnegative integers representing the result shape. Default ().
  • dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).
Returns:

A random array with the specified shape and dtype.
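To show how standard normal values can be derived from uniform ones, here is the Box-Muller transform in plain Python. This is a swapped-in classical technique chosen for illustration. It is not a description of how symjax or JAX generates normals, and `box_muller` is a hypothetical helper.

```python
import math
import random


def box_muller(u1: float, u2: float) -> tuple[float, float]:
    """Two independent standard normal variates from u1 in (0, 1] and u2 in [0, 1)."""
    r = math.sqrt(-2.0 * math.log(u1))
    theta = 2.0 * math.pi * u2
    return r * math.cos(theta), r * math.sin(theta)


rng = random.Random(0)
# 1.0 - rng.random() lies in (0, 1], which keeps log(u1) finite.
zs = [z for _ in range(10_000) for z in box_muller(1.0 - rng.random(), rng.random())]
print(sum(zs) / len(zs))  # close to 0
```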

symjax.tensor.random.pareto(key, b, shape=None, dtype=numpy.float64)

Sample Pareto random values with given shape and float dtype.

Parameters:
  • key – a PRNGKey used as the random key.
  • b – a float or array of floats broadcast-compatible with shape representing the shape (tail-index) parameter of the distribution.
  • shape – optional, a tuple of nonnegative integers specifying the result shape. Must be broadcast-compatible with b. The default (None) produces a result shape equal to b.shape.
  • dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).
Returns:

A random array with the specified dtype and with shape given by shape if shape is not None, or else by b.shape.
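The following sketch assumes, as in `jax.random.pareto`, a unit-scale Pareto distribution with density b / x**(b + 1) on x ≥ 1. Under that assumption, inverting the CDF maps a uniform variate to a Pareto variate. The helper is hypothetical plain Python, not symjax code.

```python
def pareto_from_uniform(u: float, b: float) -> float:
    """Hypothetical helper: Pareto variate with density b / x**(b + 1) on x >= 1, from u in [0, 1)."""
    # CDF F(x) = 1 - x**(-b) for x >= 1, so F^{-1}(u) = (1 - u) ** (-1 / b).
    return (1.0 - u) ** (-1.0 / b)


print(pareto_from_uniform(0.0, 3.0))  # 1.0, the lower bound of the support
print(pareto_from_uniform(0.5, 1.0))  # 2.0, the median when b = 1
```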

symjax.tensor.random.randint(key: jax.numpy.ndarray, shape: Sequence[int], minval: Union[int, jax.numpy.ndarray], maxval: Union[int, jax.numpy.ndarray], dtype: numpy.dtype = numpy.int64)

Sample uniform random integers in [minval, maxval) with given shape/dtype.

Parameters:
  • key – a PRNGKey used as the random key.
  • shape – a tuple of nonnegative integers representing the result shape.
  • minval – int or array of ints broadcast-compatible with shape, a minimum (inclusive) value for the range.
  • maxval – int or array of ints broadcast-compatible with shape, a maximum (exclusive) value for the range.
  • dtype – optional, an int dtype for the returned values (default int64 if jax_enable_x64 is true, otherwise int32).
Returns:

A random array with the specified shape and dtype.
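The half-open range [minval, maxval) means `maxval` itself is never produced. Python's `random.randrange` follows the same convention, so it serves as a quick plain-Python illustration. It is not the symjax API.

```python
import random

rng = random.Random(0)
minval, maxval = 0, 3
# randrange(start, stop) excludes stop, mirroring the half-open range [minval, maxval).
draws = [rng.randrange(minval, maxval) for _ in range(1_000)]
print(sorted(set(draws)))  # [0, 1, 2]; maxval itself never appears
```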

symjax.tensor.random.shuffle(key: jax.numpy.ndarray, x: jax.numpy.ndarray, axis: int = 0) → jax.numpy.ndarray

Shuffle the elements of an array uniformly at random along an axis.

Parameters:
  • key – a PRNGKey used as the random key.
  • x – the array to be shuffled.
  • axis – optional, an int axis along which to shuffle (default 0).
Returns:

A shuffled version of x.
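Shuffling along axis 0 permutes whole slices, for example the rows of a matrix, while leaving each slice's contents unchanged. The plain-Python sketch below illustrates this for a nested list by applying one random permutation of the row indices. The helper `shuffle_axis0` is hypothetical and handles only axis 0.

```python
import random


def shuffle_axis0(rng: random.Random, rows: list) -> list:
    """Hypothetical helper: return the rows (axis 0) of a nested list in random order, each row intact."""
    perm = list(range(len(rows)))
    rng.shuffle(perm)  # Fisher-Yates shuffle of the row indices
    return [rows[i] for i in perm]


x = [[0, 1], [2, 3], [4, 5], [6, 7]]
y = shuffle_axis0(random.Random(0), x)
print(sorted(y) == x)  # True: same rows, possibly in a different order
```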

symjax.tensor.random.truncated_normal(key: jax.numpy.ndarray, lower: Union[float, jax.numpy.ndarray], upper: Union[float, jax.numpy.ndarray], shape: Optional[Sequence[int]] = None, dtype: numpy.dtype = numpy.float64) → jax.numpy.ndarray

Sample truncated standard normal random values with given shape and dtype.

Parameters:
  • key – a PRNGKey used as the random key.
  • lower – a float or array of floats representing the lower bound for truncation. Must be broadcast-compatible with upper.
  • upper – a float or array of floats representing the upper bound for truncation. Must be broadcast-compatible with lower.
  • shape – optional, a tuple of nonnegative integers specifying the result shape. Must be broadcast-compatible with lower and upper. The default (None) produces a result shape by broadcasting lower and upper.
  • dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).
Returns:

A random array with the specified dtype and shape given by shape if shape is not None, or else by broadcasting lower and upper. Returns values in the open interval (lower, upper).
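The simplest way to sample a standard normal conditioned on lower < z < upper is rejection: draw normals and keep the first one inside the open interval. This is a swapped-in technique for illustration only. It becomes slow when the interval carries little probability mass, and it is not claimed to be how symjax or JAX implements `truncated_normal`. The helper name is hypothetical.

```python
import random


def truncated_normal_sketch(rng: random.Random, lower: float, upper: float) -> float:
    """Hypothetical helper: standard normal conditioned on lower < z < upper, by rejection."""
    # Expected number of attempts is 1 / P(lower < Z < upper), so narrow or far-out intervals are slow.
    while True:
        z = rng.gauss(0.0, 1.0)
        if lower < z < upper:
            return z


rng = random.Random(0)
samples = [truncated_normal_sketch(rng, -1.0, 2.0) for _ in range(1_000)]
print(min(samples) > -1.0 and max(samples) < 2.0)  # True
```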

symjax.tensor.random.uniform(key: jax.numpy.ndarray, shape: Sequence[int] = (), dtype: numpy.dtype = numpy.float64, minval: Union[float, jax.numpy.ndarray] = 0.0, maxval: Union[float, jax.numpy.ndarray] = 1.0) → jax.numpy.ndarray

Sample uniform random values in [minval, maxval) with given shape/dtype.

Parameters:
  • key – a PRNGKey used as the random key.
  • shape – optional, a tuple of nonnegative integers representing the result shape. Default ().
  • dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).
  • minval – optional, a minimum (inclusive) value broadcast-compatible with shape for the range (default 0).
  • maxval – optional, a maximum (exclusive) value broadcast-compatible with shape for the range (default 1).
Returns:

A random array with the specified shape and dtype.
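Sampling on a general range [minval, maxval) amounts to an affine rescaling of a draw on [0, 1): minval + (maxval - minval) * u. The plain-Python sketch below demonstrates this. `uniform_sketch` is a hypothetical helper that returns a flat list, not symjax's implementation.

```python
import random


def uniform_sketch(rng: random.Random, minval: float, maxval: float, n: int) -> list[float]:
    """Hypothetical helper: n draws on [minval, maxval) by affinely rescaling draws on [0, 1)."""
    return [minval + (maxval - minval) * rng.random() for _ in range(n)]


vals = uniform_sketch(random.Random(0), -2.0, 5.0, 10_000)
print(min(vals) >= -2.0 and max(vals) < 5.0)  # True
```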