symjax.tensor.linalg.solve_triangular

symjax.tensor.linalg.solve_triangular(a, b, trans=0, lower=False, unit_diagonal=False, overwrite_b=False, debug=None, check_finite=True)

Solve the equation a x = b for x, assuming a is a triangular matrix.
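Here is a minimal plain-NumPy sketch of the substitution a triangular solve performs. A lower-triangular a uses forward substitution and an upper-triangular one uses back substitution. It is only an illustration, not the LAX-backend code. The helper `triangular_solve_reference` is hypothetical, and the sketch assumes a square a with a nonzero diagonal and a 1-D b.

```python
import numpy as np


def triangular_solve_reference(a, b, lower=False):
    """Hypothetical helper: solve a x = b by forward (lower) or back (upper) substitution."""
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)
    n = a.shape[0]
    x = np.zeros(n)
    # Lower-triangular systems are solved top to bottom, upper ones bottom to top.
    rows = range(n) if lower else range(n - 1, -1, -1)
    for i in rows:
        if lower:
            known = a[i, :i] @ x[:i]          # contribution of already-solved x[0..i-1]
        else:
            known = a[i, i + 1:] @ x[i + 1:]  # contribution of already-solved x[i+1..n-1]
        x[i] = (b[i] - known) / a[i, i]
    return x


a = np.array([[3, 0, 0, 0], [2, 1, 0, 0], [1, 0, 1, 0], [1, 1, 1, 1]])
b = np.array([4, 2, 4, 2])
x = triangular_solve_reference(a, b, lower=True)
print(np.allclose(a @ x, b))  # True
```

Each unknown is computed from the ones already solved. This is why a triangular solve costs O(M²) operations, compared with O(M³) for a general dense solve.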

LAX-backend implementation of scipy.linalg.solve_triangular(). The original SciPy docstring follows.

Parameters:
  • a ((M, M) array_like) – A triangular matrix.
  • b ((M,) or (M, N) array_like) – Right-hand side matrix in a x = b.
  • lower (bool, optional) – If True, use only the data in the lower triangle of a. By default, only the upper triangle is used.
  • trans ({0, 1, 2, 'N', 'T', 'C'}, optional) – Type of system to solve (default 0):
      0 or 'N' – a x = b
      1 or 'T' – a^T x = b
      2 or 'C' – a^H x = b
  • unit_diagonal (bool, optional) – If True, the diagonal elements of a are assumed to be 1 and are not referenced.
  • overwrite_b (bool, optional) – Allow overwriting data in b (may enhance performance).
  • check_finite (bool, optional) – Whether to check that the input matrices contain only finite numbers. Disabling the check may improve performance, but may cause problems (crashes, non-termination) if the inputs contain infinities or NaNs.
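The options above can be checked against scipy.linalg.solve_triangular, whose docstring this page reproduces. This sketch assumes NumPy and SciPy are installed and reuses the matrix from the example at the end of this page. It exercises `trans` and `unit_diagonal`, and it shows that only one triangle of a is read.

```python
import numpy as np
from scipy.linalg import solve_triangular

a = np.array([[3., 0., 0., 0.],
              [2., 1., 0., 0.],
              [1., 0., 1., 0.],
              [1., 1., 1., 1.]])
b = np.array([4., 2., 4., 2.])

# trans='T' (or 1) solves a^T x = b while still reading only the lower triangle of a.
x_t = solve_triangular(a, b, lower=True, trans='T')
print(np.allclose(a.T @ x_t, b))  # True

# unit_diagonal=True treats every diagonal entry as 1, so a[0, 0] == 3 is ignored.
x_u = solve_triangular(a, b, lower=True, unit_diagonal=True)
a_unit = a.copy()
np.fill_diagonal(a_unit, 1.0)
print(np.allclose(a_unit @ x_u, b))  # True

# With lower=True, entries stored strictly above the diagonal are never referenced.
a_noisy = a + np.triu(np.full((4, 4), 99.0), k=1)
x_noisy = solve_triangular(a_noisy, b, lower=True)
print(np.allclose(x_noisy, solve_triangular(a, b, lower=True)))  # True
```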
Returns:

x – Solution to the system a x = b. The returned array has the same shape as b.

Return type:

(M,) or (M, N) ndarray
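When b is 2-D, each column is treated as a separate right-hand side, and the result has the same (M, N) shape. The following small sketch uses SciPy's implementation and assumes NumPy and SciPy are available.

```python
import numpy as np
from scipy.linalg import solve_triangular

a = np.array([[3., 0.],
              [2., 1.]])
# Two right-hand sides stacked as columns: B has shape (M, N) = (2, 2).
B = np.array([[3., 6.],
              [4., 5.]])
X = solve_triangular(a, B, lower=True)
print(X.shape)  # (2, 2), matching B

# Each column of X solves the system for the corresponding column of B.
for j in range(B.shape[1]):
    print(np.allclose(X[:, j], solve_triangular(a, B[:, j], lower=True)))  # True
```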

Raises:

LinAlgError – If a is singular.
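In SciPy, this error is raised when a diagonal entry of the referenced triangle is zero and `unit_diagonal` is False. The sketch below assumes SciPy. This page does not say whether the LAX-backend version raises the same way, so do not rely on the exception there.

```python
import numpy as np
from scipy.linalg import LinAlgError, solve_triangular

# a[1, 1] == 0, so this lower-triangular matrix is singular.
a = np.array([[1., 0.],
              [2., 0.]])
b = np.array([1., 1.])

raised = False
try:
    solve_triangular(a, b, lower=True)
except LinAlgError as err:
    # SciPy reports which diagonal entry stopped the substitution.
    raised = True
    print("singular:", err)
```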

Notes

New in SciPy version 0.9.0.

Examples

Solve the lower triangular system a x = b, where:

     [3  0  0  0]       [4]
a =  [2  1  0  0]   b = [2]
     [1  0  1  0]       [4]
     [1  1  1  1]       [2]
>>> import numpy as np
>>> from scipy.linalg import solve_triangular
>>> a = np.array([[3, 0, 0, 0], [2, 1, 0, 0], [1, 0, 1, 0], [1, 1, 1, 1]])
>>> b = np.array([4, 2, 4, 2])
>>> x = solve_triangular(a, b, lower=True)
>>> x
array([ 1.33333333, -0.66666667,  2.66666667, -1.33333333])
>>> a.dot(x)  # Check the result
array([4., 2., 4., 2.])