#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""Base of symjax."""
import fnmatch
import jax
import numpy
from jax import jacfwd, jacrev
import symjax
from symjax import tensor as t
from symjax.tensor import random
import networkx as nx
import re
import collections.abc
import os
def natural_key(string_):
"""See http://www.codinghorror.com/blog/archives/001018.html"""
return [int(s) if s.isdigit() else s for s in re.split(r"(\d+)", string_)]
def current_scope():
"""Current scope."""
return current_graph()._scopes[-1]
def current_graph():
"""Current graph."""
assert len(symjax._graphs)
return symjax._graphs[-1]
class Graph(nx.DiGraph):
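    """Computational graph of symjax, built on top of :class:`networkx.DiGraph`.

    Nodes are symjax tensors (variables, placeholders, constants, ops) and
    edges connect each op to its arguments; the graph also keeps track of
    scopes, updates and name counters.
    """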
def __init__(self, name, *args, **kwargs):
super().__init__(*args, **kwargs)
self._name = name
self.reset()
self._jax_to_op = {}
def reset(self):
self.clear()
self._scopes = [Scope(absolute_name="/" + self.name, graph=self)]
self._updates = {}
self._scopes_history = []
self._branches = {}
self._name_counts = {}
def __repr__(self):
msg = "Graph(name:{}, n_nodes:{})".format(self.name, len(self.nodes))
return msg
def __str__(self):
return self.__repr__()
@property
def name(self):
return self._name
def add_updates(self, updates):
        # TODO: we add the update tensors to the graph in case some of them
        # are plain lists/values that are not graph nodes yet; this might not
        # be needed
for i in updates.values():
self._add(i)
self._updates.update(updates)
def _get_name_scope(self, name, tensor):
return self.scope._get_name_scope(name, tensor)
def is_connected(self, node_1, node_2, directed=True):
"""check if two nodes are connected in the graph.
This function is useful to check wheter two nodes (or Op)
as connected in the graph and are thus dependent on each other.
For example this is useful to select on trainable variables that
affect a specific tensor.
Args:
node_1: Tensor
node_2: Tensor
directed: bool
whether to test for both directions or not, if the graph
is not directed then this parameter has no effect.
If ``False`` then the function will return ``True``
is the nodes are connected no matter the direction
Returns:
bool
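
        Example (a minimal sketch checking that an op depends on a variable):

            >>> import symjax
            >>> import symjax.tensor as T
            >>> v = T.Variable(1., name='v', dtype='float32')
            >>> out = v * 2
            >>> symjax.current_graph().is_connected(v, out)
            True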
"""
if directed:
return nx.has_path(self, node_1, node_2)
else:
return nx.has_path(self, node_1, node_2) or nx.has_path(
self, node_2, node_1
)
def _add(self, tensor, *args, _attrs=None, **kwargs):
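        """Add ``tensor`` to the graph (if not already present) and connect it
        to the nodes given in ``args``/``kwargs`` through named edges; raw
        (non-symbolic) values are wrapped into ``Constant``/``Tuple`` nodes.
        Returns the corresponding graph node."""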
_attrs = _attrs or {}
        # first, if the node is hashable (and not a tuple), check whether it
        # is already present in the graph
        if isinstance(tensor, collections.abc.Hashable) and type(tensor) != tuple:
if tensor in self.nodes:
return tensor
# then, if the node is a constant (not tensor, variable, ...)
if not t.isvar(tensor):
# if it is hashable then we can add it as is as a node
# and we return it since it is the same object
# if isinstance(tensor, collections.Hashable):
# self.add_node(tensor, root=False)
# return tensor
# otherwise we have to make it hashable, to do so we use
# the constant object
# during creation of the object, it will be added to the
# graph automatically, we can return the object
# else:
return t.Constant(tensor)
if type(tensor) == list or type(tensor) == tuple:
return t.Tuple(*tensor)
        # at this point the node is a symbolic tensor: add it to the graph
        # and, if it is an Op, connect it to its (keyword) arguments
if isinstance(tensor, t.Tensor):
self.add_node(tensor, **_attrs)
if isinstance(tensor, t.Op):
for i, arg in enumerate(args):
node = self._add(arg)
if self.has_edge(node, tensor):
self[node][tensor]["name"] += "+arg" + str(i)
else:
self.add_edge(node, tensor, name="arg" + str(i))
for key, arg in kwargs.items():
node = self._add(arg)
if self.has_edge(node, tensor):
self[node][tensor]["name"] += "+" + key
else:
self.add_edge(node, tensor, name=key)
return tensor
def roots(self, nodes, roots=None):
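        """Return, without duplicates, the root nodes (nodes flagged with the
        ``root`` attribute, typically variables, placeholders and random
        seeds) that ``nodes`` depend on."""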
if roots is None:
roots = []
if type(nodes) == tuple or type(nodes) == list:
for node in nodes:
self.roots(node, roots)
else:
if self.nodes[nodes]["root"]:
roots.append(nodes)
for i in self.ancestors(nodes):
if self.nodes[i]["root"]:
roots.append(i)
return list(set(roots))
def ancestors(self, nodes, acc=None):
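        """Recursively collect all the (direct and indirect) predecessors of
        ``nodes`` in the graph and return them as a set."""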
if acc is None:
acc = set()
if type(nodes) == list:
for node in nodes:
acc = acc.union(self.ancestors(node, acc))
return acc
else:
predecessors = list(self.predecessors(nodes))
predecessors = [p for p in predecessors if p not in acc]
if len(predecessors) == 0:
return acc
else:
acc = acc.union(set(predecessors))
return self.ancestors(predecessors, acc)
def clone(self, node, todos, done=None, share_random_ops=True):
"""
Function that allows replacing subgraphs of a computational graph.
It returns a copy of the initial subgraph with the corresponding
substitutions.
Parameters
----------
        node : Tensor
            the node whose subgraph should be copied
        todos : dict
            dictionary describing which subgraphs (keys) should be replaced
            by what (values)
        done : dict (optional)
            internal accumulator of already-cloned nodes
        share_random_ops : bool
            whether the original tensor and the cloned one share the same
            random variable realisations or not
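
        Example
        -------
        A minimal sketch that substitutes one variable for another:

        >>> import symjax
        >>> import symjax.tensor as T
        >>> x = T.Variable(1., name='x', dtype='float32')
        >>> y = x * 2
        >>> x2 = T.Variable(3., name='x2', dtype='float32')
        >>> y2 = symjax.current_graph().clone(y, {x: x2})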
"""
if done is None:
done = {}
if node in done:
return done[node]
elif node in todos:
return todos[node]
elif isinstance(node, t.Variable) or isinstance(node, t.Placeholder):
return node
elif isinstance(node, t.RandomOp) and share_random_ops:
return node
elif isinstance(node, t.Constant):
return node.value
elif len(set(todos) - set(done)) == 0:
return node
elif isinstance(node, t.OpItem):
parent = list(self.predecessors(node))[0]
index = int(self[parent][node]["name"].split("parent_index")[1])
done[node] = self.clone(parent, todos, done)[index]
return done[node]
elif type(node) == tuple or type(node) == list or type(node) == t.Tuple:
return [self.clone(n, todos, done) for n in node]
args, kwargs = self.get_args_kwargs(node, evaluate=False)
new_args = [
self.clone(n, todos, done) for n in args if not isinstance(n, t.Seed)
]
new_kwargs = {name: self.clone(n, todos, done) for name, n in kwargs.items()}
fun = self.get_node_attribute(node, "jax_function")
done[node] = symjax._fn_to_op[fun.__name__](*new_args, **new_kwargs)
return done[node]
def get_node_attribute(self, node, attr):
return nx.get_node_attributes(self, attr)[node]
def unnest(self, container):
for i in container:
if isinstance(i, (list, tuple)):
for j in self.unnest(i):
yield j
else:
yield i
def get(self, input_tensor, tracker=None, frozen=True):
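        """Compute the value of ``input_tensor`` (a tensor, or a nested
        list/tuple of tensors) by traversing the graph. ``tracker`` is a
        dictionary mapping placeholders (or any tensor) to explicit values;
        a value is required for every placeholder that ``input_tensor``
        depends on."""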
if not t.isvar(input_tensor):
return input_tensor
if tracker is None:
tracker = {}
mapped = {}
else:
mapped = tracker.copy()
tomap = []
tosearch = list(self.unnest([input_tensor]))
        # the first graph traversal walks up through the parents until the
        # furthest ancestors are reached and stores their values in the
        # `mapped` hashmap
while len(tosearch):
for item in tosearch:
tosearch.remove(item)
if item in tracker:
mapped[item] = tracker[item]
elif isinstance(item, t.Variable) or isinstance(item, t.Constant):
mapped[item] = item.value
elif isinstance(item, t.Placeholder):
if item in tracker:
mapped[item] = tracker[item]
else:
raise ValueError(
" no value given for placeholder {}".format(item)
)
elif isinstance(item, t.OpItem):
parent = list(self.predecessors(item))[0]
tomap.append(parent)
if parent not in tosearch:
tosearch.append(parent)
elif not t.isvar(item) and not isinstance(item, t.Constant):
tomap.append(t.Constant(item))
else:
# this item is an Op and thus will need to be mapped later on
tomap.append(item)
                # we retrieve the parents of the Op node to go higher up in the graph
item_parents = list(self.predecessors(item))
                # we loop through the parents and only add the unique ones to
                # the search list for the next loop iteration
for it in list(self.unnest(list(item_parents))):
if it not in tosearch:
tosearch.append(it)
# now tosearch is empty, and the parents of item that could be assigned
# an explicit value are mapped into the mapped hashmap
# we start from the end as the last added items will have parents already in mapped
for item in tomap[::-1]:
if isinstance(item, t.OpItem):
parent = list(self.predecessors(item))[0]
index = int(self[parent][item]["name"].split("parent_index")[1])
if parent not in mapped:
mapped[parent] = self._get_value_given_mapped(
parent, mapped, frozen=frozen
)
mapped[item] = mapped[parent][index]
elif isinstance(item, t.MultiOutputOp):
mapped[item] = self._get_value_given_mapped(item, mapped, frozen=frozen)
for node, value in zip(item, mapped[item]):
mapped[node] = value
else:
mapped[item] = self._get_value_given_mapped(item, mapped, frozen=frozen)
        # finally, map the requested tensor(s) onto their computed values
        return self._to_mapped_value(input_tensor, mapped)
def _get_value_given_mapped(self, item, mapped, frozen=False):
if isinstance(item, t.Constant):
return item.value
elif not t.isvar(item):
return item
args, kwargs = self.get_args_kwargs(item, evaluate=False, frozen=frozen)
mapped_args = self._to_mapped_value(args, mapped)
mapped_kwargs = {
key: self._to_mapped_value(value, mapped) for key, value in kwargs.items()
}
return self.nodes[item]["jax_function"](*mapped_args, **mapped_kwargs)
def _to_mapped_value(self, args, mapped):
if not t.isvar(args):
return args
elif type(args) == list:
return [self._to_mapped_value(arg, mapped) for arg in args]
elif type(args) == tuple:
return tuple([self._to_mapped_value(arg, mapped) for arg in args])
else:
return mapped[args]
def all_predecessors(self, item):
if type(item) == list or type(item) == tuple:
predecessors = []
for i in item:
predecessors += self.all_predecessors(i)
return predecessors
else:
return self.predecessors(item)
def get_args_kwargs(self, node, tracker=None, evaluate=True, frozen=True):
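        """Return the positional and keyword arguments of ``node`` as stored
        on its incoming edges. If ``evaluate`` is ``True`` the arguments are
        evaluated through :meth:`get` (using ``tracker``), otherwise the
        symbolic parents themselves are returned."""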
if evaluate:
assert tracker is not None
parents = list(self.all_predecessors(node))
all_args = {}
for i in range(len(parents)):
all_args[self[parents[i]][node]["name"]] = self.get(
parents[i], tracker, frozen=frozen
)
else:
all_args = {
self[parent][node]["name"]: parent
for parent in self.all_predecessors(node)
}
# now we inspect if there are duplicate args
for arg in list(all_args.keys()):
if "+" in arg:
items = arg.split("+")
for item in items:
all_args.update({item: all_args[arg]})
del all_args[arg]
arg_names = [name for name in all_args.keys() if "arg" == name[:3]]
arg_names = sorted(arg_names, key=natural_key)
args = [all_args[name] for name in arg_names]
for name in arg_names:
del all_args[name]
return args, all_args
@property
def scope(self):
return self._scopes[-1]
@property
def scope_name(self):
return self._scopes[-1].absolute_name
def variables(self, trainable=True):
variables = [n for n in self.nodes if isinstance(n, t.Variable)]
if trainable is None:
return variables
return [v for v in variables if v.trainable == trainable]
@property
def placeholders(self):
placeholders = [n for n in self.nodes if type(n) == t.Placeholder]
return placeholders
@property
def ops(self):
ops = [n for n in self.nodes if type(n) in [t.Op, t.MultiOutputOp]]
return ops
@property
def updates(self):
return self._updates
@property
def other_nodes(self):
return list(
set(self.nodes)
- (
set(self.ops).union(
set(self.placeholders).union(set(self.variables(None)))
)
)
)
class Scope:
    """
    Scope in which variables and operations are defined.
Example
-------
.. doctest::
>>> import symjax
>>> import symjax.tensor as T
>>> v = T.Variable(1, name='v')
>>> # the current (default) scope is the root of the graph
>>> print(v.scope)
/
>>> with symjax.Scope('my_scope'):
... w = T.Variable(2, name='w')
... out = v*w
>>> print(out.scope)
/my_scope/
>>> print(w.scope)
/my_scope/
>>> #it is also possible to keep a scope persistently
>>> scope = symjax.Scope('second_scope')
>>> with scope:
... other = out * w
>>> print(other.scope)
/second_scope/
    >>> # this allows keeping track of internal ops directly
>>> print(scope.ops)
[Op(name=multiply, fn=multiply, shape=(), dtype=int32, scope=/second_scope/)]
"""
def __init__(
self,
relative_name=None,
absolute_name=None,
reattach=False,
reuse=False,
graph=None,
):
"""Constructor."""
assert relative_name is not None or absolute_name is not None
if relative_name is not None:
assert "/" not in relative_name
self.reattach = reattach
self.graph = graph
self.reuse = reuse
self.relative_name = relative_name
self.absolute_name = absolute_name
if reuse or reattach:
assert reattach
assert absolute_name in self.graph._scopes_history
def __enter__(self):
"""Set global variables."""
if self.graph is None:
self.graph = current_graph()
if self.absolute_name is None:
self.absolute_name = os.path.join(self.graph.scope_name, self.relative_name)
if self.reattach:
return self
if self.absolute_name in self.graph._scopes_history:
cpt = 1
self.absolute_name += "_"
while self.absolute_name + str(cpt) in self.graph._scopes_history:
cpt += 1
self.absolute_name += str(cpt)
self.relative_name += "_" + str(cpt)
self.graph._scopes_history.append(self.absolute_name)
self.graph._scopes.append(self)
return self
def __exit__(self, *args):
"""Delete globals."""
self.graph._scopes = self.graph._scopes[:-1]
# def save_variables(self, path):
# """Save graph."""
# if ".npz" != path[:-4]:
# path += ".npz"
# numpy.savez(
# path,
# **dict([(v.name, symjax.tensor.get(v)) for v in self.variables]),
# )
# def variables(self, trainable=True):
# return get_variables(scope=self.full_name, trainable=trainable)
# @property
# def ops(self):
# return get_ops(scope=self.full_name)
# @property
# def placeholders(self):
# return get_placeholders(scope=self.full_name)
# def load_variables(self, path):
# """Load graph."""
# data = numpy.load(path)
# for name, value in data.items():
# self.variable[name].update(value)
# def reset(self):
# for var in self.variables:
# var.reset()
def _get_name_scope(self, name, tensor):
# if self.full_name is None:
# self.__enter__()
if isinstance(tensor, symjax.tensor.Placeholder):
nodes = self.graph.placeholders
elif isinstance(tensor, symjax.tensor.Variable):
nodes = self.graph.variables(None)
elif isinstance(tensor, symjax.tensor.Op):
nodes = self.graph.ops
else:
nodes = self.graph.other_nodes
names = [os.path.join(m.scope, m.name) for m in nodes if hasattr(m, "name")]
test_name = os.path.join(self.absolute_name, name)
# if we never used this name before then we can just keep it as is
if test_name not in names:
return name, self.absolute_name
        # otherwise we append the suffix _N with N the number of times we
        # already used this name. To keep track of how many times a name has
        # already been used we employ a look-up table for fast retrieval, and
        # we also increase that count by 1
else:
# note that if we allow reuse, then we just return the variable
# that had this name
if self.reuse and isinstance(tensor, symjax.tensor.Variable):
return nodes[nodes.index(tensor)]
# if we never saw this variable before (first time it is repeated)
# then we create its entry in the look up table
if test_name not in self.graph._name_counts:
self.graph._name_counts[test_name] = 0
# we increase the counter
self.graph._name_counts[test_name] += 1
return (
name + "_" + str(self.graph._name_counts[test_name]),
self.absolute_name,
)
def reset_variables(name="*", scope="*", trainable=None):
    """
    Utility to reset variables based on their names.

    Parameters
    ----------
    name: str (default=*)
        the name (or part of the name) of all the variables that should be
        reset; it can include the glob (*) wildcard to search for all
        matching names
    scope: str (default=*)
        the scope (or part of the scope) that the variables to reset must
        match; it can also include the glob (*) wildcard
    trainable: bool or None (optional, default=None)
        if not None, only the matched variables whose ``trainable`` attribute
        matches the given value are reset
Returns
-------
None
Example
-------
.. doctest::
>>> import symjax
>>> w = symjax.tensor.Variable(1., name='w', dtype='float32')
>>> x = symjax.tensor.Variable(2., name='x', dtype='float32')
>>> f = symjax.function(outputs=[w, x], updates={w:w + 1,x:x + 1})
>>> for i in range(10):
... print(f())
[array(1., dtype=float32), array(2., dtype=float32)]
[array(2., dtype=float32), array(3., dtype=float32)]
[array(3., dtype=float32), array(4., dtype=float32)]
[array(4., dtype=float32), array(5., dtype=float32)]
[array(5., dtype=float32), array(6., dtype=float32)]
[array(6., dtype=float32), array(7., dtype=float32)]
[array(7., dtype=float32), array(8., dtype=float32)]
[array(8., dtype=float32), array(9., dtype=float32)]
[array(9., dtype=float32), array(10., dtype=float32)]
[array(10., dtype=float32), array(11., dtype=float32)]
>>> # reset only the w variable
>>> symjax.reset_variables('w')
>>> # reset all variables
>>> symjax.reset_variables('*')
"""
variables = get_variables(name=name, scope=scope, trainable=trainable)
for var in variables:
var.reset()
def save_variables(
path_or_file,
name="*",
scope="*",
trainable=None,
):
"""saves the graph variables.
The saving is done via ``numpy.savez`` for fast and compressed storage.
Parameters:
-----------
path_or_file: str or file
the path and name of the file to save the variables in or an
open file object
name: str (optional)
the name string that the variables to save must match
scope: str (optional)
the scope name string that the variables to save must match
trainable: bool or None
the option of the variables to save (``True``, ``False`` or ``None``)
"""
if type(path_or_file) == str:
if path_or_file[-4:] != ".npz":
path_or_file += ".npz"
variables = get_variables(name, scope, trainable)
numpy.savez(
path_or_file,
**dict(
[
(
v.scope + v.name,
symjax.tensor.get(v),
)
for v in variables
]
),
)
def load_variables(path_or_file, name="*", scope="*", trainable=None):
    """Load the graph variables.

    The loading is done via ``numpy.load`` of a file previously created by
    :func:`save_variables`.

    Parameters
    ----------
    path_or_file: str or file
        the path and name of the file to load the variables from, or an
        open file object
    name: str (optional)
        the name string that the variables to load must match
    scope: str (optional)
        the scope name string that the variables to load must match
    trainable: bool or None
        the ``trainable`` attribute that the variables to load must match
        (``True`` or ``False``), or ``None`` to load them all
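
    Example
    -------
    A minimal save/load round-trip sketch (written to a temporary file):

    >>> import os, tempfile
    >>> import symjax
    >>> w = symjax.tensor.Variable(1., name='w', dtype='float32')
    >>> path = os.path.join(tempfile.mkdtemp(), 'weights.npz')
    >>> symjax.save_variables(path, name='w')
    >>> symjax.load_variables(path, name='w')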
"""
if type(path_or_file) == str:
if path_or_file[-4:] != ".npz":
path_or_file += ".npz"
variables = get_variables(name, scope, trainable=trainable)
data = numpy.load(path_or_file)
for var in variables:
name_in_file = var.scope + var.name
if name_in_file not in data:
raise Warning("{} not in loaded file".format(name_in_file))
var.update(data[name_in_file])
def get_variables(name="*", scope="/", trainable=True):
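    """Return the variables of the current graph matching ``name`` and ``scope``.

    Both ``name`` and ``scope`` are glob patterns (matched with :mod:`fnmatch`);
    ``trainable`` can be ``True``, ``False`` or ``None`` (no filtering on the
    ``trainable`` attribute).

    Example (a minimal sketch):

    >>> import symjax
    >>> w = symjax.tensor.Variable(1., name='w', dtype='float32')
    >>> ws = symjax.get_variables(name='w*', trainable=None)
    """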
matched = current_graph().variables(trainable)
output = []
for m in matched:
if len(fnmatch.filter([m.name], name)) and len(
fnmatch.filter([m.scope], scope + "*")
):
output.append(m)
return output
def get_placeholders(name="*", scope="/"):
    """
    Same as :func:`symjax.get_variables` but for placeholders.
"""
matched = current_graph().placeholders
output = []
for m in matched:
if len(fnmatch.filter([m.name], name)) and len(
fnmatch.filter([m.scope], scope + "*")
):
output.append(m)
return output
def get_ops(name="*", scope="/"):
    """
    Same as :func:`symjax.get_variables` but for ops.
"""
matched = current_graph().ops
output = []
for m in matched:
if len(fnmatch.filter([m.name], name)) and len(
fnmatch.filter([m.scope], scope + "*")
):
output.append(m)
return output
def get_updates(name="*", scope="/", variables=None):
    """
    Same as :func:`symjax.get_variables` but for the graph updates: returns a
    dictionary mapping each variable to its update tensor. If ``variables``
    is given, only the updates of those variables are returned.
"""
matched = current_graph().updates
output = {}
if variables is not None:
for v in variables:
if v in matched:
output[v] = matched[v]
return output
for var, update in matched.items():
if len(fnmatch.filter([update.name], name)) and len(
fnmatch.filter([update.scope], scope + "*")
):
output[var] = update
return output
def add_updates(updates):
current_scope().add_updates(updates)
def gradients(scalar, variables):
    """Compute the gradients of a scalar w.r.t. a given list of variables.
Arguments
---------
scalar: :class:`symjax.tensor.base.Tensor`
the variable to differentiate
variables: List or Tuple
the variables used to compute the derivative.
Returns
-------
gradients: Tuple
        the sequence of gradients ordered as given in the input variables
Example
-------
.. doctest::
>>> import symjax
>>> w = symjax.tensor.ones(3)
>>> x = symjax.tensor.Variable(2., name='x', dtype='float32')
>>> l = (w ** 2).sum() * x
>>> g = symjax.gradients(l, [w])[0]
>>> f = symjax.function(outputs=g, updates={x:x + 1})
>>> for i in range(2):
... print(f())
[4. 4. 4.]
[6. 6. 6.]
"""
if numpy.prod(scalar.shape) != 1:
raise RuntimeError("the variable to differentiate is not a scalar")
if not isinstance(scalar, t.Tensor):
raise RuntimeError("the variable used in gradients should be a Tensor type")
if isinstance(variables, t.Tensor):
input_variables = [variables]
input_list = False
else:
input_variables = variables.copy()
input_list = True
    # get the argnums of the variables that we differentiate with respect to
    argnums = list(range(len(input_variables)))
    # get all the roots of the scalar; this is needed as otherwise they would
    # not be inputs of the gradient function and thus a change of their value
    # would not affect the gradient computation; we also ensure uniqueness
input_variables += [
i for i in current_graph().roots(scalar) if i not in input_variables
]
    # create a dummy function that is needed for jax to compute a gradient
    # function; this function builds the graph of computation from all the
    # roots to the scalar variable s.t. automatic differentiation can be applied
def internal_gradient_function(*args):
return current_graph().get(
scalar,
{input_variables[i]: args[i] for i in range(len(input_variables))},
)
# now we obtain the grad function. In fact, Jax returns a function that,
# when it is called, returns the gradient values, this function is then
# used to generate the Tuple of symbolic variables
grad_fn = jax.grad(internal_gradient_function, argnums)
wrap_fn = t.jax_wrap(grad_fn)
if input_list:
return wrap_fn(*input_variables)
else:
return wrap_fn(*input_variables)[0]
def jacobians(tensor, variables, mode="forward"):
    """Compute the jacobians of a tensor w.r.t. a given list of variables.

    The tensor need not be a vector, but will be treated as such. For
    example if ``tensor.shape`` is (10, 3, 3) and a variable shape is
    (10, 10), the resulting jacobian has shape (10, 3, 3, 10, 10). It is
    possible to specify the mode as forward or backward. For tall
    jacobians, forward is faster and vice-versa.

    Arguments
    ---------
    tensor: Tensor
        the tensor to differentiate
    variables: List or Tuple
        the variables used to compute the derivative.
    mode: str (default='forward')
        the differentiation mode, either ``'forward'`` (``jax.jacfwd``) or
        ``'backward'`` (``jax.jacrev``)

    Returns
    -------
    jacobians: Tuple
        the sequence of jacobians ordered as given in the input variables
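
    Example
    -------
    A minimal sketch (jacobian of an elementwise square):

    >>> import symjax
    >>> import symjax.tensor as T
    >>> x = T.Variable([1., 2.], name='x', dtype='float32')
    >>> y = x ** 2
    >>> J = symjax.jacobians(y, [x])[0]  # expected shape (2, 2)
    >>> f = symjax.function(outputs=J)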
"""
if isinstance(variables, t.Tensor):
input_variables = [variables]
input_list = False
else:
input_variables = variables.copy()
input_list = True
    # get the argnums of the variables that we differentiate with respect to
    argnums = list(range(len(input_variables)))
    # get all the roots of the tensor; this is needed as otherwise they would
    # not be inputs of the jacobian function and thus a change of their value
    # would not affect the jacobian computation; we also ensure uniqueness
input_variables += [
i for i in current_graph().roots(tensor) if i not in input_variables
]
    # create a dummy function that is needed for jax to compute a jacobian
    # function; this function builds the graph of computation from all the
    # roots to the tensor s.t. automatic differentiation can be applied
def internal_jacobian(*args):
return current_graph().get(
tensor,
{input_variables[i]: args[i] for i in range(len(input_variables))},
)
# now we obtain the jacobian function. In fact, Jax returns a function that
# when it is called, returns the jacobian values, this function is then
# used to generate the Tuple of symbolic variables
if mode == "forward":
jacob_fn = jacfwd(internal_jacobian, argnums)
elif mode == "backward":
jacob_fn = jacrev(internal_jacobian, argnums)
else:
raise RuntimeError(
"mode {} not recognized, use forward or backward".format(mode)
)
wrap_fn = t.jax_wrap(jacob_fn, False)
if input_list:
return wrap_fn(*input_variables)
else:
return wrap_fn(*input_variables)[0]
class function:
    """Generate a user function that compiles a computational graph.

    The function is defined from the given inputs, outputs and update policy
    of variables. It internally jit-compiles the underlying jax computational
    graph for performance.
Arguments
---------
    args: trailing tuple
        the inputs to the function to be compiled. The tuple should
        contain all the placeholders that are roots of any output of the
        function and of the update values
outputs: List (optional)
the outputs of the function, if a single element, it can be
given as a standalone and not a list
    updates: Dict (optional)
        the dictionary of updates as per {var: new_value} for any
        variable of the graph
backend: 'cpu' or 'gpu'
the backend to use to run the function on
default_value: not implemented
not implemented
Returns
-------
    callable:
        the user frontend function that takes the specified inputs,
        returns the specified outputs and internally performs the
        updates
Examples
--------
>>> import symjax
>>> import symjax.tensor as T
>>> x = T.ones((4, 4))
>>> xs = x.sum() + 1
>>> f = symjax.function(outputs=xs)
>>> print(f())
17.0
>>> w = T.Variable(0., name='w', dtype='float32')
>>> increment = symjax.function(updates={w: w + 1})
>>> for i in range(10):
... increment()
>>> print(w.value)
10.0
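
    A sketch with a placeholder input; the ``T.Placeholder(shape, dtype)``
    constructor signature is assumed here:

    >>> x = T.Placeholder((4,), 'float32', name='input')
    >>> y = x * w
    >>> h = symjax.function(x, outputs=y)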
"""
def __init__(
self,
*args,
outputs=[],
updates=None, # noqa
device=None,
backend=None,
default_value=None,
frozen=False
):
"""Initialize."""
# check the given updates (if any) and ensure that they only
# update Variable objects
if updates is None:
updates = {}
for update in updates:
if not isinstance(update, t.Variable):
raise RuntimeError(
"{} is not a Variable and cannot be updated".format(update)
)
# ensure that all inputs are actual placeholders or variables
for arg in args:
if not isinstance(arg, t.Tensor):
                raise RuntimeError(
                    "{} is not a Tensor type. Only tensor types can be"
                    " function inputs".format(arg)
                )
# gather all roots, they need to be explicit as inputs of the
# underlying functions otherwise they are treated as constants
# and any change in their value will not appear when running the
# function
self.updates_keys = list(updates.keys())
self.updates_values = list(updates.values())
self.args = args
self.outputs = outputs
outs = self.updates_values.copy()
outs += [outputs] if isinstance(outputs, t.Tensor) else outputs
self.all_roots = set(symjax.current_graph().roots(outs))
        # check the function inputs: they must at least contain all the
        # placeholders needed to compute the output values
placeholders_in_root = [
x for x in self.all_roots if isinstance(x, t.Placeholder)
]
        # check for placeholders that were not given as function inputs
non_givens = set(placeholders_in_root) - set(self.args)
if len(non_givens) > 0:
raise RuntimeError(
"""\
Missing placeholders from the function inputs...\n\
...Missings are: {}""".format(
non_givens
)
)
# the roots are made of variables, random tensors, placeholders. We
# already ensured that all placeholders are given as inputs to the
# function. Now we must ensure that the other ones will also be given
# as inputs to not be treated as constants by jax.
        # we also remove the update keys because we will explicitly feed them
self.extra_inputs = set(self.all_roots) - set(self.args).union(
self.updates_keys
)
self.extra_inputs = list(self.extra_inputs)
allargs = list(self.args) + self.updates_keys + self.extra_inputs
def to_jit(*jitargs):
feed_dict = dict(zip(allargs, jitargs))
outputs = [self.outputs, self.updates_values]
return symjax.current_graph().get(outputs, feed_dict)
self.jited = jax.jit(to_jit, device=device, backend=backend)
        # define the frontend function that takes the input variables and
        # internally computes and updates the variables from updates, if any
def meta(*fnargs):
# ensure that the number of arguments is correct
assert len(fnargs) == len(self.args)
for fnarg, classarg in zip(fnargs, self.args):
if hasattr(fnarg, "shape"):
if fnarg.shape != classarg.shape and 0 not in classarg.shape:
raise RuntimeError(
"wrong input given for {}".format(classarg)
+ ", given is {}".format(fnarg)
+ ", shape={}".format(fnarg.shape)
)
            # retrieve the function outputs and updated values, and apply the updates
jited_add_inputs = symjax.current_graph().get(
self.updates_keys + self.extra_inputs
)
jitoutputs, jitupdates = self.jited(*fnargs, *jited_add_inputs)
for key, update in zip(self.updates_keys, jitupdates):
key.update(update)
# update the seed such that the random numbers of the
# next call are all different
if not frozen:
for i in self.extra_inputs:
if isinstance(i, t.Seed):
i.update()
return jitoutputs
self.meta = meta
def __call__(self, *args, device_get=True):
"""Callable fn."""
        # in the presence of RandomTensor(s) in the graph, the PRNG seeds are
        # updated at each function call so that successive calls do not all
        # return the same realisation
outputs = self.meta(*args)
if device_get:
outputs = jax.device_get(outputs)
return outputs