symjax.probabilities

Implementations of basic distributions, together with their (log) densities, sampling, KL divergences, and entropies.

Categorical([probabilities, logits, eps]) – Categorical distribution.
Normal(mean, cov) – (Batched) multivariate normal distribution.
KL(X, Y[, EPS]) – Kullback–Leibler divergence; for Normal distributions, X and Y are specified by means and log standard deviations.

Detailed Descriptions

symjax.probabilities.KL(X, Y, EPS=1e-08)

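Kullback–Leibler divergence between X and Y. For Normal distributions, X and Y are specified by their means and log standard deviations (https://en.wikipedia.org/wiki/Kullback-Leibler_divergence#Multivariate_normal_distributions). For two d-dimensional normal densities p = N(μ₁, Σ₁) and q = N(μ₂, Σ₂), the divergence reduces to a closed form: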

\[
\begin{aligned}
KL(p\,\|\,q) &= \int \big[\log p(x) - \log q(x)\big]\, p(x)\,dx\\
&= \int \Big[\tfrac{1}{2}\log\tfrac{|\Sigma_2|}{|\Sigma_1|} - \tfrac{1}{2}(x-\mu_1)^T\Sigma_1^{-1}(x-\mu_1) + \tfrac{1}{2}(x-\mu_2)^T\Sigma_2^{-1}(x-\mu_2)\Big]\, p(x)\,dx\\
&= \tfrac{1}{2}\log\tfrac{|\Sigma_2|}{|\Sigma_1|} - \tfrac{1}{2}\operatorname{tr}\big\{\mathbb{E}\big[(x-\mu_1)(x-\mu_1)^T\big]\,\Sigma_1^{-1}\big\} + \tfrac{1}{2}\mathbb{E}\big[(x-\mu_2)^T\Sigma_2^{-1}(x-\mu_2)\big]\\
&= \tfrac{1}{2}\log\tfrac{|\Sigma_2|}{|\Sigma_1|} - \tfrac{1}{2}\operatorname{tr}\{I_d\} + \tfrac{1}{2}(\mu_1-\mu_2)^T\Sigma_2^{-1}(\mu_1-\mu_2) + \tfrac{1}{2}\operatorname{tr}\{\Sigma_2^{-1}\Sigma_1\}\\
&= \tfrac{1}{2}\Big[\log\tfrac{|\Sigma_2|}{|\Sigma_1|} - d + \operatorname{tr}\{\Sigma_2^{-1}\Sigma_1\} + (\mu_2-\mu_1)^T\Sigma_2^{-1}(\mu_2-\mu_1)\Big].
\end{aligned}
\]
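As a numerical check of the closed form, here is a minimal NumPy sketch; it is not symjax's implementation. `kl_mvn` and `kl_diag_normal` are hypothetical helper names. The diagonal version assumes that "specified by means and log stds" means Σ = diag(exp(2·logstd)), and it sums the per-dimension terms over the last axis so that leading axes act as a batch.

```python
import numpy as np


def kl_mvn(mu1, cov1, mu2, cov2):
    """KL(N(mu1, cov1) || N(mu2, cov2)) for one pair of d-dimensional normals.

    Hypothetical helper (not symjax API) that follows the last line of the
    derivation term by term. Covariances are assumed positive definite.
    """
    d = mu1.shape[-1]
    cov2_inv = np.linalg.inv(cov2)
    diff = mu2 - mu1
    # slogdet returns (sign, log|det|); the sign is +1 for positive-definite input
    _, logdet1 = np.linalg.slogdet(cov1)
    _, logdet2 = np.linalg.slogdet(cov2)
    return 0.5 * (logdet2 - logdet1 - d
                  + np.trace(cov2_inv @ cov1)
                  + diff @ cov2_inv @ diff)


def kl_diag_normal(mu1, logstd1, mu2, logstd2):
    """Same divergence for diagonal covariances given as log standard deviations.

    Assumption: variance = exp(2 * logstd). The per-dimension terms are summed
    over the last axis, so leading axes behave as batch axes.
    """
    var1 = np.exp(2.0 * logstd1)
    var2 = np.exp(2.0 * logstd2)
    per_dim = logstd2 - logstd1 + (var1 + (mu1 - mu2) ** 2) / (2.0 * var2) - 0.5
    return per_dim.sum(axis=-1)
```

For diagonal covariances the two helpers should agree, and the divergence of a distribution from itself is zero.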
class symjax.probabilities.Categorical(probabilities=None, logits=None, eps=1e-08)
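The signature above has no description. The following NumPy sketch shows one common way to build a categorical distribution from either probabilities or logits. `CategoricalSketch` is a hypothetical name. Two points are assumptions about what the parameters mean, not documented symjax behavior: the softmax conversion from logits, and the role of `eps` as a guard that keeps the log finite when a probability is exactly zero.

```python
import numpy as np


class CategoricalSketch:
    """Hypothetical stand-in, not symjax's Categorical: a categorical
    distribution over the last axis, built from probabilities or logits."""

    def __init__(self, probabilities=None, logits=None, eps=1e-8):
        if (probabilities is None) == (logits is None):
            raise ValueError("pass exactly one of probabilities or logits")
        if logits is not None:
            # Numerically stable softmax over the last axis (assumed conversion)
            shifted = logits - logits.max(axis=-1, keepdims=True)
            exp = np.exp(shifted)
            probabilities = exp / exp.sum(axis=-1, keepdims=True)
        self.probabilities = probabilities
        # Assumption: eps keeps log() finite when a probability is exactly zero
        self.eps = eps

    def log_prob(self, k):
        """Log-probability of integer category k (batched over leading axes)."""
        idx = np.expand_dims(k, -1)
        p = np.take_along_axis(self.probabilities, idx, axis=-1)[..., 0]
        return np.log(p + self.eps)

    def entropy(self):
        """Shannon entropy in nats, computed over the last axis."""
        p = self.probabilities
        return -(p * np.log(p + self.eps)).sum(axis=-1)
```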
class symjax.probabilities.Normal(mean, cov)

A (batched) multivariate normal distribution.

Parameters:
  • mean (N-dimensional Tensor) – the mean of the normal distribution. The last axis indexes the data dimensions; the leading axes are batch axes.
  • cov ((N or N+1)-dimensional Tensor) – the covariance. If cov is N-dimensional, the covariance is assumed to be diagonal and cov holds its diagonal entries; if cov is (N+1)-dimensional, its last two axes hold the full covariance matrix and must therefore have equal size.
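To make the two `cov` conventions concrete, here is a minimal NumPy sketch that expands either form into a full batch of D × D matrices. `as_full_covariance` is a hypothetical helper and not part of symjax. It assumes that an N-dimensional `cov` stores the diagonal entries (variances).

```python
import numpy as np


def as_full_covariance(mean, cov):
    """Return a full (..., D, D) covariance for either cov convention.

    Hypothetical helper (not symjax API).
    mean: (..., D) array; leading axes are batch axes.
    cov:  (..., D) array    -> assumed to hold the diagonal of the covariance;
          (..., D, D) array -> already a full covariance, returned unchanged.
    """
    if cov.ndim == mean.ndim:
        # N-dimensional cov: put each row of variances on the diagonal of a D x D matrix
        return cov[..., :, None] * np.eye(mean.shape[-1])
    if cov.ndim == mean.ndim + 1:
        if cov.shape[-1] != cov.shape[-2]:
            raise ValueError("last two axes of a full covariance must have equal size")
        return cov
    raise ValueError("cov must have N or N+1 dimensions")
```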
entropy()

Compute the differential entropy of the multivariate normal.

Returns: h – Entropy of the multivariate normal distribution.
Return type: scalar
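For reference, the differential entropy of a d-dimensional normal is h = ½ log det(2πe Σ) = (d/2)(1 + log 2π) + ½ log|Σ|. It does not depend on the mean. The following is a minimal NumPy sketch of that formula. `mvn_entropy` is a hypothetical helper and not symjax code. For a batch of covariances it returns one value per batch element rather than a single scalar.

```python
import numpy as np


def mvn_entropy(cov):
    """Differential entropy (in nats) of N(mu, cov); independent of mu.

    Hypothetical helper (not symjax API).
    cov: (..., D, D) positive-definite covariance; leading axes are batch axes.
    """
    d = cov.shape[-1]
    # slogdet works on stacks of matrices and avoids overflow in det()
    _, logdet = np.linalg.slogdet(cov)
    return 0.5 * d * (1.0 + np.log(2.0 * np.pi)) + 0.5 * logdet
```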
log_prob(x)

Log of the multivariate normal probability density function.

Parameters: x (Tensor) – samples at which to evaluate the log pdf; the last axis of x indexes the components.
Returns: pdf – Log of the probability density function evaluated at x.
Return type: Tensor
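The log density being evaluated is log p(x) = −½[d log 2π + log|Σ| + (x − μ)ᵀ Σ⁻¹ (x − μ)]. The following minimal NumPy sketch computes it for a single distribution with a full covariance. `mvn_log_prob` is a hypothetical helper. It works through a Cholesky factor instead of an explicit inverse, which is the usual numerically stable route. Whether symjax does the same is not stated here.

```python
import numpy as np


def mvn_log_prob(x, mean, cov):
    """Log-density of N(mean, cov) at x, for a single distribution.

    Hypothetical helper (not symjax API).
    x: (..., D) samples, last axis = components; mean: (D,); cov: (D, D).
    """
    d = mean.shape[-1]
    chol = np.linalg.cholesky(cov)  # cov = L @ L.T, L lower triangular
    diff = x - mean  # (..., D)
    # Solve L z = diff for each sample, so ||z||^2 = diff^T cov^{-1} diff
    z = np.linalg.solve(chol, diff[..., None])[..., 0]
    maha = np.sum(z ** 2, axis=-1)
    # log|cov| = 2 * sum(log diag(L))
    logdet = 2.0 * np.sum(np.log(np.diag(chol)))
    return -0.5 * (d * np.log(2.0 * np.pi) + logdet + maha)
```

With a diagonal covariance, the joint log density equals the sum of the per-component 1-D log densities, which gives a convenient sanity check.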
prob(value)

Multivariate normal probability density function.

Parameters: value (Tensor) – samples at which to evaluate the pdf; the last axis of value indexes the components.
Returns: pdf – Probability density function evaluated at value.
Return type: Tensor
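In exact arithmetic the density is the exponential of the log density: p(x) = (2π)^(−d/2) |Σ|^(−1/2) exp(−½ (x − μ)ᵀ Σ⁻¹ (x − μ)). The following minimal NumPy sketch evaluates this formula directly. `mvn_prob` is a hypothetical helper and not symjax code. A 1-D check that the density integrates to about 1 follows the function.

```python
import numpy as np


def mvn_prob(value, mean, cov):
    """Density of N(mean, cov) at value, from the explicit formula.

    Hypothetical helper (not symjax API).
    value: (..., D), last axis = components; mean: (D,); cov: (D, D).
    """
    d = mean.shape[-1]
    diff = value - mean
    # Quadratic form diff^T cov^{-1} diff for every sample in the batch
    maha = np.einsum("...i,ij,...j->...", diff, np.linalg.inv(cov), diff)
    norm = (2.0 * np.pi) ** (-d / 2.0) * np.linalg.det(cov) ** -0.5
    return norm * np.exp(-0.5 * maha)


# 1-D sanity check: a Riemann sum of the density over a wide grid
xs = np.linspace(-20.0, 20.0, 4001)[:, None]
density = mvn_prob(xs, np.array([1.0]), np.array([[2.0]]))
total = density.sum() * (xs[1, 0] - xs[0, 0])
```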
sample()

Draw random samples from a multivariate normal distribution.

Returns: rvs – Random variates drawn from the distribution with the given mean and cov.
Return type: ndarray or scalar
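A common way to draw from N(μ, Σ) is the reparameterization x = μ + L z, with Σ = L Lᵀ (Cholesky) and z ~ N(0, I). The following NumPy sketch implements it under the assumption of a full, positive-definite covariance. `mvn_sample` and its `n` and `rng` arguments are hypothetical. The documented `sample()` takes no arguments, and how symjax manages randomness is not described here.

```python
import numpy as np


def mvn_sample(mean, cov, n, rng):
    """Draw n samples from N(mean, cov) as mean + L z, with cov = L L^T.

    Hypothetical helper (not symjax API): n and rng are passed explicitly.
    mean: (D,); cov: (D, D) positive definite; returns an (n, D) array.
    """
    chol = np.linalg.cholesky(cov)
    z = rng.standard_normal((n, mean.shape[-1]))  # rows are i.i.d. N(0, I)
    # Row-vector form of L z: (L z)^T = z^T L^T
    return mean + z @ chol.T


rng = np.random.default_rng(0)
mean = np.array([1.0, -2.0])
cov = np.array([[2.0, 0.5], [0.5, 1.0]])
samples = mvn_sample(mean, cov, 200_000, rng)
```

With many samples, the empirical mean and covariance should approach `mean` and `cov`.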