symjax
Graph and Compilation
function (*args[, outputs, updates, device, …]) – Generate a user function that compiles a computational graph.
Graph (name, *args, **kwargs)
Scope ([relative_name, absolute_name, …]) – Define a scope for variables and operations to belong to.
Derivatives
gradients (scalar, variables) – Compute the gradients of a scalar with respect to a given list of variables.
jacobians (tensor, variables[, mode]) – Compute the jacobians of a tensor with respect to a given list of variables.
Graph Access and Manipulation
clone
current_graph () – Current graph.
get_variables ([name, scope, trainable])
get_ops ([name, scope]) – Same as symjax.get_variables but for ops.
get_placeholders ([name, scope]) – Same as symjax.get_variables but for placeholders.
get_updates ([name, scope, variables]) – Same as symjax.get_variables but for updates.
save_variables (path_or_file[, name, scope, …]) – Save the graph variables.
load_variables (path_or_file[, name, scope, …]) – Load the graph variables.
reset_variables ([name, scope, trainable]) – Reset variables based on their names.
Detailed Descriptions
symjax.function(*args, outputs=[], updates=None, device=None, backend=None, default_value=None, frozen=False)
Generate a user function that compiles a computational graph.
The function is built from the given inputs, outputs and variable update policy. Internally, it JIT-compiles the underlying JAX computational graph for performance.
Parameters: - args (trailing tuple) – the inputs of the function to be compiled. The tuple should contain all the placeholders that are roots of any of the given outputs or update values
- outputs (List (optional)) – the outputs of the function; a single output can be given on its own instead of in a list
- updates (Dict (optional)) – the dictionary of updates, given as {var: new_value}, for any variable of the graph
- backend ('cpu' or 'gpu') – the backend to run the function on
- default_value (not implemented) – not implemented
Returns: the user-facing function that takes the specified inputs, returns the specified outputs and internally performs the updates
Return type: callable
Examples
>>> import symjax
>>> import symjax.tensor as T
>>> x = T.ones((4, 4))
>>> xs = x.sum() + 1
>>> f = symjax.function(outputs=xs)
>>> print(f())
17.0

>>> w = T.Variable(0., name='w', dtype='float32')
>>> increment = symjax.function(updates={w: w + 1})
>>> for i in range(10):
...     increment()
>>> print(w.value)
10.0
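The examples imply an ordering: outputs are evaluated with the current variable values, and the updates are applied afterwards (in the reset_variables example, the first call returns w = 1 even though that same call updates w to 2). The pure-Python sketch below models that behavior with a plain dictionary of values. The names `make_function` and `State`, and the lambda-based outputs and updates, are hypothetical and not part of symjax; computing every new value from the old state before assigning any of them is also an assumption.

```python
from typing import Callable, Dict, List

# Hypothetical stand-in for graph variables: a name -> current value mapping.
State = Dict[str, float]


def make_function(
    state: State,
    outputs: List[Callable[[State], float]],
    updates: Dict[str, Callable[[State], float]],
) -> Callable[[], List[float]]:
    """Model of output-then-update semantics (not the symjax implementation)."""

    def f() -> List[float]:
        # 1. Evaluate every output with the values held before this call.
        results = [out(state) for out in outputs]
        # 2. Compute all new values from that same old state (assumed simultaneous)...
        new_values = {name: upd(state) for name, upd in updates.items()}
        # 3. ...and only then write them back.
        state.update(new_values)
        return results

    return f


state = {"w": 1.0, "x": 2.0}
f = make_function(
    state,
    outputs=[lambda s: s["w"], lambda s: s["x"]],
    updates={"w": lambda s: s["w"] + 1, "x": lambda s: s["x"] + 1},
)
print(f())    # [1.0, 2.0]
print(f())    # [2.0, 3.0]
print(state)  # {'w': 3.0, 'x': 4.0}
```

With this ordering, a function that only has updates (like `increment` above) simply advances the state on every call.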
class symjax.Scope(relative_name=None, absolute_name=None, reattach=False, reuse=False, graph=None)
Define a scope for variables and operations to belong to.
symjax.gradients(scalar, variables)
Compute the gradients of a scalar with respect to a given list of variables.
Parameters: - scalar (symjax.tensor.base.Tensor) – the scalar to differentiate
- variables (List or Tuple) – the variables with respect to which the derivative is computed
Returns: gradients – the sequence of gradients, ordered as the input variables
Return type: Tuple
Example
>>> import symjax
>>> w = symjax.tensor.ones(3)
>>> x = symjax.tensor.Variable(2., name='x', dtype='float32')
>>> l = (w ** 2).sum() * x
>>> g = symjax.gradients(l, [w])[0]
>>> f = symjax.function(outputs=g, updates={x: x + 1})
>>> for i in range(2):
...     print(f())
[4. 4. 4.]
[6. 6. 6.]
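For the loss in this example, l = x * sum(w**2), so dl/dw = 2 x w. With w = [1, 1, 1], this gives [4, 4, 4] at x = 2 and [6, 6, 6] after the update has set x = 3. The NumPy sketch below checks that analytic gradient against central finite differences. It is only an independent sanity check of the arithmetic, not how symjax computes gradients; the helper names are hypothetical.

```python
import numpy as np


def loss(w: np.ndarray, x: float) -> float:
    # Same expression as the example: l = (w ** 2).sum() * x
    return float((w ** 2).sum() * x)


def analytic_grad(w: np.ndarray, x: float) -> np.ndarray:
    # d/dw [x * sum(w_i ** 2)] = 2 * x * w
    return 2.0 * x * w


def numerical_grad(w: np.ndarray, x: float, eps: float = 1e-6) -> np.ndarray:
    # Central finite differences, perturbing one coordinate of w at a time.
    grad = np.zeros_like(w)
    for i in range(w.size):
        step = np.zeros_like(w)
        step[i] = eps
        grad[i] = (loss(w + step, x) - loss(w - step, x)) / (2 * eps)
    return grad


w = np.ones(3)
for x in (2.0, 3.0):  # the value of x before and after the first update
    print(analytic_grad(w, x), numerical_grad(w, x))
```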
symjax.jacobians(tensor, variables, mode='forward')
Compute the jacobians of a tensor with respect to a given list of variables.
The tensor does not need to be a vector, but it is treated as one. For example, if tensor.shape is (10, 3, 3) and a variable's shape is (10, 10), the resulting jacobian has shape (10, 3, 3, 10, 10). The mode can be set to forward or backward: for tall jacobians (more outputs than inputs), forward mode is faster, and for wide jacobians, backward mode is faster.
Parameters: - tensor (Tensor) – the tensor to differentiate
- variables (List or Tuple) – the variables with respect to which the derivative is computed
- mode (str, default='forward') – the differentiation mode, forward or backward
Returns: jacobians – the sequence of jacobians, ordered as the input variables
Return type: Tuple
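The shape rule described for jacobians (the tensor's shape followed by the variable's shape) can be illustrated with a small finite-difference sketch in NumPy, reusing the (10, 3, 3) and (10, 10) shapes from the text. The helper `numerical_jacobian` is hypothetical: it only mirrors the output layout, whereas symjax.jacobians offers forward and backward differentiation modes rather than finite differences.

```python
from typing import Callable

import numpy as np


def numerical_jacobian(
    fn: Callable[[np.ndarray], np.ndarray], var: np.ndarray, eps: float = 1e-5
) -> np.ndarray:
    """Central-difference jacobian with shape fn(var).shape + var.shape."""
    out_shape = np.shape(fn(var))
    jac = np.zeros(out_shape + var.shape)
    # Perturb one entry of the variable at a time; each perturbation fills
    # the output-shaped slice jac[..., idx].
    for idx in np.ndindex(*var.shape):
        step = np.zeros_like(var)
        step[idx] = eps
        jac[(Ellipsis,) + idx] = (fn(var + step) - fn(var - step)) / (2 * eps)
    return jac


rng = np.random.default_rng(0)
A = rng.normal(size=(10, 3, 3, 10, 10))
v = rng.normal(size=(10, 10))


def tensor_fn(var: np.ndarray) -> np.ndarray:
    # Linear map from a (10, 10) variable to a (10, 3, 3) tensor.
    return np.tensordot(A, var, axes=2)


jac = numerical_jacobian(tensor_fn, v)
print(jac.shape)                      # (10, 3, 3, 10, 10)
print(np.allclose(jac, A, atol=1e-6))  # True: a linear map's jacobian is its coefficient array
```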
symjax.get_updates(name='*', scope='/', variables=None)
Same as symjax.get_variables but for updates.
symjax.save_variables(path_or_file, name='*', scope='*', trainable=None)
Save the graph variables.
The saving is done via numpy.savez for fast storage in a single .npz archive.
Parameters: - path_or_file (str or file) – the path and name of the file to save the variables in, or an open file object
- name (str (optional)) – the name string that the variables to save must match
- scope (str (optional)) – the scope name string that the variables to save must match
- trainable (bool or None) – if not None, only the matched variables whose trainable attribute equals this value (True or False) are saved
symjax.load_variables(path_or_file, name='*', scope='*', trainable=None)
Load the graph variables.
The loading reads back the .npz archive that save_variables writes with numpy.savez.
Parameters: - path_or_file (str or file) – the path and name of the file to load the variables from, or an open file object
- name (str (optional)) – the name string that the variables to load must match
- scope (str (optional)) – the scope name string that the variables to load must match
- trainable (bool or None) – if not None, only the matched variables whose trainable attribute equals this value (True or False) are loaded
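save_variables and load_variables rely on NumPy's .npz archives. The sketch below shows a save/load round trip with numpy.savez and numpy.load on an in-memory file object. Keying the arrays by variable name (here the hypothetical names 'w' and 'x') is an assumption made for illustration, not a description of the exact layout symjax writes.

```python
import io

import numpy as np

# Hypothetical variable values, keyed by variable name (an assumption).
variables = {
    "w": np.array(1.0, dtype=np.float32),
    "x": np.array([2.0, 3.0], dtype=np.float32),
}

# Save: numpy.savez writes one .npy entry per keyword argument into a zip archive.
buffer = io.BytesIO()  # an open file object, like those path_or_file accepts
np.savez(buffer, **variables)

# Load: rewind the file and read the archive back with numpy.load.
buffer.seek(0)
with np.load(buffer) as archive:
    restored = {name: archive[name] for name in archive.files}

for name, value in restored.items():
    print(name, value, value.dtype)
```

Note that numpy.savez stores the arrays uncompressed; numpy.savez_compressed is the compressed variant and accepts the same arguments.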
symjax.reset_variables(name='*', scope='*', trainable=None)
Utility to reset variables based on their names.
Parameters: - name (str (default=*)) – the name (or part of the name) of the variables to reset; it can include the glob character (*) to match all names fitting the pattern
- trainable (bool or None (optional, default=None)) – if not None, only the matched variables whose trainable attribute equals the given value are reset
Return type: None
Example
>>> import symjax
>>> w = symjax.tensor.Variable(1., name='w', dtype='float32')
>>> x = symjax.tensor.Variable(2., name='x', dtype='float32')
>>> f = symjax.function(outputs=[w, x], updates={w: w + 1, x: x + 1})
>>> for i in range(10):
...     print(f())
[array(1., dtype=float32), array(2., dtype=float32)]
[array(2., dtype=float32), array(3., dtype=float32)]
[array(3., dtype=float32), array(4., dtype=float32)]
[array(4., dtype=float32), array(5., dtype=float32)]
[array(5., dtype=float32), array(6., dtype=float32)]
[array(6., dtype=float32), array(7., dtype=float32)]
[array(7., dtype=float32), array(8., dtype=float32)]
[array(8., dtype=float32), array(9., dtype=float32)]
[array(9., dtype=float32), array(10., dtype=float32)]
[array(10., dtype=float32), array(11., dtype=float32)]
>>> # reset only the w variable
>>> symjax.reset_variables('w')
>>> # reset all variables
>>> symjax.reset_variables('*')
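The name argument of reset_variables accepts glob patterns, and trainable further filters the matched variables. The pure-Python sketch below models that selection logic. It assumes shell-style matching as provided by Python's fnmatch module; the `HypotheticalVariable` record, its `initial` and `value` fields, and the `reset` helper are illustrative assumptions, not symjax internals.

```python
from dataclasses import dataclass
from fnmatch import fnmatchcase
from typing import List, Optional


@dataclass
class HypotheticalVariable:
    name: str
    initial: float  # value the variable was created with
    value: float    # current value
    trainable: bool = True


def reset(variables: List[HypotheticalVariable], name: str = "*",
          trainable: Optional[bool] = None) -> None:
    for var in variables:
        # Keep only the variables whose name matches the glob pattern (assumed fnmatch-style)...
        if not fnmatchcase(var.name, name):
            continue
        # ...and, when trainable is given, whose trainable flag equals it.
        if trainable is not None and var.trainable != trainable:
            continue
        var.value = var.initial


# State after the ten calls of the example: w went 1 -> 11 and x went 2 -> 12.
# Marking x as non-trainable is a hypothetical choice to exercise the filter.
w = HypotheticalVariable("w", initial=1.0, value=11.0)
x = HypotheticalVariable("x", initial=2.0, value=12.0, trainable=False)
graph = [w, x]

reset(graph, "w")                  # only w is reset
print(w.value, x.value)            # 1.0 12.0
reset(graph, "*", trainable=True)  # x is skipped because it is not trainable
print(w.value, x.value)            # 1.0 12.0
reset(graph, "*")                  # every variable is reset
print(w.value, x.value)            # 1.0 2.0
```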