Source code for symjax.tensor.base

import re
from functools import wraps

import warnings
import jax
import jax.numpy as jnp
import numpy
import os
import symjax
import numpy as np


def is_constant(item):
    """Tell whether ``item`` depends only on constants or shapes.

    ``item`` is reported constant when it is a :class:`Constant`, when
    :func:`isvar` finds no variable in it, or when it is ``None``. A plain
    ``list``/``tuple`` that does contain variables is constant only if each
    of its elements is; in that case the answer comes from ``numpy.all``.
    Anything else is not constant.
    """
    if isinstance(item, Constant):
        return True
    if not isvar(item):
        return True
    if item is None:
        return True
    if type(item) in (list, tuple):
        # exact type check: subclasses of list/tuple are not recursed into
        return numpy.all([is_constant(element) for element in item])
    return False


def _add_method(cls):
    # important we keep the self inside the function call !
    def decorator(func, name=""):
        @wraps(func)
        def wrapper(self, *args, **kwargs):
            return func(self, *args, **kwargs)

        if name == "":
            setattr(cls, func.__name__, wrapper)
        else:
            setattr(cls, name, wrapper)
        return func  # returning func means func can still be used normally

    return decorator


def _args_formatting(args, extra_args, indices):
    """utility function to be used in the Tensor class to correctly join the
    args and extra_args based on the indices

    Parameters:
    -----------

    args: List

    extra_args: List

    indices: List of binary values
        the indices (one per element) to join args and extra_args in the correct
        order

    """
    output = ()
    arg_iterator = iter(args)
    extra_arg_iterator = iter(extra_args)
    for i in indices:
        if i:
            output += (next(extra_arg_iterator),)
        else:
            output += (next(arg_iterator),)
    return output


def isvar(item):
    """Check whether ``item`` is, or contains, a symbolic variable.

    Slices never count. Plain lists and tuples are searched recursively and
    the answer is produced by ``numpy.any``. Any other object counts when it
    is a :class:`Tensor` instance (or exactly a :class:`Constant` /
    :class:`MultiOutputOp`) and is not callable.
    """
    kind = type(item)
    if kind is slice:
        return False
    if kind is list or kind is tuple:
        # exact types only, so a tuple subclass such as MultiOutputOp is
        # handled by the branch below instead of being unpacked
        return numpy.any([isvar(entry) for entry in item])
    is_symbolic = isinstance(item, Tensor) or kind in (Constant, MultiOutputOp)
    return is_symbolic and not callable(item)


# Matches a whole single-line call signature such as ``numpy.foo(a, b)``,
# optionally preceded by an ``out =`` style prefix. In the visible part of
# this file it is only referenced from commented-out code in ``jax_wrap``.
_numpy_signature_re = re.compile(r"^([\w., ]+=)?\s*[\w\.]+\(.*\)$")


def update_numpydoc(docstr, fun, op):
    """Remove from a numpy docstring the parameters ``op`` does not take.

    The text between the ``Parameters`` underline and the ``Returns``
    heading is split into entries; an entry is kept only when the first name
    before `` : `` appears in ``op.__code__.co_varnames``. When ``op`` has no
    ``__code__`` attribute the docstring is returned unchanged. ``fun`` is
    accepted but not used.
    """
    if not hasattr(op, "__code__"):
        return docstr

    # Some numpy docstrings carry one extra level of indentation on every
    # line; when the text starts with it, strip the first four-space run of
    # each line.
    if docstr[:4] == "    ":
        docstr = "\n".join(
            line.replace("    ", "", 1) for line in docstr.split("\n")
        )

    header_idx = docstr.find("Parameters")
    begin_idx = docstr.find("--\n", header_idx) + 2
    end_idx = docstr.find("Returns", begin_idx)

    accepted = op.__code__.co_varnames
    kept = []
    # continuation lines (indented by four spaces) are glued to their entry
    # with a "@@" marker so that each entry sits on a single line
    for entry in docstr[begin_idx:end_idx].replace("\n    ", "@@").split("\n"):
        leading_name = entry[: entry.find(" : ")].split(", ")[0]
        if leading_name in accepted:
            kept.append(entry)

    parameters = "\n".join(kept).replace("@@", "\n    ")
    return docstr[: begin_idx + 1] + parameters + docstr[end_idx - 2 :]


def get_output_tree(jax_function, *args, **kwargs):
    """Evaluate the output structure of ``jax_function`` via ``jax.eval_shape``.

    Arguments for which :func:`is_constant` holds are treated as static:
    they are captured in a closure instead of being handed to
    ``jax.eval_shape``, because shape inference does not work with static
    arguments (for instance the axes of a transpose). The remaining
    arguments are passed through ``jax.eval_shape`` and merged back with the
    static ones, in their original positions, before ``jax_function`` is
    called.

    Returns whatever ``jax.eval_shape`` returns for that call.
    """
    # keyword arguments first: route each one to the static or traced group
    static_kwargs, var_kwargs = {}, {}
    for name, value in kwargs.items():
        bucket = static_kwargs if is_constant(value) else var_kwargs
        bucket[name] = value

    # then the positional ones, remembering which positions are static
    who_static = [1 if is_constant(arg) else 0 for arg in args]
    static_args = [arg for arg, flag in zip(args, who_static) if flag]
    var_args = [arg for arg, flag in zip(args, who_static) if not flag]

    def abstract_func(*args, **kwargs):
        # only receives the non-static values; re-insert the static ones
        all_args = _args_formatting(args, static_args, who_static)
        return jax_function(*all_args, **kwargs, **static_kwargs)

    return jax.eval_shape(abstract_func, *var_args, **var_kwargs)


def jax_wrap(func, doc_func=None):
    """Turn a jax function into a factory of symbolic nodes.

    The returned ``op`` accepts the arguments of ``func`` plus two optional
    keywords: ``name`` (name of the created node) and ``seed`` (forwarded to
    :class:`RandomOp` as ``_seed``). Calling it infers the output
    shapes/dtypes with :func:`get_output_tree` and returns

    * a :class:`MultiOutputOp` when the inferred output is a list or tuple,
    * a :class:`RandomOp` when ``func`` is listed in
      ``symjax.tensor.random._RANDOM_FUNCTIONS``,
    * an :class:`Op` otherwise.

    ``op`` is registered in ``symjax._fn_to_op`` under ``func.__name__``
    unless that name is already present. ``doc_func`` (``func`` by default)
    supplies the metadata copied onto ``op`` by ``functools.wraps``.
    """
    if doc_func is None:
        doc_func = func

    @wraps(doc_func)
    def op(*args, seed=None, **kwargs):
        # the node name is not an argument of the jax function: set it aside
        op_name = kwargs.pop("name", None)

        # multi-output nodes are handed over as plain tuples of their items
        args = tuple(
            tuple(arg) if isinstance(arg, MultiOutputOp) else arg for arg in args
        )
        for key in kwargs:
            if isinstance(kwargs[key], MultiOutputOp):
                kwargs[key] = tuple(kwargs[key])

        # random functions expect a key as first argument; a fixed one is
        # enough here since only the output shape and dtype are needed
        is_random = func in symjax.tensor.random._RANDOM_FUNCTIONS
        if is_random:
            probe_args = (jax.random.PRNGKey(0),) + args
        else:
            probe_args = args
        tree = get_output_tree(func, *probe_args, **kwargs)

        # pick the Tensor subclass matching the inferred output
        feed = {"_jax_function": func, "name": op_name}
        if type(tree) == list or type(tree) == tuple:
            feed["_shapes"] = [leaf.shape for leaf in tree]
            feed["_dtypes"] = [leaf.dtype for leaf in tree]
            return MultiOutputOp(*args, **feed, **kwargs)

        feed["_shape"] = tree.shape
        feed["_dtype"] = tree.dtype
        if is_random:
            return RandomOp(*args, _seed=seed, **feed, **kwargs)
        return Op(*args, **feed, **kwargs)

    if func.__name__ not in symjax._fn_to_op:
        symjax._fn_to_op[func.__name__] = op

    # when ``func`` has a docstring the wrapper takes the name of ``func``
    # (and not the one of ``doc_func``); its docstring is left as copied by
    # ``wraps``
    if getattr(func, "__doc__", None) is not None:
        op.__name__ = func.__name__

    return op


def wrap_class(c, method_exceptions=None):
    """Build a class whose instances expose ``c`` through symbolic ops.

    On creation of an instance of the returned class, ``c`` is instantiated
    once with every :class:`Tensor` argument replaced by a zero array of the
    same shape and dtype. The instance attributes that are not keys of
    ``c.__dict__`` and hold a ``jax.interpreters.xla.DeviceArray`` are then
    re-created as symbolic attributes, and the public callables of the
    instance are wrapped with :func:`jax_wrap`.

    Parameters:
    -----------

    c: class
        the class to wrap

    method_exceptions: list (optional)
        names of attributes that must not be wrapped
    """

    def _concrete(value):
        # symbolic inputs are swapped for zero arrays of identical shape/dtype
        if isinstance(value, Tensor):
            return jnp.zeros(value.shape, dtype=value.dtype)
        return value

    class meta:
        _method_exceptions = method_exceptions

        def __new__(cls, *args, **kwargs):
            # run the real constructor on concrete stand-ins of the inputs
            dummy_args = [_concrete(value) for value in args]
            dummy_kwargs = {key: _concrete(value) for key, value in kwargs.items()}

            # attributes present on the instance but absent from the class
            # namespace, restricted to device arrays, are the ones to rebuild
            class_attrs = c.__dict__.keys()
            instance = c(*dummy_args, **dummy_kwargs)
            state = instance.__dict__
            news = [
                attr
                for attr in state
                if attr not in class_attrs
                and isinstance(state[attr], jax.interpreters.xla.DeviceArray)
            ]

            # maps the class inputs to the attributes produced by the creator
            def function(*args, **kwargs):
                return [instance.__dict__[n] for n in news]

            init_op = jax_wrap(function)

            # regular object creation, carrying what __init__ will need
            obj = super().__new__(cls)
            obj._init_op = init_op
            obj._news = news

            # wrap every callable attribute that is neither dunder nor excluded
            skipped = cls._method_exceptions or []
            for att in dir(instance):
                if att[:2] == "__" or att in skipped:
                    continue
                member = getattr(instance, att)
                if callable(member):
                    setattr(obj, att, jax_wrap(member))
            return obj

        def __init__(self, *args, **kwargs):
            attrs = self._init_op(*args, **kwargs)
            for attr_name, attr_value in zip(self._news, attrs):
                self.__dict__[attr_name] = attr_value

    return meta


class Tensor:
    """Base class of the symbolic nodes defined in this module.

    Creating an instance registers it in ``symjax.current_graph()``; name
    and scope are read back from that graph, while shape and dtype are kept
    on the instance as ``_shape`` / ``_dtype``.
    """

    __array_priority__ = 1000

    def __new__(cls, *args, inplace_copy=None, **kwargs):
        # when ``inplace_copy`` is given it is returned as is and no new
        # object is allocated
        if inplace_copy is None:
            return super().__new__(cls)
        return inplace_copy

    def __init__(self, *args, inplace_copy=None, **kwargs):
        """Record shape/dtype found in ``_attrs`` and add the node to the
        current graph; does nothing when ``inplace_copy`` is given."""
        if inplace_copy is not None:
            return

        if "_attrs" in kwargs:
            attrs = kwargs["_attrs"]
            for field in ("_shape", "_dtype"):
                if field in attrs:
                    setattr(self, field, attrs[field])

        symjax.current_graph()._add(self, *args, **kwargs)

    @property
    def name(self):
        """name stored for this node in the current graph"""
        node = symjax.current_graph().nodes[self]
        return node["name"]

    @property
    def scope(self):
        """scope stored for this node in the current graph"""
        node = symjax.current_graph().nodes[self]
        return node["scope"]

    def clone(self, givens):
        """delegate to the ``clone`` method of the current graph"""
        graph = symjax.current_graph()
        return graph.clone(self, givens)

    @property
    def T(self):
        """result of ``self.transpose()``"""
        return self.transpose()

    @property
    def shape(self):
        """shape recorded at creation"""
        return self._shape

    @property
    def dtype(self):
        """dtype recorded at creation"""
        return self._dtype

    @property
    def ndim(self):
        """number of dimensions, i.e. the length of ``shape``"""
        return len(self.shape)

    def get(self, givens=None):
        """delegate to the ``get`` method of the current graph"""
        graph = symjax.current_graph()
        return graph.get(self, givens)

    def variables(self, trainable=True):
        """variables returned by ``symjax.get_variables`` for this node's
        scope and the given ``trainable`` flag"""
        return symjax.get_variables(scope=self.scope, trainable=trainable)


class Constant(Tensor):
    """Graph node wrapping a fixed value.

    The value is stored as a node attribute of the current graph; ``shape``
    and ``dtype`` are both reported as ``None``.
    """

    def __init__(self, value):
        graph = symjax.current_graph()
        name, scope = graph._get_name_scope("constant", self)
        attrs = {
            "name": name,
            "scope": scope,
            "value": value,
            "root": False,
        }
        super().__init__(_attrs=attrs)

    @property
    def value(self):
        """the wrapped value, read from the current graph"""
        node = symjax.current_graph().nodes[self]
        return node["value"]

    @property
    def shape(self):
        """always ``None``"""
        return None

    @property
    def dtype(self):
        """always ``None``"""
        return None

    def get(self):
        """return the wrapped value"""
        return self.value

    def __repr__(self):
        value_type = type(self.value)
        return "ConstantValue({})".format(value_type)

    def __str__(self):
        return repr(self)


class Op(Tensor):
    """Tensor obtained by applying a function to its inputs.

    The positional and extra keyword arguments are forwarded to
    :class:`Tensor`; the function is recorded as the ``jax_function``
    attribute of the node, together with its name, scope, shape and dtype.
    """

    def __init__(
        self,
        *args,
        _jax_function=None,
        _shape=None,
        _dtype=None,
        name=None,
        **kwargs,
    ):
        graph = symjax.current_graph()

        # nothing to do for an object that is already a node of the graph
        if self in graph.nodes:
            return

        # the node is named after the function unless a name is given
        if name is None:
            name = _jax_function.__name__
        name, scope = graph._get_name_scope(name, self)

        attrs = {
            "name": name,
            "scope": scope,
            "_shape": _shape,
            "_dtype": _dtype,
            "jax_function": _jax_function,
            "root": False,
        }
        super().__init__(*args, _attrs=attrs, **kwargs)

    @property
    def fn_name(self):
        """``__name__`` of the function recorded for this node"""
        node = symjax.current_graph().nodes[self]
        return node["jax_function"].__name__

    def __repr__(self):
        template = "Op(name={}, fn={}, shape={}, dtype={}, scope={})"
        return template.format(
            self.name, self.fn_name, self.shape, self.dtype, self.scope
        )

    def __str__(self):
        return repr(self)


def _tuple(*args):
    # gather the positional arguments into a single tuple
    return tuple(args)


# renamed before being wrapped: jax_wrap reads ``__name__`` of the function
# it receives (registry key and default node name)
_tuple.__name__ = "jax_tuple"

# symbolic counterpart of ``_tuple`` built with jax_wrap
Tuple = jax_wrap(_tuple)


class MultiOutputOp(Op, tuple, Tensor):
    """Node for a function returning several outputs.

    The object is a tuple holding one :class:`OpItem` per output and is also
    registered as a node of the current graph, linked by an edge to each of
    its items.
    """

    __array_priority__ = 0

    def __new__(
        cls,
        *args,
        _jax_function,
        _shapes,
        _dtypes,
        name=None,
        **kwargs,
    ):
        scope = symjax.current_graph().scope.absolute_name

        # one OpItem per (shape, dtype) pair, named ``<function name>[i]``
        items = [
            OpItem(
                _attrs={
                    "name": _jax_function.__name__ + "[{}]".format(index),
                    "scope": scope,
                    "jax_function": _jax_function,
                    "root": False,
                    "parent_index": index,
                    "_shape": shape,
                    "_dtype": dtype,
                }
            )
            for index, (shape, dtype) in enumerate(zip(_shapes, _dtypes))
        ]
        return super().__new__(cls, tuple(items))

    def __init__(self, *args, _jax_function, _shapes, _dtypes, name=None, **kwargs):
        if name is None:
            name = _jax_function.__name__

        name, scope = symjax.current_graph()._get_name_scope(name, self)

        # Tensor.__init__ is called directly (Op.__init__ is bypassed); no
        # "_shape" is recorded and the dtype slot holds the class itself
        attrs = {
            "name": name,
            "scope": scope,
            "_dtype": MultiOutputOp,
            "jax_function": _jax_function,
            "root": False,
        }
        Tensor.__init__(self, *args, _attrs=attrs, **kwargs)

        # connect the parent node to each of its items
        graph = symjax.current_graph()
        for index, child in enumerate(self):
            graph.add_edge(self, child, name="parent_index" + str(index))
            graph.nodes[child]["parent"] = self

    # string conversions and length are the ones of the underlying tuple
    def __str__(self):
        return tuple.__str__(self)

    def __repr__(self):
        return tuple.__repr__(self)

    def __len__(self):
        return tuple.__len__(self)


class OpItem(Op, Tensor):
    """Single output of a :class:`MultiOutputOp`.

    Initialisation goes straight to :class:`Tensor` (``Op.__init__`` is not
    used), so every node attribute comes from the ``_attrs`` given by the
    caller.
    """

    def __init__(self, *args, **kwargs):
        Tensor.__init__(self, *args, **kwargs)

    @property
    def parent(self):
        """value of the ``parent`` attribute of this node in the current graph"""
        node = symjax.current_graph().nodes[self]
        return node["parent"]

    @property
    def parent_index(self):
        """value of the ``parent_index`` attribute of this node in the
        current graph"""
        node = symjax.current_graph().nodes[self]
        return node["parent_index"]


class RandomOp(Op, Tensor):
    """
    Node produced by a random function, i.e. a function taking a
    ``jax.random.PRNGKey`` as first argument. If the function does not take
    a key as first argument, this class should not be used.

    A :class:`Seed` node is created for each instance and passed as the
    first input of the node, ahead of the user supplied arguments.

    Notes:
        This class is not meant to be used by the user; instances are
        created by the wrappers returned by :func:`jax_wrap`.

    Args:
        *args: inputs of the random function (without the key).
        _jax_function (callable): the random function.
        _shape: shape of the output.
        _dtype: dtype of the output.
        _seed (:obj:`int`, optional): seed used to build the key. When
            ``None`` a random 31-bit integer is drawn from ``os.urandom``.
        name (:obj:`str`, optional): name of the node, defaults to the name
            of the function.
        **kwargs: keyword inputs of the random function.

    """

    def __init__(
        self, *args, _jax_function, _shape, _dtype, _seed, name=None, **kwargs
    ):

        if name is None:
            name = _jax_function.__name__

        name, scope = symjax.current_graph()._get_name_scope(name, self)

        if _seed is None:
            # a single random byte would only allow 256 distinct seeds, so
            # that unseeded ops would frequently share the same key; draw
            # 31 bits instead (kept non-negative and within int32 range)
            _seed = int.from_bytes(os.urandom(4), "little") & 0x7FFFFFFF
        seed_op = Seed(_seed)

        # Tensor.__init__ is called directly so that the seed node can be
        # inserted as first input
        Tensor.__init__(
            self,
            seed_op,
            *args,
            _attrs={
                "name": name,
                "scope": scope,
                "_shape": _shape,
                "_dtype": _dtype,
                "jax_function": _jax_function,
                "root": False,
            },
            **kwargs,
        )

    @property
    def seed(self):
        """value of the ``seed`` attribute of this node in the current graph"""
        # NOTE(review): ``__init__`` above does not store a "seed" attribute
        # itself; presumably the graph sets it -- confirm.
        return symjax.current_graph().nodes[self]["seed"]

    def __repr__(self):
        name = "RandomOp(name={}, fn={}, shape={}, dtype={}, scope={})"
        return name.format(self.name, self.fn_name, self.shape, self.dtype, self.scope)


class Variable(Tensor):

    """variable that is a standalone persistent tensor. this tensor
    can be updated and differentiated.

    Parameters:
    -----------

        initializer: array, Tensor or callable
            the value given as an array (or a :class:`Tensor`, in which
            case its ``get()`` value is used), or a callable which takes
            the shape as input and returns the value

        name: str (optional)
            the name of the variable, there is no test of name duplication

        trainable: bool
            whether the variable is trainable or not. It is set as an
            attribute and can be accessed.

        shape: tuple (optional)
            the shape of the variable. It is passed to the initializer when
            it is a callable and otherwise checked against the given value

        dtype: dtype (optional)
            the dtype of the variable, the value is cast to it when given

    Raises:
    -------

        RuntimeError
            if ``trainable`` is set together with ``dtype == "bool"`` or if
            the value does not have the requested shape
    """

    def __init__(
        self,
        initializer,
        name="unnamed",
        trainable=True,
        shape=None,
        dtype=None,
    ):

        if trainable and dtype == "bool":
            raise RuntimeError("error impossible learning with dtype bool")

        # the shape has to be concrete, it can not involve symbolic tensors
        assert not isvar(shape)

        name, scope = symjax.current_graph()._get_name_scope(name, self)
        value = self._reset(initializer, shape, dtype)
        # fall back on the shape/dtype of the computed value when not given
        shape = tuple(shape or value.shape)
        dtype = jax.numpy.dtype(dtype) if dtype is not None else value.dtype

        super().__init__(
            _attrs={
                "name": name,
                "scope": scope,
                "_shape": shape,
                "_dtype": dtype,
                "trainable": trainable,
                "initializer": initializer,
                "root": True,
                "value": value,
            }
        )

    @property
    def trainable(self):
        """the ``trainable`` flag stored for this node in the current graph"""
        return symjax.current_graph().nodes[self]["trainable"]

    @property
    def initializer(self):
        """the initializer stored for this node in the current graph"""
        return symjax.current_graph().nodes[self]["initializer"]

    def _reset(self, init, shape, dtype):
        """compute the initial value of the variable from ``init``, whether
        it is an array, a Tensor or an initializer. If it is a random
        initializer, nothing guarantees that two calls give the same value
        as opposed to the array case

        Raises a RuntimeError when ``shape`` is given and the value does not
        have that shape.
        """
        if isinstance(shape, list):
            shape = tuple(shape)

        if callable(init):
            if shape is None:
                warnings.warn("given shape was None, using ()")
                shape = ()
            init = init(shape)

        if isinstance(init, Tensor):
            value = init.get()
        else:
            value = numpy.array(init)

        if dtype is not None:
            value = value.astype(dtype)

        if shape is not None and shape != jax.numpy.shape(value):
            raise RuntimeError(
                "given value and shape do not match (got {} expected {})".format(
                    value.shape, shape
                )
            )

        return jax.device_put(value)

    def reset(self):
        """recompute the value from the stored initializer and assign it"""
        self.update(self._reset(self.initializer, self.shape, self.dtype))

    @property
    def value(self):
        """current value of the variable, as stored in the current graph"""
        return symjax.current_graph().nodes[self]["value"]

    def update(self, update_value, fast=True):
        """assign a new value for the variable

        Parameters:
        -----------

            update_value:
                the new value

            fast: bool
                if True (default) the value is stored as is, without any
                evaluation nor shape/dtype check. If False the value is
                first evaluated through the current graph, then reshaped
                and/or cast (with a warning) when its shape or dtype
                differs from the ones of the variable
        """
        if fast:
            symjax.current_graph().nodes[self]["value"] = update_value
            # the fast path must stop here: falling through would run the
            # whole validation below and overwrite the value just stored
            return

        new_value = symjax.current_graph().get(update_value)

        new_shape = jax.numpy.shape(new_value)
        if self.shape != new_shape:
            warnings.warn(
                "Variable and update of {} are not the same shape "
                "(expected {}, got {})... attempting to cast".format(
                    self, self.shape, new_shape
                )
            )
            new_value = jax.numpy.reshape(new_value, self.shape)

        # values without a dtype attribute are compared through their type
        if hasattr(new_value, "dtype"):
            ntype = new_value.dtype
        else:
            ntype = type(new_value)
        if self.dtype != ntype:
            warnings.warn(
                "Variable and update of {} are not the same dtype "
                "(expected {}, got {})... attempting to cast".format(
                    self, self.dtype, ntype
                )
            )
            new_value = jax.numpy.asarray(new_value).astype(self.dtype)

        symjax.current_graph().nodes[self]["value"] = new_value

    def __repr__(self):
        name = "Variable(name={}, shape={}, dtype={}, trainable={}, scope={})"
        return name.format(
            self.name, self.shape, self.dtype, self.trainable, self.scope
        )


class Seed(Variable, Tensor):
    """non-trainable root node holding a ``jax.random.PRNGKey``

    Parameters:
    -----------

        seed: int
            the seed given to ``jax.random.PRNGKey`` to create the key
    """

    def __init__(self, seed):

        name, scope = symjax.current_graph()._get_name_scope("seed", self)

        # Tensor.__init__ is called directly (Variable.__init__ is bypassed),
        # hence no "initializer" attribute is stored for a seed
        Tensor.__init__(
            self,
            _attrs={
                "name": name,
                "scope": scope,
                "value": jax.random.PRNGKey(seed),
                "root": True,
                "_shape": (2,),
                "_dtype": "uint32",
                "trainable": False,
            },
        )

    def update(self, other_seed=None):
        """

        update the seed either by splitting the current one
        effectively generating a new random seed
        or by using a given one

        Raises a RuntimeError when the given seed is not of length 2.

        """
        if other_seed is not None:
            if len(other_seed) != 2:
                # the offending value is now really inserted in the message
                raise RuntimeError(
                    "given updated seed {} is not valid".format(other_seed)
                )
            symjax.current_graph().nodes[self]["value"] = other_seed
        else:
            new_key = jax.random.split(self.value, 1)[0]
            symjax.current_graph().nodes[self]["value"] = new_key

    def __repr__(self):
        name = "Seed(name={}, scope={})"
        return name.format(self.name, self.scope)

    def reset(self):
        """does nothing for a seed"""
        pass


class Placeholder(Tensor):
    """input gate of the computational graph. A placeholder receives its
    value from outside the graph, as opposed to for example variables which
    are living in memory and are not fed externally.

    Parameters:
    -----------

        shape: tuple
            the shape of the placeholder (stored as a tuple)

        dtype: dtype
            the dtype of the placeholder (converted with jax.numpy.dtype)

        name: str (optional)
            the name of the variable, there is no test of name duplication
    """

    def __init__(self, shape, dtype, name="unnamed"):
        graph = symjax.current_graph()
        name, scope = graph._get_name_scope(name, self)

        attrs = {
            "name": name,
            "scope": scope,
            "_dtype": jax.numpy.dtype(dtype),
            "_shape": tuple(shape),
            "root": True,
        }
        super().__init__(_attrs=attrs)

    def __repr__(self):
        template = "Placeholder(name={}, shape={}, dtype={}, scope={})"
        return template.format(self.name, self.shape, self.dtype, self.scope)


def placeholder_like(item, name="", force=True):
    """Create placeholder(s) with the shape and dtype of ``item``.

    ``None`` is returned as is. Plain lists and tuples are processed element
    by element (``name`` is not forwarded to the elements) and rebuilt with
    the same container type. Otherwise a :class:`Placeholder` is created
    from ``item.shape`` and ``item.dtype``, except when ``force`` is False
    and :func:`isvar` does not hold for ``item``, in which case ``item``
    itself is returned.
    """
    if item is None:
        return None

    container = type(item)
    if container is list or container is tuple:
        return container([placeholder_like(member, force=force) for member in item])

    if force or isvar(item):
        return Placeholder(item.shape, item.dtype, name=name)
    return item


def match(l1, l2, output):
    """Pair the leaves of two identically nested lists/tuples.

    Parameters:
    -----------

    l1: (nested) list/tuple or single item
        the keys of the mapping

    l2: (nested) list/tuple or single item
        the values, nested like ``l1``

    output: dict or None
        mapping filled in place with ``{leaf of l1: leaf of l2}``; a new
        dict is created when None

    Returns:
    --------

    dict
        the filled mapping. Returning it makes the ``output=None`` case
        usable: the freshly created dict was previously lost.
    """
    if output is None:
        output = dict()
    if type(l1) == list or type(l1) == tuple:
        # exact types only: subclasses of list/tuple are treated as leaves
        for a, b in zip(l1, l2):
            match(a, b, output)
    else:
        output[l1] = l2
    return output


def symjax_to_jax_fn(func):
    """Wrap a function written with symbolic tensors so that it can be
    called on concrete values.

    The returned function creates placeholders mirroring its positional
    arguments, calls ``func`` on them and evaluates the result through the
    current graph, feeding each placeholder with the matching argument.
    """

    def newfn(*args, fn=func):
        placeholders = placeholder_like(args)
        symbolic_outputs = fn(*placeholders)

        # placeholder -> given value; ``None`` arguments have no placeholder
        # and their entry is dropped
        feed_dict = {}
        match(placeholders, args, feed_dict)
        feed_dict.pop(None, None)

        return symjax.current_graph().get(symbolic_outputs, feed_dict)

    return newfn


def clone(tensor, givens):
    """Functional form of ``tensor.clone(givens)``."""
    cloned = tensor.clone(givens)
    return cloned


def get(tensor, tracker=None):
    """Call ``get`` of the current graph with ``tensor`` and ``tracker``."""
    graph = symjax.current_graph()
    return graph.get(tensor, tracker)