Welcome to SymJAX’s documentation!¶
SymJAX = JAX+NetworkX
JAX
JAX is an XLA Python interface that provides a NumPy-like user experience with just-in-time compilation and Autograd-powered automatic differentiation. XLA is a compiler that optimizes a computational graph by fusing multiple kernels into one, avoiding intermediate computations, reducing memory operations and improving performance.
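As a quick sketch of what these two features look like in plain JAX (independent of SymJAX):

```python
import jax
import jax.numpy as jnp

# A simple function: XLA can fuse these operations into one compiled kernel.
def loss(x):
    return jnp.sum(x ** 2)

grad_loss = jax.grad(loss)   # Autograd-style automatic differentiation
fast_loss = jax.jit(loss)    # just-in-time compilation through XLA

x = jnp.arange(3.0)          # [0., 1., 2.]
print(fast_loss(x))          # 5.0
print(grad_loss(x))          # [0., 2., 4.]  (gradient of sum(x**2) is 2*x)
```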
NetworkX
NetworkX is a Python package for the creation, manipulation, and study of directed and undirected graphs.
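To see why a graph library is a natural fit for a computational-graph backend, here is a minimal, illustrative sketch (the node names "x", "y", "add", "mul" are ours, not SymJAX internals) that stores operations as nodes of a NetworkX DiGraph and recovers a valid evaluation order topologically:

```python
import networkx as nx

# Build a tiny computational graph for: out = (x + y) * y
g = nx.DiGraph()
g.add_edge("x", "add")
g.add_edge("y", "add")
g.add_edge("add", "mul")
g.add_edge("y", "mul")

# A topological sort yields an order in which every node's
# inputs are computed before the node itself.
order = list(nx.topological_sort(g))
print(order)  # e.g. ['x', 'y', 'add', 'mul'] (topological orders are not unique)
```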
SymJAX is a symbolic-programming version of JAX that provides a Theano-like user experience thanks to a NetworkX-powered computational-graph backend. In addition to simplifying graph input/output, variable updates and graph utilities, SymJAX features machine-learning and deep-learning tools similar to Lasagne and TensorFlow 1, as well as lazy on-the-go execution as in PyTorch and TensorFlow 2.
This is a research project under active development, not an official product; expect bugs and sharp edges. Please help by trying it out and reporting bugs and missing pieces.
Installation Guide : Installation
Implementation Walkthrough : Computational Graph
Developer Guide : Development
Updates Roadmap : Roadmap
Modules¶
We briefly describe below the structure of SymJAX and, for each module, the closest analog (in terms of functionality) from other well-known libraries:
- symjax.data: everything related to downloading/importing/batchifying/patchifying datasets; a large corpus of time-series and computer-vision datasets, similar to tensorflow_datasets with additional utilities
- symjax.tensor: everything related to operating on tensors (array-like objects), similar to numpy and theano.tensor; specialized submodules are
  - symjax.tensor.linalg: like scipy.linalg and numpy.linalg
  - symjax.tensor.fft: like numpy.fft
  - symjax.tensor.signal: like scipy.signal, plus additional time-frequency and wavelet tools
  - symjax.tensor.random: like numpy.random
- symjax.nn: everything related to machine/deep learning, mixing lasagne, tensorflow, torch.nn and keras, subdivided into
  - symjax.nn.layers: like lasagne.layers or tf.keras.layers
  - symjax.nn.optimizers: like lasagne.optimizers or tf.keras.optimizers
  - symjax.nn.losses: like lasagne.losses or tf.keras.losses
  - symjax.nn.initializers: like lasagne.initializers or tf.keras.initializers
  - symjax.nn.schedules: external variable-state control (such as learning-rate schedules) as in tf.keras.optimizers.schedules or optax
- symjax.probabilities: like tensorflow-probability
- symjax.rl: like tf-agents or OpenAI Spinning Up and Baselines (no environment is implemented, as Gym already provides a large collection); submodules are
  - symjax.rl.utils: utilities to interact with environments, play, learn, manage buffers, …
  - symjax.rl.agents: basic agents such as DDPG, PPO, DQN, …
Tutorials¶
SymJAX¶
- Function: compiling a graph into an executable (function)
- Clone: one line multipurpose graph replacement
- Variable batch length (shape)
- while
- Graph Saving and Loading
- Graph visualization
- Wrap: Jax function/computation to SymJAX Op
- Wrap: Jax class to SymJAX class