symjax.tensor.linalg.lu

symjax.tensor.linalg.lu(a, permute_l=False, overwrite_a=False, check_finite=True)

Compute pivoted LU decomposition of a matrix.

LAX-backend implementation of lu(). Original docstring below.

The decomposition is:

A = P L U

where P is a permutation matrix, L lower triangular with unit diagonal elements, and U upper triangular.
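The structure of the three factors can be checked numerically. The sketch below uses SciPy's `scipy.linalg.lu`, the routine whose docstring is reproduced here, on the assumption that the symjax wrapper returns factors with the same properties. The variable names are illustrative only.

```python
import numpy as np
from scipy.linalg import lu

A = np.array([[2.0, 5.0, 8.0, 7.0],
              [5.0, 2.0, 2.0, 8.0],
              [7.0, 5.0, 6.0, 6.0],
              [5.0, 4.0, 4.0, 8.0]])

P, L, U = lu(A)

# P is a permutation matrix: entries are 0 or 1, with exactly one 1
# in every row and every column.
is_permutation = bool(
    np.all((P == 0) | (P == 1))
    and np.all(P.sum(axis=0) == 1)
    and np.all(P.sum(axis=1) == 1)
)

# L is lower triangular with ones on its diagonal.
is_unit_lower = bool(np.allclose(L, np.tril(L)) and np.allclose(np.diag(L), 1.0))

# U is upper triangular.
is_upper = bool(np.allclose(U, np.triu(U)))

# The product of the factors reproduces A.
reconstructs = bool(np.allclose(P @ L @ U, A))
```

All four flags should be `True` for any square input on which the factorization succeeds.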

Parameters

a : (M, N) array_like
Array to decompose
permute_l : bool, optional
Perform the multiplication P*L (Default: do not permute)
overwrite_a : bool, optional
Whether to overwrite data in a (may improve performance)
check_finite : bool, optional
Whether to check that the input matrix contains only finite numbers. Disabling may give a performance gain, but may result in problems (crashes, non-termination) if the inputs do contain infinities or NaNs.
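As a rough illustration of `check_finite`, the sketch below shows how the SciPy reference routine reacts to a NaN in the input when the check is left enabled. This is SciPy's behavior; whether the LAX-backend wrapper performs the same check is not stated here and should be treated as an assumption.

```python
import numpy as np
from scipy.linalg import lu

bad = np.array([[1.0, np.nan],
                [2.0, 3.0]])

# With check_finite=True (the default), SciPy validates the input first
# and raises ValueError when it finds a NaN or an infinity.
try:
    lu(bad, check_finite=True)
    rejected = False
except ValueError:
    rejected = True
```

Passing `check_finite=False` skips this validation, which is only safe when the input is already known to be finite.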

Returns

(If permute_l == False)

p : (M, M) ndarray
Permutation matrix
l : (M, K) ndarray
Lower triangular or trapezoidal matrix with unit diagonal. K = min(M, N)
u : (K, N) ndarray
Upper triangular or trapezoidal matrix

(If permute_l == True)

pl : (M, K) ndarray
Permuted L matrix. K = min(M, N)
u : (K, N) ndarray
Upper triangular or trapezoidal matrix
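The shapes above matter mostly for non-square inputs. The following sketch, again using SciPy's `lu` as a stand-in for the wrapped routine, factors a tall 4 x 3 matrix and compares the two return conventions. The test matrix is arbitrary.

```python
import numpy as np
from scipy.linalg import lu

# A tall matrix: M = 4, N = 3, so K = min(M, N) = 3.
A = np.arange(12, dtype=float).reshape(4, 3) + np.eye(4, 3)

# Default: three factors with shapes (M, M), (M, K), (K, N).
p, l, u = lu(A)
shapes_default = (p.shape, l.shape, u.shape)

# permute_l=True: the permutation is folded into L, giving two factors
# with shapes (M, K) and (K, N).
pl, u2 = lu(A, permute_l=True)
shapes_permuted = (pl.shape, u2.shape)

# Both conventions describe the same factorization of A.
same_product = bool(np.allclose(pl, p @ l) and np.allclose(pl @ u2, A))
```

Using `permute_l=True` avoids forming the (M, M) permutation matrix explicitly when only the product P L is needed.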

This is an LU factorization routine written for SciPy.

>>> import numpy as np
>>> from scipy.linalg import lu
>>> A = np.array([[2, 5, 8, 7], [5, 2, 2, 8], [7, 5, 6, 6], [5, 4, 4, 8]])
>>> p, l, u = lu(A)
>>> np.allclose(A - p @ l @ u, np.zeros((4, 4)))
True