symjax.tensor.random
Function | Description
bernoulli(key, p, shape) | Sample Bernoulli random values with given shape and mean.
beta(key, a, …) | Sample Beta random values with given shape and float dtype.
cauchy(key[, shape, dtype]) | Sample Cauchy random values with given shape and float dtype.
dirichlet(key, alpha[, shape, dtype]) | Sample Dirichlet random values with given shape and float dtype.
gamma(key, a[, shape, dtype]) | Sample Gamma random values with given shape and float dtype.
gumbel(key[, shape, dtype]) | Sample Gumbel random values with given shape and float dtype.
laplace(key[, shape, dtype]) | Sample Laplace random values with given shape and float dtype.
logistic(key[, shape, dtype]) | Sample logistic random values with given shape and float dtype.
multivariate_normal(key, mean, cov, shape, dtype) | Sample multivariate normal random values with given mean and covariance.
normal(key, shape, dtype) | Sample standard normal random values with given shape and float dtype.
pareto(key, b[, shape, dtype]) | Sample Pareto random values with given shape and float dtype.
randint(key, shape, minval, …) | Sample uniform random integers in [minval, maxval) with given shape/dtype.
shuffle(key, x, axis) | Shuffle the elements of an array uniformly at random along an axis.
truncated_normal(key, lower, …) | Sample truncated standard normal random values with given shape and dtype.
uniform(key, shape, dtype, minval, …) | Sample uniform random values in [minval, maxval) with given shape/dtype.
Detailed Description
-
symjax.tensor.random.bernoulli(key: jax._src.numpy.lax_numpy.ndarray, p: jax._src.numpy.lax_numpy.ndarray = 0.5, shape: Optional[Sequence[int]] = None) → jax._src.numpy.lax_numpy.ndarray
Sample Bernoulli random values with given shape and mean.
Parameters:
- key – a PRNGKey used as the random key.
- p – optional, a float or array of floats for the mean of the random variables. Must be broadcast-compatible with `shape`. Default 0.5.
- shape – optional, a tuple of nonnegative integers representing the result shape. Must be broadcast-compatible with `p.shape`. The default (None) produces a result shape equal to `p.shape`.
Returns: A random array with boolean dtype and shape given by `shape` if `shape` is not None, or else `p.shape`.
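A minimal sketch of the shape rules above, assuming JAX is installed and that this wrapper behaves like jax.random.bernoulli, whose docstring it reproduces. The call goes to JAX directly rather than through symjax, so the symjax evaluation model is not shown.

```python
import jax
import jax.numpy as jnp

# Assumption: jax.random.bernoulli stands in for symjax.tensor.random.bernoulli.
key = jax.random.PRNGKey(0)

# Scalar mean with an explicit shape: a 3x4 boolean array, each entry True w.p. 0.3.
a = jax.random.bernoulli(key, p=0.3, shape=(3, 4))

# Array mean with shape=None: the result shape follows p.shape.
p = jnp.array([0.0, 1.0, 0.5])
b = jax.random.bernoulli(key, p=p)  # b[0] is always False, b[1] always True
```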
-
symjax.tensor.random.beta(key: jax._src.numpy.lax_numpy.ndarray, a: Union[float, jax._src.numpy.lax_numpy.ndarray], b: Union[float, jax._src.numpy.lax_numpy.ndarray], shape: Optional[Sequence[int]] = None, dtype: numpy.dtype = <class 'numpy.float64'>) → jax._src.numpy.lax_numpy.ndarray
Sample Beta random values with given shape and float dtype.
Parameters:
- key – a PRNGKey used as the random key.
- a – a float or array of floats broadcast-compatible with `shape` representing the first parameter “alpha”.
- b – a float or array of floats broadcast-compatible with `shape` representing the second parameter “beta”.
- shape – optional, a tuple of nonnegative integers specifying the result shape. Must be broadcast-compatible with `a` and `b`. The default (None) produces a result shape by broadcasting `a` and `b`.
- dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).
Returns: A random array with the specified dtype and shape given by `shape` if `shape` is not None, or else by broadcasting `a` and `b`.
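A minimal sketch of how `a`, `b` and `shape` combine, assuming JAX is installed and that the wrapper follows jax.random.beta (called directly here as a stand-in).

```python
import jax
import jax.numpy as jnp

# Assumption: jax.random.beta stands in for symjax.tensor.random.beta.
key = jax.random.PRNGKey(1)
a = jnp.array([1.0, 2.0, 5.0])  # "alpha", shape (3,)
b = 2.0                         # "beta", a scalar that broadcasts against a

x = jax.random.beta(key, a, b)                # shape=None: broadcast of a and b, (3,)
y = jax.random.beta(key, a, b, shape=(4, 3))  # (4, 3) is broadcast-compatible with (3,)
```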
-
symjax.tensor.random.cauchy(key, shape=(), dtype=<class 'numpy.float64'>)
Sample Cauchy random values with given shape and float dtype.
Parameters:
- key – a PRNGKey used as the random key.
- shape – optional, a tuple of nonnegative integers representing the result shape. Default ().
- dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).
Returns: A random array with the specified shape and dtype.
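A minimal sketch, assuming JAX is installed and using jax.random.cauchy as a stand-in for this wrapper. Cauchy draws are heavy-tailed, so the example checks the median (0 for the standard distribution) rather than the mean, which does not exist.

```python
import jax
import jax.numpy as jnp

# Assumption: jax.random.cauchy stands in for symjax.tensor.random.cauchy.
key = jax.random.PRNGKey(2)
x = jax.random.cauchy(key, shape=(10_000,), dtype=jnp.float32)

# The sample mean is unstable for Cauchy data, but the sample median
# should sit close to the true median of 0.
med = float(jnp.median(x))
```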
-
symjax.tensor.random.dirichlet(key, alpha, shape=None, dtype=<class 'numpy.float64'>)
Sample Dirichlet random values with given shape and float dtype.
Parameters:
- key – a PRNGKey used as the random key.
- alpha – an array of shape `(..., n)` used as the concentration parameter of the random variables.
- shape – optional, a tuple of nonnegative integers specifying the result batch shape; that is, the prefix of the result shape excluding the last axis (of size `n`). Must be broadcast-compatible with `alpha.shape[:-1]`. The default (None) produces a result shape equal to `alpha.shape`.
- dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).
Returns: A random array with the specified dtype and shape given by `shape + (alpha.shape[-1],)` if `shape` is not None, or else `alpha.shape`.
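A minimal sketch of the batch-shape rule, assuming JAX is installed and using jax.random.dirichlet as a stand-in. Each draw is a probability vector over the `n` categories, so it sums to 1 along the last axis.

```python
import jax
import jax.numpy as jnp

# Assumption: jax.random.dirichlet stands in for symjax.tensor.random.dirichlet.
key = jax.random.PRNGKey(3)
alpha = jnp.array([1.0, 2.0, 3.0])  # n = 3 categories

x = jax.random.dirichlet(key, alpha)              # shape=None -> alpha.shape == (3,)
y = jax.random.dirichlet(key, alpha, shape=(5,))  # batch shape (5,) -> result (5, 3)

row_sums = y.sum(axis=-1)  # every draw lies on the probability simplex
```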
-
symjax.tensor.random.gamma(key, a, shape=None, dtype=<class 'numpy.float64'>)
Sample Gamma random values with given shape and float dtype.
Parameters:
- key – a PRNGKey used as the random key.
- a – a float or array of floats broadcast-compatible with `shape` representing the concentration (shape) parameter of the distribution; the scale is fixed to 1.
- shape – optional, a tuple of nonnegative integers specifying the result shape. Must be broadcast-compatible with `a`. The default (None) produces a result shape equal to `a.shape`.
- dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).
Returns: A random array with the specified dtype and with shape given by `shape` if `shape` is not None, or else by `a.shape`.
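A minimal sketch, assuming JAX is installed and using jax.random.gamma as a stand-in. Because the samples have unit scale, a different scale `theta` is obtained by multiplying the draws, which scales the mean from `a` to `a * theta`.

```python
import jax
import jax.numpy as jnp

# Assumption: jax.random.gamma stands in for symjax.tensor.random.gamma.
key = jax.random.PRNGKey(4)
a = jnp.array([0.5, 1.0, 10.0])
x = jax.random.gamma(key, a)  # shape=None: result shape is a.shape, (3,)

# Unit-scale Gamma(a) has mean a; multiplying by theta gives scale theta (mean a * theta).
theta = 3.0
y = theta * jax.random.gamma(key, 2.0, shape=(20_000,))
```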
-
symjax.tensor.random.gumbel(key, shape=(), dtype=<class 'numpy.float64'>)
Sample Gumbel random values with given shape and float dtype.
Parameters:
- key – a PRNGKey used as the random key.
- shape – optional, a tuple of nonnegative integers representing the result shape. Default ().
- dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).
Returns: A random array with the specified shape and dtype.
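One common use of Gumbel noise is the Gumbel-max trick: adding independent standard Gumbel noise to log-probabilities and taking the argmax yields a categorical sample. The sketch below is an illustration of that technique, not something this page describes; it assumes JAX is installed and uses jax.random.gumbel as a stand-in for this wrapper.

```python
import jax
import jax.numpy as jnp

# Assumption: jax.random.gumbel stands in for symjax.tensor.random.gumbel.
key = jax.random.PRNGKey(5)
probs = jnp.array([0.2, 0.3, 0.5])
logits = jnp.log(probs)

# Gumbel-max trick: argmax(logits + Gumbel noise) is a draw from softmax(logits).
n = 20_000
g = jax.random.gumbel(key, shape=(n, 3))
idx = jnp.argmax(logits + g, axis=-1)
freq = jnp.bincount(idx, length=3) / n  # empirical category frequencies
```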
-
symjax.tensor.random.laplace(key, shape=(), dtype=<class 'numpy.float64'>)
Sample Laplace random values with given shape and float dtype.
Parameters:
- key – a PRNGKey used as the random key.
- shape – optional, a tuple of nonnegative integers representing the result shape. Default ().
- dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).
Returns: A random array with the specified shape and dtype.
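A minimal sketch, assuming JAX is installed and using jax.random.laplace as a stand-in. The function takes no location or scale argument, so the example builds `Laplace(loc, scale)` by hand as `loc + scale * x`.

```python
import jax
import jax.numpy as jnp

# Assumption: jax.random.laplace stands in for symjax.tensor.random.laplace.
key = jax.random.PRNGKey(6)
x = jax.random.laplace(key, shape=(20_000,))  # standard Laplace: location 0, scale 1

# |X| is Exponential(1) for a standard Laplace X, so mean(|x|) should be near 1.
mean_abs = float(jnp.mean(jnp.abs(x)))

# Location/scale variant built by hand: Laplace(loc=5, scale=2), median 5.
y = 5.0 + 2.0 * x
```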
-
symjax.tensor.random.logistic(key, shape=(), dtype=<class 'numpy.float64'>)
Sample logistic random values with given shape and float dtype.
Parameters:
- key – a PRNGKey used as the random key.
- shape – optional, a tuple of nonnegative integers representing the result shape. Default ().
- dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).
Returns: A random array with the specified shape and dtype.
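A minimal sketch, assuming JAX is installed and using jax.random.logistic as a stand-in. The CDF of the standard logistic distribution is the sigmoid, so mapping the draws through `jax.nn.sigmoid` should give values that look Uniform(0, 1).

```python
import jax
import jax.numpy as jnp

# Assumption: jax.random.logistic stands in for symjax.tensor.random.logistic.
key = jax.random.PRNGKey(7)
x = jax.random.logistic(key, shape=(20_000,))

# Probability integral transform: sigmoid(x) should be ~Uniform(0, 1),
# i.e. mean 1/2 and variance 1/12.
u = jax.nn.sigmoid(x)
```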
-
symjax.tensor.random.multivariate_normal(key: jax._src.numpy.lax_numpy.ndarray, mean: jax._src.numpy.lax_numpy.ndarray, cov: jax._src.numpy.lax_numpy.ndarray, shape: Optional[Sequence[int]] = None, dtype: numpy.dtype = <class 'numpy.float64'>) → jax._src.numpy.lax_numpy.ndarray
Sample multivariate normal random values with given mean and covariance.
Parameters:
- key – a PRNGKey used as the random key.
- mean – a mean vector of shape `(..., n)`.
- cov – a positive definite covariance matrix of shape `(..., n, n)`. The batch shape `...` must be broadcast-compatible with that of `mean`.
- shape – optional, a tuple of nonnegative integers specifying the result batch shape; that is, the prefix of the result shape excluding the last axis. Must be broadcast-compatible with `mean.shape[:-1]` and `cov.shape[:-2]`. The default (None) produces a result batch shape by broadcasting together the batch shapes of `mean` and `cov`.
- dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).
Returns: A random array with the specified dtype and shape given by `shape + mean.shape[-1:]` if `shape` is not None, or else `broadcast_shapes(mean.shape[:-1], cov.shape[:-2]) + mean.shape[-1:]`.
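A minimal sketch of both shape modes, assuming JAX is installed and using jax.random.multivariate_normal as a stand-in: an explicit batch shape with a single mean, and a batch of means sharing one covariance with `shape=None`.

```python
import jax
import jax.numpy as jnp

# Assumption: jax.random.multivariate_normal stands in for the symjax wrapper.
key = jax.random.PRNGKey(8)
mean = jnp.zeros(2)                        # n = 2
cov = jnp.array([[1.0, 0.8], [0.8, 1.0]])  # positive definite

# Explicit batch shape (50000,) -> result shape (50000, 2).
x = jax.random.multivariate_normal(key, mean, cov, shape=(50_000,))
emp_cov = jnp.cov(x, rowvar=False)  # should approximate cov

# shape=None with batched means: broadcast_shapes((4,), ()) + (2,) -> (4, 2).
means = jnp.arange(8.0).reshape(4, 2)
y = jax.random.multivariate_normal(key, means, cov)
```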
-
symjax.tensor.random.normal(key: jax._src.numpy.lax_numpy.ndarray, shape: Sequence[int] = (), dtype: numpy.dtype = <class 'numpy.float64'>) → jax._src.numpy.lax_numpy.ndarray
Sample standard normal random values with given shape and float dtype.
Parameters:
- key – a PRNGKey used as the random key.
- shape – optional, a tuple of nonnegative integers representing the result shape. Default ().
- dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).
Returns: A random array with the specified shape and dtype.
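A minimal sketch, assuming JAX is installed and using jax.random.normal as a stand-in. It shows two things the `key` parameter implies under JAX's PRNG model: the same key always produces the same values, so fresh randomness needs a new key from `jax.random.split`; and a non-standard normal is obtained by shifting and scaling.

```python
import jax
import jax.numpy as jnp

# Assumption: jax.random.normal stands in for symjax.tensor.random.normal.
key = jax.random.PRNGKey(9)

a = jax.random.normal(key, (3,))
b = jax.random.normal(key, (3,))      # same key -> identical values
key, subkey = jax.random.split(key)
c = jax.random.normal(subkey, (3,))   # fresh key -> different values

# N(mu, sigma^2) from standard normal draws.
mu, sigma = 3.0, 0.5
z = jax.random.normal(key, shape=(10_000,))
x = mu + sigma * z
```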
-
symjax.tensor.random.pareto(key, b, shape=None, dtype=<class 'numpy.float64'>)
Sample Pareto random values with given shape and float dtype.
Parameters:
- key – a PRNGKey used as the random key.
- b – a float or array of floats broadcast-compatible with `shape` representing the shape parameter of the distribution.
- shape – optional, a tuple of nonnegative integers specifying the result shape. Must be broadcast-compatible with `b`. The default (None) produces a result shape equal to `b.shape`.
- dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).
Returns: A random array with the specified dtype and with shape given by `shape` if `shape` is not None, or else by `b.shape`.
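A minimal sketch, assuming JAX is installed and using jax.random.pareto as a stand-in. It assumes the JAX convention that the samples have minimum value 1, so for `b > 1` the mean is `b / (b - 1)`.

```python
import jax
import jax.numpy as jnp

# Assumption: jax.random.pareto stands in for symjax.tensor.random.pareto,
# with support [1, inf) as in JAX.
key = jax.random.PRNGKey(10)
b = jnp.array([1.5, 3.0, 10.0])
x = jax.random.pareto(key, b)  # shape=None: result shape is b.shape, (3,)

# For b > 1 the mean is b / (b - 1), i.e. 1.5 when b = 3.
y = jax.random.pareto(key, 3.0, shape=(20_000,))
```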
-
symjax.tensor.random.randint(key: jax._src.numpy.lax_numpy.ndarray, shape: Sequence[int], minval: Union[int, jax._src.numpy.lax_numpy.ndarray], maxval: Union[int, jax._src.numpy.lax_numpy.ndarray], dtype: numpy.dtype = <class 'numpy.int64'>)
Sample uniform random integers in [minval, maxval) with given shape/dtype.
Parameters:
- key – a PRNGKey used as the random key.
- shape – a tuple of nonnegative integers representing the result shape.
- minval – int or array of ints broadcast-compatible with `shape`, a minimum (inclusive) value for the range.
- maxval – int or array of ints broadcast-compatible with `shape`, a maximum (exclusive) value for the range.
- dtype – optional, an int dtype for the returned values (default int64 if jax_enable_x64 is true, otherwise int32).
Returns: A random array with the specified shape and dtype.
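A minimal sketch, assuming JAX is installed and using jax.random.randint as a stand-in. The first call simulates a six-sided die with faces 0 to 5 (`maxval` is exclusive). The second call uses per-column bounds that broadcast against `shape`.

```python
import jax
import jax.numpy as jnp

# Assumption: jax.random.randint stands in for symjax.tensor.random.randint.
key = jax.random.PRNGKey(11)
x = jax.random.randint(key, shape=(1000,), minval=0, maxval=6)  # values 0..5

# Per-column bounds of shape (3,) broadcast against the result shape (2, 3).
lo = jnp.array([0, 10, 100])
hi = jnp.array([5, 20, 200])
y = jax.random.randint(key, (2, 3), lo, hi)
```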
-
symjax.tensor.random.shuffle(key: jax._src.numpy.lax_numpy.ndarray, x: jax._src.numpy.lax_numpy.ndarray, axis: int = 0) → jax._src.numpy.lax_numpy.ndarray
Shuffle the elements of an array uniformly at random along an axis.
Parameters:
- key – a PRNGKey used as the random key.
- x – the array to be shuffled.
- axis – optional, an int axis along which to shuffle (default 0).
Returns: A shuffled version of `x`.
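A minimal sketch, assuming JAX is installed. Recent JAX releases deprecate jax.random.shuffle in favour of `jax.random.permutation(..., independent=True)`, so that call is used as a stand-in. This assumes the symjax wrapper, like JAX's old shuffle, permutes each 1-D slice along `axis` independently; check this against the versions you have installed.

```python
import jax
import jax.numpy as jnp

# Assumption: permutation(..., independent=True) matches symjax.tensor.random.shuffle.
key = jax.random.PRNGKey(12)
x = jnp.arange(12).reshape(4, 3)

# Each column (each 1-D slice along axis 0) is permuted separately, so the
# set of values in every column is preserved.
y = jax.random.permutation(key, x, axis=0, independent=True)
```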
-
symjax.tensor.random.truncated_normal(key: jax._src.numpy.lax_numpy.ndarray, lower: Union[float, jax._src.numpy.lax_numpy.ndarray], upper: Union[float, jax._src.numpy.lax_numpy.ndarray], shape: Optional[Sequence[int]] = None, dtype: numpy.dtype = <class 'numpy.float64'>) → jax._src.numpy.lax_numpy.ndarray
Sample truncated standard normal random values with given shape and dtype.
Parameters:
- key – a PRNGKey used as the random key.
- lower – a float or array of floats representing the lower bound for truncation. Must be broadcast-compatible with `upper`.
- upper – a float or array of floats representing the upper bound for truncation. Must be broadcast-compatible with `lower`.
- shape – optional, a tuple of nonnegative integers specifying the result shape. Must be broadcast-compatible with `lower` and `upper`. The default (None) produces a result shape by broadcasting `lower` and `upper`.
- dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).
Returns: A random array with the specified dtype and shape given by `shape` if `shape` is not None, or else by broadcasting `lower` and `upper`. The values lie in the open interval `(lower, upper)`.
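A minimal sketch, assuming JAX is installed and using jax.random.truncated_normal as a stand-in. The first call uses scalar bounds and checks the open interval. The second call lets array bounds of shapes (2,) and (2, 1) broadcast to a (2, 2) result.

```python
import jax
import jax.numpy as jnp

# Assumption: jax.random.truncated_normal stands in for the symjax wrapper.
key = jax.random.PRNGKey(13)
x = jax.random.truncated_normal(key, lower=-1.0, upper=2.0, shape=(10_000,))

# shape=None: the result shape is the broadcast of lower (2,) and upper (2, 1).
lo = jnp.array([-1.0, 0.0])
hi = jnp.array([[1.0], [3.0]])
y = jax.random.truncated_normal(key, lo, hi)  # shape (2, 2)
```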
-
symjax.tensor.random.uniform(key: jax._src.numpy.lax_numpy.ndarray, shape: Sequence[int] = (), dtype: numpy.dtype = <class 'numpy.float64'>, minval: Union[float, jax._src.numpy.lax_numpy.ndarray] = 0.0, maxval: Union[float, jax._src.numpy.lax_numpy.ndarray] = 1.0) → jax._src.numpy.lax_numpy.ndarray
Sample uniform random values in [minval, maxval) with given shape/dtype.
Parameters:
- key – a PRNGKey used as the random key.
- shape – optional, a tuple of nonnegative integers representing the result shape. Default ().
- dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).
- minval – optional, a minimum (inclusive) value broadcast-compatible with `shape` for the range (default 0).
- maxval – optional, a maximum (exclusive) value broadcast-compatible with `shape` for the range (default 1).
Returns: A random array with the specified shape and dtype.
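A minimal sketch, assuming JAX is installed and using jax.random.uniform as a stand-in. It shows a scalar range and per-column ranges whose `minval`/`maxval` arrays broadcast against `shape`.

```python
import jax
import jax.numpy as jnp

# Assumption: jax.random.uniform stands in for symjax.tensor.random.uniform.
key = jax.random.PRNGKey(14)
x = jax.random.uniform(key, shape=(1000,), minval=-2.0, maxval=2.0)

# Per-column ranges [0, 1) and [10, 20) broadcast against the result shape (3, 2).
lo = jnp.array([0.0, 10.0])
hi = jnp.array([1.0, 20.0])
y = jax.random.uniform(key, (3, 2), minval=lo, maxval=hi)
```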