symjax

Graph and Compilation

function(*args[, outputs, updates, device, …]) Generate a user function that compiles a computational graph.
Graph(name, *args, **kwargs)
Scope([relative_name, absolute_name, …]) Define a scope that variables/operations created within it belong to.

Derivatives

gradients(scalar, variables) Compute the gradients of a scalar w.r.t. a given list of variables.
jacobians(tensor, variables[, mode]) Compute the Jacobians of a tensor w.r.t. a given list of variables.

Graph Access and Manipulation

clone
current_graph() Return the current graph.
get_variables([name, scope, trainable]) Return the graph variables matching the given name, scope and trainable attribute.
get_ops([name, scope]) Same as symjax.get_variables but for ops
get_placeholders([name, scope]) Same as symjax.get_variables but for placeholders
get_updates([name, scope, variables]) Same as symjax.get_variables but for updates
save_variables(path_or_file[, name, scope, …]) Save the graph variables.
load_variables(path_or_file[, name, scope, …]) Load the graph variables.
reset_variables([name, scope, trainable]) Reset variables based on their names.

Detailed Descriptions

symjax.function(*args, outputs=[], updates=None, device=None, backend=None, default_value=None, frozen=False)[source]

Generate a user function that compiles a computational graph.

Based on the given inputs, outputs and variable-update policy, this function internally JIT-compiles the underlying JAX computational graph for performance.

Parameters:
  • args (trailing tuple) – the inputs of the function to be compiled. The tuple should contain all the placeholders that are roots of any of the function's outputs and of the update values
  • outputs (List (optional)) – the outputs of the function; a single element can be given standalone rather than as a list
  • updates (Dict (optional)) – the dictionary of updates of the form {var: new_value} for any variables of the graph
  • backend ('cpu' or 'gpu') – the backend to run the function on
  • default_value (not implemented) – not implemented
Returns:

the user-facing function that takes the specified inputs, returns the specified outputs and internally performs the updates

Return type:

callable

Examples

>>> import symjax
>>> import symjax.tensor as T
>>> x = T.ones((4, 4))
>>> xs = x.sum() + 1
>>> f = symjax.function(outputs=xs)
>>> print(f())
17.0
>>> w = T.Variable(0., name='w', dtype='float32')
>>> increment = symjax.function(updates={w: w + 1})
>>> for i in range(10):
...     increment()
>>> print(w.value)
10.0
class symjax.Graph(name, *args, **kwargs)[source]
class symjax.Scope(relative_name=None, absolute_name=None, reattach=False, reuse=False, graph=None)[source]

Define a scope that any variable/operation created within it belongs to.

Example
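
A minimal usage sketch, assuming Scope acts as a context manager that prefixes the names of variables created inside it (the scope name 'model' is illustrative):

>>> import symjax
>>> import symjax.tensor as T
>>> with symjax.Scope('model'):
...     w = T.Variable(0., name='w', dtype='float32')
>>> # the variable now lives under the 'model' scope and can be
>>> # retrieved by scope matching (glob pattern assumed)
>>> vs = symjax.get_variables(scope='*model*')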


symjax.gradients(scalar, variables)[source]

Compute the gradients of a scalar w.r.t. a given list of variables.

Parameters:
  • scalar (symjax.tensor.base.Tensor) – the scalar to differentiate
  • variables (List or Tuple) – the variables used to compute the derivative.
Returns:

gradients – the sequence of gradients, ordered as the given input variables

Return type:

Tuple

Example

>>> import symjax
>>> w = symjax.tensor.ones(3)
>>> x = symjax.tensor.Variable(2., name='x', dtype='float32')
>>> l = (w ** 2).sum() * x
>>> g = symjax.gradients(l, [w])[0]
>>> f = symjax.function(outputs=g, updates={x:x + 1})
>>> for i in range(2):
...    print(f())
[4. 4. 4.]
[6. 6. 6.]
symjax.jacobians(tensor, variables, mode='forward')[source]

Compute the Jacobians of a tensor w.r.t. a given list of variables.

The tensor need not be a vector, but it will be treated as such. For example, if tensor.shape is (10, 3, 3) and a variable's shape is (10, 10), the resulting Jacobian has shape (10, 3, 3, 10, 10). The mode can be set to forward or backward: for tall Jacobians forward is faster, and vice versa.

Parameters:
  • tensor (Tensor) – the tensor to differentiate
  • variables (List or Tuple) – the variables used to compute the derivative.
  • mode ('forward' or 'backward') – the differentiation mode to use
Returns:

jacobians – the sequence of Jacobians, ordered as the given input variables

Return type:

Tuple
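
Example

A minimal sketch mirroring the symjax.gradients example above; the Jacobian of the elementwise square at w = ones(3) is 2·I, noted here as a comment rather than as verified output:

>>> import symjax
>>> w = symjax.tensor.ones(3)
>>> y = w ** 2
>>> # Jacobian of y w.r.t. w has shape (3, 3); at w = [1, 1, 1] it is 2 * identity
>>> J = symjax.jacobians(y, [w])[0]
>>> f = symjax.function(outputs=J)
>>> out = f()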

symjax.current_graph()[source]

Return the current graph.

symjax.get_variables(name='*', scope='/', trainable=True)[source]

Return the graph variables whose name, scope and trainable attribute match the given values.
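Example

A minimal usage sketch, assuming glob-style name matching as suggested by the default name='*' and that variables are trainable by default:

>>> import symjax
>>> import symjax.tensor as T
>>> w = T.Variable(0., name='w', dtype='float32')
>>> b = T.Variable(0., name='b', dtype='float32')
>>> # retrieve all trainable variables in the graph
>>> vs = symjax.get_variables(name='*')
>>> # retrieve only the variables whose name matches a pattern
>>> ws = symjax.get_variables(name='w')
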
symjax.get_ops(name='*', scope='/')[source]

Same as symjax.get_variables but for ops

symjax.get_placeholders(name='*', scope='/')[source]

Same as symjax.get_variables but for placeholders

symjax.get_updates(name='*', scope='/', variables=None)[source]

Same as symjax.get_variables but for updates

symjax.save_variables(path_or_file, name='*', scope='*', trainable=None)[source]

Save the graph variables.

The saving is done via numpy.savez for fast and compressed storage.

Parameters:
  • path_or_file (str or file) – the path and name of the file to save the variables in, or an open file object
  • name (str (optional)) – the name string that the variables to save must match
  • scope (str (optional)) – the scope name string that the variables to save must match
  • trainable (bool or None) – the trainable attribute that the variables to save must match (True, False, or None to save all)
symjax.load_variables(path_or_file, name='*', scope='*', trainable=None)[source]

Load the graph variables.

The loading is done via numpy.load, matching the compressed numpy.savez format used for saving.

Parameters:
  • path_or_file (str or file) – the path and name of the file to load the variables from, or an open file object
  • name (str (optional)) – the name string that the variables to load must match
  • scope (str (optional)) – the scope name string that the variables to load must match
  • trainable (bool or None) – the trainable attribute that the variables to load must match (True, False, or None to load all)
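
Example

A minimal save/load round trip; the file path is illustrative:

>>> import symjax
>>> import symjax.tensor as T
>>> w = T.Variable(1., name='w', dtype='float32')
>>> # persist all graph variables to disk, then restore them
>>> symjax.save_variables('/tmp/weights.npz')
>>> symjax.load_variables('/tmp/weights.npz')
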
symjax.reset_variables(name='*', scope='*', trainable=None)[source]

Utility to reset variables based on their names.

Parameters:
  • name (str (default=*)) – the name (or part of the name) of all the variables that should be reset; it can include the glob character (*) to match all names
  • scope (str (default=*)) – the scope name string that the variables to reset must match
  • trainable (bool or None (optional, default=None)) – if not None, only the matched variables whose trainable attribute equals the given value are reset
Returns:

Return type:

None

Example

>>> import symjax
>>> w = symjax.tensor.Variable(1., name='w', dtype='float32')
>>> x = symjax.tensor.Variable(2., name='x', dtype='float32')
>>> f = symjax.function(outputs=[w, x], updates={w:w + 1,x:x + 1})
>>> for i in range(10):
...    print(f())
[array(1., dtype=float32), array(2., dtype=float32)]
[array(2., dtype=float32), array(3., dtype=float32)]
[array(3., dtype=float32), array(4., dtype=float32)]
[array(4., dtype=float32), array(5., dtype=float32)]
[array(5., dtype=float32), array(6., dtype=float32)]
[array(6., dtype=float32), array(7., dtype=float32)]
[array(7., dtype=float32), array(8., dtype=float32)]
[array(8., dtype=float32), array(9., dtype=float32)]
[array(9., dtype=float32), array(10., dtype=float32)]
[array(10., dtype=float32), array(11., dtype=float32)]
>>> # reset only the w variable
>>> symjax.reset_variables('w')
>>> # reset all variables
>>> symjax.reset_variables('*')