Source code for symjax.base

#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""Base of symjax."""

import fnmatch

import jax
import numpy
from jax import jacfwd, jacrev

import symjax
from symjax import tensor as t
from symjax.tensor import random
import networkx as nx
import re
import collections.abc
import os


def natural_key(string_):
    """See http://www.codinghorror.com/blog/archives/001018.html"""
    return [int(s) if s.isdigit() else s for s in re.split(r"(\d+)", string_)]


def current_scope():
    """Current scope."""
    return current_graph()._scopes[-1]


def current_graph():
    """Current graph."""
    assert len(symjax._graphs)
    return symjax._graphs[-1]

class Graph(nx.DiGraph):
    def __init__(self, name, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._name = name
        self.reset()
        self._jax_to_op = {}

    def reset(self):
        self.clear()
        self._scopes = [Scope(absolute_name="/" + self.name, graph=self)]
        self._updates = {}
        self._scopes_history = []
        self._branches = {}
        self._name_counts = {}

    def __repr__(self):
        msg = "Graph(name:{}, n_nodes:{})".format(self.name, len(self.nodes))
        return msg

    def __str__(self):
        return self.__repr__()

    @property
    def name(self):
        return self._name

    def add_updates(self, updates):
        # TODO
        # we add just in case there are some lists or something
        # this might not be needed
        for i in updates.values():
            self._add(i)
        self._updates.update(updates)

    def _get_name_scope(self, name, tensor):
        return self.scope._get_name_scope(name, tensor)

    def is_connected(self, node_1, node_2, directed=True):
        """Check if two nodes are connected in the graph.

        This function is useful to check whether two nodes (or Ops) are
        connected in the graph and are thus dependent on each other. For
        example, this can be used to select only the trainable variables
        that affect a specific tensor.

        Args:
            node_1: Tensor
            node_2: Tensor
            directed: bool
                whether to test only the direction ``node_1`` to ``node_2``
                or both. If the graph is not directed this parameter has no
                effect. If ``False`` the function returns ``True`` if the
                nodes are connected in either direction.

        Returns:
            bool
        """
        if directed:
            return nx.has_path(self, node_1, node_2)
        else:
            return nx.has_path(self, node_1, node_2) or nx.has_path(
                self, node_2, node_1
            )

    def _add(self, tensor, *args, _attrs=None, **kwargs):
        _attrs = _attrs or {}

        # first we check if the node is hashable and already in the graph
        if isinstance(tensor, collections.abc.Hashable) and type(tensor) != tuple:
            if tensor in self.nodes:
                return tensor

        # then, if the node is a constant (not tensor, variable, ...)
        if not t.isvar(tensor):
            # if it is hashable then we can add it as is as a node
            # and we return it since it is the same object
            # if isinstance(tensor, collections.abc.Hashable):
            #     self.add_node(tensor, root=False)
            #     return tensor
            # otherwise we have to make it hashable, to do so we use
            # the constant object
            # during creation of the object, it will be added to the
            # graph automatically, we can return the object
            # else:
            return t.Constant(tensor)

        # now check if it is a list or a tuple
        if type(tensor) == list or type(tensor) == tuple:
            return t.Tuple(*tensor)

        if isinstance(tensor, t.Tensor):
            self.add_node(tensor, **_attrs)

        if isinstance(tensor, t.Op):
            for i, arg in enumerate(args):
                node = self._add(arg)
                if self.has_edge(node, tensor):
                    self[node][tensor]["name"] += "+arg" + str(i)
                else:
                    self.add_edge(node, tensor, name="arg" + str(i))
            for key, arg in kwargs.items():
                node = self._add(arg)
                if self.has_edge(node, tensor):
                    self[node][tensor]["name"] += "+" + key
                else:
                    self.add_edge(node, tensor, name=key)

        return tensor

    def roots(self, nodes, roots=None):
        if roots is None:
            roots = []
        if type(nodes) == tuple or type(nodes) == list:
            for node in nodes:
                self.roots(node, roots)
        else:
            if self.nodes[nodes]["root"]:
                roots.append(nodes)
            for i in self.ancestors(nodes):
                if self.nodes[i]["root"]:
                    roots.append(i)
        return list(set(roots))

    def ancestors(self, nodes, acc=None):
        if acc is None:
            acc = set()
        if type(nodes) == list:
            for node in nodes:
                acc = acc.union(self.ancestors(node, acc))
            return acc
        else:
            predecessors = list(self.predecessors(nodes))
            predecessors = [p for p in predecessors if p not in acc]
            if len(predecessors) == 0:
                return acc
            else:
                acc = acc.union(set(predecessors))
                return self.ancestors(predecessors, acc)

    def clone(self, node, todos, done=None, share_random_ops=True):
        """
        Replace subgraphs of a computational graph.

        It returns a copy of the initial subgraph with the corresponding
        substitutions.

        Parameters
        ----------

        node : Tensor
            the node to replace

        todos : dict
            dictionary describing which subgraphs should be replaced by what

        share_random_ops : bool
            whether to share the random variable realisations between the
            original tensor and the cloned one
        """
        if done is None:
            done = {}

        if node in done:
            return done[node]
        elif node in todos:
            return todos[node]
        elif isinstance(node, t.Variable) or isinstance(node, t.Placeholder):
            return node
        elif isinstance(node, t.RandomOp) and share_random_ops:
            return node
        elif isinstance(node, t.Constant):
            return node.value
        elif len(set(todos) - set(done)) == 0:
            return node
        elif isinstance(node, t.OpItem):
            parent = list(self.predecessors(node))[0]
            index = int(self[parent][node]["name"].split("parent_index")[1])
            done[node] = self.clone(parent, todos, done)[index]
            return done[node]
        elif type(node) == tuple or type(node) == list or type(node) == t.Tuple:
            return [self.clone(n, todos, done) for n in node]

        args, kwargs = self.get_args_kwargs(node, evaluate=False)
        new_args = [
            self.clone(n, todos, done) for n in args if not isinstance(n, t.Seed)
        ]
        new_kwargs = {
            name: self.clone(n, todos, done) for name, n in kwargs.items()
        }

        fun = self.get_node_attribute(node, "jax_function")
        done[node] = symjax._fn_to_op[fun.__name__](*new_args, **new_kwargs)
        return done[node]

    def get_node_attribute(self, node, attr):
        return nx.get_node_attributes(self, attr)[node]

    def unnest(self, container):
        for i in container:
            if isinstance(i, (list, tuple)):
                for j in self.unnest(i):
                    yield j
            else:
                yield i

    def get(self, input_tensor, tracker=None, frozen=True):

        if not t.isvar(input_tensor):
            return input_tensor

        if tracker is None:
            tracker = {}
            mapped = {}
        else:
            mapped = tracker.copy()

        tomap = []
        tosearch = list(self.unnest([input_tensor]))

        # the first graph traversal goes through the parents and reaches the
        # furthest parents to map their values into the mapped hashmap
        while len(tosearch):
            for item in tosearch:
                # note: items are removed while iterating; any entry skipped
                # by the iterator is picked up by the enclosing while-loop
                tosearch.remove(item)
                if item in tracker:
                    mapped[item] = tracker[item]
                elif isinstance(item, t.Variable) or isinstance(item, t.Constant):
                    mapped[item] = item.value
                elif isinstance(item, t.Placeholder):
                    if item in tracker:
                        mapped[item] = tracker[item]
                    else:
                        raise ValueError(
                            "no value given for placeholder {}".format(item)
                        )
                elif isinstance(item, t.OpItem):
                    parent = list(self.predecessors(item))[0]
                    tomap.append(parent)
                    if parent not in tosearch:
                        tosearch.append(parent)
                elif not t.isvar(item) and not isinstance(item, t.Constant):
                    tomap.append(t.Constant(item))
                else:
                    # this item is an Op and thus will need to be mapped later
                    tomap.append(item)
                    # we retrieve the parents of the Op node to go up the graph
                    item_parents = list(self.predecessors(item))
                    # we loop through the parents and only add the unique ones
                    # to the search list for the next loop iteration
                    for it in list(self.unnest(list(item_parents))):
                        if it not in tosearch:
                            tosearch.append(it)

        # now tosearch is empty, and the parents of item that could be
        # assigned an explicit value are mapped into the mapped hashmap;
        # we start from the end as the last added items will have parents
        # already in mapped
        for item in tomap[::-1]:
            if isinstance(item, t.OpItem):
                parent = list(self.predecessors(item))[0]
                index = int(self[parent][item]["name"].split("parent_index")[1])
                if parent not in mapped:
                    mapped[parent] = self._get_value_given_mapped(
                        parent, mapped, frozen=frozen
                    )
                mapped[item] = mapped[parent][index]
            elif isinstance(item, t.MultiOutputOp):
                mapped[item] = self._get_value_given_mapped(
                    item, mapped, frozen=frozen
                )
                for node, value in zip(item, mapped[item]):
                    mapped[node] = value
            else:
                mapped[item] = self._get_value_given_mapped(
                    item, mapped, frozen=frozen
                )

        # now generate the mapped output
        return self._to_mapped_value(input_tensor, mapped)

    def _get_value_given_mapped(self, item, mapped, frozen=False):

        if isinstance(item, t.Constant):
            return item.value
        elif not t.isvar(item):
            return item

        args, kwargs = self.get_args_kwargs(item, evaluate=False, frozen=frozen)

        mapped_args = self._to_mapped_value(args, mapped)
        mapped_kwargs = {
            key: self._to_mapped_value(value, mapped)
            for key, value in kwargs.items()
        }

        return self.nodes[item]["jax_function"](*mapped_args, **mapped_kwargs)

    def _to_mapped_value(self, args, mapped):
        if not t.isvar(args):
            return args
        elif type(args) == list:
            return [self._to_mapped_value(arg, mapped) for arg in args]
        elif type(args) == tuple:
            return tuple([self._to_mapped_value(arg, mapped) for arg in args])
        else:
            return mapped[args]

    def all_predecessors(self, item):
        if type(item) == list or type(item) == tuple:
            predecessors = []
            for i in item:
                predecessors += self.all_predecessors(i)
            return predecessors
        else:
            return self.predecessors(item)

    def get_args_kwargs(self, node, tracker=None, evaluate=True, frozen=True):

        if evaluate:
            assert tracker is not None
            parents = list(self.all_predecessors(node))
            all_args = {}
            for i in range(len(parents)):
                all_args[self[parents[i]][node]["name"]] = self.get(
                    parents[i], tracker, frozen=frozen
                )
        else:
            all_args = {
                self[parent][node]["name"]: parent
                for parent in self.all_predecessors(node)
            }

        # now we inspect if there are duplicate args
        for arg in list(all_args.keys()):
            if "+" in arg:
                items = arg.split("+")
                for item in items:
                    all_args.update({item: all_args[arg]})
                del all_args[arg]

        arg_names = [name for name in all_args.keys() if "arg" == name[:3]]
        arg_names = sorted(arg_names, key=natural_key)
        args = [all_args[name] for name in arg_names]

        for name in arg_names:
            del all_args[name]

        return args, all_args

    @property
    def scope(self):
        return self._scopes[-1]

    @property
    def scope_name(self):
        return self._scopes[-1].absolute_name

    def variables(self, trainable=True):
        variables = [n for n in self.nodes if isinstance(n, t.Variable)]
        if trainable is None:
            return variables
        return [v for v in variables if v.trainable == trainable]

    @property
    def placeholders(self):
        placeholders = [n for n in self.nodes if type(n) == t.Placeholder]
        return placeholders

    @property
    def ops(self):
        ops = [n for n in self.nodes if type(n) in [t.Op, t.MultiOutputOp]]
        return ops

    @property
    def updates(self):
        return self._updates

    @property
    def other_nodes(self):
        return list(
            set(self.nodes)
            - (
                set(self.ops).union(
                    set(self.placeholders).union(set(self.variables(None)))
                )
            )
        )
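
# Illustrative usage sketch (not part of the library): querying the graph that
# symjax builds as tensors are created. The tensor names below are
# hypothetical and the expected results are indicative only.
#
#     >>> import symjax
#     >>> import symjax.tensor as T
#     >>> x = T.Variable(1., name='x', dtype='float32')
#     >>> y = x * 2 + 1
#     >>> g = symjax.current_graph()
#     >>> g.is_connected(x, y)   # expected: True, y depends on x
#     >>> x in g.roots(y)        # expected: True, x is a root of y's subgraph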

class Scope:
    """
    Defining scope for any variable/operation to be in.

    Example
    -------

    .. doctest::

        >>> import symjax
        >>> import symjax.tensor as T
        >>> v = T.Variable(1, name='v')
        >>> # the current (default) scope is the root of the graph
        >>> print(v.scope)
        /
        >>> with symjax.Scope('my_scope'):
        ...     w = T.Variable(2, name='w')
        ...     out = v * w
        >>> print(out.scope)
        /my_scope/
        >>> print(w.scope)
        /my_scope/
        >>> # it is also possible to keep a scope persistently
        >>> scope = symjax.Scope('second_scope')
        >>> with scope:
        ...     other = out * w
        >>> print(other.scope)
        /second_scope/
        >>> # this allows to keep track directly of internal ops
        >>> print(scope.ops)
        [Op(name=multiply, fn=multiply, shape=(), dtype=int32, scope=/second_scope/)]

    """

    def __init__(
        self,
        relative_name=None,
        absolute_name=None,
        reattach=False,
        reuse=False,
        graph=None,
    ):
        """Constructor."""
        assert relative_name is not None or absolute_name is not None
        if relative_name is not None:
            assert "/" not in relative_name

        self.reattach = reattach
        self.graph = graph
        self.reuse = reuse
        self.relative_name = relative_name
        self.absolute_name = absolute_name

        if reuse or reattach:
            assert reattach
            assert absolute_name in self.graph._scopes_history

    def __enter__(self):
        """Set global variables."""
        if self.graph is None:
            self.graph = current_graph()

        if self.absolute_name is None:
            self.absolute_name = os.path.join(
                self.graph.scope_name, self.relative_name
            )

        if self.reattach:
            return self

        if self.absolute_name in self.graph._scopes_history:
            cpt = 1
            self.absolute_name += "_"
            while self.absolute_name + str(cpt) in self.graph._scopes_history:
                cpt += 1
            self.absolute_name += str(cpt)
            self.relative_name += "_" + str(cpt)

        self.graph._scopes_history.append(self.absolute_name)
        self.graph._scopes.append(self)
        return self

    def __exit__(self, *args):
        """Delete globals."""
        self.graph._scopes = self.graph._scopes[:-1]

    # def save_variables(self, path):
    #     """Save graph."""
    #     if ".npz" != path[:-4]:
    #         path += ".npz"
    #     numpy.savez(
    #         path,
    #         **dict([(v.name, symjax.tensor.get(v)) for v in self.variables]),
    #     )

    # def variables(self, trainable=True):
    #     return get_variables(scope=self.full_name, trainable=trainable)

    # @property
    # def ops(self):
    #     return get_ops(scope=self.full_name)

    # @property
    # def placeholders(self):
    #     return get_placeholders(scope=self.full_name)

    # def load_variables(self, path):
    #     """Load graph."""
    #     data = numpy.load(path)
    #     for name, value in data.items():
    #         self.variable[name].update(value)

    # def reset(self):
    #     for var in self.variables:
    #         var.reset()

    def _get_name_scope(self, name, tensor):

        # if self.full_name is None:
        #     self.__enter__()

        if isinstance(tensor, symjax.tensor.Placeholder):
            nodes = self.graph.placeholders
        elif isinstance(tensor, symjax.tensor.Variable):
            nodes = self.graph.variables(None)
        elif isinstance(tensor, symjax.tensor.Op):
            nodes = self.graph.ops
        else:
            nodes = self.graph.other_nodes

        names = [os.path.join(m.scope, m.name) for m in nodes if hasattr(m, "name")]
        test_name = os.path.join(self.absolute_name, name)

        # if we never used this name before then we can just keep it as is
        if test_name not in names:
            return name, self.absolute_name
        # otherwise we append the suffix _N with N the number of times we
        # already used this name. To keep track of how many times we already
        # used a name we employ a look-up table for fast retrieval, and we
        # update that count by 1
        else:
            # note that if we allow reuse, then we just return the variable
            # that had this name
            if self.reuse and isinstance(tensor, symjax.tensor.Variable):
                return nodes[nodes.index(tensor)]
            # if we never saw this variable before (first time it is repeated)
            # then we create its entry in the look-up table
            if test_name not in self.graph._name_counts:
                self.graph._name_counts[test_name] = 0
            # we increase the counter
            self.graph._name_counts[test_name] += 1

            return (
                name + "_" + str(self.graph._name_counts[test_name]),
                self.absolute_name,
            )

def reset_variables(name="*", scope="*", trainable=None):
    """
    Utility to reset variables based on their names.

    Parameters
    ----------

    name: str (default='*')
        the name (or part of the name) of all the variables that should be
        reset, it can include the glob (``*``) searching for all matching
        names

    scope: str (default='*')
        the scope (or part of the scope) that the variables to reset must
        match, it can include the glob (``*``)

    trainable: bool or None (optional, default=None)
        if not None, only reset, among the matched variables, the ones whose
        trainable attribute matches the given value

    Returns
    -------

    None

    Example
    -------

    .. doctest::

        >>> import symjax
        >>> w = symjax.tensor.Variable(1., name='w', dtype='float32')
        >>> x = symjax.tensor.Variable(2., name='x', dtype='float32')
        >>> f = symjax.function(outputs=[w, x], updates={w: w + 1, x: x + 1})
        >>> for i in range(10):
        ...     print(f())
        [array(1., dtype=float32), array(2., dtype=float32)]
        [array(2., dtype=float32), array(3., dtype=float32)]
        [array(3., dtype=float32), array(4., dtype=float32)]
        [array(4., dtype=float32), array(5., dtype=float32)]
        [array(5., dtype=float32), array(6., dtype=float32)]
        [array(6., dtype=float32), array(7., dtype=float32)]
        [array(7., dtype=float32), array(8., dtype=float32)]
        [array(8., dtype=float32), array(9., dtype=float32)]
        [array(9., dtype=float32), array(10., dtype=float32)]
        [array(10., dtype=float32), array(11., dtype=float32)]
        >>> # reset only the w variable
        >>> symjax.reset_variables('w')
        >>> # reset all variables
        >>> symjax.reset_variables('*')

    """
    variables = get_variables(name=name, scope=scope, trainable=trainable)
    for var in variables:
        var.reset()

def save_variables(path_or_file, name="*", scope="*", trainable=None):
    """Saves the graph variables.

    The saving is done via ``numpy.savez`` for fast and compressed storage.

    Parameters
    ----------

    path_or_file: str or file
        the path and name of the file to save the variables in, or an open
        file object

    name: str (optional)
        the name string that the variables to save must match

    scope: str (optional)
        the scope name string that the variables to save must match

    trainable: bool or None
        the option of the variables to save (``True``, ``False`` or ``None``)

    """
    if type(path_or_file) == str:
        if path_or_file[-4:] != ".npz":
            path_or_file += ".npz"

    variables = get_variables(name, scope, trainable)
    numpy.savez(
        path_or_file,
        **dict([(v.scope + v.name, symjax.tensor.get(v)) for v in variables]),
    )

def load_variables(path_or_file, name="*", scope="*", trainable=None):
    """Loads the graph variables.

    The loading is done via ``numpy.load`` of an archive previously written
    with ``numpy.savez`` for fast and compressed storage.

    Parameters
    ----------

    path_or_file: str or file
        the path and name of the file to load the variables from, or an open
        file object

    name: str (optional)
        the name string that the variables to load must match

    scope: str (optional)
        the scope name string that the variables to load must match

    trainable: bool or None
        the option of the variables to load (``True``, ``False`` or ``None``)

    """
    if type(path_or_file) == str:
        if path_or_file[-4:] != ".npz":
            path_or_file += ".npz"

    variables = get_variables(name, scope, trainable=trainable)
    data = numpy.load(path_or_file)
    for var in variables:
        name_in_file = var.scope + var.name
        if name_in_file not in data:
            raise Warning("{} not in loaded file".format(name_in_file))
        var.update(data[name_in_file])
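
# Illustrative usage sketch (not part of the library): a save/load round trip.
# The file path below is hypothetical; results are indicative only.
#
#     >>> import symjax
#     >>> w = symjax.tensor.Variable(1., name='w', dtype='float32')
#     >>> symjax.save_variables('/tmp/model')   # written to /tmp/model.npz
#     >>> w.update(5.)
#     >>> symjax.load_variables('/tmp/model')   # w is restored to 1.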

def get_variables(name="*", scope="/", trainable=True):
    """
    Returns the graph variables matching the given ``name`` and ``scope``
    glob patterns and, if not ``None``, the given ``trainable`` attribute.
    """
    matched = current_graph().variables(trainable)
    output = []
    for m in matched:
        if len(fnmatch.filter([m.name], name)) and len(
            fnmatch.filter([m.scope], scope + "*")
        ):
            output.append(m)
    return output
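
# Illustrative usage sketch (not part of the library): ``name`` and ``scope``
# are glob patterns matched with ``fnmatch``. The names below are hypothetical
# and the expected results are indicative only.
#
#     >>> import symjax
#     >>> import symjax.tensor as T
#     >>> with symjax.Scope('model'):
#     ...     w = T.Variable(1., name='w', dtype='float32')
#     >>> symjax.get_variables(name='w*', scope='/model')   # expected: [w]
#     >>> symjax.get_variables(trainable=False)   # expected: [] if none exist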

def get_placeholders(name="*", scope="/"):
    """
    Same as :func:`symjax.get_variables` but for placeholders.
    """
    matched = current_graph().placeholders
    output = []
    for m in matched:
        if len(fnmatch.filter([m.name], name)) and len(
            fnmatch.filter([m.scope], scope + "*")
        ):
            output.append(m)
    return output

def get_ops(name="*", scope="/"):
    """
    Same as :func:`symjax.get_variables` but for ops.
    """
    matched = current_graph().ops
    output = []
    for m in matched:
        if len(fnmatch.filter([m.name], name)) and len(
            fnmatch.filter([m.scope], scope + "*")
        ):
            output.append(m)
    return output

def get_updates(name="*", scope="/", variables=None):
    """
    Same as :func:`symjax.get_variables` but for the updates of the graph.
    """
    matched = current_graph().updates
    output = {}
    if variables is not None:
        for v in variables:
            if v in matched:
                output[v] = matched[v]
        return output

    for var, update in matched.items():
        if len(fnmatch.filter([update.name], name)) and len(
            fnmatch.filter([update.scope], scope + "*")
        ):
            output[var] = update
    return output
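
# Illustrative usage sketch (not part of the library): updates registered on
# the current graph can be queried per variable. The names below are
# hypothetical and the expected result is indicative only.
#
#     >>> import symjax
#     >>> w = symjax.tensor.Variable(0., name='w', dtype='float32')
#     >>> symjax.current_graph().add_updates({w: w + 1})
#     >>> symjax.get_updates(variables=[w])   # expected: {w: the w + 1 Op}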

def add_updates(updates):
    current_scope().add_updates(updates)

def gradients(scalar, variables):
    """Compute the gradients of a scalar w.r.t. a given list of variables.

    Arguments
    ---------

    scalar: :class:`symjax.tensor.base.Tensor`
        the variable to differentiate

    variables: List or Tuple
        the variables used to compute the derivative.

    Returns
    -------

    gradients: Tuple
        the sequence of gradients ordered as given in the input variables

    Example
    -------

    .. doctest::

        >>> import symjax
        >>> w = symjax.tensor.ones(3)
        >>> x = symjax.tensor.Variable(2., name='x', dtype='float32')
        >>> l = (w ** 2).sum() * x
        >>> g = symjax.gradients(l, [w])[0]
        >>> f = symjax.function(outputs=g, updates={x: x + 1})
        >>> for i in range(2):
        ...     print(f())
        [4. 4. 4.]
        [6. 6. 6.]

    """
    if numpy.prod(scalar.shape) != 1:
        raise RuntimeError("the variable to differentiate is not a scalar")
    if not isinstance(scalar, t.Tensor):
        raise RuntimeError("the variable used in gradients should be a Tensor type")

    if isinstance(variables, t.Tensor):
        input_variables = [variables]
        input_list = False
    else:
        input_variables = variables.copy()
        input_list = True

    # get the argnums of the variables that we differentiate with respect to
    argnums = list(range(len(input_variables)))

    # get all the roots of the scalar, this is needed as otherwise they are
    # not used as inputs of the gradient function and thus a change of their
    # value would not change the gradient computation, we also ensure
    # uniqueness
    input_variables += [
        i for i in current_graph().roots(scalar) if i not in input_variables
    ]

    # create a dummy function that is needed for jax to compute a gradient
    # function; this function builds the graph of computation from all roots
    # to the scalar variable s.t. automatic differentiation can be applied
    def internal_gradient_function(*args):
        return current_graph().get(
            scalar,
            {input_variables[i]: args[i] for i in range(len(input_variables))},
        )

    # now we obtain the grad function. In fact, jax returns a function that,
    # when called, returns the gradient values; this function is then used
    # to generate the Tuple of symbolic variables
    grad_fn = jax.grad(internal_gradient_function, argnums)
    wrap_fn = t.jax_wrap(grad_fn)
    if input_list:
        return wrap_fn(*input_variables)
    else:
        return wrap_fn(*input_variables)[0]

def jacobians(tensor, variables, mode="forward"):
    """Compute the Jacobians of a tensor w.r.t. a given list of variables.

    The tensor need not be a vector, but will be treated as such. For example
    if ``tensor.shape`` is ``(10, 3, 3)`` and a variable shape is ``(10, 10)``
    then the resulting Jacobian has shape ``(10, 3, 3, 10, 10)``. It is
    possible to specify the mode as forward or backward. For tall Jacobians,
    forward mode is faster and vice-versa.

    Arguments
    ---------

    tensor: Tensor
        the tensor to differentiate

    variables: List or Tuple
        the variables used to compute the derivative.

    mode: str
        the differentiation mode to use, ``'forward'`` or ``'backward'``

    Returns
    -------

    jacobians: Tuple
        the sequence of Jacobians ordered as given in the input variables

    """
    if isinstance(variables, t.Tensor):
        input_variables = [variables]
        input_list = False
    else:
        input_variables = variables.copy()
        input_list = True

    # get the argnums of the variables that we differentiate with respect to
    argnums = list(range(len(input_variables)))

    # get all the roots of the tensor, this is needed as otherwise they are
    # not used as inputs of the jacobian function and thus a change of their
    # value would not change the jacobian computation, we also ensure
    # uniqueness
    input_variables += [
        i for i in current_graph().roots(tensor) if i not in input_variables
    ]

    # create a dummy function that is needed for jax to compute a jacobian
    # function; this function builds the graph of computation from all roots
    # to the tensor s.t. automatic differentiation can be applied
    def internal_jacobian(*args):
        return current_graph().get(
            tensor,
            {input_variables[i]: args[i] for i in range(len(input_variables))},
        )

    # now we obtain the jacobian function. In fact, jax returns a function
    # that, when called, returns the jacobian values; this function is then
    # used to generate the Tuple of symbolic variables
    if mode == "forward":
        jacob_fn = jacfwd(internal_jacobian, argnums)
    elif mode == "backward":
        jacob_fn = jacrev(internal_jacobian, argnums)
    else:
        raise RuntimeError(
            "mode {} not recognized, use forward or backward".format(mode)
        )
    wrap_fn = t.jax_wrap(jacob_fn, False)
    if input_list:
        return wrap_fn(*input_variables)
    else:
        return wrap_fn(*input_variables)[0]
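
# Illustrative usage sketch (not part of the library), mirroring the
# ``gradients`` example above: the Jacobian of ``w * x`` w.r.t. ``w``.
# Expected values are indicative only.
#
#     >>> import symjax
#     >>> w = symjax.tensor.ones(3)
#     >>> x = symjax.tensor.Variable(2., name='x', dtype='float32')
#     >>> j = symjax.jacobians(w * x, [w])[0]   # expected shape: (3, 3)
#     >>> f = symjax.function(outputs=j)
#     >>> f()   # expected: x * identity(3), i.e. 2 on the diagonal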

class function:
    """Generate a user function that compiles a computational graph.

    The function is built from the given inputs, outputs and update policy of
    variables. It internally jit-compiles the underlying jax computational
    graph for performance.

    Arguments
    ---------

    args: trailing tuple
        the inputs to the function to be compiled. The tuple should contain
        all the placeholders that are roots of any of the given outputs and
        update values

    outputs: List (optional)
        the outputs of the function, if a single element, it can be given as
        a standalone and not a list

    updates: Dict (optional)
        the dictionary of updates as per ``{var: new_value}`` for any
        variable of the graph

    backend: 'cpu' or 'gpu'
        the backend to use to run the function on

    default_value: not implemented
        not implemented

    Returns
    -------

    callable:
        the user frontend function that takes the specified inputs, returns
        the specified outputs and performs the updates internally

    Examples
    --------

    >>> import symjax
    >>> import symjax.tensor as T
    >>> x = T.ones((4, 4))
    >>> xs = x.sum() + 1
    >>> f = symjax.function(outputs=xs)
    >>> print(f())
    17.0

    >>> w = T.Variable(0., name='w', dtype='float32')
    >>> increment = symjax.function(updates={w: w + 1})
    >>> for i in range(10):
    ...     increment()
    >>> print(w.value)
    10.0

    """

    def __init__(
        self,
        *args,
        outputs=[],
        updates=None,  # noqa
        device=None,
        backend=None,
        default_value=None,
        frozen=False
    ):
        """Initialize."""
        # check the given updates (if any) and ensure that they only
        # update Variable objects
        if updates is None:
            updates = {}

        for update in updates:
            if not isinstance(update, t.Variable):
                raise RuntimeError(
                    "{} is not a Variable and cannot be updated".format(update)
                )

        # ensure that all inputs are actual placeholders or variables
        for arg in args:
            if not isinstance(arg, t.Tensor):
                raise RuntimeError(
                    "{} is not a Tensor type. Only tensor types can be "
                    "function inputs".format(arg)
                )

        # gather all roots, they need to be explicit as inputs of the
        # underlying functions otherwise they are treated as constants
        # and any change in their value will not appear when running the
        # function
        self.updates_keys = list(updates.keys())
        self.updates_values = list(updates.values())
        self.args = args
        self.outputs = outputs

        outs = self.updates_values.copy()
        outs += [outputs] if isinstance(outputs, t.Tensor) else outputs

        self.all_roots = set(symjax.current_graph().roots(outs))

        # check the function inputs, they must at least contain all the
        # placeholders needed to compute the output values
        placeholders_in_root = [
            x for x in self.all_roots if isinstance(x, t.Placeholder)
        ]

        # check for missing placeholders
        non_givens = set(placeholders_in_root) - set(self.args)
        if len(non_givens) > 0:
            raise RuntimeError(
                "Missing placeholders from the function inputs...\n"
                "...Missing are: {}".format(non_givens)
            )

        # the roots are made of variables, random tensors, placeholders. We
        # already ensured that all placeholders are given as inputs to the
        # function. Now we must ensure that the other ones will also be given
        # as inputs to not be treated as constants by jax.
        # we also remove the update keys because we will explicitly feed them
        self.extra_inputs = set(self.all_roots) - set(self.args).union(
            self.updates_keys
        )
        self.extra_inputs = list(self.extra_inputs)

        allargs = list(self.args) + self.updates_keys + self.extra_inputs

        def to_jit(*jitargs):
            feed_dict = dict(zip(allargs, jitargs))
            outputs = [self.outputs, self.updates_values]
            return symjax.current_graph().get(outputs, feed_dict)

        self.jited = jax.jit(to_jit, device=device, backend=backend)

        # define the frontend function that takes as input the inputs variables
        # and internally computes and updates the variables from updates if any
        def meta(*fnargs):

            # ensure that the number of arguments is correct
            assert len(fnargs) == len(self.args)
            for fnarg, classarg in zip(fnargs, self.args):
                if hasattr(fnarg, "shape"):
                    if fnarg.shape != classarg.shape and 0 not in classarg.shape:
                        raise RuntimeError(
                            "wrong input given for {}".format(classarg)
                            + ", given is {}".format(fnarg)
                            + ", shape={}".format(fnarg.shape)
                        )

            # retrieve the function outputs, updated values and apply them
            jited_add_inputs = symjax.current_graph().get(
                self.updates_keys + self.extra_inputs
            )
            jitoutputs, jitupdates = self.jited(*fnargs, *jited_add_inputs)

            for key, update in zip(self.updates_keys, jitupdates):
                key.update(update)

            # update the seed such that the random numbers of the
            # next call are all different
            if not frozen:
                for i in self.extra_inputs:
                    if isinstance(i, t.Seed):
                        i.update()

            return jitoutputs

        self.meta = meta

    def __call__(self, *args, device_get=True):
        """Callable fn."""
        # in the presence of RandomTensor(s) in the graph, we keep track of
        # the number of function calls to keep accumulating the PRNGKey of
        # the jax key, otherwise each function call returns the same
        # realisation
        outputs = self.meta(*args)
        if device_get:
            outputs = jax.device_get(outputs)
        return outputs
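
# Illustrative usage sketch (not part of the library): placeholders that are
# roots of the outputs must be passed as positional inputs of the compiled
# function. The placeholder shape, dtype and name below are assumptions for
# illustration; the expected output is indicative only.
#
#     >>> import symjax
#     >>> import symjax.tensor as T
#     >>> x = T.Placeholder((3,), 'float32', name='x')
#     >>> f = symjax.function(x, outputs=(x ** 2).sum())
#     >>> f([1., 2., 3.])   # expected: 14.0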