Source code for symjax.tensor.ops_numpy

from .base import Tensor, _add_method
import inspect
import sys

import jax.lax as jla
import jax.numpy as jnp

from .base import jax_wrap
import symjax

module = sys.modules[__name__]

_JNP_NAMES = [c[0] for c in inspect.getmembers(jnp, inspect.isfunction)]
_TO_SKIP = [
    "<lambda>",
    "blackman",
    "bartlett",
    "hamming",
    "hanning",
    "kaiser",
    "add_docstring",
    "add_newdoc",
    "alen",
    "apply_along_axis",
    "apply_over_axes",
    "array2string",
    "array_equiv",
    "array_repr",
    "array_split",
    "array_str",
    "asanyarray",
    "asarray_chkfinite",
    "ascontiguousarray",
    "asfarray",
    "asfortranarray",
    "copy",
    "copyto",
    "custom_tra,nsforms",
    "delete",
    "deprecate",
    "device_put",
    "digitize",
    "disp",
    "function",
    "func",
    "find_common_type",
    "format_float_positional",
    "format_float_scientific",
    "frexp",
    "frombuffer",
    "fromfile",
    "fromfunction",
    "fromiter",
    "frompyfunc",
    "fromregex",
    "fromstring",
    "fv",
    "genfromtxt",
    "get_array_wrap",
    "get_include",
    "get_module_functions",
    "get_printoptions",
    "getbufsize",
    "geterr",
    "geterrcall",
    "geterrobj",
    "gradient",
    "int_asbuffer",
    "is_busday",
    "isrealobj",
    "issctype",
    "issubclass_",
    "issubdtype",
    "issubsctype",
    "iterable",
    "jit",
    "load",
    "loads",
    "loadtxt",
    "lookfor",
    "mafromtxt",
    "maximum_sctype",
    "may_share_memory",
    "mintypecode",
    "ndfromtxt",
    "nested_iters",
    "nextafter",
    "nper",
    "npv",
    "obj2sctype",
    "packbits",
    "printoptions",
    "recfromcsv",
    "recfromtxt",
    "removechars",
    "result_type",
    "safe_eval",
    "save",
    "savetxt",
    "savez",
    "savez_compressed",
    "sctype2char",
    "searchsorted",
    "select",
    "shares_memory",
    "show",
    "size",
    "source",
    "spacing",
    "strtobool",
    "trapz",
    "typename",
    "union1d",
    "unique",
    "unpackbits",
    "update_numpydoc",
    "who",
    "zeros_like",
    "ones_like",
]

for name in _JNP_NAMES:
    if name in _TO_SKIP:
        continue
    module.__dict__.update({name: jax_wrap(jnp.__dict__[name])})
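
# ---------------------------------------------------------------------------
# Illustrative sketch (editor's addition, not part of the original module):
# after the loop above, every wrapped jax.numpy function is exposed at module
# level and operates on symjax Tensors. The alias `T` and the helper name
# `_example_wrapped_functions` below are hypothetical.
def _example_wrapped_functions():
    import symjax.tensor as T

    x = T.ones((2, 3), dtype="float32")
    y = T.sum(x)  # wrapped jnp.sum, builds a symbolic Tensor node
    return y
# ---------------------------------------------------------------------------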

for name in [
    "issubsctype",
    "issubdtype",
    "nextafter",
    "result_type",
    "select",
    "isneginf",
    "isposinf",
    "logaddexp",
    "logaddexp2",
]:
    module.__dict__.update({name: jax_wrap(jnp.__dict__[name])})


_zeros_like = jax_wrap(jnp.zeros_like)
_ones_like = jax_wrap(jnp.ones_like)


def zeros_like(input, detach=False):
    if detach:
        return module.__dict__["zeros"](input.shape, input.dtype)
    else:
        return module.__dict__["_zeros_like"](input)


def ones_like(input, detach=False):
    if detach:
        return module.__dict__["ones"](input.shape, input.dtype)
    else:
        return module.__dict__["_ones_like"](input)
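
# Illustrative sketch (editor's addition): with `detach=True` a fresh
# zeros/ones tensor is built from the input's shape and dtype only, so the
# result does not depend on `input` in the graph; otherwise the wrapped
# jnp.zeros_like / jnp.ones_like is applied. The alias `T` and the helper
# name `_example_zeros_like` are hypothetical.
def _example_zeros_like():
    import symjax.tensor as T

    x = T.ones((4, 4))
    z = T.zeros_like(x)                        # wrapped jnp.zeros_like
    z_detached = T.zeros_like(x, detach=True)  # only reuses x.shape / x.dtype
    return z, z_detached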
cast = jax_wrap(jla.convert_element_type)
complex = jax_wrap(jla.complex)
range = module.__dict__["arange"]


def flatten(input):
    """reshape the input into a vector"""
    return module.__dict__["reshape"](input, (-1,))


def flatten2d(input):
    """reshape the input into a matrix"""
    assert input.ndim > 1
    if input.ndim == 2:
        return input
    return reshape(input, (input.shape[0], -1))


def flatten3d(input):
    """reshape the input into a 3D-tensor"""
    assert input.ndim > 1
    if input.ndim == 3:
        return input
    return reshape(input, (input.shape[0], input.shape[1], -1))


def flatten4d(input):
    """reshape the input into a 4D-tensor"""
    assert input.ndim > 1
    if input.ndim == 4:
        return input
    return reshape(input, (input.shape[0], input.shape[1], input.shape[2], -1))
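
# Illustrative sketch (editor's addition): the flatten helpers keep the
# leading axes and collapse the remaining ones into the last axis. The alias
# `T` and the helper name `_example_flatten` are hypothetical.
def _example_flatten():
    import symjax.tensor as T

    x = T.zeros((8, 3, 32, 32))
    v = T.flatten(x)    # shape (8 * 3 * 32 * 32,)
    m = T.flatten2d(x)  # shape (8, 3 * 32 * 32)
    t = T.flatten3d(x)  # shape (8, 3, 32 * 32)
    return v, m, t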
################

getitem = jax_wrap(jnp.lax_numpy._rewriting_take)
_add_method(Tensor)(getitem, "__getitem__")

# overloading the basic arithmetic operators
_add_method(Tensor)(module.__dict__["add"], "__add__")
_add_method(Tensor)(module.__dict__["add"], "__radd__")
_add_method(Tensor)(module.__dict__["multiply"], "__mul__")
_add_method(Tensor)(module.__dict__["multiply"], "__rmul__")
_add_method(Tensor)(module.__dict__["true_divide"], "__truediv__")
_add_method(Tensor)(lambda a, b: module.__dict__["true_divide"](b, a), "__rtruediv__")
_add_method(Tensor)(module.__dict__["floor_divide"], "__floordiv__")
_add_method(Tensor)(lambda a, b: module.__dict__["floor_divide"](b, a), "__rfloordiv__")
_add_method(Tensor)(module.__dict__["subtract"], "__sub__")
_add_method(Tensor)(lambda a, b: module.__dict__["subtract"](b, a), "__rsub__")
_add_method(Tensor)(module.__dict__["power"], "__pow__")
_add_method(Tensor)(module.__dict__["negative"], "__neg__")

# logical operators
_add_method(Tensor)(module.__dict__["bitwise_or"], "__or__")
_add_method(Tensor)(module.__dict__["bitwise_and"], "__and__")
_add_method(Tensor)(module.__dict__["bitwise_xor"], "__xor__")
_add_method(Tensor)(module.__dict__["invert"], "__invert__")
_add_method(Tensor)(module.__dict__["left_shift"], "__lshift__")
_add_method(Tensor)(module.__dict__["right_shift"], "__rshift__")

# overloading comparison operators
_add_method(Tensor)(module.__dict__["less"], "__lt__")
_add_method(Tensor)(module.__dict__["greater_equal"], "__rlt__")
_add_method(Tensor)(module.__dict__["greater"], "__gt__")
_add_method(Tensor)(module.__dict__["less_equal"], "__rgt__")
_add_method(Tensor)(module.__dict__["greater_equal"], "__ge__")
_add_method(Tensor)(module.__dict__["less"], "__rge__")
_add_method(Tensor)(module.__dict__["less_equal"], "__le__")
_add_method(Tensor)(module.__dict__["greater"], "__rle__")

# additional basic arithmetic operators
_add_method(Tensor)(module.__dict__["sum"], "sum")
_add_method(Tensor)(module.__dict__["prod"], "prod")
_add_method(Tensor)(module.__dict__["mean"], "mean")
_add_method(Tensor)(module.__dict__["max"], "max")
_add_method(Tensor)(module.__dict__["min"], "min")
_add_method(Tensor)(module.__dict__["std"], "std")
_add_method(Tensor)(module.__dict__["var"], "var")
_add_method(Tensor)(module.__dict__["argmax"], "argmax")
_add_method(Tensor)(module.__dict__["argmin"], "argmin")

# additional convenient operators
_add_method(Tensor)(module.__dict__["real"], "real")
_add_method(Tensor)(module.__dict__["imag"], "imag")
_add_method(Tensor)(module.__dict__["conjugate"], "conj")
_add_method(Tensor)(module.__dict__["conjugate"], "conjugate")
_add_method(Tensor)(module.__dict__["cast"], "cast")
_add_method(Tensor)(module.__dict__["cast"], "astype")
_add_method(Tensor)(module.__dict__["squeeze"], "squeeze")
_add_method(Tensor)(module.__dict__["flatten"], "flatten")
_add_method(Tensor)(module.__dict__["flatten2d"], "flatten2d")
_add_method(Tensor)(module.__dict__["reshape"], "reshape")
_add_method(Tensor)(module.__dict__["dot"], "dot")
_add_method(Tensor)(module.__dict__["repeat"], "repeat")
_add_method(Tensor)(module.__dict__["expand_dims"], "expand_dims")
_add_method(Tensor)(module.__dict__["matmul"], "matmul")
_add_method(Tensor)(module.__dict__["round"], "round")
_add_method(Tensor)(module.__dict__["transpose"], "transpose")
_add_method(Tensor)(module.__dict__["clip"], "clip")

from symjax.tensor.linalg import norm

_add_method(Tensor)(norm, "norm")
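
# Illustrative sketch (editor's addition): once the operators and methods
# above are attached, Tensors support the usual arithmetic, comparison, and
# reduction syntax directly. The alias `T` and the helper name
# `_example_tensor_methods` are hypothetical; passing a dtype string to
# `astype` is assumed to be accepted by the `cast` wrapper.
def _example_tensor_methods():
    import symjax.tensor as T

    x = T.ones((2, 3))
    y = 2 * x + 1          # __rmul__ / __add__ via `multiply` / `add`
    m = y.mean()           # reduction attached with _add_method
    z = y.astype("int32")  # `astype` maps to the `cast` wrapper
    n = y.norm()           # `norm` imported from symjax.tensor.linalg
    return y, m, z, n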