Source code for symjax.tensor.linalg

import inspect
import sys

import jax.numpy.linalg as jnpl
import jax.scipy.linalg as jspl

from .base import jax_wrap
from .normalization import normalize

from . import random
from . import ops_numpy as T
from . import ops_special as S
from . import control_flow as C

from_scipy = [
    "cholesky",
    "block_diag",
    "cho_solve",
    "eigh",
    "expm",
    # "expm_frechet",
    "inv",
    "lu",
    "lu_factor",
    "lu_solve",
    "solve_triangular",
    "tril",
    "triu",
]

NAMES = [c[0] for c in inspect.getmembers(jnpl, inspect.isfunction)] + [
    "pinv",
    "slogdet",
]
NAMES.remove("norm")

module = sys.modules[__name__]
for name in NAMES:
    if name not in from_scipy:
        module.__dict__.update({name: jax_wrap(jnpl.__dict__[name])})

for name in from_scipy:
    module.__dict__.update({name: jax_wrap(jspl.__dict__[name])})
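
# The two loops above re-export every wrapped function at module level, so
# the linear-algebra primitives are used like their numpy/scipy
# counterparts. A minimal sketch (assuming symjax's usual Placeholder-based
# workflow; the names below are illustrative, not defined in this module):
#
#     import symjax.tensor as T
#     x = T.Placeholder((4, 4), "float32")
#     s = T.linalg.svd(x)        # wrapped from jax.numpy.linalg
#     c = T.linalg.cholesky(x)   # wrapped from jax.scipy.linalg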


_norm = jax_wrap(jnpl.__dict__["norm"])


def norm(x, ord=2, axis=None, keepdims=False):
    """Tensor/Matrix/Vector norm.

    For matrices and vectors, this function is able to return one of eight
    different matrix norms, or one of an infinite number of vector norms
    (described below), depending on the value of the ``ord`` parameter.
    For higher-dimensional tensors, only :math:`0<ord<\\infty` is supported.

    Parameters
    ----------
    x : array_like
        Input array. If `axis` is None, `x` must be 1-D or 2-D, unless
        `ord` is None. If both `axis` and `ord` are None, the 2-norm of
        ``x.ravel`` will be returned.
    ord : {non-zero int, inf, -inf, 'fro', 'nuc'}, optional
        Order of the norm (see table under ``Notes``). inf means numpy's
        `inf` object. The default is `2`.
    axis : {None, int, 2-tuple of ints}, optional
        If `axis` is an integer, it specifies the axis of `x` along which
        to compute the vector norms. If `axis` is a 2-tuple, it specifies
        the axes that hold 2-D matrices, and the matrix norms of these
        matrices are computed. If `axis` is None then either a vector norm
        (when `x` is 1-D) or a matrix norm (when `x` is 2-D) is returned.
        The default is None.

        .. versionadded:: 1.8.0

    keepdims : bool, optional
        If this is set to True, the axes which are normed over are left in
        the result as dimensions with size one. With this option the result
        will broadcast correctly against the original `x`.

        .. versionadded:: 1.10.0

    Returns
    -------
    n : float or ndarray
        Norm of the matrix or vector(s).

    See Also
    --------
    scipy.linalg.norm : Similar function in SciPy.

    Notes
    -----
    For values of ``ord < 1``, the result is, strictly speaking, not a
    mathematical 'norm', but it may still be useful for various numerical
    purposes.

    The following norms can be calculated:

    =====  ============================  ==========================
    ord    norm for matrices             norm for vectors
    =====  ============================  ==========================
    None   Frobenius norm                2-norm
    'fro'  Frobenius norm                --
    'nuc'  nuclear norm                  --
    inf    max(sum(abs(x), axis=1))      max(abs(x))
    -inf   min(sum(abs(x), axis=1))      min(abs(x))
    0      --                            sum(x != 0)
    1      max(sum(abs(x), axis=0))      as below
    -1     min(sum(abs(x), axis=0))      as below
    2      2-norm (largest sing. value)  as below
    -2     smallest singular value       as below
    other  --                            sum(abs(x)**ord)**(1./ord)
    =====  ============================  ==========================

    The Frobenius norm is given by [1]_:

        :math:`||A||_F = [\\sum_{i,j} abs(a_{i,j})^2]^{1/2}`

    The nuclear norm is the sum of the singular values.

    Both the Frobenius and nuclear norm orders are only defined for
    matrices and raise a ValueError when ``x.ndim != 2``.

    References
    ----------
    .. [1] G. H. Golub and C. F. Van Loan, *Matrix Computations*,
           Baltimore, MD, Johns Hopkins University Press, 1985, pg. 15

    Examples
    --------
    >>> from numpy import linalg as LA
    >>> a = np.arange(9) - 4
    >>> a
    array([-4, -3, -2, ..., 2, 3, 4])
    >>> b = a.reshape((3, 3))
    >>> b
    array([[-4, -3, -2],
           [-1,  0,  1],
           [ 2,  3,  4]])

    >>> LA.norm(a)
    7.745966692414834
    >>> LA.norm(b)
    7.745966692414834
    >>> LA.norm(b, 'fro')
    7.745966692414834
    >>> LA.norm(a, np.inf)
    4.0
    >>> LA.norm(b, np.inf)
    9.0
    >>> LA.norm(a, -np.inf)
    0.0
    >>> LA.norm(b, -np.inf)
    2.0

    >>> LA.norm(a, 1)
    20.0
    >>> LA.norm(b, 1)
    7.0
    >>> LA.norm(a, -1)
    -4.6566128774142013e-010
    >>> LA.norm(b, -1)
    6.0
    >>> LA.norm(a, 2)
    7.745966692414834
    >>> LA.norm(b, 2)
    7.3484692283495345

    >>> LA.norm(a, -2)
    0.0
    >>> LA.norm(b, -2)
    1.8570331885190563e-016 # may vary
    >>> LA.norm(a, 3)
    5.8480354764257312 # may vary
    >>> LA.norm(a, -3)
    0.0

    Using the `axis` argument to compute vector norms:

    >>> c = np.array([[ 1, 2, 3],
    ...               [-1, 1, 4]])
    >>> LA.norm(c, axis=0)
    array([ 1.41421356,  2.23606798,  5.        ])
    >>> LA.norm(c, axis=1)
    array([ 3.74165739,  4.24264069])
    >>> LA.norm(c, ord=1, axis=1)
    array([ 6.,  6.])

    Using the `axis` argument to compute matrix norms:

    >>> m = np.arange(8).reshape(2,2,2)
    >>> LA.norm(m, axis=(1,2))
    array([  3.74165739,  11.22497216])
    >>> LA.norm(m[0, :, :]), LA.norm(m[1, :, :])
    (3.7416573867739413, 11.224972160321824)

    """
    if hasattr(axis, "__len__") and len(axis) > 2:
        return T.power(
            T.power(T.abs(x), ord).sum(axis=axis, keepdims=keepdims),
            1.0 / ord,
        )
    return _norm(x, ord, axis, keepdims)
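
# A minimal usage sketch (assuming the Placeholder workflow shown earlier):
# for tensors of rank > 2, `norm` falls back to the explicit
# sum(abs(x)**ord)**(1/ord) formula, e.g. an entrywise 3-norm over all axes:
#
#     x = T.Placeholder((2, 3, 4), "float32")
#     n = norm(x, ord=3, axis=(0, 1, 2))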
def singular_vectors_power_iteration(weight, axis=0, n_iters=1):
    # This power iteration produces approximations of `u` and `v`.
    u = normalize(random.randn(weight.shape[0]), dim=0)
    v = normalize(random.randn(weight.shape[1]), dim=0)
    for _ in range(n_iters):
        v = normalize(weight.t().dot(u), dim=0)
        u = normalize(weight.dot(v), dim=0)
    return u, v
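
# A hedged sketch: the leading singular value of `weight` can then be
# estimated with the bilinear form u^T W v (the quantity used by, e.g.,
# spectral normalization). `W` below is an illustrative placeholder:
#
#     W = T.Placeholder((64, 32), "float32")
#     u, v = singular_vectors_power_iteration(W, n_iters=5)
#     sigma = u.dot(W.dot(v))  # approximates the largest singular value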
def eigenvector_power_iteration(weight, axis=0, n_iters=1):
    # This power iteration produces approximations of `u`.
    u = normalize(random.randn(weight.shape[0]), dim=0)
    for _ in range(n_iters):
        u = normalize(weight.t().dot(u), dim=0)
    return u
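
# A hedged sketch: for a symmetric matrix the returned vector approximates
# the dominant eigenvector, and the Rayleigh quotient recovers the
# corresponding eigenvalue (`A` is an illustrative placeholder):
#
#     A = T.Placeholder((32, 32), "float32")
#     u = eigenvector_power_iteration(A, n_iters=10)
#     lam = u.dot(A.dot(u))  # Rayleigh quotient estimate (u is unit-norm)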
def _proj(v, u):
    # remove from `v` its component along `u` (the small constant guards
    # against division by zero); the second output is a dummy scan carry
    return v - v.dot(u) * u / (1e-28 + u.dot(u)), 0


def gram_schmidt(V, normalize=True):
    """gram-schmidt orthogonalization

    Parameters
    ----------
    V: Tensor of rank 2
        a matrix to orthogonalize. The vectors should be in the rows of V
    normalize: bool
        whether to renormalize the orthogonalized vectors or not,
        default ``True``

    Returns
    -------
    U: Tensor of rank 2
        a matrix with orthogonalized rows from V; the rows are unit-norm
        only when ``normalize`` is ``True``
    """
    if normalize:
        U = S.index_add(T.zeros_like(V), 0, V[0] / norm(V[0], 2))
    else:
        U = S.index_add(T.zeros_like(V), 0, V[0])

    def fn(U, v, k):
        # subtract from `v` its projections onto all previous rows at once
        coeffs = T.dot(U, v) / ((U ** 2).sum(1) + 1e-28)
        p = v - (U * coeffs[:, None]).sum(0)
        if normalize:
            return S.index_update(U, k, p / norm(p, 2)), k
        else:
            return S.index_update(U, k, p), k

    U, _ = C.scan(fn, init=U, sequences=[V[1:], T.arange(1, V.shape[0])])
    return U


def modified_gram_schmidt(V):
    """modified gram-schmidt orthogonalization

    Parameters
    ----------
    V: Tensor of rank 2
        a matrix to orthogonalize. The vectors should be in the rows of V

    Returns
    -------
    U: Tensor of rank 2
        a matrix with orthonormalized rows from V
    """
    U = S.index_add(T.zeros_like(V), 0, V[0] / norm(V[0], 2))

    def fn(U, v, k):
        # project `v` against each previous row in turn, which is more
        # numerically stable than the classical variant above
        uk, _ = C.scan(_proj, init=v, sequences=[U])
        return S.index_update(U, k, uk / norm(uk, 2)), 0

    U, _ = C.scan(fn, init=U, sequences=[V[1:], T.arange(1, V.shape[0])])
    return U
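
# A minimal usage sketch for both orthogonalization routines (illustrative
# placeholder shapes): after either call, U.dot(U.T) should be close to the
# identity matrix since the rows are orthonormalized:
#
#     V = T.Placeholder((4, 6), "float32")
#     U1 = gram_schmidt(V)            # classical Gram-Schmidt
#     U2 = modified_gram_schmidt(V)   # numerically more stable variant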