symjax.tensor.random

| Function | Description |
|---|---|
| `bernoulli(key, p, shape)` | Sample Bernoulli random values with given shape and mean. |
| `beta(key, a, …)` | Sample Beta random values with given shape and float dtype. |
| `cauchy(key[, shape, dtype])` | Sample Cauchy random values with given shape and float dtype. |
| `dirichlet(key, alpha[, shape, dtype])` | Sample Dirichlet random values with given shape and float dtype. |
| `gamma(key, a[, shape, dtype])` | Sample Gamma random values with given shape and float dtype. |
| `gumbel(key[, shape, dtype])` | Sample Gumbel random values with given shape and float dtype. |
| `laplace(key[, shape, dtype])` | Sample Laplace random values with given shape and float dtype. |
| `logistic(key[, shape, dtype])` | Sample logistic random values with given shape and float dtype. |
| `multivariate_normal(key, mean, cov, shape, dtype)` | Sample multivariate normal random values with given mean and covariance. |
| `normal(key, shape, dtype)` | Sample standard normal random values with given shape and float dtype. |
| `pareto(key, b[, shape, dtype])` | Sample Pareto random values with given shape and float dtype. |
| `randint(key, shape, minval, …)` | Sample uniform random values in [minval, maxval) with given shape/dtype. |
| `shuffle(key, x, axis)` | Shuffle the elements of an array uniformly at random along an axis. |
| `truncated_normal(key, lower, …)` | Sample truncated standard normal random values with given shape and dtype. |
| `uniform(key, shape, dtype, minval, …)` | Sample uniform random values in [minval, maxval) with given shape/dtype. |

Detailed Description
-
symjax.tensor.random.bernoulli(key: jax._src.numpy.lax_numpy.ndarray, p: jax._src.numpy.lax_numpy.ndarray = 0.5, shape: Optional[Sequence[int]] = None) → jax._src.numpy.lax_numpy.ndarray
Sample Bernoulli random values with given shape and mean.
Parameters: - key – a PRNGKey used as the random key.
- p – optional, a float or array of floats for the mean of the random variables. Must be broadcast-compatible with `shape`. Default 0.5.
- shape – optional, a tuple of nonnegative integers representing the result shape. Must be broadcast-compatible with `p.shape`. The default (None) produces a result shape equal to `p.shape`.
Returns: A random array with boolean dtype and shape given by `shape` if `shape` is not None, or else `p.shape`.
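The signature and text above are those of `jax.random.bernoulli`. A minimal sketch, assuming JAX is installed and calling that JAX function directly (the SymJAX wrapper itself is not exercised), shows how `p` broadcasts against an explicit `shape` and what the default `shape=None` produces:

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)

# p has shape (3,); shape=(1000, 3) is broadcast-compatible with it,
# so column j is drawn with mean p[j].
p = jnp.array([0.1, 0.5, 0.9])
draws = jax.random.bernoulli(key, p, shape=(1000, 3))

# With shape=None (the default) the result shape equals p.shape.
single = jax.random.bernoulli(key, p)

print(draws.dtype, draws.shape)  # bool (1000, 3)
print(single.shape)              # (3,)
print(draws.mean(axis=0))        # empirical means, close to p
```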
-
symjax.tensor.random.beta(key: jax._src.numpy.lax_numpy.ndarray, a: Union[float, jax._src.numpy.lax_numpy.ndarray], b: Union[float, jax._src.numpy.lax_numpy.ndarray], shape: Optional[Sequence[int]] = None, dtype: numpy.dtype = <class 'numpy.float64'>) → jax._src.numpy.lax_numpy.ndarray
Sample Beta random values with given shape and float dtype.
Parameters: - key – a PRNGKey used as the random key.
- a – a float or array of floats broadcast-compatible with `shape` representing the first parameter “alpha”.
- b – a float or array of floats broadcast-compatible with `shape` representing the second parameter “beta”.
- shape – optional, a tuple of nonnegative integers specifying the result shape. Must be broadcast-compatible with `a` and `b`. The default (None) produces a result shape by broadcasting `a` and `b`.
- dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).
Returns: A random array with the specified dtype and shape given by `shape` if `shape` is not None, or else by broadcasting `a` and `b`.
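As a sketch under the same assumption (calling `jax.random.beta` directly rather than the SymJAX wrapper), the following illustrates the two shape rules: with `shape=None` the output takes the broadcast of `a` and `b`, while an explicit `shape` only needs to be broadcast-compatible with them.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(1)

# a has shape (2, 1) and b has shape (3,); together they broadcast to (2, 3).
a = jnp.array([[1.5], [2.0]])
b = jnp.array([1.0, 2.0, 5.0])

x = jax.random.beta(key, a, b)                   # shape=None -> (2, 3)
y = jax.random.beta(key, a, b, shape=(4, 2, 3))  # explicit, broadcast-compatible

print(x.shape, y.shape)  # (2, 3) (4, 2, 3)
```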
-
symjax.tensor.random.cauchy(key, shape=(), dtype=<class 'numpy.float64'>)
Sample Cauchy random values with given shape and float dtype.
Parameters: - key – a PRNGKey used as the random key.
- shape – optional, a tuple of nonnegative integers representing the result shape. Default ().
- dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).
Returns: A random array with the specified shape and dtype.
-
symjax.tensor.random.dirichlet(key, alpha, shape=None, dtype=<class 'numpy.float64'>)
Sample Dirichlet random values with given shape and float dtype.
Parameters: - key – a PRNGKey used as the random key.
- alpha – an array of shape `(..., n)` used as the concentration parameter of the random variables.
- shape – optional, a tuple of nonnegative integers specifying the result batch shape; that is, the prefix of the result shape excluding the last axis (of size `n`). Must be broadcast-compatible with `alpha.shape[:-1]`. The default (None) produces a result shape equal to `alpha.shape`.
- dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).
Returns: A random array with the specified dtype and shape given by `shape + (alpha.shape[-1],)` if `shape` is not None, or else `alpha.shape`.
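The batch-shape convention is the easiest part of this entry to misread. This sketch, which calls `jax.random.dirichlet` directly and assumes JAX is available, shows that `shape` excludes the trailing axis of size `n` and that each sampled vector lies on the probability simplex.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(2)
alpha = jnp.array([1.0, 2.0, 3.0])  # concentration parameters, n = 3

# shape is the batch shape only; the trailing axis of size n is appended.
probs = jax.random.dirichlet(key, alpha, shape=(5,))

print(probs.shape)         # (5, 3)
print(probs.sum(axis=-1))  # each row sums to 1 up to rounding
```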
-
symjax.tensor.random.gamma(key, a, shape=None, dtype=<class 'numpy.float64'>)
Sample Gamma random values with given shape and float dtype.
Parameters: - key – a PRNGKey used as the random key.
- a – a float or array of floats broadcast-compatible with `shape` representing the concentration parameter of the distribution (the scale is fixed at 1).
- shape – optional, a tuple of nonnegative integers specifying the result shape. Must be broadcast-compatible with `a`. The default (None) produces a result shape equal to `a.shape`.
- dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).
Returns: A random array with the specified dtype and with shape given by `shape` if `shape` is not None, or else by `a.shape`.
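A sketch calling `jax.random.gamma` directly (assuming JAX is installed). That function draws from Gamma(a, 1), whose mean equals `a`, which gives a quick sanity check on the samples:

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(3)
a = jnp.array([0.5, 1.0, 5.0])  # one concentration parameter per column

g = jax.random.gamma(key, a)                       # shape=None -> a.shape
batch = jax.random.gamma(key, a, shape=(1000, 3))  # broadcast-compatible with a

print(g.shape, batch.shape)  # (3,) (1000, 3)
print(batch.mean(axis=0))    # Gamma(a, 1) has mean a, so roughly [0.5, 1.0, 5.0]
```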
-
symjax.tensor.random.gumbel(key, shape=(), dtype=<class 'numpy.float64'>)
Sample Gumbel random values with given shape and float dtype.
Parameters: - key – a PRNGKey used as the random key.
- shape – optional, a tuple of nonnegative integers representing the result shape. Default ().
- dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).
Returns: A random array with the specified shape and dtype.
-
symjax.tensor.random.laplace(key, shape=(), dtype=<class 'numpy.float64'>)
Sample Laplace random values with given shape and float dtype.
Parameters: - key – a PRNGKey used as the random key.
- shape – optional, a tuple of nonnegative integers representing the result shape. Default ().
- dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).
Returns: A random array with the specified shape and dtype.
-
symjax.tensor.random.logistic(key, shape=(), dtype=<class 'numpy.float64'>)
Sample logistic random values with given shape and float dtype.
Parameters: - key – a PRNGKey used as the random key.
- shape – optional, a tuple of nonnegative integers representing the result shape. Default ().
- dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).
Returns: A random array with the specified shape and dtype.
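`cauchy`, `gumbel`, `laplace`, and `logistic` all take `(key, shape=(), dtype)`. The sketch below calls their `jax.random` counterparts directly (an assumption; the SymJAX wrappers are not used) and splits one key into independent subkeys, since reusing a single key across samplers makes their draws statistically dependent.

```python
import jax
import jax.numpy as jnp

# One independent subkey per sampler.
k_cauchy, k_gumbel, k_laplace, k_logistic = jax.random.split(jax.random.PRNGKey(4), 4)

shape = (2, 3)
samples = {
    "cauchy": jax.random.cauchy(k_cauchy, shape, jnp.float32),
    "gumbel": jax.random.gumbel(k_gumbel, shape, jnp.float32),
    "laplace": jax.random.laplace(k_laplace, shape, jnp.float32),
    "logistic": jax.random.logistic(k_logistic, shape, jnp.float32),
}

for name, x in samples.items():
    print(name, x.shape, x.dtype)  # each: (2, 3) float32

# With no shape argument, the default () yields a scalar array.
scalar = jax.random.laplace(k_laplace)
print(scalar.shape)  # ()
```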
-
symjax.tensor.random.multivariate_normal(key: jax._src.numpy.lax_numpy.ndarray, mean: jax._src.numpy.lax_numpy.ndarray, cov: jax._src.numpy.lax_numpy.ndarray, shape: Optional[Sequence[int]] = None, dtype: numpy.dtype = <class 'numpy.float64'>) → jax._src.numpy.lax_numpy.ndarray
Sample multivariate normal random values with given mean and covariance.
Parameters: - key – a PRNGKey used as the random key.
- mean – a mean vector of shape `(..., n)`.
- cov – a positive definite covariance matrix of shape `(..., n, n)`. The batch shape `...` must be broadcast-compatible with that of `mean`.
- shape – optional, a tuple of nonnegative integers specifying the result batch shape; that is, the prefix of the result shape excluding the last axis. Must be broadcast-compatible with `mean.shape[:-1]` and `cov.shape[:-2]`. The default (None) produces a result batch shape by broadcasting together the batch shapes of `mean` and `cov`.
- dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).
Returns: A random array with the specified dtype and shape given by `shape + mean.shape[-1:]` if `shape` is not None, or else `broadcast_shapes(mean.shape[:-1], cov.shape[:-2]) + mean.shape[-1:]`.
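A minimal sketch calling `jax.random.multivariate_normal` directly (assuming JAX is installed): `shape` gives the batch shape, the trailing axis of size `n` comes from `mean`, and with enough samples the empirical mean and covariance approach `mean` and `cov`.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(5)
mean = jnp.array([0.0, 10.0])              # n = 2
cov = jnp.array([[1.0, 0.8], [0.8, 1.0]])  # symmetric positive definite

one = jax.random.multivariate_normal(key, mean, cov)               # shape (2,)
x = jax.random.multivariate_normal(key, mean, cov, shape=(5000,))  # shape (5000, 2)

print(one.shape, x.shape)  # (2,) (5000, 2)
print(x.mean(axis=0))      # close to mean
print(jnp.cov(x.T))        # close to cov (jnp.cov treats rows as variables)
```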
-
symjax.tensor.random.normal(key: jax._src.numpy.lax_numpy.ndarray, shape: Sequence[int] = (), dtype: numpy.dtype = <class 'numpy.float64'>) → jax._src.numpy.lax_numpy.ndarray
Sample standard normal random values with given shape and float dtype.
Parameters: - key – a PRNGKey used as the random key.
- shape – optional, a tuple of nonnegative integers representing the result shape. Default ().
- dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).
Returns: A random array with the specified shape and dtype.
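Because the sampler is a pure function of `key`, the same key always yields the same values; fresh randomness comes from `jax.random.split`. The sketch below (calling `jax.random.normal` directly, assuming JAX is installed) shows both, plus the usual shift-and-scale for a non-standard normal.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(6)

a = jax.random.normal(key, (3,))
b = jax.random.normal(key, (3,))  # same key, so identical values to a

# Split into fresh subkeys; keep `key` for future splits.
key, k1, k2 = jax.random.split(key, 3)
c = jax.random.normal(k1, (3,))  # new, independent values

# Non-standard normal N(mu, sigma^2) by shifting and scaling standard draws.
mu, sigma = 2.0, 0.5
y = mu + sigma * jax.random.normal(k2, (10000,))

print(bool(jnp.array_equal(a, b)))  # True
print(y.mean(), y.std())            # roughly 2.0 and 0.5
```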
-
symjax.tensor.random.pareto(key, b, shape=None, dtype=<class 'numpy.float64'>)
Sample Pareto random values with given shape and float dtype.
Parameters: - key – a PRNGKey used as the random key.
- b – a float or array of floats broadcast-compatible with `shape` representing the shape parameter of the distribution.
- shape – optional, a tuple of nonnegative integers specifying the result shape. Must be broadcast-compatible with `b`. The default (None) produces a result shape equal to `b.shape`.
- dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).
Returns: A random array with the specified dtype and with shape given by `shape` if `shape` is not None, or else by `b.shape`.
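A sketch calling `jax.random.pareto` directly, assuming JAX is installed. That sampler uses unit scale, so every draw is at least 1, and for `b > 1` the mean is `b / (b - 1)`:

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(7)
b = jnp.array([1.5, 3.0])  # one shape parameter per column

x = jax.random.pareto(key, b, shape=(10000, 2))

print(x.shape)         # (10000, 2)
print(x.min() >= 1.0)  # True: support starts at 1 (unit scale)
print(x[:, 1].mean())  # near b / (b - 1) = 1.5 for b = 3
```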
-
symjax.tensor.random.randint(key: jax._src.numpy.lax_numpy.ndarray, shape: Sequence[int], minval: Union[int, jax._src.numpy.lax_numpy.ndarray], maxval: Union[int, jax._src.numpy.lax_numpy.ndarray], dtype: numpy.dtype = <class 'numpy.int64'>)
Sample uniform random values in [minval, maxval) with given shape/dtype.
Parameters: - key – a PRNGKey used as the random key.
- shape – a tuple of nonnegative integers representing the shape.
- minval – int or array of ints broadcast-compatible with `shape`, a minimum (inclusive) value for the range.
- maxval – int or array of ints broadcast-compatible with `shape`, a maximum (exclusive) value for the range.
- dtype – optional, an int dtype for the returned values (default int64 if jax_enable_x64 is true, otherwise int32).
Returns: A random array with the specified shape and dtype.
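A sketch calling `jax.random.randint` directly (assuming JAX is installed). Because `maxval` is exclusive, six-sided dice rolls use `minval=1, maxval=7`; array bounds broadcast against `shape` to give one range per column.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(8)
k_dice, k_pairs = jax.random.split(key)

# minval is inclusive and maxval exclusive: values are in 1..6.
rolls = jax.random.randint(k_dice, (1000,), 1, 7)

# Array bounds broadcast against shape: column 0 in [0, 10), column 1 in [100, 200).
lo = jnp.array([0, 100])
hi = jnp.array([10, 200])
pairs = jax.random.randint(k_pairs, (4, 2), lo, hi)

print(rolls.dtype)  # int32 unless jax_enable_x64 is set
print(pairs)
```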
-
symjax.tensor.random.shuffle(key: jax._src.numpy.lax_numpy.ndarray, x: jax._src.numpy.lax_numpy.ndarray, axis: int = 0) → jax._src.numpy.lax_numpy.ndarray
Shuffle the elements of an array uniformly at random along an axis.
Parameters: - key – a PRNGKey used as the random key.
- x – the array to be shuffled.
- axis – optional, an int axis along which to shuffle (default 0).
Returns: A shuffled version of x.
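`jax.random.shuffle`, whose signature this entry reproduces, is deprecated in recent JAX releases; its deprecation notice recommends `jax.random.permutation` with `independent=True`. The sketch below (assuming a JAX version whose `permutation` accepts `axis` and `independent`) contrasts permuting whole slices along an axis with shuffling each 1-D slice independently.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(9)
k_rows, k_cols = jax.random.split(key)
x = jnp.arange(12).reshape(3, 4)

# Permute whole slices along axis 0: rows move, but each row stays intact.
rows = jax.random.permutation(k_rows, x, axis=0)

# Shuffle each 1-D slice along axis 0 independently: every column gets its own order.
cols = jax.random.permutation(k_cols, x, axis=0, independent=True)

print(rows)
print(cols)
```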
-
symjax.tensor.random.truncated_normal(key: jax._src.numpy.lax_numpy.ndarray, lower: Union[float, jax._src.numpy.lax_numpy.ndarray], upper: Union[float, jax._src.numpy.lax_numpy.ndarray], shape: Optional[Sequence[int]] = None, dtype: numpy.dtype = <class 'numpy.float64'>) → jax._src.numpy.lax_numpy.ndarray
Sample truncated standard normal random values with given shape and dtype.
Parameters: - key – a PRNGKey used as the random key.
- lower – a float or array of floats representing the lower bound for truncation. Must be broadcast-compatible with `upper`.
- upper – a float or array of floats representing the upper bound for truncation. Must be broadcast-compatible with `lower`.
- shape – optional, a tuple of nonnegative integers specifying the result shape. Must be broadcast-compatible with `lower` and `upper`. The default (None) produces a result shape by broadcasting `lower` and `upper`.
- dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).
Returns: A random array with the specified dtype and shape given by `shape` if `shape` is not None, or else by broadcasting `lower` and `upper`. Returns values in the open interval `(lower, upper)`.
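A sketch calling `jax.random.truncated_normal` directly, assuming JAX is installed. The bounds are in standard-normal units, so a truncated N(mu, sigma²) on [lo, hi] is obtained by standardizing the bounds and then shifting and scaling the draws:

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(10)
k1, k2, k3 = jax.random.split(key, 3)

lower = jnp.array([-1.0, 0.0])
upper = jnp.array([1.0, 3.0])

z = jax.random.truncated_normal(k1, lower, upper)                    # shape (2,)
zs = jax.random.truncated_normal(k2, lower, upper, shape=(1000, 2))  # shape (1000, 2)

# Truncated N(mu, sigma^2) on [lo, hi]: standardize the bounds, then shift and scale.
mu, sigma, lo, hi = 2.0, 0.5, 1.0, 3.0
scaled = mu + sigma * jax.random.truncated_normal(
    k3, (lo - mu) / sigma, (hi - mu) / sigma, shape=(1000,)
)

print(z.shape, zs.shape)  # (2,) (1000, 2)
```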
-
symjax.tensor.random.uniform(key: jax._src.numpy.lax_numpy.ndarray, shape: Sequence[int] = (), dtype: numpy.dtype = <class 'numpy.float64'>, minval: Union[float, jax._src.numpy.lax_numpy.ndarray] = 0.0, maxval: Union[float, jax._src.numpy.lax_numpy.ndarray] = 1.0) → jax._src.numpy.lax_numpy.ndarray
Sample uniform random values in [minval, maxval) with given shape/dtype.
Parameters: - key – a PRNGKey used as the random key.
- shape – optional, a tuple of nonnegative integers representing the result shape. Default ().
- dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).
- minval – optional, a minimum (inclusive) value broadcast-compatible with shape for the range (default 0).
- maxval – optional, a maximum (exclusive) value broadcast-compatible with shape for the range (default 1).
Returns: A random array with the specified shape and dtype.
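A sketch calling `jax.random.uniform` directly (assuming JAX is installed). Since `dtype` precedes `minval` and `maxval` in the signature, the bounds are safest passed by keyword; array bounds broadcast against `shape`.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(11)
k1, k2 = jax.random.split(key)

u = jax.random.uniform(k1, (3,))  # defaults: minval=0.0, maxval=1.0

# Pass the bounds by keyword; one [minval, maxval) range per column.
lo = jnp.array([-1.0, 0.0])
hi = jnp.array([1.0, 4.0])
v = jax.random.uniform(k2, (1000, 2), minval=lo, maxval=hi)

print(u.shape, v.shape)  # (3,) (1000, 2)
```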