symjax.nn

Implements machine learning and deep learning utilities for creating and adapting state-of-the-art deep neural networks, running training, scheduling learning rates, and more.

Activation functions

relu(x) Rectified linear unit activation function.
relu6(x) Rectified Linear Unit 6 activation function.
sigmoid(x) Sigmoid activation function.
softplus(x) Softplus activation function.
soft_sign(x) Soft-sign activation function.
silu(x) SiLU activation function.
swish(x, beta) Swish activation function.
log_sigmoid(x) Log-sigmoid activation function.
leaky_relu(x[, negative_slope]) Leaky rectified linear unit activation function.
hard_sigmoid(x) Hard Sigmoid activation function.
hard_silu(x) Hard SiLU activation function.
hard_swish(x) Hard Swish activation function.
hard_tanh(x) Hard \(\mathrm{tanh}\) activation function.
elu(x[, alpha]) Exponential linear unit activation function.
celu(x[, alpha]) Continuously-differentiable exponential linear unit activation.
selu(x) Scaled exponential linear unit activation.
gelu(x, approximate) Gaussian error linear unit activation function.
glu(linear_x, gated_x[, axis]) Gated linear unit activation function.

Other Ops

softmax(x[, axis]) Softmax function.
log_softmax(x[, axis]) Log-Softmax function.
normalize(x[, axis, mean, variance, epsilon]) Normalizes an array by subtracting mean and dividing by sqrt(var).
one_hot One-hot encoding.

Detailed Descriptions

symjax.nn.relu(x)[source]

Rectified linear unit activation function.

Computes the element-wise function:

\[\mathrm{relu}(x) = \max(x, 0)\]
symjax.nn.relu6(x)[source]

Rectified Linear Unit 6 activation function.

Computes the element-wise function

\[\mathrm{relu6}(x) = \min(\max(x, 0), 6)\]
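
For illustration, a minimal NumPy sketch of these two element-wise formulas (the *_ref helpers below are illustrative only, not the symjax implementation):

import numpy as np

def relu_ref(x):
    # max(x, 0), element-wise
    return np.maximum(x, 0)

def relu6_ref(x):
    # min(max(x, 0), 6), element-wise
    return np.minimum(np.maximum(x, 0), 6)
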
symjax.nn.sigmoid(x)[source]

Sigmoid activation function.

Computes the element-wise function:

\[\mathrm{sigmoid}(x) = \frac{1}{1 + e^{-x}}\]
symjax.nn.softplus(x)[source]

Softplus activation function.

Computes the element-wise function

\[\mathrm{softplus}(x) = \log(1 + e^x)\]
symjax.nn.soft_sign(x)[source]

Soft-sign activation function.

Computes the element-wise function

\[\mathrm{soft\_sign}(x) = \frac{x}{|x| + 1}\]
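
A NumPy sketch of the three formulas above (illustrative helpers only; np.logaddexp(0, x) computes log(1 + exp(x)) without overflowing for large x):

import numpy as np

def sigmoid_ref(x):
    # 1 / (1 + exp(-x))
    return 1.0 / (1.0 + np.exp(-x))

def softplus_ref(x):
    # log(1 + exp(x)), computed stably
    return np.logaddexp(0.0, x)

def soft_sign_ref(x):
    # x / (|x| + 1)
    return x / (np.abs(x) + 1.0)
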
symjax.nn.silu(x)[source]

SiLU activation function.

Computes the element-wise function:

\[\mathrm{silu}(x) = x \cdot \mathrm{sigmoid}(x) = \frac{x}{1 + e^{-x}}\]
symjax.nn.swish(x, beta)[source]

Swish activation function.

Computes the element-wise function:

\[\mathrm{swish}(x) = x \cdot \mathrm{sigmoid}(\beta x) = \frac{x}{1 + e^{-\beta x}}\]
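
An illustrative NumPy sketch showing that silu is swish with \(\beta = 1\) (the helpers are not part of symjax):

import numpy as np

def swish_ref(x, beta):
    # x * sigmoid(beta * x)
    return x / (1.0 + np.exp(-beta * x))

def silu_ref(x):
    # special case beta = 1
    return swish_ref(x, 1.0)
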
symjax.nn.log_sigmoid(x)[source]

Log-sigmoid activation function.

Computes the element-wise function:

\[\mathrm{log\_sigmoid}(x) = \log(\mathrm{sigmoid}(x)) = -\log(1 + e^{-x})\]
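
A numerically stable NumPy sketch of the formula above (illustrative only):

import numpy as np

def log_sigmoid_ref(x):
    # -log(1 + exp(-x)); logaddexp avoids overflow for very negative x
    return -np.logaddexp(0.0, -x)
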
symjax.nn.leaky_relu(x, negative_slope=0.01)[source]

Leaky rectified linear unit activation function.

Computes the element-wise function:

\[\begin{split}\mathrm{leaky\_relu}(x) = \begin{cases} x, & x \ge 0\\ \alpha x, & x < 0 \end{cases}\end{split}\]

where \(\alpha\) = negative_slope.
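
A minimal NumPy sketch of the piecewise definition (illustrative only):

import numpy as np

def leaky_relu_ref(x, negative_slope=0.01):
    # x where x >= 0, negative_slope * x elsewhere
    return np.where(x >= 0, x, negative_slope * x)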

symjax.nn.hard_sigmoid(x)[source]

Hard Sigmoid activation function.

Computes the element-wise function

\[\mathrm{hard\_sigmoid}(x) = \frac{\mathrm{relu6}(x + 3)}{6}\]
symjax.nn.hard_silu(x)[source]

Hard SiLU activation function.

Computes the element-wise function

\[\mathrm{hard\_silu}(x) = x \cdot \mathrm{hard\_sigmoid}(x)\]
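
A NumPy sketch of both hard variants, following the formulas above (illustrative only):

import numpy as np

def hard_sigmoid_ref(x):
    # relu6(x + 3) / 6
    return np.minimum(np.maximum(x + 3.0, 0.0), 6.0) / 6.0

def hard_silu_ref(x):
    # x * hard_sigmoid(x)
    return x * hard_sigmoid_ref(x)
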
symjax.nn.hard_tanh(x)[source]

Hard \(\mathrm{tanh}\) activation function.

Computes the element-wise function:

\[\begin{split}\mathrm{hard\_tanh}(x) = \begin{cases} -1, & x < -1\\ x, & -1 \le x \le 1\\ 1, & 1 < x \end{cases}\end{split}\]
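
The piecewise definition is equivalent to clipping to \([-1, 1]\); a one-line NumPy sketch (illustrative only):

import numpy as np

def hard_tanh_ref(x):
    # clamp x into [-1, 1]
    return np.clip(x, -1.0, 1.0)
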
symjax.nn.elu(x, alpha=1.0)[source]

Exponential linear unit activation function.

Computes the element-wise function:

\[\begin{split}\mathrm{elu}(x) = \begin{cases} x, & x > 0\\ \alpha \left(\exp(x) - 1\right), & x \le 0 \end{cases}\end{split}\]
symjax.nn.celu(x, alpha=1.0)[source]

Continuously-differentiable exponential linear unit activation.

Computes the element-wise function:

\[\begin{split}\mathrm{celu}(x) = \begin{cases} x, & x > 0\\ \alpha \left(\exp(\frac{x}{\alpha}) - 1\right), & x \le 0 \end{cases}\end{split}\]

For more information, see Continuously Differentiable Exponential Linear Units.
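
A NumPy sketch of elu and celu following the piecewise definitions above (illustrative only; np.minimum keeps the unused exp branch from overflowing, since np.where evaluates both branches):

import numpy as np

def elu_ref(x, alpha=1.0):
    # x for x > 0, alpha * (exp(x) - 1) otherwise
    return np.where(x > 0, x, alpha * (np.exp(np.minimum(x, 0.0)) - 1.0))

def celu_ref(x, alpha=1.0):
    # x for x > 0, alpha * (exp(x / alpha) - 1) otherwise
    return np.where(x > 0, x, alpha * (np.exp(np.minimum(x, 0.0) / alpha) - 1.0))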

symjax.nn.selu(x)[source]

Scaled exponential linear unit activation.

Computes the element-wise function:

\[\begin{split}\mathrm{selu}(x) = \lambda \begin{cases} x, & x > 0\\ \alpha e^x - \alpha, & x \le 0 \end{cases}\end{split}\]

where \(\lambda = 1.0507009873554804934193349852946\) and \(\alpha = 1.6732632423543772848170429916717\).

For more information, see Self-Normalizing Neural Networks.
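
A NumPy sketch using the constants quoted above (illustrative only):

import numpy as np

def selu_ref(x):
    lam = 1.0507009873554804934193349852946
    alpha = 1.6732632423543772848170429916717
    # lambda * (x for x > 0, alpha * (exp(x) - 1) otherwise)
    return lam * np.where(x > 0, x, alpha * (np.exp(np.minimum(x, 0.0)) - 1.0))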

symjax.nn.gelu(x, approximate: bool = True)[source]

Gaussian error linear unit activation function.

If approximate=False, computes the element-wise function:

\[\mathrm{gelu}(x) = \frac{x}{2} \left(1 + \mathrm{erf} \left( \frac{x}{\sqrt{2}} \right) \right)\]

If approximate=True, uses the approximate formulation of GELU:

\[\mathrm{gelu}(x) = \frac{x}{2} \left(1 + \mathrm{tanh} \left( \sqrt{\frac{2}{\pi}} \left(x + 0.044715 x^3 \right) \right) \right)\]

For more information, see Gaussian Error Linear Units (GELUs), section 2.

Parameters: approximate – whether to use the approximate or the exact formulation.
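
A NumPy sketch of both formulations (illustrative only; scipy.special.erf is used for the exact case):

import numpy as np
from scipy.special import erf

def gelu_ref(x, approximate=True):
    if approximate:
        # tanh-based approximation
        return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x ** 3)))
    # exact formulation via the Gaussian error function
    return 0.5 * x * (1.0 + erf(x / np.sqrt(2.0)))
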
symjax.nn.glu(linear_x, gated_x, axis=-1)[source]

Gated linear unit activation function.
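
The docstring does not spell out the formula; the sketch below assumes the standard gated-linear-unit formulation, linear_x multiplied element-wise by a sigmoid gate of gated_x, and does not model the axis argument:

import numpy as np

def glu_ref(linear_x, gated_x):
    # assumed formulation: linear_x * sigmoid(gated_x)
    return linear_x / (1.0 + np.exp(-gated_x))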

symjax.nn.softmax(x, axis=-1)[source]

Softmax function.

Computes the function which rescales elements to the range \([0, 1]\) such that the elements along axis sum to \(1\).

\[\mathrm{softmax}(x)_i = \frac{\exp(x_i)}{\sum_j \exp(x_j)}\]
Parameters: axis – the axis or axes along which the softmax is computed; the output sums to \(1\) along these dimensions. Either an integer or a tuple of integers.
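
A NumPy sketch of a numerically stable softmax (illustrative only; subtracting the per-axis maximum leaves the result unchanged):

import numpy as np

def softmax_ref(x, axis=-1):
    z = x - np.max(x, axis=axis, keepdims=True)
    e = np.exp(z)
    return e / np.sum(e, axis=axis, keepdims=True)
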
symjax.nn.log_softmax(x, axis=-1)[source]

Log-Softmax function.

Computes the logarithm of the softmax function, which rescales elements to the range \([-\infty, 0)\).

\[\mathrm{log\_softmax}(x)_i = \log \left( \frac{\exp(x_i)}{\sum_j \exp(x_j)} \right)\]
Parameters: axis – the axis or axes along which the log_softmax is computed. Either an integer or a tuple of integers.
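
The same computation in log space, sketched with NumPy (illustrative only):

import numpy as np

def log_softmax_ref(x, axis=-1):
    # log_softmax(x) = x - logsumexp(x), computed stably
    z = x - np.max(x, axis=axis, keepdims=True)
    return z - np.log(np.sum(np.exp(z), axis=axis, keepdims=True))
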
symjax.nn.normalize(x, axis=-1, mean=None, variance=None, epsilon=1e-05)[source]

Normalizes an array by subtracting mean and dividing by sqrt(var).
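
A NumPy sketch under the assumption that mean and variance default to per-axis statistics of x and that epsilon is added under the square root for numerical stability (both assumptions; only the signature above is from the source):

import numpy as np

def normalize_ref(x, axis=-1, mean=None, variance=None, epsilon=1e-05):
    if mean is None:
        mean = np.mean(x, axis=axis, keepdims=True)      # assumed default
    if variance is None:
        variance = np.var(x, axis=axis, keepdims=True)   # assumed default
    return (x - mean) / np.sqrt(variance + epsilon)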