Welcome to SymJAX’s documentation!¶
SymJAX = JAX+NetworkX
JAX
JAX is an XLA Python interface that provides a NumPy-like user experience with just-in-time compilation and Autograd-powered automatic differentiation. XLA is a compiler that optimizes a computational graph by fusing multiple kernels into one, preventing intermediate computations, reducing memory operations, and increasing performance.
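Below is a minimal sketch of the workflow this enables in plain JAX (the function and values are illustrative only):
import jax
import jax.numpy as jnp
# just-in-time compile a function: XLA fuses its kernels into one executable
@jax.jit
def loss(w, x):
    return jnp.sum((w * x - 1.0) ** 2)
# Autograd-style automatic differentiation of the same function
grad_loss = jax.grad(loss)
print(grad_loss(2.0, jnp.arange(3.0)))
# 14.0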
NetworkX
NetworkX is a Python package for the creation, manipulation, and study of directed and undirected graphs.
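For intuition, here is a minimal sketch of the kind of directed graph NetworkX manipulates (the node names are illustrative only):
import networkx as nx
# a small directed graph: two inputs feeding one operation, the same kind of
# structure used to store nodes (tensors) and edges (dependencies)
g = nx.DiGraph()
g.add_edge("x", "add")
g.add_edge("y", "add")
print(list(nx.topological_sort(g)))
# e.g. ['x', 'y', 'add']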
SymJAX is a symbolic programming version of JAX that provides a Theano-like user experience thanks to a NetworkX-powered computational-graph backend. In addition to simplifying graph input/output, variable updates, and graph utilities, SymJAX also features machine learning and deep learning tools similar to Lasagne and TensorFlow 1, as well as a lazy on-the-go execution capability like PyTorch and TensorFlow 2.
This is a research project under development, not an official product; expect bugs and sharp edges. Please help by trying it out, reporting bugs, and reporting missing pieces.
Installation Guide : Installation
Implementation Walkthrough : Computational Graph
Developer Guide : Development
Updates Roadmap : Roadmap
Modules¶
We briefly describe below the structure of SymJAX and, for each module, the closest analogs (in terms of functionality) from other well-known libraries:
- symjax.data : everything related to downloading/importing/batchifying/patchifying datasets. Large corpus of time-series and computer-vision datasets, similar to tensorflow_datasets, with additional utilities
- symjax.tensor : everything related to operating with tensors (array-like objects), similar to numpy and theano.tensor; specialized submodules are
  - symjax.tensor.linalg : like scipy.linalg and numpy.linalg
  - symjax.tensor.fft : like numpy.fft
  - symjax.tensor.signal : like scipy.signal, plus additional time-frequency and wavelet tools
  - symjax.tensor.random : like numpy.random
- symjax.nn : everything related to machine/deep learning, mixing lasagne, tensorflow, torch.nn and keras, subdivided into
  - symjax.nn.layers : like lasagne.layers or tf.keras.layers
  - symjax.nn.optimizers : like lasagne.optimizers or tf.keras.optimizers
  - symjax.nn.losses : like lasagne.losses or tf.keras.losses
  - symjax.nn.initializers : like lasagne.initializers or tf.keras.initializers
  - symjax.nn.schedules : external variable state control (such as learning-rate schedules) as in tf.keras.optimizers.schedules or optax
- symjax.probabilities : like tensorflow-probability
- symjax.rl : like TF-Agents or OpenAI SpinningUp and Baselines (no environment is implemented, as Gym already provides a large collection); submodules are
  - symjax.rl.utils : utilities to interact with environments, play, learn, buffers, …
  - symjax.rl.agents : basic agents such as DDPG, PPO, DQN, …
Tutorials¶
SymJAX¶
- Function: compiling a graph into an executable (function)
- Clone: one line multipurpose graph replacement
- Variable batch length (shape)
- While/Map/Scan
- Graph Saving and Loading
- Graph visualization
- Wrap: Jax function/computation to SymJAX Op
- Wrap: Jax class to SymJAX class
Amortized Variational Inference¶
Gallery¶
Installation¶
SymJAX has a couple of prerequisites that need to be installed first.
CPU only installation¶
Installation of SymJAX and all its dependencies (including Jax) with CPU-only support is done simply as follows:
$ pip install --upgrade jaxlib
$ pip install --upgrade jax
$ pip install --upgrade symjax
GPU installation¶
For GPU support, Jax needs to be installed first, configured for the local CUDA setup, following Jax Installation. In short, the steps involve:
1. Install the GPU drivers/libraries/compilers (cuda, cudnn, nvcc).
2. Install jax following Jax Installation.
3. Install SymJAX with
$ pip install --upgrade symjax
Manual (local/bleeding-edge) installation of SymJAX¶
In place of installing the latest official release of SymJAX from PyPI, one can install the latest version of SymJAX from the GitHub repository as follows:
Clone this repository with
$ git clone https://github.com/RandallBalestriero/SymJAX
Install with
$ cd SymJAX
$ pip install .
Note that whenever changes are made to the SymJAX GitHub repository, one can pull those changes by running
$ git pull
from within the cloned repository. However, the changes won't impact the installed version unless the install was done with
$ pip install -e .
Tutorials¶
SymJAX¶
We briefly describe some key components of SymJAX.
Function: compiling a graph into an executable (function)¶
As opposed to most current software, SymJAX proposes a symbolic viewpoint from which one creates a computational graph laying out the whole computation pipeline from inputs to outputs, including updates of persistent variables. Once this is defined, it is possible to compile the graph to optimize execution speed. In fact, knowing the graph (nodes, connections, shapes, types, constant values) is enough to produce a highly optimized executable of it. In SymJAX this is done via symjax.function, as demonstrated below:
import symjax
import symjax.tensor as T
value = T.Variable(T.ones(()))
randn = T.random.randn(())
rand = T.random.rand(())
out1 = randn * value
out2 = out1.clone({randn: rand})
f = symjax.function(rand, outputs=out2, updates={value: 2 + value})
for i in range(3):
print(f(i))
# 0.
# 3.
# 10.
# we create a simple computational graph
var = T.Variable(T.random.randn((16, 8), seed=10))
loss = ((var - T.ones_like(var)) ** 2).sum()
g = symjax.gradients(loss, [var])
opt = symjax.optimizers.SGD(loss, 0.01, params=var)
f = symjax.function(outputs=loss, updates=opt.updates)
for i in range(10):
print(f())
# 240.96829
# 231.42595
# 222.26149
# 213.45993
# 205.00691
# 196.88864
# 189.09186
# 181.60382
# 174.41231
# 167.50558
While/Map/Scan¶
An important part of many implementations resides in the use of for/while loops and in scans, which allow one to maintain and update an additional quantity through the iterations. In SymJAX, those operators are different from the Jax ones and closer to the Theano ones, as they provide explicit sequences and non_sequences arguments. Here are a few examples below:
import symjax as sj
import symjax.tensor as T
w = T.Variable(1.0, dtype="float32")
u = T.Placeholder((), "float32")
out = T.map(lambda a, w, u: (u - w) * a, [T.range(3)], non_sequences=[w, u])
f = sj.function(u, outputs=out, updates={w: w + 1})
print(f(2))
# [0, 1, 2]
print(f(2))
# [0, 0, 0]
print(f(0))
# [0, -3, -6]
w.reset()
out = T.map(lambda a, w, u: w * a * u, [T.range(3)], non_sequences=[w, u])
g = sj.gradients(out.sum(), [w])[0]
f = sj.function(u, outputs=g)
print(f(0))
# 0
print(f(1))
# 3
out = T.map(lambda a, b: a * b, [T.range(3), T.range(3)])
f = sj.function(outputs=out)
print(f())
# [0, 1, 4]
w.reset()
v = T.Placeholder((), "float32")
out = T.while_loop(
lambda i, u: i[0] + u < 5,
lambda i: (i[0] + 1.0, i[0] ** 2),
(w, 1.0),
non_sequences_cond=[v],
)
f = sj.function(v, outputs=out)
print(f(0))
# 5, 16
print(f(2))
# [3, 4]
The use of the non_sequences argument allows SymJAX to keep track of the internal function dependencies without having to execute the function. Hence, all tensors used inside a function should be part of the sequences or non_sequences op inputs.
Variable batch length (shape)¶
In many applications it is required to feed inputs of varying length to a compiled SymJAX function. This can be done by explicitly setting the corresponding shape dimension of the Placeholders to 0 (this will likely change in the future), as demonstrated below:
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import numpy as np
import symjax
x = symjax.tensor.Placeholder((0, 2), "float32")
w = symjax.tensor.Variable(1.0, dtype="float32")
p = x.sum(1)
f = symjax.function(x, outputs=p, updates={w: x.sum()})
print(f(np.ones((1, 2))))
print(w.value)
print(f(np.ones((2, 2))))
print(w.value)
# [2.]
# 2.0
# [2. 2.]
# 4.0
x = symjax.tensor.Placeholder((0, 2), "float32")
y = symjax.tensor.Placeholder((0,), "float32")
w = symjax.tensor.Variable((1, 1), dtype="float32")
loss = ((x.dot(w) - y) ** 2).mean()
g = symjax.gradients(loss, [w])[0]
other_g = symjax.gradients(x.dot(w).sum(), [w])[0]
f = symjax.function(x, y, outputs=loss, updates={w: w - 0.1 * g})
other_f = symjax.function(x, outputs=other_g)
for i in range(10):
print(f(np.ones((i + 1, 2)), -1 * np.ones(i + 1)))
print(other_f(np.ones((i + 1, 2))))
# 9.0
# [1. 1.]
# 3.2399998
# [2. 2.]
# 1.1663998
# [3. 3.]
# 0.419904
# [4. 4.]
# 0.15116541
# [5. 5.]
# 0.05441956
# [6. 6.]
# 0.019591037
# [7. 7.]
# 0.007052775
# [8. 8.]
# 0.0025389965
# [9. 9.]
# 0.0009140394
# [10. 10.]
In the backend, SymJAX automatically jits the overall (vmapped) functions for optimal performance.
Graph visualization¶
Similarly to Theano, it is possible to display the computational graph of written code, as follows:
#!/usr/bin/env python
# -*- coding: utf-8 -*-
__author__ = "Randall Balestriero"
import symjax.tensor as T
import matplotlib.pyplot as plt
from symjax.viz import compute_graph
x = T.random.randn((10,), name="x")
y = T.random.randn((10,), name="y")
z = T.random.randn((10,), name="z")
w = T.Variable(T.ones(1), name="w")
out = (x + y).sum() * w + z.sum()
graph = compute_graph(out)
graph.draw("file.png", prog="dot")
import matplotlib.image as mpimg
img = mpimg.imread("file.png")
plt.figure(figsize=(15, 5))
imgplot = plt.imshow(img)
plt.xticks()
plt.yticks()
plt.tight_layout()

Clone: one line multipurpose graph replacement¶
In most current packages, re-running an already defined computational graph with altered nodes is cumbersome. Some specific cases are handled, such as the use of layers in Keras, where one can feed any value and thus compute a feedforward pass without much change; but if one has to replace a specific variable or a more complex part of a graph, no tool is available. In Theano, the clone function allowed to do such a thing, and it is implemented in SymJAX as well. As the example below makes clear, the clone utility allows one to take an already defined computational graph and replace any subgraph in it with another, based on a node->node mapping:
import numpy as np
import symjax
import symjax.tensor as T
# we create a simple mapping with 2 matrix multiplications interleaved
# with nonlinearities
x = T.Placeholder((8,), "float32")
w_1 = T.Variable(T.random.randn((16, 8)))
w_2 = T.Variable(T.random.randn((2, 16)))
# the output can be computed easily as
output = w_2.dot(T.relu(w_1.dot(x)))
# now suppose we also wanted the same mapping but with a noise input
epsilon = T.random.randn((8,))
output_noisy = output.clone({x: x + epsilon})
f = symjax.function(x, outputs=[output, output_noisy])
for i in range(10):
print(f(np.ones(8)))
# [array([-14.496595, 8.7136 ], dtype=float32), array([-11.590391 , 4.7543654], dtype=float32)]
# [array([-14.496595, 8.7136 ], dtype=float32), array([-30.038504, 26.758451], dtype=float32)]
# [array([-14.496595, 8.7136 ], dtype=float32), array([-19.214798, 19.600328], dtype=float32)]
# [array([-14.496595, 8.7136 ], dtype=float32), array([-12.927457, 10.457445], dtype=float32)]
# [array([-14.496595, 8.7136 ], dtype=float32), array([-19.486668, 17.367273], dtype=float32)]
# [array([-14.496595, 8.7136 ], dtype=float32), array([-31.634314, 24.837488], dtype=float32)]
# [array([-14.496595, 8.7136 ], dtype=float32), array([-19.756075, 12.330083], dtype=float32)]
# [array([-14.496595, 8.7136 ], dtype=float32), array([-38.9738 , 31.588022], dtype=float32)]
# [array([-14.496595, 8.7136 ], dtype=float32), array([-19.561726, 12.192366], dtype=float32)]
# [array([-14.496595, 8.7136 ], dtype=float32), array([-33.110832, 30.104563], dtype=float32)]
Scopes, Operations/Variables/Placeholders naming and accessing¶
Accessing and naming variables, operations, and placeholders is done in a similar way as in vanilla TensorFlow: with scopes. Every variable/placeholder/operation is named and located with a unique identifier (name) per scope. If two nodes are created with the same name, the later name is augmented with an underscore and an integer; here is a brief example:
import symjax
import symjax.tensor as T
# scope/graph naming and accessing
value1 = T.Variable(T.ones((1,)))
value2 = T.Variable(T.zeros((1,)))
g = symjax.Graph("special")
with g:
value3 = T.Variable(T.zeros((1,)))
value4 = T.Variable(T.zeros((1,)))
result = value3 + value4
h = symjax.Graph("inversion")
with h:
value5 = T.Variable(T.zeros((1,)))
value6 = T.Variable(T.zeros((1,)))
value7 = T.Variable(T.zeros((1,)), name="w")
print(g.variables)
# {'unnamed_variable': Variable(name=unnamed_variable, shape=(1,), dtype=float32, trainable=True, scope=/special/),
# 'unnamed_variable_1': Variable(name=unnamed_variable_1, shape=(1,), dtype=float32, trainable=True, scope=/special/)}
print(h.variables)
# {'unnamed_variable': Variable(name=unnamed_variable, shape=(1,), dtype=float32, trainable=True, scope=/special/inversion/),
# 'unnamed_variable_1': Variable(name=unnamed_variable_1, shape=(1,), dtype=float32, trainable=True, scope=/special/inversion/),
# 'w': Variable(name=w, shape=(1,), dtype=float32, trainable=True, scope=/special/inversion/)}
print(h.variable("w"))
# Variable(name=w, shape=(1,), dtype=float32, trainable=True, scope=/special/inversion/)
# now suppose that we did not keep a reference to the graphs g/h; we can still
# recover a variable based on its name AND its scope
print(symjax.get_variables("/special/inversion/w"))
# Variable(name=w, shape=(1,), dtype=float32, trainable=True, scope=/special/inversion/)
# now if the exact scope name is not known, it is possible to use smart indexing;
# for example, suppose we do not remember it, then we can get all variables named
# 'w' among scopes
print(symjax.get_variables("*/w"))
# Variable(name=w, shape=(1,), dtype=float32, trainable=True, scope=/special/inversion/)
# if only part of the scope is known, all the variables of a given scope can
# be retrieved
print(symjax.get_variables("/special/*"))
# [Variable(name=unnamed_variable, shape=(1,), dtype=float32, trainable=True, scope=/special/),
# Variable(name=unnamed_variable_1, shape=(1,), dtype=float32, trainable=True, scope=/special/),
# Variable(name=unnamed_variable, shape=(1,), dtype=float32, trainable=True, scope=/special/inversion/),
# Variable(name=unnamed_variable_1, shape=(1,), dtype=float32, trainable=True, scope=/special/inversion/),
# Variable(name=w, shape=(1,), dtype=float32, trainable=True, scope=/special/inversion/)]
print(symjax.get_ops("*add"))
# Op(name=add, shape=(1,), dtype=float32, scope=/special/)
Graph Saving and Loading¶
An important feature of SymJAX is the ease of resetting, saving, and loading variables. This is crucial in order to save a model and reload it (possibly in a different script) to keep using it. In our case, a computational graph is completely defined by its structure and the values of the persistent nodes (the variables); hence it is enough to save the variables. This is done in a very explicit manner using the numpy.savez utility: the saved file can be accessed from any other script, and variables can be loaded, accessed, even modified, and then reloaded inside the computational graph. Here is a brief example:
import symjax
import symjax.tensor as T
g = symjax.Graph("model1")
with g:
learning_rate = T.Variable(T.ones((1,)))
with symjax.Graph("layer1"):
W1 = T.Variable(T.zeros((1,)), name="W")
b1 = T.Variable(T.zeros((1,)), name="b")
with symjax.Graph("layer2"):
W2 = T.Variable(T.zeros((1,)), name="W")
b2 = T.Variable(T.zeros((1,)), name="b")
# define an irrelevant loss function involving the parameters
loss = (W1 + b1 + W2 + b2) * learning_rate
# and a train/update function
train = symjax.function(
outputs=loss, updates={W1: W1 + 1, b1: b1 + 2, W2: W2 + 2, b2: b2 + 3}
)
# pretend we train for a while
for i in range(4):
print(train())
# [0.]
# [8.]
# [16.]
# [24.]
# now say we wanted to reset the variables and retrain, we can do
# either with g, as it contains all the variables
g.reset()
# or we can do
symjax.reset_variables("*")
# or if we wanted to only reset say variables from layer2
symjax.reset_variables("*layer2*")
# now that all has been reset, let's retrain for a while
# pretend we train for a while
for i in range(2):
print(train())
# [0.]
# [8.]
# now resetting is nice, but we might want to save the model parameters, to
# keep training later or do some other analyses. We can do so as follows:
g.save_variables("model1_saved")
# this would save all variables as they are contained in g. Now say we want to
# only save the first layer's variables; if we had kept that layer's graph as,
# say, h, we could just do ``h.save_variables('layer1_saved')''.
# Since we do not have it but recall its scope, we can do
symjax.save_variables("*layer1*", "layer1_saved")
# and for the entire set of variables just do
symjax.save_variables("*", "model1_saved")
# now suppose that after training or after resetting
symjax.reset_variables("*")
# one wants to recover the saved weights, one can do
symjax.load_variables("*", "model1_saved")
# in that case all variables will be reloaded as they were in model1_saved.
# if we used symjax.load_variables('*', 'layer1_saved'), an error would occur,
# as not all variables are present in this file; one should instead do
# (in this case, this is redundant as we loaded everything up above)
symjax.load_variables("*layer1*", "layer1_saved")
# we can now pretend to keep training our model from its saved state
for i in range(2):
print(train())
# [16.]
# [24.]
Wrap: Jax function/computation to SymJAX Op¶
Computation in Jax is done eagerly, similarly to TF2 and PyTorch. In SymJAX, the computational graph definition is done a priori with symbolic variables: no actual computation is performed during graph definition; once done, the graph is compiled with the proper inputs/outputs/updates to provide the user with a compiled function executing the graph. This graph thus involves various operations, and one can define one's own in the two following ways: first by combining already existing SymJAX functions, second by creating it in pure Jax and then wrapping it into a SymJAX symbolic operation, as demonstrated below.
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import symjax as sj
import jax.numpy as jnp
__author__ = "Randall Balestriero"
# suppose we want to compute the mean-squared error between two vectors
x = sj.tensor.random.normal((10,))
y = sj.tensor.zeros((10,))
# one way is to do so by combining SymJAX functions as
mse = ((x - y) ** 2).sum()
# notice that the basic operators are overloaded and implicitly call SymJAX ops
# another solution is to create a new SymJAX Op from a jax computation as
# follows
def mse_jax(x, y):
return jnp.sum((x - y) ** 2)
# wrap the jax computation into a SymJAX Op that can then be used as any
# SymJAX function
mse_op = sj.tensor.jax_wrap(mse_jax)
also_mse = mse_op(x, y)
print(also_mse)
# Tensor(Op=mse_jax, shape=(), dtype=float32)
# ensure that both are equivalent
f = sj.function(outputs=[mse, also_mse])
print(f())
# [array(6.0395503, dtype=float32), array(6.0395503, dtype=float32)]
A SymJAX computation graph cannot be partially defined with Jax computations; the above thus provides an easy way to wrap Jax computations into a SymJAX Op which can then be put into the graph like any other SymJAX-provided Op.
Wrap: Jax class to SymJAX class¶
One might have defined a Jax class, with a constructor possibly taking some constant values and some jax arrays, performing some computations, setting some attributes, and then interacting with those attributes when calling the class methods. It is particularly easy to pair such already-implemented classes with the SymJAX computation graph. This can be done as follows:
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import symjax as sj
import symjax.tensor as T
import jax.numpy as jnp
__author__ = "Randall Balestriero"
class product:
def __init__(self, W, V=1):
self.W = jnp.square(V * W * (W > 0).astype("float32"))
self.ndim = self.compute_ndim()
def feed(self, x):
return jnp.dot(self.W, x)
def compute_ndim(self):
return self.W.shape[0] * self.W.shape[1]
wrapped = T.wrap_class(product, method_exceptions=["compute_ndim"])
a = wrapped(T.zeros((10, 10)), V=T.ones((10, 10)))
x = T.random.randn((10, 100))
print(a.W)
# (Tensor: name=function[0], shape=(10, 10), dtype=float32)
print(a.feed(x))
# Op(name=feed, shape=(10, 100), dtype=float32, scope=/)
f = sj.function(outputs=a.feed(x))
f()
As can be seen, there are some restrictions. First, the behavior inside the constructor of the original class should be fixed, as it will be executed once by the wrapper in order to map the constructor computations into SymJAX. Second, any jax array update done internally will break the conversion, as such operations are only allowed for Variables in SymJAX; hence some care is needed. More flexibility will be provided in future versions.
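As a hypothetical illustration of the second restriction, a class along the following lines would not wrap cleanly, since its method re-assigns the internal jax array (an update SymJAX only allows for Variables):
class running_sum:
    def __init__(self, W):
        self.W = W
    def feed(self, x):
        # re-assigning internal state like this would break the
        # symbolic conversion performed by wrap_class
        self.W = self.W + x
        return self.W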
Amortized Variational Inference¶
We briefly describe some key components of SymJAX.
The principles of AVI¶
Reinforcement Learning¶
We briefly describe some key components of SymJAX.
Notations¶
immediate reward \(r_t\) is observed from the environment at state \(s_t\) by performing action \(a_t\)
total discounted reward \(G_t(\gamma)\), often abbreviated as \(G_t\) and defined as
\[G_t = \sum_{t'=t+1}^{T}\gamma^{t'-t-1}r_{t'}\]
action-value function \(Q_{\pi}(s,a)\) is the expected return starting from state \(s\), following policy \(\pi\) and taking action \(a\)
\[Q_{\pi}(s,a)=E_{\pi}[G_{t}|s_{t}=s, a_{t}=a]\]
state-value function \(V_{\pi}(s)\) is the expected return starting from state \(s\) following policy \(\pi\), as in
\[\begin{split}V_{\pi}(s)&=E_{\pi}[G_{t}|s_{t}=s]\\ &=\sum_{a \in A}\pi(a|s)Q_{\pi}(s,a)\end{split}\]
in a deterministic policy setting, one directly has \(V_{\pi}(s)=Q_{\pi}(s,\pi(s))\); with a greedy policy one has \(V^{*}_{\pi}(s)=\max_{a\in A}Q_{\pi}(s,a)\), where \(V^{*}_{\pi}\) is the best value of a state if one could follow an (unknown) optimal policy
TD-error
\[\delta_t=r_t+\gamma Q(s_{t+1},a_{t+1})-Q(s_{t},a_{t})\]
advantage value : how much better it is to take a specific action compared to the average action at the given state
\[\begin{split}A(s_t,a_t)&=Q(s_t,a_t)-V(s_t)\\ A(s_t,a_t)&=E[r_{t+1}+\gamma V(s_{t+1})]-V(s_t)\\ A(s_t,a_t)&=r_{t+1}+\gamma V(s_{t+1})-V(s_t)\end{split}\]
The formulation of policy gradients with advantage functions is extremely common, and there are many different ways of estimating the advantage function used by different algorithms.
probability of a trajectory \(\tau=(s_0,a_0,\dots,s_{T+1})\), given by
\[p(\tau|\theta)=p_{0}(s_0)\prod_{t=0}^{T}p(s_{t+1}|s_{t},a_{t})\pi_{\theta}(a_{t}|s_{t})\]
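As a quick numerical check of the definition of \(G_t\) above, here is a short sketch (the reward values are made up) computing the discounted returns backwards:
import numpy as np
gamma = 0.9
rewards = np.array([1.0, 0.0, 2.0, 1.0])  # r_{t+1}, ..., r_T (illustrative)
# G_t = sum_{t'=t+1}^{T} gamma^{t'-t-1} r_{t'}, accumulated backwards in time
G = np.zeros(len(rewards))
running = 0.0
for t in reversed(range(len(rewards))):
    running = rewards[t] + gamma * running
    G[t] = running
print(G)
# [3.349 2.61  2.9   1.   ]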
Policy gradient and REINFORCE¶
Policy gradient methods are ubiquitous in model-free reinforcement learning and appear especially frequently in recent publications. The policy gradient method is also the “actor” part of Actor-Critic methods, and its basic implementation (REINFORCE) is also known as Monte Carlo Policy Gradients. Policy gradient methods update the probability distribution of actions \(\pi(a|s)\) so that actions with higher expected reward have a higher probability value for an observed state. Their main properties are listed below, followed by the standard gradient estimator:
- needs to reach the end of an episode to compute the discounted rewards and train the model
- only needs an actor (a.k.a. policy) network
- noisy gradients and high variance => instability and slow convergence
- fails for trajectories having a cumulative reward of 0
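For reference, the standard REINFORCE estimator of the policy gradient (a textbook result, not specific to SymJAX) reads
\[\nabla_{\theta} J(\theta) = E_{\tau \sim \pi_{\theta}}\Big[\sum_{t=0}^{T} \nabla_{\theta} \log \pi_{\theta}(a_{t}|s_{t})\, G_{t}\Big]\]
so that the log-probability of each action is weighted by the return that followed it.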
Tricks¶
- normalizing discounted rewards (or advantages) : in practice it can also be important to normalize these. For example, suppose we compute [discounted cumulative reward] for all of the 20,000 actions in the batch of 100 Pong game rollouts above. One good idea is to “standardize” these returns (e.g. subtract the mean, divide by the standard deviation) before plugging them into backprop. This way we always encourage and discourage roughly half of the performed actions. Mathematically, this trick can also be interpreted as a way of controlling the variance of the policy gradient estimator.
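A minimal sketch of this standardization (the returns array is illustrative only):
import numpy as np
returns = np.array([3.2, -1.0, 0.5, 7.1])
# subtract the mean and divide by the standard deviation; the epsilon
# guards against a zero-variance batch
returns = (returns - returns.mean()) / (returns.std() + 1e-8)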
Gallery¶
Basic examples¶
Introductory examples that teach how to use SymJAX.
Basic image resampling and alignment¶
Demonstration of how to perform basic image preprocessing.
import matplotlib.pyplot as plt
import numpy as np
import symjax
image1 = np.random.rand(3, 2, 4)
image2 = np.random.rand(3, 4, 2)
image3 = np.random.rand(3, 4, 4)
all_images = [image1, image2, image3]
images = symjax.data.utils.resample_images(all_images, (6, 6))
fig = plt.figure(figsize=(8, 3))
for i in range(3):
plt.subplot(2, 3, i + 1)
plt.imshow(all_images[i].transpose(1, 2, 0), aspect="auto", vmax=10, cmap="jet")
plt.xticks([])
plt.yticks([])
plt.subplot(2, 3, i + 4)
plt.imshow(images[i].transpose(1, 2, 0), aspect="auto", vmax=10, cmap="jet")
plt.xticks([])
plt.yticks([])
plt.tight_layout()
Total running time of the script: ( 0 minutes 0.246 seconds)
Basic gradient descent (and reset)¶
Demonstration of how to compute a gradient and apply a basic gradient-update rule to minimize some loss function.
Out:
/home/vrael/anaconda3/lib/python3.7/site-packages/jax/lib/xla_bridge.py:130: UserWarning: No GPU/TPU found, falling back to CPU.
warnings.warn('No GPU/TPU found, falling back to CPU.')
import symjax
import symjax.tensor as T
import matplotlib.pyplot as plt
# GRADIENT DESCENT
z = T.Variable(3.0, dtype="float32")
loss = (z - 1) ** 2
g_z = symjax.gradients(loss, [z])[0]
symjax.current_graph().add_updates({z: z - 0.1 * g_z})
train = symjax.function(outputs=[loss, z], updates=symjax.get_updates())
losses = list()
values = list()
for i in range(200):
if (i + 1) % 50 == 0:
symjax.reset_variables("*")
a, b = train()
losses.append(a)
values.append(b)
plt.figure()
plt.subplot(121)
plt.plot(losses, "-x")
plt.ylabel("loss")
plt.xlabel("number of gradient updates")
plt.subplot(122)
plt.plot(values, "-x")
plt.axhline(1, c="red")
plt.ylabel("value")
plt.xlabel("number of gradient updates")
plt.tight_layout()
Total running time of the script: ( 0 minutes 0.278 seconds)
Basic 1d upsampling¶
In this example we demonstrate how to employ the utility functions from symjax.tensor.interpolation, which can be used for upsampling.
import matplotlib.pyplot as plt
import symjax
import symjax.tensor as T
import numpy as np
w = T.Placeholder((3,), "float32", name="w")
w_interp1 = T.interpolation.upsample_1d(w, repeat=4, mode="nearest")
w_interp2 = T.interpolation.upsample_1d(
w, repeat=4, mode="linear", boundary_condition="mirror"
)
w_interp3 = T.interpolation.upsample_1d(
w, repeat=4, mode="linear", boundary_condition="periodic"
)
w_interp4 = T.interpolation.upsample_1d(w, repeat=4)
f = symjax.function(w, outputs=[w_interp1, w_interp2, w_interp3, w_interp4])
samples = f(np.array([1, 2, 3]))
fig = plt.figure(figsize=(6, 6))
plt.subplot(411)
plt.plot(samples[0], "xg", linewidth=3, markersize=15)
plt.plot([0, 5, 10], [1, 2, 3], "ok", alpha=0.5)
plt.title("nearest-periodic")
plt.xticks([])
plt.subplot(412)
plt.plot(samples[1], "xg", linewidth=3, markersize=15)
plt.plot([0, 5, 10], [1, 2, 3], "ok", alpha=0.5)
plt.title("linear-mirror")
plt.xticks([])
plt.subplot(413)
plt.plot(samples[2], "xg", linewidth=3, markersize=15)
plt.plot([0, 5, 10], [1, 2, 3], "ok", alpha=0.5)
plt.title("linear-periodic")
plt.xticks([])
plt.subplot(414)
plt.plot(samples[3], "xg", linewidth=3, markersize=15)
plt.plot([0, 5, 10], [1, 2, 3], "ok", alpha=0.5)
plt.title("constant-0")
plt.tight_layout()
Total running time of the script: ( 0 minutes 0.825 seconds)
Pixel interpolation learning¶
In this toy example we demonstrate how to use the coordinate-interpolation techniques with a learnable parameter to map one image to another, simply by interpolating the original image values at learned coordinates.
Out:
... mnist.pkl.gz already exists
Loading mnist
Dataset mnist loaded in 0.67s.
/home/vrael/anaconda3/lib/python3.7/site-packages/jax/lib/xla_bridge.py:130: UserWarning: No GPU/TPU found, falling back to CPU.
warnings.warn('No GPU/TPU found, falling back to CPU.')
[[ 0. 0. 0. ... 27. 27. 27.]
[ 0. 1. 2. ... 25. 26. 27.]]
/home/vrael/anaconda3/lib/python3.7/site-packages/matplotlib/tight_layout.py:345: UserWarning: tight_layout not applied: number of columns in subplot specifications must be multiples of one another.
warnings.warn('tight_layout not applied: '
/home/vrael/anaconda3/lib/python3.7/site-packages/matplotlib/figure.py:445: UserWarning: Matplotlib is currently using agg, which is a non-GUI backend, so cannot show the figure.
% get_backend())
import symjax
import symjax.tensor as T
import matplotlib.pyplot as plt
import numpy as np
import os
os.environ["DATASET_PATH"] = "/home/vrael/DATASETS/"
symjax.current_graph().reset()
mnist = symjax.data.mnist()
# 2d image
images = mnist["train_set/images"][mnist["train_set/labels"] == 2][:2, 0]
images /= images.max()
np.random.seed(0)
coordinates = T.meshgrid(T.range(28), T.range(28))
coordinates = T.Variable(
T.stack([coordinates[1].flatten(), coordinates[0].flatten()]).astype("float32")
)
interp = T.interpolation.map_coordinates(images[0], coordinates, order=1).reshape(
(28, 28)
)
loss = ((interp - images[1]) ** 2).mean()
lr = symjax.nn.schedules.PiecewiseConstant(0.05, {5000: 0.01, 8000: 0.005})
symjax.nn.optimizers.Adam(loss, lr)
train = symjax.function(outputs=loss, updates=symjax.get_updates())
rec = symjax.function(outputs=interp)
losses = list()
original = coordinates.value
for i in range(100):
losses.append(train())
reconstruction = rec()
after = coordinates.value
plt.figure(figsize=(12, 6))
plt.subplot(311)
plt.semilogy(losses, "-x")
plt.ylabel("loss (l2)")
plt.title("Training loss")
plt.subplot(334)
plt.imshow(images[0], aspect="auto", cmap="plasma")
plt.xticks([])
plt.yticks([])
plt.title("input")
plt.subplot(335)
plt.imshow(images[1], aspect="auto", cmap="plasma")
plt.xticks([])
plt.yticks([])
plt.title("target")
plt.subplot(336)
plt.imshow(reconstruction, aspect="auto", cmap="plasma")
plt.xticks([])
plt.yticks([])
plt.title("reconstruction")
print(original)
plt.subplot(325)
plt.scatter(original[1][::-1], original[0], s=3)
plt.xticks([])
plt.yticks([])
plt.title("Initialized coordinates")
plt.subplot(326)
plt.scatter(after[1][::-1], after[0], s=3)
plt.xticks([])
plt.yticks([])
plt.title("Learned coordinates")
plt.tight_layout()
plt.show()
Total running time of the script: ( 0 minutes 1.922 seconds)
Basic (linear) deconvolution filter learning¶
Demonstration of how to learn a (linear) deconvolution filter with some flavors of gradient descent, assuming the true output is known.
Out:
... mnist.pkl.gz already exists
Loading mnist
Dataset mnist loaded in 0.95s.
/home/vrael/anaconda3/lib/python3.7/site-packages/matplotlib/figure.py:445: UserWarning: Matplotlib is currently using agg, which is a non-GUI backend, so cannot show the figure.
% get_backend())
import symjax
import symjax.tensor as T
import matplotlib.pyplot as plt
import numpy as np
from scipy.signal import convolve2d
import os
os.environ["DATASET_PATH"] = "/home/vrael/DATASETS/"
symjax.current_graph().reset()
true_image = symjax.data.mnist()
# 2d image
true_image = true_image["train_set/images"][0, 0]
true_image /= true_image.max()
np.random.seed(0)
noisy_image = convolve2d(true_image, np.random.randn(5, 5) / 5, "same")
# GRADIENT DESCENT
filter_1 = T.Variable(np.random.randn(8, 8) / 8, dtype="float32")
filter_2 = T.Variable(filter_1.value, dtype="float32")
reconstruction_1 = T.signal.convolve2d(noisy_image, filter_1, "same")
reconstruction_2 = T.signal.convolve2d(noisy_image, filter_2, "same")
loss1 = T.abs(reconstruction_1 - true_image).mean()
loss2 = (T.abs(reconstruction_2 - true_image) ** 2).mean()
lr = symjax.nn.schedules.PiecewiseConstant(0.05, {5000: 0.01, 8000: 0.005})
symjax.nn.optimizers.Adam(loss1 + loss2, lr)
train = symjax.function(outputs=[loss1, loss2], updates=symjax.get_updates())
rec = symjax.function(outputs=[reconstruction_1, reconstruction_2])
losses_1 = list()
losses_2 = list()
for i in range(10000):
losses = train()
losses_1.append(losses[0])
losses_2.append(losses[1])
reconstruction_1, reconstruction_2 = rec()
plt.figure(figsize=(12, 6))
plt.subplot(221)
plt.semilogy(losses_1, "-x")
plt.ylabel("log-loss (l1)")
plt.xlabel("number of gradient updates")
plt.subplot(222)
plt.semilogy(losses_2, "-x")
plt.ylabel("log-loss (l2)")
plt.xlabel("number of gradient updates")
plt.subplot(245)
plt.imshow(reconstruction_1, aspect="auto", origin="lower", cmap="plasma")
plt.xticks([])
plt.yticks([])
plt.title("reconstruction (l1)")
plt.subplot(246)
plt.imshow(reconstruction_2, aspect="auto", origin="lower", cmap="plasma")
plt.xticks([])
plt.yticks([])
plt.title("reconstruction (l2)")
plt.subplot(247)
plt.imshow(true_image, aspect="auto", origin="lower", cmap="plasma")
plt.xticks([])
plt.yticks([])
plt.title("True image")
plt.subplot(248)
plt.imshow(noisy_image, aspect="auto", origin="lower", cmap="plasma")
plt.xticks([])
plt.yticks([])
plt.title("Convolved image")
plt.tight_layout()
plt.show()
Total running time of the script: ( 0 minutes 32.168 seconds)
Basic scan/loops examples¶
In this example we demonstrate how to employ symjax.tensor.scan() and other similar functions. We first demonstrate how to compute a moving average with symjax.tensor.scan(). We then demonstrate how to do a simple for loop and then a while loop.
import matplotlib.pyplot as plt
import symjax
import symjax.tensor as T
import numpy as np
# suppose we are given a time series and we want to compute an
# exponential moving average; the EMA coefficient alpha is
# based on the user input
signal = T.Placeholder((512,), "float32", name="signal")
alpha = T.Placeholder((), "float32", "alpha")
# to use a scan function one needs a function to be applied at each step,
# in our case an exponential moving average function.
# this function should output the new value of the carry as well as an
# additional output; in our case, the carry (EMA) is also what we want to
# output at each time step
def fn(at, xt, alpha):
# the function first input is the carry, then are the (ordered)
# values from sequences and non_sequences similar to Theano
EMA = at * alpha + (1 - alpha) * xt
return EMA, EMA
# the scan function will return the carry at each time steps (first arg.)
# as well as the last one, we also need to provide an init.
last_ema, all_ema = T.scan(
fn, init=signal[0], sequences=[signal[1:]], non_sequences=[alpha]
)
f = symjax.function(signal, alpha, outputs=all_ema)
# generate a signal
x = np.cos(np.linspace(-3, 3, 512)) + np.random.randn(512) * 0.2
fig, ax = plt.subplots(3, 1, figsize=(3, 9))
for k, alpha in enumerate([0.1, 0.5, 0.9]):
ax[k].plot(x, c="b")
ax[k].plot(f(x, alpha), c="r")
ax[k].set_title("EMA: {}".format(alpha))
ax[k].set_xticks([])
ax[k].set_yticks([])
plt.tight_layout()
# Now let's do a simple map for which we can compute a simple
# moving average. The for loop will consist of moving a window and
# average the values on that window
# in that case the function also needs to be defined
def fn(window):
# the function's input is the current element of the (ordered)
# sequences (here a window of the signal), followed by any
# non_sequences values
return T.mean(window)
windowed = T.extract_signal_patches(signal, 10)
output = T.map(fn, sequences=[windowed])
f = symjax.function(signal, outputs=output)
fig, ax = plt.subplots(1, 1, figsize=(5, 2))
ax.plot(x, c="b")
ax.plot(f(x), c="r")
ax.set_title("SMA: 10")
ax.set_xticks([])
ax.set_yticks([])
plt.tight_layout()
Total running time of the script: ( 0 minutes 0.397 seconds)
Basic image transform (TPS/affine)¶
In this example we demonstrate how to employ the utility functions symjax.tensor.interpolation.affine_transform and symjax.tensor.interpolation.thin_plate_spline to transform/interpolate images.
Out:
/home/vrael/SymJAX/symjax/tensor/interpolation.py:548: RuntimeWarning: divide by zero encountered in log
log_r_2 = np.log(r_2)
... mnist.pkl.gz already exists
Loading mnist
Dataset mnist loaded in 1.07s.
/home/vrael/anaconda3/lib/python3.7/site-packages/matplotlib/figure.py:445: UserWarning: Matplotlib is currently using agg, which is a non-GUI backend, so cannot show the figure.
% get_backend())
import matplotlib.pyplot as plt
import symjax
import symjax.tensor as T
import numpy as np
x = T.Placeholder((10, 1, 28, 28), "float32")
points = T.Placeholder((10, 2 * 16), "float32")
thetas = T.Placeholder((10, 6), "float32")
affine = T.interpolation.affine_transform(x, thetas)
tps = T.interpolation.thin_plate_spline(x, points)
f = symjax.function(x, thetas, outputs=affine)
g = symjax.function(x, points, outputs=tps)
data = symjax.data.mnist()["train_set/images"][:10]
plt.figure(figsize=(20, 6))
plt.subplot(2, 8, 1)
plt.imshow(data[0][0])
plt.title("original")
plt.ylabel("TPS")
plt.xticks([])
plt.yticks([])
plt.subplot(2, 8, 2)
points = np.zeros((10, 2 * 16))
plt.imshow(g(data, points)[0][0])
plt.title("identity")
plt.xticks([])
plt.yticks([])
plt.subplot(2, 8, 3)
points = np.zeros((10, 2 * 16))
points[:, :16] += 0.3
plt.imshow(g(data, points)[0][0])
plt.title("x translation")
plt.xticks([])
plt.yticks([])
plt.subplot(2, 8, 4)
points = np.zeros((10, 2 * 16))
points[:, 16:] += 0.3
plt.imshow(g(data, points)[0][0])
plt.title("y translation")
plt.xticks([])
plt.yticks([])
plt.subplot(2, 8, 5)
points = np.random.randn(10, 2 * 16) * 0.2
plt.imshow(g(data, points)[0][0])
plt.title("random")
plt.xticks([])
plt.yticks([])
plt.subplot(2, 8, 6)
points = np.meshgrid(np.linspace(-1, 1, 4), np.linspace(-1, 1, 4))
points = np.concatenate([points[0].reshape(-1), points[1].reshape(-1)]) * 0.4
points = points[None] * np.ones((10, 1))
plt.imshow(g(data, points)[0][0])
plt.title("zoom")
plt.xticks([])
plt.yticks([])
plt.subplot(2, 8, 7)
points = np.meshgrid(np.linspace(-1, 1, 4), np.linspace(-1, 1, 4))
points = np.concatenate([points[0].reshape(-1), points[1].reshape(-1)]) * -0.2
points = points[None] * np.ones((10, 1))
plt.imshow(g(data, points)[0][0])
plt.title("zoom")
plt.xticks([])
plt.yticks([])
plt.subplot(2, 8, 8)
points = np.zeros((10, 2 * 16))
points[:, 1::2] -= 0.1
points[:, ::2] += 0.1
plt.imshow(g(data, points)[0][0])
plt.title("blob")
plt.xticks([])
plt.yticks([])
plt.subplot(2, 8, 9)
plt.imshow(data[0][0])
plt.title("original")
plt.ylabel("Affine")
plt.xticks([])
plt.yticks([])
plt.subplot(2, 8, 10)
points = np.zeros((10, 6))
points[:, 0] = 1
points[:, 4] = 1
plt.imshow(f(data, points)[0][0])
plt.title("identity")
plt.xticks([])
plt.yticks([])
plt.subplot(2, 8, 11)
points = np.zeros((10, 6))
points[:, 0] = 1
points[:, 4] = 1
points[:, 2] = 0.2
plt.imshow(f(data, points)[0][0])
plt.title("x translation")
plt.xticks([])
plt.yticks([])
plt.subplot(2, 8, 12)
points = np.zeros((10, 6))
points[:, 0] = 1
points[:, 4] = 1
points[:, 5] = 0.2
plt.imshow(f(data, points)[0][0])
plt.title("y translation")
plt.xticks([])
plt.yticks([])
plt.subplot(2, 8, 13)
points = np.zeros((10, 6))
points[:, 0] = 1
points[:, 4] = 1
points[:, 1] = 0.4
plt.imshow(f(data, points)[0][0])
plt.title("skewness x")
plt.xticks([])
plt.yticks([])
plt.subplot(2, 8, 14)
points = np.zeros((10, 6))
points[:, 0] = 1.4
points[:, 4] = 1.4
plt.imshow(f(data, points)[0][0])
plt.title("zoom")
plt.xticks([])
plt.yticks([])
plt.subplot(2, 8, 15)
points = np.zeros((10, 6))
points[:, 0] = 1.4
points[:, 4] = 1.0
plt.imshow(f(data, points)[0][0])
plt.title("zoom x")
plt.xticks([])
plt.yticks([])
plt.subplot(2, 8, 16)
points = np.zeros((10, 6))
points[:, 0] = 1
points[:, 4] = 1
points[:, 3] = 0.4
plt.imshow(f(data, points)[0][0])
plt.title("skewness y")
plt.xticks([])
plt.yticks([])
plt.tight_layout()
plt.show()
Total running time of the script: ( 0 minutes 3.074 seconds)
Computation times¶
In this example we demonstrate how to perform a simple optimization with Adam in TF and SymJAX, and compare the computation times.
Out:
False 10
TF1
WARNING:tensorflow:From /home/vrael/anaconda3/lib/python3.7/site-packages/tensorflow/python/compat/v2_compat.py:96: disable_resource_variables (from tensorflow.python.ops.variable_scope) is deprecated and will be removed in a future version.
Instructions for updating:
non-resource variables are not supported in the long term
SJ
/home/vrael/anaconda3/lib/python3.7/site-packages/jax/lib/xla_bridge.py:130: UserWarning: No GPU/TPU found, falling back to CPU.
warnings.warn('No GPU/TPU found, falling back to CPU.')
False 100
TF1
SJ
False 200
TF1
SJ
False 400
TF1
SJ
False 1000
TF1
SJ
True 10
TF1
SJ
True 100
TF1
SJ
True 200
TF1
SJ
True 400
TF1
SJ
True 1000
TF1
SJ
/home/vrael/anaconda3/lib/python3.7/site-packages/matplotlib/figure.py:445: UserWarning: Matplotlib is currently using agg, which is a non-GUI backend, so cannot show the figure.
% get_backend())
import matplotlib.pyplot as plt
import symjax
import symjax.tensor as T
from symjax.nn import optimizers
import numpy as np
import time
lr = 0.01
BS = 10000
D = 1000
X = np.random.randn(BS, D).astype("float32")
Y = X.dot(np.random.randn(D, 1).astype("float32")) + 2
def TF1(x, y, N, preallocate=False):
import tensorflow.compat.v1 as tf
tf.compat.v1.disable_v2_behavior()
if preallocate:
tf_input = tf.constant(x)
tf_output = tf.constant(y)
else:
tf_input = tf.placeholder(dtype=tf.float32, shape=[BS, D])
tf_output = tf.placeholder(dtype=tf.float32, shape=[BS, 1])
np.random.seed(0)
tf_W = tf.Variable(np.random.randn(D, 1).astype("float32"))
tf_b = tf.Variable(
np.random.randn(
1,
).astype("float32")
)
tf_loss = tf.reduce_mean((tf.matmul(tf_input, tf_W) + tf_b - tf_output) ** 2)
train_op = tf.train.AdamOptimizer(lr).minimize(tf_loss)
# initialize session
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.Session(config=config)
sess.run(tf.global_variables_initializer())
if not preallocate:
t = time.time()
for i in range(N):
sess.run(train_op, feed_dict={tf_input: x, tf_output: y})
else:
t = time.time()
for i in range(N):
sess.run(train_op)
return time.time() - t
def TF2(x, y, N, preallocate=False):
import tensorflow as tf
optimizer = tf.keras.optimizers.Adam(learning_rate=0.1)
np.random.seed(0)
tf_W = tf.Variable(np.random.randn(D, 1).astype("float32"))
tf_b = tf.Variable(
np.random.randn(
1,
).astype("float32")
)
@tf.function
def train(tf_input, tf_output):
with tf.GradientTape() as tape:
tf_loss = tf.reduce_mean(
(tf.matmul(tf_input, tf_W) + tf_b - tf_output) ** 2
)
grads = tape.gradient(tf_loss, [tf_W, tf_b])
optimizer.apply_gradients(zip(grads, [tf_W, tf_b]))
return tf_loss
if preallocate:
x = tf.constant(x)
y = tf.constant(y)
t = time.time()
for i in range(N):
l = train(x, y)
return time.time() - t
def SJ(x, y, N, preallocate=False):
symjax.current_graph().reset()
sj_input = T.Placeholder(dtype=np.float32, shape=[BS, D])
sj_output = T.Placeholder(dtype=np.float32, shape=[BS, 1])
np.random.seed(0)
sj_W = T.Variable(np.random.randn(D, 1).astype("float32"))
sj_b = T.Variable(
np.random.randn(
1,
).astype("float32")
)
sj_loss = ((sj_input.dot(sj_W) + sj_b - sj_output) ** 2).mean()
optimizers.Adam(sj_loss, lr)
train = symjax.function(sj_input, sj_output, updates=symjax.get_updates())
if preallocate:
import jax
x = jax.device_put(x)
y = jax.device_put(y)
t = time.time()
for i in range(N):
train(x, y)
return time.time() - t
values = []
Ns = [10, 100, 200, 400, 1000]
for pre in [False, True]:
for N in Ns:
print(pre, N)
print("TF1")
values.append(TF1(X, Y, N, pre))
# print("TF2")
# values.append(TF2(X, Y, N, pre))
print("SJ")
values.append(SJ(X, Y, N, pre))
values = np.array(values).reshape((2, len(Ns), 2))
for i, ls in enumerate(["-", "--"]):
for j, c in enumerate(["r", "g"]):
plt.plot(Ns, values[i, :, j], linestyle=ls, c=c, linewidth=3, alpha=0.8)
plt.legend(["TF1 no prealloc.", "SJ no prealloc.", "TF1 prealloc.", "SJ prealloc."])
plt.show()
Total running time of the script: ( 1 minutes 46.270 seconds)
Adam TF and SymJAX¶
In this example we demonstrate how to perform a simple optimization with Adam in TF and SymJAX
Out:
Placeholder(name=x, shape=(), dtype=float32, scope=/) Op(name=true_divide, fn=true_divide, shape=(), dtype=float32, scope=/ExponentialMovingAverage/)
[Variable(name=EMA, shape=(), dtype=float32, trainable=False, scope=/ExponentialMovingAverage/), Placeholder(name=x, shape=(), dtype=float32, scope=/), Variable(name=num_steps, shape=(), dtype=int32, trainable=False, scope=/ExponentialMovingAverage/)]
Placeholder(name=x, shape=(), dtype=float32, scope=/) Op(name=add, fn=add, shape=(), dtype=float32, scope=/ExponentialMovingAverage/)
[Placeholder(name=x, shape=(), dtype=float32, scope=/), Variable(name=EMA, shape=(), dtype=float32, trainable=False, scope=/ExponentialMovingAverage/)]
(tqdm progress bars for the successive 400- and 1000-iteration TF1 and SymJAX runs elided)
/home/vrael/anaconda3/lib/python3.7/site-packages/matplotlib/figure.py:445: UserWarning: Matplotlib is currently using agg, which is a non-GUI backend, so cannot show the figure.
% get_backend())
import matplotlib.pyplot as plt
import symjax
import symjax.tensor as T
from symjax.nn import optimizers
import numpy as np
from tqdm import tqdm
BS = 1000
D = 500
X = np.random.randn(BS, D).astype("float32")
Y = X.dot(np.random.randn(D, 1).astype("float32")) + 2
# TensorFlow (v1) reference implementation of the quadratic-loss regression
def TF1(x, y, N, lr, model, preallocate=False):
import tensorflow.compat.v1 as tf
tf.compat.v1.disable_v2_behavior()
tf.reset_default_graph()
tf_input = tf.placeholder(dtype=tf.float32, shape=[BS, D])
tf_output = tf.placeholder(dtype=tf.float32, shape=[BS, 1])
np.random.seed(0)
tf_W = tf.Variable(np.random.randn(D, 1).astype("float32"))
tf_b = tf.Variable(
np.random.randn(
1,
).astype("float32")
)
tf_loss = tf.reduce_mean((tf.matmul(tf_input, tf_W) + tf_b - tf_output) ** 2)
if model == "SGD":
train_op = tf.train.GradientDescentOptimizer(lr).minimize(tf_loss)
elif model == "Adam":
train_op = tf.train.AdamOptimizer(lr).minimize(tf_loss)
# initialize session
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.Session(config=config)
sess.run(tf.global_variables_initializer())
losses = []
for i in tqdm(range(N)):
losses.append(
sess.run([tf_loss, train_op], feed_dict={tf_input: x, tf_output: y})[0]
)
return losses
# TensorFlow (v1) exponential moving average of a scalar input
def TF_EMA(X):
import tensorflow.compat.v1 as tf
tf.compat.v1.disable_v2_behavior()
tf.reset_default_graph()
x = tf.placeholder("float32")
# Create an ExponentialMovingAverage object
ema = tf.train.ExponentialMovingAverage(decay=0.9)
op = ema.apply([x])
out = ema.average(x)
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.Session(config=config)
sess.run(tf.global_variables_initializer(), feed_dict={x: X[0]})
outputs = []
for i in range(len(X)):
sess.run(op, feed_dict={x: X[i]})
outputs.append(sess.run(out))
return outputs
# SymJAX exponential moving average, with optional debiasing
def SJ_EMA(X, debias=True):
symjax.current_graph().reset()
x = T.Placeholder((), "float32", name="x")
value = symjax.nn.schedules.ExponentialMovingAverage(x, 0.9, debias=debias)[0]
print(x, value)
print(symjax.current_graph().roots(value))
train = symjax.function(x, outputs=value, updates=symjax.get_updates())
outputs = []
for i in range(len(X)):
outputs.append(train(X[i]))
return outputs
# SymJAX implementation of the same regression problem
def SJ(x, y, N, lr, model, preallocate=False):
symjax.current_graph().reset()
sj_input = T.Placeholder(dtype=np.float32, shape=[BS, D])
sj_output = T.Placeholder(dtype=np.float32, shape=[BS, 1])
np.random.seed(0)
sj_W = T.Variable(np.random.randn(D, 1).astype("float32"))
sj_b = T.Variable(
np.random.randn(
1,
).astype("float32")
)
sj_loss = ((sj_input.dot(sj_W) + sj_b - sj_output) ** 2).mean()
if model == "SGD":
optimizers.SGD(sj_loss, lr)
elif model == "Adam":
optimizers.Adam(sj_loss, lr)
train = symjax.function(
sj_input, sj_output, outputs=sj_loss, updates=symjax.get_updates()
)
losses = []
for i in tqdm(range(N)):
losses.append(train(x, y))
return losses
sample = np.random.randn(100)
plt.figure()
plt.plot(sample, label="Original signal", alpha=0.5)
plt.plot(TF_EMA(sample), c="orange", label="TF ema", linewidth=2, alpha=0.5)
plt.plot(SJ_EMA(sample), c="green", label="SJ ema (biased)", linewidth=2, alpha=0.5)
plt.plot(
SJ_EMA(sample, False),
c="green",
linestyle="--",
label="SJ ema (unbiased)",
linewidth=2,
alpha=0.5,
)
plt.legend()
plt.figure()
Ns = [400, 1000]
lrs = [0.001, 0.01, 0.1]
colors = ["r", "b", "g"]
for k, N in enumerate(Ns):
plt.subplot(1, len(Ns), 1 + k)
for c, lr in enumerate(lrs):
loss = TF1(X, Y, N, lr, "Adam")
plt.plot(loss, c=colors[c], linestyle="-", alpha=0.5)
loss = SJ(X, Y, N, lr, "Adam")
plt.plot(loss, c=colors[c], linestyle="--", alpha=0.5, linewidth=2)
plt.title("lr:" + str(lr))
plt.suptitle("Adam Optimization quadratic loss (-:TF, --:SJ)")
plt.figure()
Ns = [400, 1000]
lrs = [0.001, 0.01, 0.1]
colors = ["r", "b", "g"]
for k, N in enumerate(Ns):
plt.subplot(1, len(Ns), 1 + k)
for c, lr in enumerate(lrs):
loss = TF1(X, Y, N, lr, "SGD")
plt.plot(loss, c=colors[c], linestyle="-", alpha=0.5)
loss = SJ(X, Y, N, lr, "SGD")
plt.plot(loss, c=colors[c], linestyle="--", alpha=0.5, linewidth=2)
plt.title("lr:" + str(lr))
plt.xlabel("steps")
plt.suptitle("GD Optimization quadratic loss (-:TF, --:SJ)")
plt.show()
Total running time of the script: ( 0 minutes 27.428 seconds)
Deep Neural Networks¶
Note
Click here to download the full example code
RNN/GRU example¶
example of vanilla RNN for time series regression
import symjax.tensor as T
from symjax import nn
import symjax
import numpy as np
import matplotlib.pyplot as plt
symjax.current_graph().reset()
# create the network
BATCH_SIZE = 32
TIME = 32
WIDTH = 32
C = 1
np.random.seed(0)
timeseries = T.Placeholder((BATCH_SIZE, TIME, C), "float32", name="time-series")
target = T.Placeholder((BATCH_SIZE, TIME), "float32", name="target")
rnn = nn.layers.RNN(timeseries, np.zeros((BATCH_SIZE, WIDTH)), WIDTH)
rnn = nn.layers.RNN(rnn, np.zeros((BATCH_SIZE, WIDTH)), WIDTH)
rnn = nn.layers.Dense(rnn, 1, flatten=False)
gru = nn.layers.GRU(timeseries, np.zeros((BATCH_SIZE, WIDTH)), WIDTH)
gru = nn.layers.GRU(gru, np.zeros((BATCH_SIZE, WIDTH)), WIDTH)
gru = nn.layers.Dense(gru, 1, flatten=False)
loss = ((target - rnn[:, :, 0]) ** 2).mean()
lossg = ((target - gru[:, :, 0]) ** 2).mean()
lr = nn.schedules.PiecewiseConstant(0.01, {1000: 0.005, 1800: 0.001})
nn.optimizers.Adam(loss + lossg, lr)
train = symjax.function(
timeseries,
target,
outputs=[loss, lossg],
updates=symjax.get_updates(),
)
predict = symjax.function(timeseries, outputs=[rnn[:, :, 0], gru[:, :, 0]])
x = [
np.random.randn(TIME) * 0.1 + np.cos(shift + np.linspace(-5, 10, TIME))
for shift in np.random.randn(BATCH_SIZE * 200) * 0.3
]
w = np.random.randn(TIME) * 0.01
y = [(w + np.roll(xi, 2) * 0.4) ** 3 for xi in x]
y = np.stack(y)
x = np.stack(x)[:, :, None]
x /= np.linalg.norm(x, 2, 1, keepdims=True)
x -= x.min()
y /= np.linalg.norm(y, 2, 1, keepdims=True)
loss = []
for i in range(10):
for xb, yb in symjax.data.utils.batchify(x, y, batch_size=BATCH_SIZE):
loss.append(train(xb, yb))
loss = np.stack(loss)
plt.figure(figsize=(8, 8))
plt.subplot(121)
plt.plot(loss[:, 0], c="g", label="Elman")
plt.plot(loss[:, 1], c="r", label="GRU")
plt.title("Training loss")
plt.xlabel("Iterations")
plt.ylabel("MSE")
plt.legend()
pred = predict(x[:BATCH_SIZE])
for i in range(4):
plt.subplot(4, 2, 2 + 2 * i)
plt.plot(x[i, :, 0], "-x", c="k", label="input")
plt.plot(y[i], "-x", c="b", label="target")
plt.plot(pred[0][i], "-x", c="g", label="Elman")
plt.plot(pred[1][i], "-x", c="r", label="GRU")
plt.title("Predictions")
plt.legend()
plt.show()
Total running time of the script: ( 1 minutes 7.615 seconds)
Note
Click here to download the full example code
MNIST classification¶
example of image (MNIST) classification on small part of the data and with a small architecture
Out:
... mnist.pkl.gz already exists
Loading mnist
Dataset mnist loaded in 0.89s.
(64, 1, 28, 28)
(64, 64, 26, 26)
(64, 64, 26, 26)
(64, 64, 26, 26)
(64, 64, 13, 13)
(64, 64, 11, 11)
(64, 64, 11, 11)
(64, 64, 11, 11)
(64, 64, 5, 5)
(64, 64, 1, 1)
(64, 10)
...epoch: 0
Test Loss and Accu: [2.303416 0.09815705]
Train Loss and Accu [1.9887893 0.4032258]
...epoch: 1
Test Loss and Accu: [2.8440666 0.11348157]
Train Loss and Accu [1.6667411 0.59727824]
...epoch: 2
Test Loss and Accu: [1.9810963 0.2217548]
Train Loss and Accu [1.4739343 0.7021169]
...epoch: 3
Test Loss and Accu: [1.4481678 0.62369794]
Train Loss and Accu [1.3121623 0.7636089]
...epoch: 4
Test Loss and Accu: [1.3252217 0.6608574]
Train Loss and Accu [1.1772408 0.81602824]
...epoch: 5
Test Loss and Accu: [1.1886992 0.7281651]
Train Loss and Accu [1.0639273 0.8251008]
...epoch: 6
Test Loss and Accu: [1.0929953 0.79547274]
Train Loss and Accu [0.9512595 0.86441535]
...epoch: 7
Test Loss and Accu: [1.0527598 0.7073317]
Train Loss and Accu [0.87056917 0.88508064]
...epoch: 8
Test Loss and Accu: [0.9697831 0.7624199]
Train Loss and Accu [0.7969418 0.8840726]
...epoch: 9
Test Loss and Accu: [0.9283603 0.7219551]
Train Loss and Accu [0.7198963 0.8986895]
...epoch: 10
Test Loss and Accu: [0.8994963 0.73297274]
Train Loss and Accu [0.66082746 0.9153226 ]
...epoch: 11
Test Loss and Accu: [0.7540454 0.82471955]
Train Loss and Accu [0.6040695 0.9203629]
...epoch: 12
Test Loss and Accu: [0.6891631 0.86588544]
Train Loss and Accu [0.55794346 0.9279234 ]
...epoch: 13
Test Loss and Accu: [0.88701797 0.7079327 ]
Train Loss and Accu [0.5097796 0.9309476]
...epoch: 14
Test Loss and Accu: [0.6690547 0.8370393]
Train Loss and Accu [0.45673898 0.9485887 ]
...epoch: 15
Test Loss and Accu: [0.6740171 0.82261616]
Train Loss and Accu [0.42886272 0.94758064]
...epoch: 16
Test Loss and Accu: [0.66504765 0.79917866]
Train Loss and Accu [0.41574922 0.9465726 ]
...epoch: 17
Test Loss and Accu: [0.5708862 0.8608774]
Train Loss and Accu [0.36107954 0.9596774 ]
...epoch: 18
Test Loss and Accu: [0.49747434 0.89663464]
Train Loss and Accu [0.34122872 0.9621976 ]
...epoch: 19
Test Loss and Accu: [0.5516419 0.85917467]
Train Loss and Accu [0.31652793 0.96622986]
Text(0.5, 0.98, 'MNIST (1K data) classification task')
import symjax.tensor as T
from symjax import nn
import symjax
import numpy as np
import matplotlib.pyplot as plt
from symjax.data import mnist
from symjax.data.utils import batchify
import os
os.environ["DATASET_PATH"] = "/home/vrael/DATASETS/"
symjax.current_graph().reset()
# load the dataset
mnist = mnist()
# some renormalization, and we only keep the first 2000 images
mnist["train_set/images"] = mnist["train_set/images"][:2000]
mnist["train_set/labels"] = mnist["train_set/labels"][:2000]
mnist["train_set/images"] /= mnist["train_set/images"].max((1, 2, 3), keepdims=True)
mnist["test_set/images"] /= mnist["test_set/images"].max((1, 2, 3), keepdims=True)
# create the network
BATCH_SIZE = 64
images = T.Placeholder((BATCH_SIZE, 1, 28, 28), "float32", name="images")
labels = T.Placeholder((BATCH_SIZE,), "int32", name="labels")
deterministic = T.Placeholder((1,), "bool")
layer = [nn.layers.Identity(images)]
for l in range(2):
layer.append(nn.layers.Conv2D(layer[-1], 64, (3, 3), b=None))
# due to the small size of the dataset we can
# increase the update of the bn moving averages
layer.append(
nn.layers.BatchNormalization(
layer[-1], [1], deterministic, beta_1=0.9, beta_2=0.9
)
)
layer.append(nn.leaky_relu(layer[-1]))
layer.append(nn.layers.Pool2D(layer[-1], (2, 2)))
layer.append(nn.layers.Pool2D(layer[-1], layer[-1].shape.get()[2:], pool_type="AVG"))
layer.append(nn.layers.Dense(layer[-1], 10))
# each layer is itself a tensor which represents its output and thus
# any tensor operation can be used on the layer instance, for example
for l in layer:
print(l.shape.get())
loss = nn.losses.sparse_softmax_crossentropy_logits(labels, layer[-1]).mean()
accuracy = nn.losses.accuracy(labels, layer[-1])
nn.optimizers.Adam(loss, 0.001)
test = symjax.function(images, labels, deterministic, outputs=[loss, accuracy])
train = symjax.function(
images,
labels,
deterministic,
outputs=[loss, accuracy],
updates=symjax.get_updates(),
)
test_accuracy = []
train_accuracy = []
for epoch in range(20):
print("...epoch:", epoch)
L = list()
for x, y in batchify(
mnist["test_set/images"],
mnist["test_set/labels"],
batch_size=BATCH_SIZE,
option="continuous",
):
L.append(test(x, y, 1))
print("Test Loss and Accu:", np.mean(L, 0))
test_accuracy.append(np.mean(L, 0))
L = list()
for x, y in batchify(
mnist["train_set/images"],
mnist["train_set/labels"],
batch_size=BATCH_SIZE,
option="random_see_all",
):
L.append(train(x, y, 0))
train_accuracy.append(np.mean(L, 0))
print("Train Loss and Accu", np.mean(L, 0))
train_accuracy = np.array(train_accuracy)
test_accuracy = np.array(test_accuracy)
plt.subplot(121)
plt.plot(test_accuracy[:, 1], c="k")
plt.plot(train_accuracy[:, 1], c="b")
plt.xlabel("epochs")
plt.ylabel("accuracy")
plt.subplot(122)
plt.plot(test_accuracy[:, 0], c="k")
plt.plot(train_accuracy[:, 0], c="b")
plt.xlabel("epochs")
plt.ylabel("accuracy")
plt.suptitle("MNIST (1K data) classification task")
Total running time of the script: ( 2 minutes 12.597 seconds)
Note
Click here to download the full example code
Image classification, Keras and SymJAX¶
example of image classification with deep networks using Keras and SymJAX
Out:
... cifar-10-python.tar.gz already exists
Loading cifar10: 100%|##########| 5/5 [00:05<00:00, 1.19s/it]
Dataset cifar10 loaded in 6.59s.
import symjax.tensor as T
from symjax import nn
import symjax
import numpy as np
import matplotlib.pyplot as plt
import sys
sys.setrecursionlimit(3500)
def classif_tf(train_x, train_y, test_x, test_y, mlp=True):
import tensorflow as tf
from tensorflow.keras import layers
batch_size = 128
inputs = layers.Input(shape=(3, 32, 32))
if not mlp:
out = layers.Permute((2, 3, 1))(inputs)
out = layers.Conv2D(32, 3, activation="relu")(out)
for i in range(3):
for j in range(3):
conv = layers.Conv2D(
32 * (i + 1), 3, activation="linear", padding="SAME"
)(out)
bn = layers.BatchNormalization(axis=-1)(conv)
relu = layers.Activation("relu")(bn)
conv = layers.Conv2D(
32 * (i + 1), 3, activation="linear", padding="SAME"
)(relu)
bn = layers.BatchNormalization(axis=-1)(conv)
out = layers.Add()([out, bn])
out = layers.AveragePooling2D()(out)
out = layers.Conv2D(32 * (i + 2), 1, activation="linear")(out)
print(out.shape)
out = layers.GlobalAveragePooling2D()(out)
else:
out = layers.Flatten()(inputs)
for i in range(6):
out = layers.Dense(4000, activation="linear")(out)
bn = layers.BatchNormalization(axis=-1)(out)
out = layers.Activation("relu")(bn)
outputs = layers.Dense(10, activation="linear")(out)
model = tf.keras.Model(inputs, outputs)
optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
for epoch in range(5):
accu = 0
for x, y in symjax.data.utils.batchify(
train_x, train_y, batch_size=batch_size, option="random"
):
with tf.GradientTape() as tape:
preds = model(x, training=True)
loss = tf.reduce_mean(
tf.nn.sparse_softmax_cross_entropy_with_logits(y, preds)
)
accu += tf.reduce_mean(tf.cast(y == tf.argmax(preds, 1), "float32"))
grads = tape.gradient(loss, model.trainable_variables)
optimizer.apply_gradients(zip(grads, model.trainable_variables))
print("training", accu / (len(train_x) // batch_size))
accu = 0
for x, y in symjax.data.utils.batchify(
test_x, test_y, batch_size=batch_size, option="continuous"
):
preds = model(x, training=False)
accu += tf.reduce_mean(tf.cast(y == tf.argmax(preds, 1), "float32"))
print(accu / (len(test_x) // batch_size))
def classif_sj(train_x, train_y, test_x, test_y, mlp=True):
symjax.current_graph().reset()
from symjax import nn
batch_size = 128
input = T.Placeholder((batch_size, 3, 32, 32), "float32")
labels = T.Placeholder((batch_size,), "int32")
deterministic = T.Placeholder((), "bool")
if not mlp:
out = nn.relu(nn.layers.Conv2D(input, 32, (3, 3)))
for i in range(3):
for j in range(3):
conv = nn.layers.Conv2D(out, 32 * (i + 1), (3, 3), pad="SAME")
bn = nn.layers.BatchNormalization(
conv, [1], deterministic=deterministic
)
bn = nn.relu(bn)
conv = nn.layers.Conv2D(bn, 32 * (i + 1), (3, 3), pad="SAME")
bn = nn.layers.BatchNormalization(
conv, [1], deterministic=deterministic
)
out = out + bn
out = nn.layers.Pool2D(out, (2, 2), pool_type="AVG")
out = nn.layers.Conv2D(out, 32 * (i + 2), (1, 1))
# out = out.mean((2, 3))
out = nn.layers.Pool2D(out, out.shape.get()[-2:], pool_type="AVG")
else:
out = input
for i in range(6):
out = nn.layers.Dense(out, 4000)
out = nn.relu(
nn.layers.BatchNormalization(out, [1], deterministic=deterministic)
)
outputs = nn.layers.Dense(out, 10)
loss = nn.losses.sparse_softmax_crossentropy_logits(labels, outputs).mean()
nn.optimizers.Adam(loss, 0.001)
accu = T.equal(outputs.argmax(1), labels).astype("float32").mean()
train = symjax.function(
input,
labels,
deterministic,
outputs=[loss, accu, outputs],
updates=symjax.get_updates(),
)
test = symjax.function(input, labels, deterministic, outputs=accu)
for epoch in range(5):
accu = 0
for x, y in symjax.data.utils.batchify(
train_x, train_y, batch_size=batch_size, option="random"
):
accu += train(x, y, 0)[1]
print("training", accu / (len(train_x) // batch_size))
accu = 0
for x, y in symjax.data.utils.batchify(
test_x, test_y, batch_size=batch_size, option="continuous"
):
accu += test(x, y, 1)
print(accu / (len(test_x) // batch_size))
cifar10 = symjax.data.cifar10()
train_x, train_y = cifar10["train_set/images"], cifar10["train_set/labels"]
test_x, test_y = cifar10["test_set/images"], cifar10["test_set/labels"]
train_x /= train_x.max()
test_x /= test_x.max()
# classif_sj(train_x, train_y, test_x, test_y, False)
# classif_tf(train_x, train_y, test_x, test_y, False)
Total running time of the script: ( 0 minutes 6.730 seconds)
Datasets¶
Note
Click here to download the full example code
Speech picidae Dataset¶
This example shows how to download/load/import speech picidae
Out:
... PicidaeDataset.zip already exists
/home/vrael/SymJAX/symjax/data/picidae.py:96: WavFileWarning: Chunk (non-data) not understood, skipping it.
wavs.append(wav_read(byt)[1].astype("float32"))
100%|##########| 3369/3369 [00:04<00:00, 816.31it/s]
Dataset picidae loaded in 4.16s.
import symjax
import matplotlib.pyplot as plt
picidae = symjax.data.picidae()
plt.figure(figsize=(10, 4))
for i in range(10):
plt.subplot(2, 5, 1 + i)
plt.plot(picidae["wavs"][i])
plt.title(str(picidae["labels"][i]))
plt.tight_layout()
Total running time of the script: ( 0 minutes 4.670 seconds)
Note
Click here to download the full example code
MNIST Dataset¶
This example shows how to download/load/import MNIST
Out:
... mnist.pkl.gz already exists
Loading mnist
Dataset mnist loaded in 0.81s.
import symjax
import matplotlib.pyplot as plt
mnist = symjax.data.mnist()
plt.figure(figsize=(10, 4))
for i in range(10):
plt.subplot(2, 5, 1 + i)
plt.imshow(mnist["train_set/images"][i, 0], aspect="auto", cmap="Greys")
plt.xticks([])
plt.yticks([])
plt.title(str(mnist["train_set/labels"][i]))
plt.tight_layout()
Total running time of the script: ( 0 minutes 1.088 seconds)
Note
Click here to download the full example code
CIFAR10 Dataset¶
This example shows how to download/load/import CIFAR10
Out:
... cifar-10-python.tar.gz already exists
Loading cifar10: 100%|##########| 5/5 [00:05<00:00, 1.19s/it]
Dataset cifar10 loaded in 6.55s.
import symjax
import matplotlib.pyplot as plt
cifar10 = symjax.data.cifar10()
plt.figure(figsize=(10, 4))
for i in range(10):
plt.subplot(2, 5, 1 + i)
image = cifar10["train_set/images"][i]
label = cifar10["train_set/labels"][i]
plt.imshow(image.transpose((1, 2, 0)) / image.max(), aspect="auto", cmap="Greys")
plt.xticks([])
plt.yticks([])
plt.title("{}:{}".format(label, cifar10["label_to_name"][label]))
plt.tight_layout()
Total running time of the script: ( 0 minutes 6.801 seconds)
Note
Click here to download the full example code
RNTK kernel¶
time series regression and classification
Out:
Op(name=multiply, fn=multiply, shape=(3, 3), dtype=float32, scope=/)
[array([[374.33334, 120.098 , 131.59918],
[120.098 , 302.12424, 86.20666],
[131.59918, 86.20666, 203.82948]], dtype=float32), array([[71.971275, 48.085426, 45.39569 ],
[48.085426, 60.663128, 35.050926],
[45.39569 , 35.050926, 38.225647]], dtype=float32)]
import numpy as np
import symjax
import symjax.tensor as T
import networkx as nx
def RNTK_first_time_step(x, param):
# this is for computing the first GP and RNTK for t = 1. Both for relu and erf
sw = param["sigmaw"]
su = param["sigmau"]
sb = param["sigmab"]
sh = param["sigmah"]
X = x * x[:, None]
print(X)
n = X.shape[0]
GP_new = sh ** 2 * sw ** 2 * T.eye(n, n) + (su ** 2 / m) * X + sb ** 2
RNTK_new = GP_new
return RNTK_new, GP_new
def RNTK_relu(x, RNTK_old, GP_old, param, output):
sw = param["sigmaw"]
su = param["sigmau"]
sb = param["sigmab"]
sv = param["sigmav"]
a = T.diag(GP_old) # GP_old is in R^{n*n} having the output gp kernel
# of all pairs of data in the data set
B = a * a[:, None]
C = T.sqrt(B) # in R^{n*n}
    D = GP_old / C  # this is lambda in the ReLU analytical formula
    # clip D into [-1, 1] for numerical stability
    E = T.clip(D, -1, 1)
F = (1 / (2 * np.pi)) * (E * (np.pi - T.arccos(E)) + T.sqrt(1 - E ** 2)) * C
G = (np.pi - T.arccos(E)) / (2 * np.pi)
if output:
GP_new = sv ** 2 * F
RNTK_new = sv ** 2.0 * RNTK_old * G + GP_new
else:
X = x * x[:, None]
GP_new = sw ** 2 * F + (su ** 2 / m) * X + sb ** 2
RNTK_new = sw ** 2.0 * RNTK_old * G + GP_new
return RNTK_new, GP_new
L = 10
N = 3
DATA = T.Placeholder((N, L), "float32", name="data")
# parameters
param = {}
param["sigmaw"] = 1.33
param["sigmau"] = 1.45
param["sigmab"] = 1.2
param["sigmah"] = 0.4
param["sigmav"] = 2.34
m = 1
# first time step
RNTK, GP = RNTK_first_time_step(DATA[:, 0], param)
for t in range(1, L):
RNTK, GP = RNTK_relu(DATA[:, t], RNTK, GP, param, False)
RNTK, GP = RNTK_relu(0, RNTK, GP, param, True)
f = symjax.function(DATA, outputs=[RNTK, GP])
# three series of length L
a = np.random.randn(L)
b = np.random.randn(L)
c = np.random.randn(L)
example = np.stack([a, b, c])  # it is of shape (3, L)
print(f(example))
Total running time of the script: ( 0 minutes 4.321 seconds)
Signal Processing¶
Note
Click here to download the full example code
Morlet Wavelet in time and Fourier domain¶
This example shows how to generate a wavelet filter-bank.
import symjax
import symjax.tensor as T
import matplotlib.pyplot as plt
import numpy as np
J = 5
Q = 4
scales = T.power(2, T.linspace(0.1, J - 1, J * Q))
scales = scales[:, None]
wavelet = symjax.tensor.signal.complex_morlet(5 * scales, np.pi / scales)
waveletw = symjax.tensor.signal.fourier_complex_morlet(
5 * scales, np.pi / scales, wavelet.shape[-1]
)
f = symjax.function(outputs=[wavelet, waveletw])
wavelet, waveletw = f()
plt.subplot(121)
for i in range(J * Q):
plt.plot(2 * i + wavelet[i].real, c="b")
plt.plot(2 * i + wavelet[i].imag, c="r")
plt.subplot(122)
for i in range(J * Q):
plt.plot(i + waveletw[i].real, c="b")
plt.plot(i + waveletw[i].imag, c="r")
Total running time of the script: ( 0 minutes 10.489 seconds)
Development¶
The SymJAX project was started by Randall Balestriero in early 2020. As an open-source project, we highly welcome contributions (current contributors)!
Philosophy¶
SymJAX started from the need to combine the best functionalities of Theano, Tensorflow (v1) and Lasagne. While we propose various deep-learning-oriented methods, SymJAX shall remain as general as possible in its core; methods should be grouped as much as possible into specialized submodules, and complete documentation should be provided, preferably along with a working example located in the Gallery.
How to contribute¶
If you are willing to help, we recommend following the steps below before opening a pull request:
- Coding conventions: we use the PEP8 style guide for Python code and the black formatting
- Docstrings: we use the numpydoc docstring guide for documenting functions directly from their docstrings and automatically generating the documentation with sphinx. Please provide code with up-to-date docstrings.
- Continuous Integration: to ensure that all the SymJAX functionalities are tested after each modification, run
pytest
from the main SymJAX directory. All tests should pass before considering a change to be successful. If new functionalities are added, it is highly preferable to also add a simple test in the tests/
directory to ensure that results are as expected (a minimal sketch is given below). A Github action will automatically test the code at each push
(see Test the code).
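For instance, a minimal sketch of such a test; the file name and the tested behavior are illustrative, not an existing test of the repository:
# tests/test_function.py (illustrative)
import numpy as np
import symjax
import symjax.tensor as T

def test_function_computes_sum():
    # build a small graph, compile it, and check the numerical result
    symjax.current_graph().reset()
    x = T.Placeholder((4,), "float32", name="x")
    f = symjax.function(x, outputs=x.sum() + 1)
    np.testing.assert_allclose(f(np.zeros(4, "float32")), 1.0)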
Build/Test the doc¶
To rebuild the documentation, install several packages:
pip install -r docs/requirements.txt
To generate the documentation, go to the docs directory and run:
make html
You can then see the generated documentation in docs/_build/html/index.html.
If examples/code-blocks are added to the documentation, they have to be tested.
To do so, add the specific module/function in the tests/doc.py file and run:
python tests/doc.py
If all tests pass, the change is ready for review and can be put in a PR.
Every time changes are pushed to the Github master branch, the SymJAX documentation (at symjax.readthedocs.io) is rebuilt based on the .readthedocs.yml and docs/conf.py configuration files. For each automated documentation build you can see the documentation build logs.
Test the code¶
To run all the SymJAX tests, we recommend using pytest
or pytest-xdist
. First, install pytest-xdist
and pytest-benchmark
by running
pip install pytest-xdist pytest-benchmark
.
Then, from the repository root directory run:
pytest
If all tests pass successfully, the code is ready for a PR.
symjax¶
Graph and Compilation¶
function (*args[, outputs, updates, device, …]) |
Generate a user function that compiles a computational graph. |
Graph (name, *args, **kwargs) |
|
Scope ([relative_name, absolute_name, …]) |
Defining scope for any variable/operation to be in. |
Derivatives¶
gradients (scalar, variables) |
Compute the gradients of a scalar w.r.t to a given list of variables. |
jacobians (tensor, variables[, mode]) |
Compute the jacobians of a tensor w.r.t to a given list of variables. |
Graph Access and Manipulation¶
clone |
|
current_graph () |
Current graph. |
get_variables ([name, scope, trainable]) |
|
get_ops ([name, scope]) |
Same as symjax.get_variables but for ops |
get_placeholders ([name, scope]) |
Same as symjax.get_variables but for placeholders |
get_updates ([name, scope, variables]) |
Same as symjax.get_variables but for update operations |
save_variables (path_or_file[, name, scope, …]) |
saves the graph variables. |
load_variables (path_or_file[, name, scope, …]) |
loads the graph variables. |
reset_variables ([name, scope, trainable]) |
utility to reset variables based on their names |
Detailed Descriptions¶
symjax.function(*args, outputs=[], updates=None, device=None, backend=None, default_value=None, frozen=False)[source]¶
Generate a user function that compiles a computational graph.
Based on given inputs, outputs and update policy of variables, this function internally jit-compiles the underlying jax computational graph for performance.
Parameters: - args (trailing tuple) – the inputs to the function to be compiled. The tuple should contain all the placeholders that are roots of any output given to the function, and the update values
- outputs (List (optional)) – the outputs of the function; a single element can be given as a standalone and not a list
- updates (Dict (optional)) – the dictionary of updates as per {var: new_value} for any variable of the graph
- backend ('cpu' or 'gpu') – the backend to use to run the function on
- default_value (not implemented) – not implemented
Returns: the user frontend function that takes the specified inputs, returns the specified outputs and performs the updates internally
Return type: callable
Examples
>>> import symjax
>>> import symjax.tensor as T
>>> x = T.ones((4, 4))
>>> xs = x.sum() + 1
>>> f = symjax.function(outputs=xs)
>>> print(f())
17.0
>>> w = T.Variable(0., name='w', dtype='float32')
>>> increment = symjax.function(updates={w: w + 1})
>>> for i in range(10):
...     increment()
>>> print(w.value)
10.0
class symjax.Scope(relative_name=None, absolute_name=None, reattach=False, reuse=False, graph=None)[source]¶
Defining scope for any variable/operation to be in.
Example
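The docstring example is missing from this page; below is a minimal usage sketch, assuming Scope is used as a context manager that prefixes the names of the variables created inside it (the scope name and glob pattern are illustrative):
import symjax
import symjax.tensor as T

with symjax.Scope("encoder"):
    w = T.Variable(1.0, name="w", dtype="float32")
# the variable is now retrievable through its scope
print(symjax.get_variables(scope="/encoder/*"))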
symjax.gradients(scalar, variables)[source]¶
Compute the gradients of a scalar w.r.t to a given list of variables.
Parameters: - scalar (symjax.tensor.base.Tensor) – the variable to differentiate
- variables (List or Tuple) – the variables used to compute the derivative.
Returns: gradients – the sequence of gradients ordered as given in the input variables
Return type: Tuple
Example
>>> import symjax
>>> w = symjax.tensor.ones(3)
>>> x = symjax.tensor.Variable(2., name='x', dtype='float32')
>>> l = (w ** 2).sum() * x
>>> g = symjax.gradients(l, [w])[0]
>>> f = symjax.function(outputs=g, updates={x: x + 1})
>>> for i in range(2):
...     print(f())
[4. 4. 4.]
[6. 6. 6.]
symjax.jacobians(tensor, variables, mode='forward')[source]¶
Compute the jacobians of a tensor w.r.t to a given list of variables.
The tensor need not be a vector, but will be treated as such. For example if tensor.shape is (10, 3, 3) and a variable shape is (10, 10), the resulting jacobian has shape (10, 3, 3, 10, 10). It is possible to specify the mode forward or backward. For tall jacobians, forward is faster and vice-versa.
Parameters: - tensor (Tensor) – the variable to differentiate
- variables (List or Tuple) – the variables used to compute the derivative.
Returns: jacobians – the sequence of jacobians ordered as given in the input variables
Return type: Tuple
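No example is given here; a minimal sketch based on the signature above (the values are illustrative):
import symjax
import symjax.tensor as T

w = T.Variable([1.0, 2.0, 3.0], name="w", dtype="float32")
y = w ** 2                       # shape (3,)
j = symjax.jacobians(y, [w])[0]  # shape (3, 3), with 2 * w on the diagonal
f = symjax.function(outputs=j)
print(f())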
symjax.get_updates(name='*', scope='/', variables=None)[source]¶
Same as symjax.get_variables but for the update operations registered in the graph (e.g. by optimizers).
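This mirrors the usage in the examples above: optimizers register the updates of the trained variables into the current graph, and get_updates() collects them to be fed to symjax.function:
import numpy as np
import symjax
import symjax.tensor as T
from symjax.nn import optimizers

x = T.Placeholder((4,), "float32")
w = T.Variable(np.ones(4).astype("float32"))
loss = ((x - w) ** 2).mean()
optimizers.SGD(loss, 0.1)  # registers the update w <- w - 0.1 * grad
train = symjax.function(x, outputs=loss, updates=symjax.get_updates())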
symjax.save_variables(path_or_file, name='*', scope='*', trainable=None)[source]¶
saves the graph variables.
The saving is done via numpy.savez for fast and compressed storage.
- path_or_file (str or file) – the path and name of the file to save the variables in, or an open file object
- name (str (optional)) – the name string that the variables to save must match
- scope (str (optional)) – the scope name string that the variables to save must match
- trainable (bool or None) – the option of the variables to save (True, False or None)
symjax.load_variables(path_or_file, name='*', scope='*', trainable=None)[source]¶
loads the graph variables.
The loading is done via numpy.load for fast and compressed storage.
- path_or_file (str or file) – the path and name of the file to load the variables from, or an open file object
- name (str (optional)) – the name string that the variables to load must match
- scope (str (optional)) – the scope name string that the variables to load must match
- trainable (bool or None) – the option of the variables to load (True, False or None)
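A minimal round-trip sketch using the two functions above (the file path is illustrative):
import symjax
import symjax.tensor as T

w = T.Variable(1.0, name="w", dtype="float32")
symjax.save_variables("/tmp/model.npz")  # saves all variables ('*' by default)
# ... training modifies the variables ...
symjax.load_variables("/tmp/model.npz")  # restores the saved values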
symjax.reset_variables(name='*', scope='*', trainable=None)[source]¶
utility to reset variables based on their names
Parameters: - name (str (default='*')) – the name (or part of the name) of all the variables that should be reset; it can include the glob (*) searching for all matching names
- trainable (bool or None (optional, default=None)) – if not None, only reset, among the matched variables, the ones whose trainable attribute matches the given value
Returns: None
Return type: None
Example
>>> import symjax
>>> w = symjax.tensor.Variable(1., name='w', dtype='float32')
>>> x = symjax.tensor.Variable(2., name='x', dtype='float32')
>>> f = symjax.function(outputs=[w, x], updates={w: w + 1, x: x + 1})
>>> for i in range(10):
...     print(f())
[array(1., dtype=float32), array(2., dtype=float32)]
[array(2., dtype=float32), array(3., dtype=float32)]
[array(3., dtype=float32), array(4., dtype=float32)]
[array(4., dtype=float32), array(5., dtype=float32)]
[array(5., dtype=float32), array(6., dtype=float32)]
[array(6., dtype=float32), array(7., dtype=float32)]
[array(7., dtype=float32), array(8., dtype=float32)]
[array(8., dtype=float32), array(9., dtype=float32)]
[array(9., dtype=float32), array(10., dtype=float32)]
[array(10., dtype=float32), array(11., dtype=float32)]
>>> # reset only the w variable
>>> symjax.reset_variables('w')
>>> # reset all variables
>>> symjax.reset_variables('*')
symjax.data¶
Utilities¶
symjax.data.patchify_1d (x, window_length, stride) |
extract patches from a numpy array |
symjax.data.patchify_2d (x, window_length, stride) |
|
symjax.data.train_test_split (*args[, …]) |
split given data into two non overlapping sets |
symjax.data.batchify (*args, batch_size[, …]) |
|
symjax.data.resample_images (images, target_shape) |
|
symjax.data.download_dataset (path, dataset, …) |
dataset downloading utility |
symjax.data.extract_file (filename, target) |
Images¶
symjax.data.mnist.load ([path]) |
The MNIST database of handwritten digits, available from this page has a training set of 60,000 examples, and a test set of 10,000 examples. |
symjax.data.emnist.load ([option, path]) |
Grayscale digit/letter classification. |
symjax.data.fashionmnist.load ([path]) |
Grayscale image classification |
symjax.data.dsprites.load ([path]) |
greyscale image classification and disentanglement |
symjax.data.svhn.load ([path]) |
Street number classification. |
symjax.data.cifar10.load ([path]) |
Image classification. |
symjax.data.cifar100.load ([path]) |
Image classification. |
symjax.data.celebA.load |
|
symjax.data.ibeans.load |
|
symjax.data.cassava.load |
|
symjax.data.stl10.load ([path]) |
Image classification with extra unlabeled images. |
symjax.data.tinyimagenet.load |
Audio¶
symjax.data.audiomnist.load ([path]) |
digit recognition |
symjax.data.univariate_timeseries.load |
|
symjax.data.dcase_2019_task4.load ([path]) |
synthetic data for polyphonic event detection |
symjax.data.groove_MIDI.load ([path]) |
The Groove MIDI Dataset (GMD) is composed of 13.6 hours of aligned MIDI and (synthesized) audio of human-performed, tempo-aligned expressive drumming. |
symjax.data.speech_commands.load ([path]) |
|
symjax.data.picidae.load ([path]) |
|
symjax.data.esc.load ([path]) |
ESC-10/50: Environmental Sound Classification |
symjax.data.warblr.load |
|
symjax.data.gtzan.load ([path]) |
music genre classification |
symjax.data.dclde.load |
|
symjax.data.irmas.load ([path]) |
music instrument classification |
symjax.data.vocalset.load |
|
symjax.data.freefield1010.load ([path]) |
Audio binary classification, presence or absence of bird songs. |
symjax.data.birdvox_70k.load ([path]) |
a dataset for avian flight call detection in half-second clips |
symjax.data.birdvox_dcase_20k.load |
|
symjax.data.seizures_neonatal.load |
|
symjax.data.sonycust.load ([path]) |
multilabel urban sound classification |
symjax.data.FSDKaggle2018.load |
|
symjax.data.TUTacousticscenes2017.load ([path]) |
Acoustic Scene classification |
Detailed description¶
symjax.data.patchify_1d(x, window_length, stride)[source]¶
extract patches from a numpy array
Parameters: - x (array-like) – the input data to extract patches from, any shape, the last dimension is the one being patched
- window_length (int) – the length of the patches
- stride (int) – the amount of stride (bins separating two consecutive patches)
Returns: x_patches – the number of patches is put in the pre-last dimension (-2)
Return type: array-like
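A small usage sketch, with the output shape inferred from the description above:
import numpy as np
import symjax

x = np.random.randn(2, 100)
patches = symjax.data.patchify_1d(x, 20, 10)
# 9 windows of length 20 per row, stored in the pre-last dimension
print(patches.shape)  # expected: (2, 9, 20)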
symjax.data.train_test_split(*args, train_size=0.8, stratify=None, seed=None)[source]¶
split given data into two non overlapping sets
Parameters: - *args (inputs) – the sets to be split by the function
- train_size (scalar) – the amount of data to put in the first set, either an integer value being the actual number of data to keep, or a ratio (0 to 1 number)
- stratify (array (optional)) – the optional stratify guide used to split the arrays s.t. the proportions w.r.t. the stratify array are kept in both sets
- seed (integer (optional)) – the seed for the random number generator for reproducibility
Returns:
Example
x = numpy.random.randn(100, 4)
y = numpy.random.randn(100)
train, test = train_test_split(x, y, train_size=0.5)
print(train[0].shape, train[1].shape)
# (50, 4) (50,)
print(test[0].shape, test[1].shape)
# (50, 4) (50,)
class symjax.data.batchify(*args, batch_size, option='random', load_func=None, extra_process=0, n_batches=None)[source]¶
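No description is given here; batchify is used throughout the examples above to iterate over aligned mini-batches of the given arrays, e.g.:
import numpy as np
import symjax

x = np.random.randn(100, 4).astype("float32")
y = np.random.randn(100).astype("float32")
# option can be "random", "continuous" or "random_see_all", as in the examples above
for xb, yb in symjax.data.utils.batchify(x, y, batch_size=32, option="random"):
    print(xb.shape, yb.shape)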
symjax.data.resample_images(images, target_shape, ratio='same', order=1, mode='nearest', data_format='channels_first')[source]¶
symjax.data.download_dataset(path, dataset, urls_names, baseurl='', extract=False)[source]¶
dataset downloading utility
Args:
- path (string) – the path where the dataset should be downloaded
- dataset (string) – the name of the dataset, used as the folder name
- urls_names (dict) – dictionary mapping urls to filenames. If the urls have a common root, it can be omitted from this variable and put into the baseurl argument
- baseurl (string) – the common url to prepend onto each url in urls_names
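A hypothetical call matching the argument descriptions (the url and file names are placeholders):
import symjax

symjax.data.download_dataset(
    path="/tmp/DATASETS/",
    dataset="mydataset",
    urls_names={"train.zip": "train.zip", "test.zip": "test.zip"},
    baseurl="https://example.com/mydataset/",
)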
symjax.data.mnist.load(path=None)[source]¶
The MNIST database of handwritten digits, available from this page, has a training set of 60,000 examples and a test set of 10,000 examples. It is a subset of a larger set available from NIST. The digits have been size-normalized and centered in a fixed-size image.
It is a good database for people who want to try learning techniques and pattern recognition methods on real-world data while spending minimal efforts on preprocessing and formatting.
Parameters: path (str (optional)) – default ($DATASET_PATH), the path to look for the data and where the data will be downloaded if not present
Returns: - train_images (array)
- train_labels (array)
- valid_images (array)
- valid_labels (array)
- test_images (array)
- test_labels (array)
symjax.data.emnist.load(option='byclass', path=None)[source]¶
Grayscale digit/letter classification.
The EMNIST Dataset
Gregory Cohen, Saeed Afshar, Jonathan Tapson, and Andre van Schaik
The MARCS Institute for Brain, Behaviour and Development Western Sydney University Penrith, Australia 2751
Email: g.cohen@westernsydney.edu.au
The EMNIST dataset is a set of handwritten character digits derived from the NIST Special Database 19 (https://www.nist.gov/srd/nist-special-database-19) and converted to a 28x28 pixel image format and dataset structure that directly matches the MNIST dataset (http://yann.lecun.com/exdb/mnist/). Further information on the dataset contents and conversion process can be found in the paper available at https://arxiv.org/abs/1702.05373v1.
The dataset is provided in two file formats. Both versions of the dataset contain identical information, and are provided entirely for the sake of convenience. The first dataset is provided in a Matlab format that is accessible through both Matlab and Python (using the scipy.io.loadmat function). The second version of the dataset is provided in the same binary format as the original MNIST dataset as outlined in http://yann.lecun.com/exdb/mnist/
There are six different splits provided in this dataset. A short summary of the dataset is provided below:
- EMNIST ByClass: 814,255 characters. 62 unbalanced classes.
- EMNIST ByMerge: 814,255 characters. 47 unbalanced classes.
- EMNIST Balanced: 131,600 characters. 47 balanced classes.
- EMNIST Letters: 145,600 characters. 26 balanced classes.
- EMNIST Digits: 280,000 characters. 10 balanced classes.
- EMNIST MNIST: 70,000 characters. 10 balanced classes.
The full complement of the NIST Special Database 19 is available in the ByClass and ByMerge splits. The EMNIST Balanced dataset contains a set of characters with an equal number of samples per class. The EMNIST Letters dataset merges a balanced set of the uppercase and lowercase letters into a single 26-class task. The EMNIST Digits and EMNIST MNIST dataset provide balanced handwritten digit datasets directly compatible with the original MNIST dataset.
Please refer to the EMNIST paper (available at https://arxiv.org/abs/1702.05373v1) for further details of the dataset structure.
Please cite the following paper when using or referencing the dataset:
Cohen, G., Afshar, S., Tapson, J., & van Schaik, A. (2017). EMNIST: an extension of MNIST to handwritten letters. Retrieved from http://arxiv.org/abs/1702.05373
The dataset consists of the following files:
.
+– gzip.zip
¦  +– emnist-balanced-mapping.txt
¦  +– emnist-balanced-test-images-idx3-ubyte.gz
¦  +– emnist-balanced-test-labels-idx1-ubyte.gz
¦  +– emnist-balanced-train-images-idx3-ubyte.gz
¦  +– emnist-balanced-train-labels-idx1-ubyte.gz
¦  +– emnist-byclass-mapping.txt
¦  +– emnist-byclass-test-images-idx3-ubyte.gz
¦  +– emnist-byclass-test-labels-idx1-ubyte.gz
¦  +– emnist-byclass-train-images-idx3-ubyte.gz
¦  +– emnist-byclass-train-labels-idx1-ubyte.gz
¦  +– emnist-bymerge-mapping.txt
¦  +– emnist-bymerge-test-images-idx3-ubyte.gz
¦  +– emnist-bymerge-test-labels-idx1-ubyte.gz
¦  +– emnist-bymerge-train-images-idx3-ubyte.gz
¦  +– emnist-bymerge-train-labels-idx1-ubyte.gz
¦  +– emnist-digits-mapping.txt
¦  +– emnist-digits-test-images-idx3-ubyte.gz
¦  +– emnist-digits-test-labels-idx1-ubyte.gz
¦  +– emnist-digits-train-images-idx3-ubyte.gz
¦  +– emnist-digits-train-labels-idx1-ubyte.gz
¦  +– emnist-letters-mapping.txt
¦  +– emnist-letters-test-images-idx3-ubyte.gz
¦  +– emnist-letters-test-labels-idx1-ubyte.gz
¦  +– emnist-letters-train-images-idx3-ubyte.gz
¦  +– emnist-letters-train-labels-idx1-ubyte.gz
¦  +– emnist-mnist-mapping.txt
¦  +– emnist-mnist-test-images-idx3-ubyte.gz
¦  +– emnist-mnist-test-labels-idx1-ubyte.gz
¦  +– emnist-mnist-train-images-idx3-ubyte.gz
¦  +– emnist-mnist-train-labels-idx1-ubyte.gz
+– matlab.zip
¦  +– emnist-balanced.mat
¦  +– emnist-byclass.mat
¦  +– emnist-bymerge.mat
¦  +– emnist-digits.mat
¦  +– emnist-letters.mat
¦  +– emnist-mnist.mat
+– Readme.txt
symjax.data.fashionmnist.load(path=None)[source]¶
Grayscale image classification
Zalando's article image classification. Fashion-MNIST is a dataset of Zalando's article images consisting of a training set of 60,000 examples and a test set of 10,000 examples. Each example is a 28x28 grayscale image, associated with a label from 10 classes. We intend Fashion-MNIST to serve as a direct drop-in replacement for the original MNIST dataset for benchmarking machine learning algorithms. It shares the same image size and structure of training and testing splits.
symjax.data.dsprites.load(path=None)[source]¶
greyscale image classification and disentanglement
This dataset consists of 737,280 images of 2D shapes, procedurally generated from 5 ground truth independent latent factors, controlling the shape, scale, rotation and position of a sprite. This data can be used to assess the disentanglement properties of unsupervised learning methods.
dSprites is a dataset of 2D shapes procedurally generated from 6 ground truth independent latent factors. These factors are color, shape, scale, rotation, x and y positions of a sprite.
All possible combinations of these latents are present exactly once, generating N = 737280 total images.
https://github.com/deepmind/dsprites-dataset
Parameters: path (str (optional)) – default ($DATASET_PATH), the path to look for the data and where the data will be downloaded if not present
Returns: - images (array)
- latent (array)
- classes (array)
symjax.data.svhn.load(path=None)[source]¶
Street number classification.
The SVHN dataset is a real-world image dataset for developing machine learning and object recognition algorithms with minimal requirement on data preprocessing and formatting. It can be seen as similar in flavor to MNIST (e.g., the images are of small cropped digits), but incorporates an order of magnitude more labeled data (over 600,000 digit images) and comes from a significantly harder, unsolved, real world problem (recognizing digits and numbers in natural scene images). SVHN is obtained from house numbers in Google Street View images.
Parameters: path (str (optional)) – default $DATASET_PATH, the path to look for the data and where the data will be downloaded if not present
Returns: - train_images (array)
- train_labels (array)
- test_images (array)
- test_labels (array)
symjax.data.cifar10.load(path=None)[source]¶
Image classification. The CIFAR-10 (https://www.cs.toronto.edu/~kriz/cifar.html) dataset was collected by Alex Krizhevsky, Vinod Nair, and Geoffrey Hinton. It consists of 60000 32x32 colour images in 10 classes, with 6000 images per class. There are 50000 training images and 10000 test images. The dataset is divided into five training batches and one test batch, each with 10000 images. The test batch contains exactly 1000 randomly selected images from each class. The training batches contain the remaining images in random order, but some training batches may contain more images from one class than another. Between them, the training batches contain exactly 5000 images from each class.
Parameters: path (str (optional)) – default ($DATASET_PATH), the path to look for the data and where the data will be downloaded if not present
Returns: - train_images (array)
- train_labels (array)
- test_images (array)
- test_labels (array)
symjax.data.cifar100.load(path=None)[source]¶
Image classification.
The CIFAR-100 (https://www.cs.toronto.edu/~kriz/cifar.html) dataset is just like the CIFAR-10, except it has 100 classes containing 600 images each. There are 500 training images and 100 testing images per class. The 100 classes in the CIFAR-100 are grouped into 20 superclasses. Each image comes with a “fine” label (the class to which it belongs) and a “coarse” label (the superclass to which it belongs).
symjax.data.stl10.load(path=None)[source]¶
Image classification with extra unlabeled images.
The STL-10 dataset is an image recognition dataset for developing unsupervised feature learning, deep learning, self-taught learning algorithms. It is inspired by the CIFAR-10 dataset but with some modifications. In particular, each class has fewer labeled training examples than in CIFAR-10, but a very large set of unlabeled examples is provided to learn image models prior to supervised training. The primary challenge is to make use of the unlabeled data (which comes from a similar but different distribution from the labeled data) to build a useful prior. We also expect that the higher resolution of this dataset (96x96) will make it a challenging benchmark for developing more scalable unsupervised learning methods.
Parameters: path (str (optional)) – the path to look for the data and where it will be downloaded if not present
Returns: - train_images (array) – the training images
- train_labels (array) – the training labels
- test_images (array) – the test images
- test_labels (array) – the test labels
- extra_images (array) – the unlabeled additional images
symjax.tensor¶
Implements the NumPy API, using the primitives in jax.lax.
As SymJAX follows the JAX restrictions, not all NumPy functions are present.
- Notably, since JAX arrays are immutable, NumPy APIs that mutate arrays in-place cannot be implemented in JAX. However, JAX is often able to provide an alternative API that is purely functional. For example, instead of in-place array updates (x[i] = y), JAX provides the pure indexed update function jax.ops.index_update() (see the sketch after this list).
- NumPy is very aggressive at promoting values to float64 type. JAX is sometimes less aggressive about type promotion.
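A short sketch of the functional update mentioned in the first point, written in plain jax:
import jax.numpy as jnp
from jax import ops

x = jnp.zeros(3)
# x[0] = 1. would fail: JAX arrays are immutable
y = ops.index_update(x, 0, 1.0)  # returns a new array, x is left unchanged
print(x)  # [0. 0. 0.]
print(y)  # [1. 0. 0.]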
Finally, since SymJAX uses jit-compilation, any function whose output shape depends on the contents (and not only on the shape) of its input is incompatible and thus not implemented. The XLA compiler requires that the shapes of arrays be known at compile time. While it would be possible to provide an implementation of an API such as numpy.nonzero(), we would be unable to JIT-compile it because the shape of its output depends on the contents of the input data.
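To illustrate, a small sketch in plain jax of why data-dependent output shapes cannot be jit-compiled (the function is illustrative):
import jax
import jax.numpy as jnp

def positive_entries(x):
    # the output length depends on the values in x, not only on its shape
    return x[x > 0]

x = jnp.array([1.0, -2.0, 3.0])
print(positive_entries(x))  # works eagerly: [1. 3.]
# jax.jit(positive_entries)(x)  # fails: output shape unknown at trace time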
Not every function in NumPy is implemented; contributions are welcome!
Numpy Ops¶
abs (x) |
Calculate the absolute value element-wise. |
absolute (x) |
Calculate the absolute value element-wise. |
add (x1, x2) |
Add arguments element-wise. |
all (a[, axis, out, keepdims]) |
Test whether all array elements along a given axis evaluate to True. |
allclose (a, b[, rtol, atol, equal_nan]) |
Returns True if two arrays are element-wise equal within a tolerance. |
alltrue (a[, axis, out, keepdims]) |
Test whether all array elements along a given axis evaluate to True. |
amax (a[, axis, out, keepdims, initial, where]) |
Return the maximum of an array or maximum along an axis. |
amin (a[, axis, out, keepdims, initial, where]) |
Return the minimum of an array or minimum along an axis. |
angle (z) |
Return the angle of the complex argument. |
any (a[, axis, out, keepdims]) |
Test whether any array element along a given axis evaluates to True. |
append (arr, values[, axis]) |
Append values to the end of an array. |
arange (start[, stop, step, dtype]) |
Return evenly spaced values within a given interval. |
arccos (x) |
Trigonometric inverse cosine, element-wise. |
arccosh (x) |
Inverse hyperbolic cosine, element-wise. |
arcsin (x) |
Inverse sine, element-wise. |
arcsinh (x) |
Inverse hyperbolic sine element-wise. |
arctan (x) |
Trigonometric inverse tangent, element-wise. |
arctan2 (x1, x2) |
Element-wise arc tangent of x1/x2 choosing the quadrant correctly. |
arctanh (x) |
Inverse hyperbolic tangent element-wise. |
argmax (a[, axis, out]) |
Returns the indices of the maximum values along an axis. |
argmin (a[, axis, out]) |
Returns the indices of the minimum values along an axis. |
argsort (a[, axis, kind, order]) |
Returns the indices that would sort an array. |
argwhere (a) |
Find the indices of array elements that are non-zero, grouped by element. |
around (a[, decimals, out]) |
Round an array to the given number of decimals. |
array (object[, dtype, copy, order, ndmin]) |
Create an array. |
array_repr |
|
array_str |
|
asarray (a[, dtype, order]) |
Convert the input to an array. |
atleast_1d (*arys) |
Convert inputs to arrays with at least one dimension. |
atleast_2d (*arys) |
View inputs as arrays with at least two dimensions. |
atleast_3d (*arys) |
View inputs as arrays with at least three dimensions. |
bartlett |
|
bincount (x[, weights, minlength, length]) |
Count number of occurrences of each value in array of non-negative ints. |
bitwise_and (x1, x2) |
Compute the bit-wise AND of two arrays element-wise. |
bitwise_not (x) |
Compute bit-wise inversion, or bit-wise NOT, element-wise. |
bitwise_or (x1, x2) |
Compute the bit-wise OR of two arrays element-wise. |
bitwise_xor (x1, x2) |
Compute the bit-wise XOR of two arrays element-wise. |
blackman |
|
block (arrays) |
Assemble an nd-array from nested lists of blocks. |
broadcast_arrays (*args) |
Like Numpy’s broadcast_arrays but doesn’t return views. |
broadcast_to (arr, shape) |
Broadcast an array to a new shape. |
can_cast (from_, to[, casting]) |
Returns True if cast between data types can occur according to the casting rule. |
ceil (x) |
Return the ceiling of the input, element-wise. |
clip (a[, a_min, a_max, out]) |
Clip (limit) the values in an array. |
column_stack (tup) |
Stack 1-D arrays as columns into a 2-D array. |
compress (condition, a[, axis, out]) |
Return selected slices of an array along given axis. |
concatenate (arrays[, axis]) |
Join a sequence of arrays along an existing axis. |
conj (x) |
Return the complex conjugate, element-wise. |
conjugate (x) |
Return the complex conjugate, element-wise. |
convolve (a, v[, mode, precision]) |
Returns the discrete, linear convolution of two one-dimensional sequences. |
copysign (x1, x2) |
Change the sign of x1 to that of x2, element-wise. |
corrcoef (x[, y, rowvar]) |
Return Pearson product-moment correlation coefficients. |
correlate (a, v[, mode, precision]) |
Cross-correlation of two 1-dimensional sequences. |
cos (x) |
Cosine element-wise. |
cosh (x) |
Hyperbolic cosine, element-wise. |
count_nonzero (a[, axis, keepdims]) |
Counts the number of non-zero values in the array a . |
cov (m[, y, rowvar, bias, ddof, fweights, …]) |
Estimate a covariance matrix, given data and weights. |
cross (a, b[, axisa, axisb, axisc, axis]) |
Return the cross product of two (arrays of) vectors. |
cumsum (a[, axis, dtype, out]) |
Return the cumulative sum of the elements along a given axis. |
cumprod (a[, axis, dtype, out]) |
Return the cumulative product of elements along a given axis. |
cumproduct (a[, axis, dtype, out]) |
Return the cumulative product of elements along a given axis. |
deg2rad (x) |
Convert angles from degrees to radians. |
degrees (x) |
Convert angles from radians to degrees. |
diag (v[, k]) |
Extract a diagonal or construct a diagonal array. |
diag_indices (n[, ndim]) |
Return the indices to access the main diagonal of an array. |
diag_indices_from (arr) |
Return the indices to access the main diagonal of an n-dimensional array. |
diagflat (v[, k]) |
Create a two-dimensional array with the flattened input as a diagonal. |
diagonal (a[, offset, axis1, axis2]) |
Return specified diagonals. |
digitize |
|
divide (x1, x2) |
Returns a true division of the inputs, element-wise. |
divmod (x1, x2) |
Return element-wise quotient and remainder simultaneously. |
dot (a, b, *[, precision]) |
Dot product of two arrays. |
dsplit (ary, indices_or_sections) |
Split array into multiple sub-arrays along the 3rd axis (depth). |
dstack (tup) |
Stack arrays in sequence depth wise (along third axis). |
ediff1d (ary[, to_end, to_begin]) |
The differences between consecutive elements of an array. |
einsum (*operands[, out, optimize, precision]) |
Evaluates the Einstein summation convention on the operands. |
equal (x1, x2) |
Return (x1 == x2) element-wise. |
empty (shape[, dtype]) |
Return a new array of given shape and type, filled with zeros. |
empty_like (a[, dtype, shape]) |
Return an array of zeros with the same shape and type as a given array. |
exp (x) |
Calculate the exponential of all elements in the input array. |
exp2 (x) |
Calculate 2**p for all p in the input array. |
expand_dims (a, axis: Union[int, Tuple[int, …]]) |
Expand the shape of an array. |
expm1 (x) |
Calculate exp(x) - 1 for all elements in the array. |
extract (condition, arr) |
Return the elements of an array that satisfy some condition. |
eye (N[, M, k, dtype]) |
Return a 2-D array with ones on the diagonal and zeros elsewhere. |
fabs (x) |
Compute the absolute values element-wise. |
fix (x[, out]) |
Round to nearest integer towards zero. |
flatnonzero (a) |
Return indices that are non-zero in the flattened version of a. |
flip (m[, axis]) |
Reverse the order of elements in an array along the given axis. |
fliplr (m) |
Flip array in the left/right direction. |
flipud (m) |
Flip array in the up/down direction. |
float_power (x1, x2) |
First array elements raised to powers from second array, element-wise. |
floor (x) |
Return the floor of the input, element-wise. |
floor_divide (x1, x2) |
Return the largest integer smaller or equal to the division of the inputs. |
fmax (x1, x2) |
Element-wise maximum of array elements. |
fmin (x1, x2) |
Element-wise minimum of array elements. |
fmod (x1, x2) |
Return the element-wise remainder of division. |
frexp |
Decompose the elements of x into mantissa and twos exponent. |
full (shape, fill_value[, dtype]) |
Return a new array of given shape and type, filled with fill_value. |
full_like (a, fill_value[, dtype, shape]) |
Return a full array with the same shape and type as a given array. |
gcd (x1, x2) |
Returns the greatest common divisor of |x1| and |x2| |
geomspace (start, stop[, num, endpoint, …]) |
Return numbers spaced evenly on a log scale (a geometric progression). |
greater (x1, x2) |
Return the truth value of (x1 > x2) element-wise. |
greater_equal (x1, x2) |
Return the truth value of (x1 >= x2) element-wise. |
hamming |
Return the Hamming window. |
hanning |
Return the Hanning window. |
heaviside (x1, x2) |
Compute the Heaviside step function. |
histogram (a[, bins, range, weights, density]) |
Compute the histogram of a set of data. |
histogram_bin_edges (a[, bins, range, weights]) |
Function to calculate only the edges of the bins used by the histogram |
hsplit (ary, indices_or_sections) |
Split an array into multiple sub-arrays horizontally (column-wise). |
hstack (tup) |
Stack arrays in sequence horizontally (column wise). |
hypot (x1, x2) |
Given the “legs” of a right triangle, return its hypotenuse. |
identity (n[, dtype]) |
Return the identity array. |
imag (val) |
Return the imaginary part of the complex argument. |
in1d (ar1, ar2[, assume_unique, invert]) |
Test whether each element of a 1-D array is also present in a second array. |
indices (dimensions[, dtype, sparse]) |
Return an array representing the indices of a grid. |
inner (a, b, *[, precision]) |
Inner product of two arrays. |
isclose (a, b[, rtol, atol, equal_nan]) |
Returns a boolean array where two arrays are element-wise equal within a tolerance. |
iscomplex (x) |
Returns a bool array, where True if input element is complex. |
isfinite (x) |
Test element-wise for finiteness (not infinity or not Not a Number). |
isin (element, test_elements[, …]) |
Calculates element in test_elements, broadcasting over element only. |
isinf (x) |
Test element-wise for positive or negative infinity. |
isnan (x) |
Test element-wise for NaN and return result as a boolean array. |
isneginf (x[, out]) |
Test element-wise for negative infinity, return result as bool array. |
isposinf (x[, out]) |
Test element-wise for positive infinity, return result as bool array. |
isreal (x) |
Returns a bool array, where True if input element is real. |
isscalar (element) |
Returns True if the type of element is a scalar type. |
issubdtype (arg1, arg2) |
Returns True if first argument is a typecode lower/equal in type hierarchy. |
issubsctype (arg1, arg2) |
Determine if the first argument is a subclass of the second argument. |
ix_ (*args) |
Construct an open mesh from multiple sequences. |
kaiser |
Return the Kaiser window. |
kron (a, b) |
Kronecker product of two arrays. |
lcm (x1, x2) |
Returns the lowest common multiple of |x1| and |x2| |
ldexp (x1, x2) |
Returns x1 * 2**x2, element-wise. |
left_shift (x1, x2) |
Shift the bits of an integer to the left. |
less (x1, x2) |
Return the truth value of (x1 < x2) element-wise. |
less_equal (x1, x2) |
Return the truth value of (x1 <= x2) element-wise. |
linspace (start, stop[, num, endpoint, …]) |
Return evenly spaced numbers over a specified interval. |
log (x) |
Natural logarithm, element-wise. |
log10 (x) |
Return the base 10 logarithm of the input array, element-wise. |
log1p (x) |
Return the natural logarithm of one plus the input array, element-wise. |
log2 (x) |
Base-2 logarithm of x. |
logaddexp (x1, x2) |
Logarithm of the sum of exponentiations of the inputs. |
logaddexp2 (x1, x2) |
Logarithm of the sum of exponentiations of the inputs in base-2. |
logical_and (*args) |
Compute the truth value of x1 AND x2 element-wise. |
logical_not (*args) |
Compute the truth value of NOT x element-wise. |
logical_or (*args) |
Compute the truth value of x1 OR x2 element-wise. |
logical_xor (*args) |
Compute the truth value of x1 XOR x2, element-wise. |
logspace (start, stop[, num, endpoint, base, …]) |
Return numbers spaced evenly on a log scale. |
matmul (a, b, *[, precision]) |
Matrix product of two arrays. |
max (a[, axis, out, keepdims, initial, where]) |
Return the maximum of an array or maximum along an axis. |
maximum (x1, x2) |
Element-wise maximum of array elements. |
mean (a[, axis, dtype, out, keepdims]) |
Compute the arithmetic mean along the specified axis. |
median (a[, axis, out, overwrite_input, keepdims]) |
Compute the median along the specified axis. |
meshgrid (*args, **kwargs) |
Return coordinate matrices from coordinate vectors. |
min (a[, axis, out, keepdims, initial, where]) |
Return the minimum of an array or minimum along an axis. |
minimum (x1, x2) |
Element-wise minimum of array elements. |
mod (x1, x2) |
Return element-wise remainder of division. |
moveaxis (a, source, destination) |
Move axes of an array to new positions. |
msort (a) |
Return a copy of an array sorted along the first axis. |
multiply (x1, x2) |
Multiply arguments element-wise. |
nan_to_num (x[, copy, nan, posinf, neginf]) |
Replace NaN with zero and infinity with large finite numbers (default behaviour) or with the numbers defined by the user using the nan, posinf and/or neginf keywords. |
nanargmax (a[, axis]) |
Return the indices of the maximum values in the specified axis ignoring NaNs. |
nanargmin (a[, axis]) |
Return the indices of the minimum values in the specified axis ignoring NaNs. |
nancumprod (a[, axis, dtype, out]) |
Return the cumulative product of array elements over a given axis treating Not a Numbers (NaNs) as one. |
nancumsum (a[, axis, dtype, out]) |
Return the cumulative sum of array elements over a given axis treating Not a Numbers (NaNs) as zero. |
nanmax (a[, axis, out, keepdims]) |
Return the maximum of an array or maximum along an axis, ignoring any NaNs. |
nanmedian (a[, axis, out, overwrite_input, …]) |
Compute the median along the specified axis, while ignoring NaNs. |
nanmin (a[, axis, out, keepdims]) |
Return minimum of an array or minimum along an axis, ignoring any NaNs. |
nanpercentile (a, q[, axis, out, …]) |
Compute the qth percentile of the data along the specified axis, while ignoring nan values. |
nanprod (a[, axis, dtype, out, keepdims]) |
Return the product of array elements over a given axis treating Not a Numbers (NaNs) as ones. |
nanquantile (a, q[, axis, out, …]) |
Compute the qth quantile of the data along the specified axis, while ignoring nan values. |
nansum (a[, axis, dtype, out, keepdims]) |
Return the sum of array elements over a given axis treating Not a Numbers (NaNs) as zero. |
negative (x) |
Numerical negative, element-wise. |
nextafter (x1, x2) |
Return the next floating-point value after x1 towards x2, element-wise. |
nonzero (a) |
Return the indices of the elements that are non-zero. |
not_equal (x1, x2) |
Return (x1 != x2) element-wise. |
ones (shape[, dtype]) |
Return a new array of given shape and type, filled with ones. |
ones_like (input[, detach]) |
Return an array of ones with the same shape and type as a given array. |
outer (a, b[, out]) |
Compute the outer product of two vectors. |
packbits |
Packs the elements of a binary-valued array into bits in a uint8 array. |
pad (array, pad_width[, mode, …]) |
Pad an array. |
percentile (a, q[, axis, out, …]) |
Compute the q-th percentile of the data along the specified axis. |
polyadd (a1, a2) |
Find the sum of two polynomials. |
polyder (p[, m]) |
Return the derivative of the specified order of a polynomial. |
polymul (a1, a2, *[, trim_leading_zeros]) |
Find the product of two polynomials. |
polysub (a1, a2) |
Difference (subtraction) of two polynomials. |
polyval (p, x) |
Evaluate a polynomial at specific values. |
power (x1, x2) |
First array elements raised to powers from second array, element-wise. |
positive (x) |
Numerical positive, element-wise. |
prod (a[, axis, dtype, out, keepdims, …]) |
Return the product of array elements over a given axis. |
product (a[, axis, dtype, out, keepdims, …]) |
Return the product of array elements over a given axis. |
promote_types (a, b) |
Returns the type to which a binary operation should cast its arguments. |
ptp (a[, axis, out, keepdims]) |
Range of values (maximum - minimum) along an axis. |
quantile (a, q[, axis, out, overwrite_input, …]) |
Compute the q-th quantile of the data along the specified axis. |
rad2deg (x) |
Convert angles from radians to degrees. |
radians (x) |
Convert angles from degrees to radians. |
ravel (a[, order]) |
Return a contiguous flattened array. |
real (val) |
Return the real part of the complex argument. |
reciprocal (x) |
Return the reciprocal of the argument, element-wise. |
remainder (x1, x2) |
Return element-wise remainder of division. |
repeat (a, repeats[, axis, total_repeat_length]) |
Repeat elements of an array. |
reshape (a, newshape[, order]) |
Gives a new shape to an array without changing its data. |
result_type (*args) |
Returns the type that results from applying the NumPy type promotion rules to the arguments. |
right_shift (x1, x2) |
Shift the bits of an integer to the right. |
rint (x) |
Round elements of the array to the nearest integer. |
roll (a, shift[, axis]) |
Roll array elements along a given axis. |
rollaxis (a, axis[, start]) |
Roll the specified axis backwards, until it lies in a given position. |
roots (p, *[, strip_zeros]) |
Return the roots of a polynomial with coefficients given in p. |
rot90 (m[, k, axes]) |
Rotate an array by 90 degrees in the plane specified by axes. |
round (a[, decimals, out]) |
Round an array to the given number of decimals. |
row_stack (tup) |
Stack arrays in sequence vertically (row wise). |
searchsorted |
Find indices where elements should be inserted to maintain order. |
select (condlist, choicelist[, default]) |
Return an array drawn from elements in choicelist, depending on conditions. |
sign (x) |
Returns an element-wise indication of the sign of a number. |
signbit (x) |
Returns element-wise True where signbit is set (less than zero). |
sin (x) |
Trigonometric sine, element-wise. |
sinc (x) |
Return the sinc function. |
sinh (x) |
Hyperbolic sine, element-wise. |
sometrue (a[, axis, out, keepdims]) |
Test whether any array element along a given axis evaluates to True. |
sort (a[, axis, kind, order]) |
Return a sorted copy of an array. |
split (ary, indices_or_sections[, axis]) |
Split an array into multiple sub-arrays as views into ary. |
sqrt (x) |
Return the non-negative square-root of an array, element-wise. |
square (x) |
Return the element-wise square of the input. |
squeeze (a[, axis]) |
Remove single-dimensional entries from the shape of an array. |
stack (arrays[, axis, out]) |
Join a sequence of arrays along a new axis. |
std (a[, axis, dtype, out, ddof, keepdims]) |
Compute the standard deviation along the specified axis. |
subtract (x1, x2) |
Subtract arguments, element-wise. |
sum (a[, axis, dtype, out, keepdims, …]) |
Sum of array elements over a given axis. |
swapaxes (a, axis1, axis2) |
Interchange two axes of an array. |
take (a, indices[, axis, out, mode]) |
Take elements from an array along an axis. |
take_along_axis (arr, indices, axis) |
Take values from the input array by matching 1d index and data slices. |
tan (x) |
Compute tangent element-wise. |
tanh (x) |
Compute hyperbolic tangent element-wise. |
tensordot (a, b[, axes, precision]) |
Compute tensor dot product along specified axes. |
tile (A, reps) |
Construct an array by repeating A the number of times given by reps. |
trace (a[, offset, axis1, axis2, dtype, out]) |
Return the sum along diagonals of the array. |
transpose (a[, axes]) |
Reverse or permute the axes of an array; returns the modified array. |
tri (N[, M, k, dtype]) |
An array with ones at and below the given diagonal and zeros elsewhere. |
tril (m[, k]) |
Lower triangle of an array. |
tril_indices (*args, **kwargs) |
Return the indices for the lower-triangle of an (n, m) array. |
tril_indices_from (arr[, k]) |
Return the indices for the lower-triangle of arr. |
triu (m[, k]) |
Upper triangle of an array. |
triu_indices (*args, **kwargs) |
Return the indices for the upper-triangle of an (n, m) array. |
triu_indices_from (arr[, k]) |
Return the indices for the upper-triangle of arr. |
true_divide (x1, x2) |
Returns a true division of the inputs, element-wise. |
trunc (x) |
Return the truncated value of the input, element-wise. |
unique |
Find the unique elements of an array. |
unpackbits |
Unpacks elements of a uint8 array into a binary-valued output array. |
unravel_index (indices, shape) |
Converts a flat index or array of flat indices into a tuple of coordinate arrays. |
unwrap (p[, discont, axis]) |
Unwrap by changing deltas between values to 2*pi complement. |
vander (x[, N, increasing]) |
Generate a Vandermonde matrix. |
var (a[, axis, dtype, out, ddof, keepdims]) |
Compute the variance along the specified axis. |
vdot (a, b, *[, precision]) |
Return the dot product of two vectors. |
vsplit (ary, indices_or_sections) |
Split an array into multiple sub-arrays vertically (row-wise). |
vstack (tup) |
Stack arrays in sequence vertically (row wise). |
where (condition[, x, y]) |
Return elements chosen from x or y depending on condition. |
zeros (shape[, dtype]) |
Return a new array of given shape and type, filled with zeros. |
zeros_like (input[, detach]) |
Return an array of zeros with the same shape and type as a given array. |
stop_gradient (x) |
Stops gradient computation. |
one_hot (i, N[, dtype]) |
Create a one-hot encoding of i of size N. |
dimshuffle (tensor, pattern) |
Reorder the dimensions of this variable, optionally inserting broadcasted dimensions. |
flatten (input) |
Reshape the input into a vector. |
flatten2d (input) |
Reshape the input into a matrix. |
flatten3d (input) |
Reshape the input into a 3D-tensor. |
flatten4d (input) |
Reshape the input into a 4D-tensor. |
Indexed Operations¶
index |
Helper object for building indexes for indexed update functions. |
index_update (x, idx, y[, …]) |
Pure equivalent of x[idx] = y . |
index_min (x, idx, y[, indices_are_sorted, …]) |
Pure equivalent of x[idx] = minimum(x[idx], y) . |
index_add (x, idx, y[, indices_are_sorted, …]) |
Pure equivalent of x[idx] += y . |
index_max (x, idx, y[, indices_are_sorted, …]) |
Pure equivalent of x[idx] = maximum(x[idx], y) . |
index_take (src, idxs, axes) |
|
index_in_dim (operand, index, axis, keepdims) |
Convenience wrapper around slice to perform int indexing. |
dynamic_slice_in_dim (operand, start_index, …) |
Convenience wrapper around dynamic_slice applying to one dimension. |
dynamic_slice (operand, start_indices, …) |
Wraps XLA’s DynamicSlice operator. |
dynamic_index_in_dim (operand, index, axis, …) |
Convenience wrapper around dynamic_slice to perform int indexing. |
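All of the indexed operations above are pure: because SymJAX tensors are immutable graph nodes, each returns a new tensor rather than mutating x. The behaviour can be sketched directly with jax.numpy’s at[] syntax (an illustration of the semantics, not the SymJAX API itself):
>>> import jax.numpy as jnp
>>> x = jnp.zeros(5)
>>> print(x.at[2].set(7.0))   # pure counterpart of x[2] = 7.0
[0. 0. 7. 0. 0.]
>>> print(x.at[2].add(1.0))   # pure counterpart of x[2] += 1.0
[0. 0. 1. 0. 0.]
>>> print(x)                  # the original tensor is untouched
[0. 0. 0. 0. 0.]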
Control flow Ops¶
cond (pred, true_fun, false_fun[, …]) |
Conditional branch evaluation. |
fori_loop |
Loop from a lower to an upper bound, applying body_fun to a carried value (as in jax.lax.fori_loop). |
map (f, sequences[, non_sequences]) |
Map a function over leading array axes. |
scan (f, init, sequences[, non_sequences, …]) |
Scan a function over leading array axes while carrying along state. |
while_loop (cond_fun, body_fun, sequences[, …]) |
Call body_fun repeatedly in a loop while cond_fun is True. |
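To make the carry/sequence semantics of scan concrete, here is a cumulative-sum sketch written against jax.lax.scan, whose behaviour the entry above mirrors (illustrative only; SymJAX’s own scan additionally separates sequences from non_sequences per its signature):
>>> import jax.numpy as jnp
>>> from jax import lax
>>> def step(carry, x):
...     return carry + x, carry + x   # (state for the next step, output for this step)
>>> total, partials = lax.scan(step, 0.0, jnp.arange(1.0, 5.0))
>>> print(total, partials)
10.0 [ 1.  3.  6. 10.]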
Detailed Descriptions¶
symjax.tensor.abs(x)¶
Calculate the absolute value element-wise.
LAX-backend implementation of absolute(). Original docstring below.
absolute(x, /, out=None, *, where=True, casting=’same_kind’, order=’K’, dtype=None, subok=True[, signature, extobj])
np.abs is a shorthand for this function.
Parameters: x (array_like) – Input array.
Returns: absolute – An ndarray containing the absolute value of each element in x. For complex input, a + ib, the absolute value is \(\sqrt{a^2 + b^2}\). This is a scalar if x is a scalar.
Return type: ndarray
Examples
>>> x = np.array([-1.2, 1.2]) >>> np.absolute(x) array([ 1.2, 1.2]) >>> np.absolute(1.2 + 1j) 1.5620499351813308
Plot the function over [-10, 10]:
>>> import matplotlib.pyplot as plt
>>> x = np.linspace(start=-10, stop=10, num=101) >>> plt.plot(x, np.absolute(x)) >>> plt.show()
Plot the function over the complex plane:
>>> xx = x + 1j * x[:, np.newaxis] >>> plt.imshow(np.abs(xx), extent=[-10, 10, -10, 10], cmap='gray') >>> plt.show()
symjax.tensor.absolute(x)[source]¶
Calculate the absolute value element-wise.
LAX-backend implementation of absolute(). Original docstring below.
absolute(x, /, out=None, *, where=True, casting=’same_kind’, order=’K’, dtype=None, subok=True[, signature, extobj])
np.abs is a shorthand for this function.
Parameters: x (array_like) – Input array.
Returns: absolute – An ndarray containing the absolute value of each element in x. For complex input, a + ib, the absolute value is \(\sqrt{a^2 + b^2}\). This is a scalar if x is a scalar.
Return type: ndarray
Examples
>>> x = np.array([-1.2, 1.2]) >>> np.absolute(x) array([ 1.2, 1.2]) >>> np.absolute(1.2 + 1j) 1.5620499351813308
Plot the function over [-10, 10]:
>>> import matplotlib.pyplot as plt
>>> x = np.linspace(start=-10, stop=10, num=101) >>> plt.plot(x, np.absolute(x)) >>> plt.show()
Plot the function over the complex plane:
>>> xx = x + 1j * x[:, np.newaxis] >>> plt.imshow(np.abs(xx), extent=[-10, 10, -10, 10], cmap='gray') >>> plt.show()
symjax.tensor.add(x1, x2)¶
Add arguments element-wise.
LAX-backend implementation of add(). Original docstring below.
add(x1, x2, /, out=None, *, where=True, casting=’same_kind’, order=’K’, dtype=None, subok=True[, signature, extobj])
Parameters: x1, x2 (array_like) – The arrays to be added. If x1.shape != x2.shape, they must be broadcastable to a common shape (which becomes the shape of the output).
Returns: add – The sum of x1 and x2, element-wise. This is a scalar if both x1 and x2 are scalars.
Return type: ndarray or scalar
Notes
Equivalent to x1 + x2 in terms of array broadcasting.
Examples
>>> np.add(1.0, 4.0) 5.0 >>> x1 = np.arange(9.0).reshape((3, 3)) >>> x2 = np.arange(3.0) >>> np.add(x1, x2) array([[ 0., 2., 4.], [ 3., 5., 7.], [ 6., 8., 10.]])
symjax.tensor.all(a, axis=None, out=None, keepdims=None)[source]¶
Test whether all array elements along a given axis evaluate to True.
LAX-backend implementation of all(). Original docstring below.
Parameters:
- a (array_like) – Input array or object that can be converted to an array.
- axis (None or int or tuple of ints, optional) – Axis or axes along which a logical AND reduction is performed. The default (axis=None) is to perform a logical AND over all the dimensions of the input array. axis may be negative, in which case it counts from the last to the first axis.
- out (ndarray, optional) – Alternate output array in which to place the result. It must have the same shape as the expected output and its type is preserved (e.g., if dtype(out) is float, the result will consist of 0.0’s and 1.0’s). See ufuncs-output-type for more details.
- keepdims (bool, optional) – If this is set to True, the axes which are reduced are left in the result as dimensions with size one. With this option, the result will broadcast correctly against the input array.
Returns: all – A new boolean or array is returned unless out is specified, in which case a reference to out is returned.
Return type: ndarray, bool
See also
ndarray.all() – equivalent method
any() – Test whether any element along a given axis evaluates to True.
Notes
Not a Number (NaN), positive infinity and negative infinity evaluate to True because these are not equal to zero.
Examples
>>> np.all([[True,False],[True,True]]) False
>>> np.all([[True,False],[True,True]], axis=0) array([ True, False])
>>> np.all([-1, 4, 5]) True
>>> np.all([1.0, np.nan]) True
>>> o=np.array(False) >>> z=np.all([-1, 4, 5], out=o) >>> id(z), id(o), z (28293632, 28293632, array(True)) # may vary
symjax.tensor.allclose(a, b, rtol=1e-05, atol=1e-08, equal_nan=False)[source]¶
Returns True if two arrays are element-wise equal within a tolerance.
LAX-backend implementation of allclose(). Original docstring below.
The tolerance values are positive, typically very small numbers. The relative difference (rtol * abs(b)) and the absolute difference atol are added together to compare against the absolute difference between a and b.
NaNs are treated as equal if they are in the same place and if equal_nan=True. Infs are treated as equal if they are in the same place and of the same sign in both arrays.
Parameters:
- a, b (array_like) – Input arrays to compare.
- rtol (float) – The relative tolerance parameter (see Notes).
- atol (float) – The absolute tolerance parameter (see Notes).
- equal_nan (bool) – Whether to compare NaN’s as equal. If True, NaN’s in a will be considered equal to NaN’s in b in the output array.
Returns: allclose – Returns True if the two arrays are equal within the given tolerance; False otherwise.
Return type: bool
Notes
If the following equation is element-wise True, then allclose returns True.
absolute(a - b) <= (atol + rtol * absolute(b))
The above equation is not symmetric in a and b, so that allclose(a, b) might be different from allclose(b, a) in some rare cases.
The comparison of a and b uses standard broadcasting, which means that a and b need not have the same shape in order for allclose(a, b) to evaluate to True. The same is true for equal but not array_equal.
Examples
>>> np.allclose([1e10,1e-7], [1.00001e10,1e-8]) False >>> np.allclose([1e10,1e-8], [1.00001e10,1e-9]) True >>> np.allclose([1e10,1e-8], [1.0001e10,1e-9]) False >>> np.allclose([1.0, np.nan], [1.0, np.nan]) False >>> np.allclose([1.0, np.nan], [1.0, np.nan], equal_nan=True) True
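A short sketch of the asymmetry noted above: since rtol scales only the second argument, swapping operands near the threshold can flip the result (plain NumPy, same formula):
>>> import numpy as np
>>> a, b = np.array([1.0]), np.array([1.1])
>>> np.allclose(a, b, rtol=0.095, atol=0.0)  # 0.1 <= 0.095 * 1.1
True
>>> np.allclose(b, a, rtol=0.095, atol=0.0)  # 0.1 <= 0.095 * 1.0
False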
symjax.tensor.alltrue(a, axis=None, out=None, keepdims=None)¶
Test whether all array elements along a given axis evaluate to True.
LAX-backend implementation of all(). Original docstring below.
Parameters:
- a (array_like) – Input array or object that can be converted to an array.
- axis (None or int or tuple of ints, optional) – Axis or axes along which a logical AND reduction is performed. The default (axis=None) is to perform a logical AND over all the dimensions of the input array. axis may be negative, in which case it counts from the last to the first axis.
- out (ndarray, optional) – Alternate output array in which to place the result. It must have the same shape as the expected output and its type is preserved (e.g., if dtype(out) is float, the result will consist of 0.0’s and 1.0’s). See ufuncs-output-type for more details.
- keepdims (bool, optional) – If this is set to True, the axes which are reduced are left in the result as dimensions with size one. With this option, the result will broadcast correctly against the input array.
Returns: all – A new boolean or array is returned unless out is specified, in which case a reference to out is returned.
Return type: ndarray, bool
See also
ndarray.all() – equivalent method
any() – Test whether any element along a given axis evaluates to True.
Notes
Not a Number (NaN), positive infinity and negative infinity evaluate to True because these are not equal to zero.
Examples
>>> np.all([[True,False],[True,True]]) False
>>> np.all([[True,False],[True,True]], axis=0) array([ True, False])
>>> np.all([-1, 4, 5]) True
>>> np.all([1.0, np.nan]) True
>>> o=np.array(False) >>> z=np.all([-1, 4, 5], out=o) >>> id(z), id(o), z (28293632, 28293632, array(True)) # may vary
symjax.tensor.amax(a, axis=None, out=None, keepdims=None, initial=None, where=None)¶
Return the maximum of an array or maximum along an axis.
LAX-backend implementation of amax(). Original docstring below.
Parameters:
- a (array_like) – Input data.
- axis (None or int or tuple of ints, optional) – Axis or axes along which to operate. By default, flattened input is used.
- out (ndarray, optional) – Alternative output array in which to place the result. Must be of the same shape and buffer length as the expected output. See ufuncs-output-type for more details.
- keepdims (bool, optional) – If this is set to True, the axes which are reduced are left in the result as dimensions with size one. With this option, the result will broadcast correctly against the input array.
- initial (scalar, optional) – The minimum value of an output element. Must be present to allow computation on empty slice. See ~numpy.ufunc.reduce for details.
- where (array_like of bool, optional) – Elements to compare for the maximum. See ~numpy.ufunc.reduce for details.
Returns: amax – Maximum of a. If axis is None, the result is a scalar value. If axis is given, the result is an array of dimension
a.ndim - 1
.Return type: ndarray or scalar
See also
amin() – The minimum value of an array along a given axis, propagating any NaNs.
nanmax() – The maximum value of an array along a given axis, ignoring any NaNs.
maximum() – Element-wise maximum of two arrays, propagating any NaNs.
fmax() – Element-wise maximum of two arrays, ignoring any NaNs.
argmax() – Return the indices of the maximum values.
Notes
NaN values are propagated, that is if at least one item is NaN, the corresponding max value will be NaN as well. To ignore NaN values (MATLAB behavior), please use nanmax.
Don’t use amax for element-wise comparison of 2 arrays; when a.shape[0] is 2, maximum(a[0], a[1]) is faster than amax(a, axis=0).
Examples
>>> a = np.arange(4).reshape((2,2)) >>> a array([[0, 1], [2, 3]]) >>> np.amax(a) # Maximum of the flattened array 3 >>> np.amax(a, axis=0) # Maxima along the first axis array([2, 3]) >>> np.amax(a, axis=1) # Maxima along the second axis array([1, 3]) >>> np.amax(a, where=[False, True], initial=-1, axis=0) array([-1, 3]) >>> b = np.arange(5, dtype=float) >>> b[2] = np.NaN >>> np.amax(b) nan >>> np.amax(b, where=~np.isnan(b), initial=-1) 4.0 >>> np.nanmax(b) 4.0
You can use an initial value to compute the maximum of an empty slice, or to initialize it to a different value:
>>> np.max([[-50], [10]], axis=-1, initial=0) array([ 0, 10])
Notice that the initial value is used as one of the elements for which the maximum is determined, unlike for the default argument Python’s max function, which is only used for empty iterables.
>>> np.max([5], initial=6) 6 >>> max([5], default=6) 5
symjax.tensor.amin(a, axis=None, out=None, keepdims=None, initial=None, where=None)¶
Return the minimum of an array or minimum along an axis.
LAX-backend implementation of amin(). Original docstring below.
Parameters:
- a (array_like) – Input data.
- axis (None or int or tuple of ints, optional) – Axis or axes along which to operate. By default, flattened input is used.
- out (ndarray, optional) – Alternative output array in which to place the result. Must be of the same shape and buffer length as the expected output. See ufuncs-output-type for more details.
- keepdims (bool, optional) – If this is set to True, the axes which are reduced are left in the result as dimensions with size one. With this option, the result will broadcast correctly against the input array.
- initial (scalar, optional) – The maximum value of an output element. Must be present to allow computation on empty slice. See ~numpy.ufunc.reduce for details.
- where (array_like of bool, optional) – Elements to compare for the minimum. See ~numpy.ufunc.reduce for details.
Returns: amin – Minimum of a. If axis is None, the result is a scalar value. If axis is given, the result is an array of dimension
a.ndim - 1
.Return type: ndarray or scalar
See also
amax() – The maximum value of an array along a given axis, propagating any NaNs.
nanmin() – The minimum value of an array along a given axis, ignoring any NaNs.
minimum() – Element-wise minimum of two arrays, propagating any NaNs.
fmin() – Element-wise minimum of two arrays, ignoring any NaNs.
argmin() – Return the indices of the minimum values.
Notes
NaN values are propagated, that is if at least one item is NaN, the corresponding min value will be NaN as well. To ignore NaN values (MATLAB behavior), please use nanmin.
Don’t use amin for element-wise comparison of 2 arrays; when a.shape[0] is 2, minimum(a[0], a[1]) is faster than amin(a, axis=0).
Examples
>>> a = np.arange(4).reshape((2,2)) >>> a array([[0, 1], [2, 3]]) >>> np.amin(a) # Minimum of the flattened array 0 >>> np.amin(a, axis=0) # Minima along the first axis array([0, 1]) >>> np.amin(a, axis=1) # Minima along the second axis array([0, 2]) >>> np.amin(a, where=[False, True], initial=10, axis=0) array([10, 1])
>>> b = np.arange(5, dtype=float) >>> b[2] = np.NaN >>> np.amin(b) nan >>> np.amin(b, where=~np.isnan(b), initial=10) 0.0 >>> np.nanmin(b) 0.0
>>> np.min([[-50], [10]], axis=-1, initial=0) array([-50, 0])
Notice that the initial value is used as one of the elements for which the minimum is determined, unlike for the default argument Python’s min function, which is only used for empty iterables.
Notice that this isn’t the same as Python’s default argument.
>>> np.min([6], initial=5) 5 >>> min([6], default=5) 6
symjax.tensor.angle(z)[source]¶
Return the angle of the complex argument.
LAX-backend implementation of angle(). Original docstring below.
Parameters: z (array_like) – A complex number or sequence of complex numbers.
Returns: angle – The counterclockwise angle from the positive real axis on the complex plane in the range (-pi, pi], with dtype as numpy.float64.
Changed in version 1.16.0: This function works on subclasses of ndarray like ma.array.
Return type: ndarray or scalar
See also
Notes
Although the angle of the complex number 0 is undefined,
numpy.angle(0)
returns the value 0.Examples
>>> np.angle([1.0, 1.0j, 1+1j]) # in radians array([ 0. , 1.57079633, 0.78539816]) # may vary >>> np.angle(1+1j, deg=True) # in degrees 45.0
symjax.tensor.any(a, axis=None, out=None, keepdims=None)[source]¶
Test whether any array element along a given axis evaluates to True.
LAX-backend implementation of any(). Original docstring below.
Returns single boolean unless axis is not None.
Parameters:
- a (array_like) – Input array or object that can be converted to an array.
- axis (None or int or tuple of ints, optional) – Axis or axes along which a logical OR reduction is performed. The default (axis=None) is to perform a logical OR over all the dimensions of the input array. axis may be negative, in which case it counts from the last to the first axis.
- out (ndarray, optional) – Alternate output array in which to place the result. It must have the same shape as the expected output and its type is preserved (e.g., if it is of type float, then it will remain so, returning 1.0 for True and 0.0 for False, regardless of the type of a). See ufuncs-output-type for more details.
- keepdims (bool, optional) – If this is set to True, the axes which are reduced are left in the result as dimensions with size one. With this option, the result will broadcast correctly against the input array.
Returns: any – A new boolean or ndarray is returned unless out is specified, in which case a reference to out is returned.
Return type: bool or ndarray
See also
ndarray.any() – equivalent method
all() – Test whether all elements along a given axis evaluate to True.
Notes
Not a Number (NaN), positive infinity and negative infinity evaluate to True because these are not equal to zero.
Examples
>>> np.any([[True, False], [True, True]]) True
>>> np.any([[True, False], [False, False]], axis=0) array([ True, False])
>>> np.any([-1, 0, 5]) True
>>> np.any(np.nan) True
>>> o=np.array(False) >>> z=np.any([-1, 4, 5], out=o) >>> z, o (array(True), array(True)) >>> # Check now that z is a reference to o >>> z is o True >>> id(z), id(o) # identity of z and o # doctest: +SKIP (191614240, 191614240)
symjax.tensor.append(arr, values, axis=None)[source]¶
Append values to the end of an array.
LAX-backend implementation of append(). Original docstring below.
Parameters:
- arr (array_like) – Values are appended to a copy of this array.
- values (array_like) – These values are appended to a copy of arr. It must be of the correct shape (the same shape as arr, excluding axis). If axis is not specified, values can be any shape and will be flattened before use.
- axis (int, optional) – The axis along which values are appended. If axis is not given, both arr and values are flattened before use.
Returns: append – A copy of arr with values appended to axis. Note that append does not occur in-place: a new array is allocated and filled. If axis is None, out is a flattened array.
Return type: ndarray
See also
insert() – Insert elements into an array.
delete() – Delete elements from an array.
Examples
>>> np.append([1, 2, 3], [[4, 5, 6], [7, 8, 9]]) array([1, 2, 3, ..., 7, 8, 9])
When axis is specified, values must have the correct shape.
>>> np.append([[1, 2, 3], [4, 5, 6]], [[7, 8, 9]], axis=0) array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) >>> np.append([[1, 2, 3], [4, 5, 6]], [7, 8, 9], axis=0) Traceback (most recent call last): ... ValueError: all the input arrays must have same number of dimensions, but the array at index 0 has 2 dimension(s) and the array at index 1 has 1 dimension(s)
symjax.tensor.arange(start, stop=None, step=None, dtype=None)[source]¶
Return evenly spaced values within a given interval.
LAX-backend implementation of arange(). Original docstring below.
arange([start,] stop[, step,], dtype=None)
Values are generated within the half-open interval [start, stop) (in other words, the interval including start but excluding stop). For integer arguments the function is equivalent to the Python built-in range function, but returns an ndarray rather than a list.
When using a non-integer step, such as 0.1, the results will often not be consistent. It is better to use numpy.linspace for these cases.
Returns: arange – Array of evenly spaced values. For floating point arguments, the length of the result is ceil((stop - start)/step). Because of floating point overflow, this rule may result in the last element of out being greater than stop.
Return type: ndarray
See also
numpy.linspace – Evenly spaced numbers with careful handling of endpoints.
numpy.ogrid – Arrays of evenly spaced numbers in N-dimensions.
numpy.mgrid – Grid-shaped arrays of evenly spaced numbers in N-dimensions.
Examples
>>> np.arange(3) array([0, 1, 2]) >>> np.arange(3.0) array([ 0., 1., 2.]) >>> np.arange(3,7) array([3, 4, 5, 6]) >>> np.arange(3,7,2) array([3, 5])
symjax.tensor.arccos(x)¶
Trigonometric inverse cosine, element-wise.
LAX-backend implementation of arccos(). Original docstring below.
arccos(x, /, out=None, *, where=True, casting=’same_kind’, order=’K’, dtype=None, subok=True[, signature, extobj])
The inverse of cos so that, if y = cos(x), then x = arccos(y).
Parameters: x (array_like) – x-coordinate on the unit circle. For real arguments, the domain is [-1, 1].
Returns: angle – The angle of the ray intersecting the unit circle at the given x-coordinate in radians [0, pi]. This is a scalar if x is a scalar.
Return type: ndarray
Notes
arccos is a multivalued function: for each x there are infinitely many numbers z such that cos(z) = x. The convention is to return the angle z whose real part lies in [0, pi].
For real-valued input data types, arccos always returns real output. For each value that cannot be expressed as a real number or infinity, it yields nan and sets the invalid floating point error flag.
For complex-valued input, arccos is a complex analytic function that has branch cuts [-inf, -1] and [1, inf] and is continuous from above on the former and from below on the latter.
The inverse cos is also known as acos or cos^-1.
References
M. Abramowitz and I.A. Stegun, “Handbook of Mathematical Functions”, 10th printing, 1964, pp. 79. http://www.math.sfu.ca/~cbm/aands/
Examples
We expect the arccos of 1 to be 0, and of -1 to be pi:
>>> np.arccos([1, -1]) array([ 0. , 3.14159265])
Plot arccos:
>>> import matplotlib.pyplot as plt >>> x = np.linspace(-1, 1, num=100) >>> plt.plot(x, np.arccos(x)) >>> plt.axis('tight') >>> plt.show()
symjax.tensor.arccosh(x)¶
Inverse hyperbolic cosine, element-wise.
LAX-backend implementation of arccosh(). Original docstring below.
arccosh(x, /, out=None, *, where=True, casting=’same_kind’, order=’K’, dtype=None, subok=True[, signature, extobj])
Parameters: x (array_like) – Input array.
Returns: arccosh – Array of the same shape as x. This is a scalar if x is a scalar.
Return type: ndarray
Notes
arccosh is a multivalued function: for each x there are infinitely many numbers z such that cosh(z) = x. The convention is to return the z whose imaginary part lies in [-pi, pi] and the real part in [0, inf].
For real-valued input data types, arccosh always returns real output. For each value that cannot be expressed as a real number or infinity, it yields nan and sets the invalid floating point error flag.
For complex-valued input, arccosh is a complex analytical function that has a branch cut [-inf, 1] and is continuous from above on it.
References
[1] M. Abramowitz and I.A. Stegun, “Handbook of Mathematical Functions”, 10th printing, 1964, pp. 86. http://www.math.sfu.ca/~cbm/aands/
[2] Wikipedia, “Inverse hyperbolic function”, https://en.wikipedia.org/wiki/Arccosh
Examples
>>> np.arccosh([np.e, 10.0]) array([ 1.65745445, 2.99322285]) >>> np.arccosh(1) 0.0
symjax.tensor.arcsin(x)¶
Inverse sine, element-wise.
LAX-backend implementation of arcsin(). Original docstring below.
arcsin(x, /, out=None, *, where=True, casting=’same_kind’, order=’K’, dtype=None, subok=True[, signature, extobj])
Parameters: x (array_like) – y-coordinate on the unit circle.
Returns: angle – The inverse sine of each element in x, in radians and in the closed interval [-pi/2, pi/2]. This is a scalar if x is a scalar.
Return type: ndarray
Notes
arcsin is a multivalued function: for each x there are infinitely many numbers z such that \(sin(z) = x\). The convention is to return the angle z whose real part lies in [-pi/2, pi/2].
For real-valued input data types, arcsin always returns real output. For each value that cannot be expressed as a real number or infinity, it yields nan and sets the invalid floating point error flag.
For complex-valued input, arcsin is a complex analytic function that has, by convention, the branch cuts [-inf, -1] and [1, inf] and is continuous from above on the former and from below on the latter.
The inverse sine is also known as asin or sin^{-1}.
References
Abramowitz, M. and Stegun, I. A., Handbook of Mathematical Functions, 10th printing, New York: Dover, 1964, pp. 79ff. http://www.math.sfu.ca/~cbm/aands/
Examples
>>> np.arcsin(1) # pi/2 1.5707963267948966 >>> np.arcsin(-1) # -pi/2 -1.5707963267948966 >>> np.arcsin(0) 0.0
symjax.tensor.arcsinh(x)¶
Inverse hyperbolic sine, element-wise.
LAX-backend implementation of arcsinh(). Original docstring below.
arcsinh(x, /, out=None, *, where=True, casting=’same_kind’, order=’K’, dtype=None, subok=True[, signature, extobj])
Parameters: x (array_like) – Input array.
Returns: out – Array of the same shape as x. This is a scalar if x is a scalar.
Return type: ndarray or scalar
Notes
arcsinh is a multivalued function: for each x there are infinitely many numbers z such that sinh(z) = x. The convention is to return the z whose imaginary part lies in [-pi/2, pi/2].
For real-valued input data types, arcsinh always returns real output. For each value that cannot be expressed as a real number or infinity, it returns nan and sets the invalid floating point error flag.
For complex-valued input, arcsinh is a complex analytical function that has branch cuts [1j, infj] and [-1j, -infj] and is continuous from the right on the former and from the left on the latter.
The inverse hyperbolic sine is also known as asinh or sinh^-1.
References
[1] M. Abramowitz and I.A. Stegun, “Handbook of Mathematical Functions”, 10th printing, 1964, pp. 86. http://www.math.sfu.ca/~cbm/aands/
[2] Wikipedia, “Inverse hyperbolic function”, https://en.wikipedia.org/wiki/Arcsinh
Examples
>>> np.arcsinh(np.array([np.e, 10.0])) array([ 1.72538256, 2.99822295])
symjax.tensor.arctan(x)¶
Trigonometric inverse tangent, element-wise.
LAX-backend implementation of arctan(). Original docstring below.
arctan(x, /, out=None, *, where=True, casting=’same_kind’, order=’K’, dtype=None, subok=True[, signature, extobj])
The inverse of tan, so that if
y = tan(x)
thenx = arctan(y)
.Parameters: x (array_like) – Returns: out – Out has the same shape as x. Its real part is in [-pi/2, pi/2]
(arctan(+/-inf)
returns+/-pi/2
). This is a scalar if x is a scalar.Return type: ndarray or scalar See also
Notes
arctan is a multi-valued function: for each x there are infinitely many numbers z such that tan(z) = x. The convention is to return the angle z whose real part lies in [-pi/2, pi/2].
For real-valued input data types, arctan always returns real output. For each value that cannot be expressed as a real number or infinity, it yields nan and sets the invalid floating point error flag.
For complex-valued input, arctan is a complex analytic function that has [1j, infj] and [-1j, -infj] as branch cuts, and is continuous from the left on the former and from the right on the latter.
The inverse tangent is also known as atan or tan^{-1}.
References
Abramowitz, M. and Stegun, I. A., Handbook of Mathematical Functions, 10th printing, New York: Dover, 1964, pp. 79. http://www.math.sfu.ca/~cbm/aands/
Examples
We expect the arctan of 0 to be 0, and of 1 to be pi/4:
>>> np.arctan([0, 1]) array([ 0. , 0.78539816])
>>> np.pi/4 0.78539816339744828
Plot arctan:
>>> import matplotlib.pyplot as plt >>> x = np.linspace(-10, 10) >>> plt.plot(x, np.arctan(x)) >>> plt.axis('tight') >>> plt.show()
symjax.tensor.arctan2(x1, x2)¶
Element-wise arc tangent of x1/x2 choosing the quadrant correctly.
LAX-backend implementation of arctan2(). Original docstring below.
arctan2(x1, x2, /, out=None, *, where=True, casting=’same_kind’, order=’K’, dtype=None, subok=True[, signature, extobj])
The quadrant (i.e., branch) is chosen so that arctan2(x1, x2) is the signed angle in radians between the ray ending at the origin and passing through the point (1,0), and the ray ending at the origin and passing through the point (x2, x1). (Note the role reversal: the “y-coordinate” is the first function parameter, the “x-coordinate” is the second.) By IEEE convention, this function is defined for x2 = +/-0 and for either or both of x1 and x2 = +/-inf (see Notes for specific values).
This function is not defined for complex-valued arguments; for the so-called argument of complex values, use angle.
Parameters:
- x1 (array_like, real-valued) – y-coordinates.
- x2 (array_like, real-valued) – x-coordinates. If x1.shape != x2.shape, they must be broadcastable to a common shape (which becomes the shape of the output).
Returns: angle – Array of angles in radians, in the range [-pi, pi]. This is a scalar if both x1 and x2 are scalars.
Return type: ndarray
Notes
arctan2 is identical to the atan2 function of the underlying C library. The following special values are defined in the C standard: [1]_
x1      x2      arctan2(x1,x2)
+/- 0   +0      +/- 0
+/- 0   -0      +/- pi
> 0     +/-inf  +0 / +pi
< 0     +/-inf  -0 / -pi
+/-inf  +inf    +/- (pi/4)
+/-inf  -inf    +/- (3*pi/4)
Note that +0 and -0 are distinct floating point numbers, as are +inf and -inf.
References
[1] ISO/IEC standard 9899:1999, “Programming language C.” Examples
Consider four points in different quadrants:
>>> x = np.array([-1, +1, +1, -1]) >>> y = np.array([-1, -1, +1, +1]) >>> np.arctan2(y, x) * 180 / np.pi array([-135., -45., 45., 135.])
Note the order of the parameters. arctan2 is defined also when x2 = 0 and at several other special points, obtaining values in the range [-pi, pi]:
>>> np.arctan2([1., -1.], [0., 0.]) array([ 1.57079633, -1.57079633]) >>> np.arctan2([0., 0., np.inf], [+0., -0., np.inf]) array([ 0. , 3.14159265, 0.78539816])
symjax.tensor.arctanh(x)¶
Inverse hyperbolic tangent, element-wise.
LAX-backend implementation of arctanh(). Original docstring below.
arctanh(x, /, out=None, *, where=True, casting=’same_kind’, order=’K’, dtype=None, subok=True[, signature, extobj])
Parameters: x (array_like) – Input array.
Returns: out – Array of the same shape as x. This is a scalar if x is a scalar.
Return type: ndarray or scalar
See also
emath.arctanh()
Notes
arctanh is a multivalued function: for each x there are infinitely many numbers z such that tanh(z) = x. The convention is to return the z whose imaginary part lies in [-pi/2, pi/2].
For real-valued input data types, arctanh always returns real output. For each value that cannot be expressed as a real number or infinity, it yields nan and sets the invalid floating point error flag.
For complex-valued input, arctanh is a complex analytical function that has branch cuts [-1, -inf] and [1, inf] and is continuous from above on the former and from below on the latter.
The inverse hyperbolic tangent is also known as atanh or tanh^-1.
References
[1] M. Abramowitz and I.A. Stegun, “Handbook of Mathematical Functions”, 10th printing, 1964, pp. 86. http://www.math.sfu.ca/~cbm/aands/
[2] Wikipedia, “Inverse hyperbolic function”, https://en.wikipedia.org/wiki/Arctanh
Examples
>>> np.arctanh([0, -0.5]) array([ 0. , -0.54930614])
symjax.tensor.argmax(a, axis=None, out=None)[source]¶
Returns the indices of the maximum values along an axis.
LAX-backend implementation of argmax(). Original docstring below.
Parameters:
- a (array_like) – Input array.
- axis (int, optional) – By default, the index is into the flattened array, otherwise along the specified axis.
- out (array, optional) – If provided, the result will be inserted into this array. It should be of the appropriate shape and dtype.
Returns: index_array – Array of indices into the array. It has the same shape as a.shape with the dimension along axis removed.
Return type: ndarray of ints
See also
ndarray.argmax(), argmin()
amax() – The maximum value along a given axis.
unravel_index() – Convert a flat index into an index tuple.
take_along_axis() – Apply np.expand_dims(index_array, axis) from argmax to an array as if by calling max.
Notes
In case of multiple occurrences of the maximum values, the indices corresponding to the first occurrence are returned.
Examples
>>> a = np.arange(6).reshape(2,3) + 10 >>> a array([[10, 11, 12], [13, 14, 15]]) >>> np.argmax(a) 5 >>> np.argmax(a, axis=0) array([1, 1, 1]) >>> np.argmax(a, axis=1) array([2, 2])
Indexes of the maximal elements of a N-dimensional array:
>>> ind = np.unravel_index(np.argmax(a, axis=None), a.shape) >>> ind (1, 2) >>> a[ind] 15
>>> b = np.arange(6) >>> b[1] = 5 >>> b array([0, 5, 2, 3, 4, 5]) >>> np.argmax(b) # Only the first occurrence is returned. 1
>>> x = np.array([[4,2,3], [1,0,3]]) >>> index_array = np.argmax(x, axis=-1) >>> # Same as np.max(x, axis=-1, keepdims=True) >>> np.take_along_axis(x, np.expand_dims(index_array, axis=-1), axis=-1) array([[4], [3]]) >>> # Same as np.max(x, axis=-1) >>> np.take_along_axis(x, np.expand_dims(index_array, axis=-1), axis=-1).squeeze(axis=-1) array([4, 3])
symjax.tensor.argmin(a, axis=None, out=None)[source]¶
Returns the indices of the minimum values along an axis.
LAX-backend implementation of argmin(). Original docstring below.
Parameters:
- a (array_like) – Input array.
- axis (int, optional) – By default, the index is into the flattened array, otherwise along the specified axis.
- out (array, optional) – If provided, the result will be inserted into this array. It should be of the appropriate shape and dtype.
Returns: index_array – Array of indices into the array. It has the same shape as a.shape with the dimension along axis removed.
Return type: ndarray of ints
See also
ndarray.argmin(), argmax()
amin() – The minimum value along a given axis.
unravel_index() – Convert a flat index into an index tuple.
take_along_axis() – Apply np.expand_dims(index_array, axis) from argmin to an array as if by calling min.
Notes
In case of multiple occurrences of the minimum values, the indices corresponding to the first occurrence are returned.
Examples
>>> a = np.arange(6).reshape(2,3) + 10 >>> a array([[10, 11, 12], [13, 14, 15]]) >>> np.argmin(a) 0 >>> np.argmin(a, axis=0) array([0, 0, 0]) >>> np.argmin(a, axis=1) array([0, 0])
Indices of the minimum elements of a N-dimensional array:
>>> ind = np.unravel_index(np.argmin(a, axis=None), a.shape) >>> ind (0, 0) >>> a[ind] 10
>>> b = np.arange(6) + 10 >>> b[4] = 10 >>> b array([10, 11, 12, 13, 10, 15]) >>> np.argmin(b) # Only the first occurrence is returned. 0
>>> x = np.array([[4,2,3], [1,0,3]]) >>> index_array = np.argmin(x, axis=-1) >>> # Same as np.min(x, axis=-1, keepdims=True) >>> np.take_along_axis(x, np.expand_dims(index_array, axis=-1), axis=-1) array([[2], [0]]) >>> # Same as np.max(x, axis=-1) >>> np.take_along_axis(x, np.expand_dims(index_array, axis=-1), axis=-1).squeeze(axis=-1) array([2, 0])
symjax.tensor.argsort(a, axis=-1, kind='quicksort', order=None)[source]¶
Returns the indices that would sort an array.
LAX-backend implementation of argsort(). Original docstring below.
Perform an indirect sort along the given axis using the algorithm specified by the kind keyword. It returns an array of indices of the same shape as a that index data along the given axis in sorted order.
Parameters: - a (array_like) – Array to sort.
- axis (int or None, optional) – Axis along which to sort. The default is -1 (the last axis). If None, the flattened array is used.
- kind ({'quicksort', 'mergesort', 'heapsort', 'stable'}, optional) – Sorting algorithm. The default is ‘quicksort’. Note that both ‘stable’ and ‘mergesort’ use timsort under the covers and, in general, the actual implementation will vary with data type. The ‘mergesort’ option is retained for backwards compatibility.
- order (str or list of str, optional) – When a is an array with fields defined, this argument specifies which fields to compare first, second, etc. A single field can be specified as a string, and not all fields need be specified, but unspecified fields will still be used, in the order in which they come up in the dtype, to break ties.
Returns: index_array – Array of indices that sort a along the specified axis. If a is one-dimensional, a[index_array] yields a sorted a. More generally, np.take_along_axis(a, index_array, axis=axis) always yields the sorted a, irrespective of dimensionality.
Return type: ndarray, int
See also
sort() – Describes sorting algorithms used.
lexsort() – Indirect stable sort with multiple keys.
ndarray.sort() – Inplace sort.
argpartition() – Indirect partial sort.
take_along_axis() – Apply index_array from argsort to an array as if by calling sort.
Notes
See sort for notes on the different sorting algorithms.
As of NumPy 1.4.0 argsort works with real/complex arrays containing nan values. The enhanced sort order is documented in sort.
Examples
One dimensional array:
>>> x = np.array([3, 1, 2]) >>> np.argsort(x) array([1, 2, 0])
Two-dimensional array:
>>> x = np.array([[0, 3], [2, 2]]) >>> x array([[0, 3], [2, 2]])
>>> ind = np.argsort(x, axis=0) # sorts along first axis (down) >>> ind array([[0, 1], [1, 0]]) >>> np.take_along_axis(x, ind, axis=0) # same as np.sort(x, axis=0) array([[0, 2], [2, 3]])
>>> ind = np.argsort(x, axis=1) # sorts along last axis (across) >>> ind array([[0, 1], [0, 1]]) >>> np.take_along_axis(x, ind, axis=1) # same as np.sort(x, axis=1) array([[0, 3], [2, 2]])
Indices of the sorted elements of a N-dimensional array:
>>> ind = np.unravel_index(np.argsort(x, axis=None), x.shape) >>> ind (array([0, 1, 1, 0]), array([0, 0, 1, 1])) >>> x[ind] # same as np.sort(x, axis=None) array([0, 2, 2, 3])
Sorting with keys:
>>> x = np.array([(1, 0), (0, 1)], dtype=[('x', '<i4'), ('y', '<i4')]) >>> x array([(1, 0), (0, 1)], dtype=[('x', '<i4'), ('y', '<i4')])
>>> np.argsort(x, order=('x','y')) array([1, 0])
>>> np.argsort(x, order=('y','x')) array([0, 1])
symjax.tensor.around(a, decimals=0, out=None)¶
Round an array to the given number of decimals.
LAX-backend implementation of round_(). Original docstring below.
around : equivalent function; see for details.
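Since this entry carries no example of its own, here is a minimal NumPy sketch of the documented rounding behaviour (negative decimals round to tens, hundreds, …):
>>> import numpy as np
>>> np.around([0.37, 1.64], decimals=1)
array([0.4, 1.6])
>>> np.around([1, 5, 27], decimals=-1)
array([ 0,  0, 30])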
symjax.tensor.asarray(a, dtype=None, order=None)[source]¶
Convert the input to an array.
LAX-backend implementation of asarray(). Original docstring below.
Parameters:
- a (array_like) – Input data, in any form that can be converted to an array. This includes lists, lists of tuples, tuples, tuples of tuples, tuples of lists and ndarrays.
- dtype (data-type, optional) – By default, the data-type is inferred from the input data.
- order ({'C', 'F'}, optional) – Whether to use row-major (C-style) or column-major (Fortran-style) memory representation. Defaults to ‘C’.
Returns: out – Array interpretation of a. No copy is performed if the input is already an ndarray with matching dtype and order. If a is a subclass of ndarray, a base class ndarray is returned.
Return type: ndarray
See also
asanyarray() – Similar function which passes through subclasses.
ascontiguousarray() – Convert input to a contiguous array.
asfarray() – Convert input to a floating point ndarray.
asfortranarray() – Convert input to an ndarray with column-major memory order.
asarray_chkfinite() – Similar function which checks input for NaNs and Infs.
fromiter() – Create an array from an iterator.
fromfunction() – Construct an array by executing a function on grid positions.
Examples
Convert a list into an array:
>>> a = [1, 2] >>> np.asarray(a) array([1, 2])
Existing arrays are not copied:
>>> a = np.array([1, 2]) >>> np.asarray(a) is a True
If dtype is set, array is copied only if dtype does not match:
>>> a = np.array([1, 2], dtype=np.float32) >>> np.asarray(a, dtype=np.float32) is a True >>> np.asarray(a, dtype=np.float64) is a False
Contrary to asanyarray, ndarray subclasses are not passed through:
>>> issubclass(np.recarray, np.ndarray) True >>> a = np.array([(1.0, 2), (3.0, 4)], dtype='f4,i4').view(np.recarray) >>> np.asarray(a) is a False >>> np.asanyarray(a) is a True
symjax.tensor.atleast_1d(*arys)[source]¶
Convert inputs to arrays with at least one dimension.
LAX-backend implementation of atleast_1d(). Original docstring below.
Scalar inputs are converted to 1-dimensional arrays, whilst higher-dimensional inputs are preserved.
Parameters: arys1, arys2, … (array_like) – One or more input arrays.
Returns: ret – An array, or list of arrays, each with a.ndim >= 1. Copies are made only if necessary.
Return type: ndarray
See also
atleast_2d, atleast_3d
Examples
>>> np.atleast_1d(1.0) array([1.])
>>> x = np.arange(9.0).reshape(3,3) >>> np.atleast_1d(x) array([[0., 1., 2.], [3., 4., 5.], [6., 7., 8.]]) >>> np.atleast_1d(x) is x True
>>> np.atleast_1d(1, [3, 4]) [array([1]), array([3, 4])]
symjax.tensor.atleast_2d(*arys)[source]¶
View inputs as arrays with at least two dimensions.
LAX-backend implementation of atleast_2d(). Original docstring below.
Parameters: arys1, arys2, … (array_like) – One or more array-like sequences. Non-array inputs are converted to arrays. Arrays that already have two or more dimensions are preserved.
Returns: res, res2, … – An array, or list of arrays, each with a.ndim >= 2. Copies are avoided where possible, and views with two or more dimensions are returned.
Return type: ndarray
See also
atleast_1d, atleast_3d
Examples
>>> np.atleast_2d(3.0) array([[3.]])
>>> x = np.arange(3.0) >>> np.atleast_2d(x) array([[0., 1., 2.]]) >>> np.atleast_2d(x).base is x True
>>> np.atleast_2d(1, [1, 2], [[1, 2]]) [array([[1]]), array([[1, 2]]), array([[1, 2]])]
symjax.tensor.atleast_3d(*arys)[source]¶
View inputs as arrays with at least three dimensions.
LAX-backend implementation of atleast_3d(). Original docstring below.
Parameters: arys1, arys2, … (array_like) – One or more array-like sequences. Non-array inputs are converted to arrays. Arrays that already have three or more dimensions are preserved.
Returns: res1, res2, … – An array, or list of arrays, each with a.ndim >= 3. Copies are avoided where possible, and views with three or more dimensions are returned. For example, a 1-D array of shape (N,) becomes a view of shape (1, N, 1), and a 2-D array of shape (M, N) becomes a view of shape (M, N, 1).
Return type: ndarray
See also
atleast_1d, atleast_2d
Examples
>>> np.atleast_3d(3.0) array([[[3.]]])
>>> x = np.arange(3.0) >>> np.atleast_3d(x).shape (1, 3, 1)
>>> x = np.arange(12.0).reshape(4,3) >>> np.atleast_3d(x).shape (4, 3, 1) >>> np.atleast_3d(x).base is x.base # x is a reshape, so not base itself True
>>> for arr in np.atleast_3d([1, 2], [[1, 2]], [[[1, 2]]]): ... print(arr, arr.shape) # doctest: +SKIP ... [[[1] [2]]] (1, 2, 1) [[[1] [2]]] (1, 2, 1) [[[1 2]]] (1, 1, 2)
symjax.tensor.bitwise_and(x1, x2)¶
Compute the bit-wise AND of two arrays element-wise.
LAX-backend implementation of bitwise_and(). Original docstring below.
bitwise_and(x1, x2, /, out=None, *, where=True, casting=’same_kind’, order=’K’, dtype=None, subok=True[, signature, extobj])
Computes the bit-wise AND of the underlying binary representation of the integers in the input arrays. This ufunc implements the C/Python operator &.
Parameters: x1, x2 (array_like) – Only integer and boolean types are handled. If x1.shape != x2.shape, they must be broadcastable to a common shape (which becomes the shape of the output).
Returns: out – Result. This is a scalar if both x1 and x2 are scalars.
Return type: ndarray or scalar
See also
logical_and(), bitwise_or(), bitwise_xor()
binary_repr() – Return the binary representation of the input number as a string.
Examples
The number 13 is represented by 00001101. Likewise, 17 is represented by 00010001. The bit-wise AND of 13 and 17 is therefore 00000001, or 1:
>>> np.bitwise_and(13, 17) 1
>>> np.bitwise_and(14, 13) 12 >>> np.binary_repr(12) '1100' >>> np.bitwise_and([14,3], 13) array([12, 1])
>>> np.bitwise_and([11,7], [4,25]) array([0, 1]) >>> np.bitwise_and(np.array([2,5,255]), np.array([3,14,16])) array([ 2, 4, 16]) >>> np.bitwise_and([True, True], [False, True]) array([False, True])
- symjax.tensor.bitwise_not(x)¶ Compute bit-wise inversion, or bit-wise NOT, element-wise.
LAX-backend implementation of invert(). Original docstring below.
invert(x, /, out=None, *, where=True, casting='same_kind', order='K', dtype=None, subok=True[, signature, extobj])
Computes the bit-wise NOT of the underlying binary representation of the integers in the input arrays. This ufunc implements the C/Python operator ~.
For signed integer inputs, the two’s complement is returned. In a two’s-complement system negative numbers are represented by the two’s complement of the absolute value. This is the most common method of representing signed integers on computers [1]. An N-bit two’s-complement system can represent every integer in the range \(-2^{N-1}\) to \(+2^{N-1}-1\).
Parameters: x (array_like) – Only integer and boolean types are handled.
Returns: out – Result. This is a scalar if x is a scalar.
Return type: ndarray or scalar
See also
bitwise_and(), bitwise_or(), bitwise_xor(), logical_not()
binary_repr()
- Return the binary representation of the input number as a string.
Notes
bitwise_not is an alias for invert:
>>> np.bitwise_not is np.invert True
References
[1] Wikipedia, “Two’s complement”, https://en.wikipedia.org/wiki/Two’s_complement
Examples
We’ve seen that 13 is represented by 00001101. The invert or bit-wise NOT of 13 is then:
>>> x = np.invert(np.array(13, dtype=np.uint8))
>>> x
242
>>> np.binary_repr(x, width=8)
'11110010'
The result depends on the bit-width:
>>> x = np.invert(np.array(13, dtype=np.uint16))
>>> x
65522
>>> np.binary_repr(x, width=16)
'1111111111110010'
When using signed integer types the result is the two’s complement of the result for the unsigned type:
>>> np.invert(np.array([13], dtype=np.int8))
array([-14], dtype=int8)
>>> np.binary_repr(-14, width=8)
'11110010'
Booleans are accepted as well:
>>> np.invert(np.array([True, False]))
array([False,  True])
- symjax.tensor.bitwise_or(x1, x2)¶ Compute the bit-wise OR of two arrays element-wise.
LAX-backend implementation of bitwise_or(). Original docstring below.
bitwise_or(x1, x2, /, out=None, *, where=True, casting='same_kind', order='K', dtype=None, subok=True[, signature, extobj])
Computes the bit-wise OR of the underlying binary representation of the integers in the input arrays. This ufunc implements the C/Python operator |.
Parameters: x1, x2 (array_like) – Only integer and boolean types are handled. If x1.shape != x2.shape, they must be broadcastable to a common shape (which becomes the shape of the output).
Returns: out – Result. This is a scalar if both x1 and x2 are scalars.
Return type: ndarray or scalar
See also
logical_or(), bitwise_and(), bitwise_xor()
binary_repr()
- Return the binary representation of the input number as a string.
Examples
The number 13 has the binary representation 00001101. Likewise, 16 is represented by 00010000. The bit-wise OR of 13 and 16 is then 00011101, or 29:
>>> np.bitwise_or(13, 16)
29
>>> np.binary_repr(29)
'11101'
>>> np.bitwise_or(32, 2)
34
>>> np.bitwise_or([33, 4], 1)
array([33, 5])
>>> np.bitwise_or([33, 4], [1, 2])
array([33, 6])
>>> np.bitwise_or(np.array([2, 5, 255]), np.array([4, 4, 4]))
array([ 6, 5, 255])
>>> np.array([2, 5, 255]) | np.array([4, 4, 4])
array([ 6, 5, 255])
>>> np.bitwise_or(np.array([2, 5, 255, 2147483647], dtype=np.int32),
...               np.array([4, 4, 4, 2147483647], dtype=np.int32))
array([ 6, 5, 255, 2147483647])
>>> np.bitwise_or([True, True], [False, True])
array([ True, True])
- symjax.tensor.bitwise_xor(x1, x2)¶ Compute the bit-wise XOR of two arrays element-wise.
LAX-backend implementation of bitwise_xor(). Original docstring below.
bitwise_xor(x1, x2, /, out=None, *, where=True, casting='same_kind', order='K', dtype=None, subok=True[, signature, extobj])
Computes the bit-wise XOR of the underlying binary representation of the integers in the input arrays. This ufunc implements the C/Python operator ^.
Parameters: x1, x2 (array_like) – Only integer and boolean types are handled. If x1.shape != x2.shape, they must be broadcastable to a common shape (which becomes the shape of the output).
Returns: out – Result. This is a scalar if both x1 and x2 are scalars.
Return type: ndarray or scalar
See also
logical_xor(), bitwise_and(), bitwise_or()
binary_repr()
- Return the binary representation of the input number as a string.
Examples
The number 13 is represented by 00001101. Likewise, 17 is represented by 00010001. The bit-wise XOR of 13 and 17 is therefore 00011100, or 28:
>>> np.bitwise_xor(13, 17)
28
>>> np.binary_repr(28)
'11100'
>>> np.bitwise_xor(31, 5)
26
>>> np.bitwise_xor([31,3], 5)
array([26, 6])
>>> np.bitwise_xor([31,3], [5,6])
array([26, 5])
>>> np.bitwise_xor([True, True], [False, True])
array([ True, False])
- symjax.tensor.block(arrays)[source]¶ Assemble an nd-array from nested lists of blocks.
LAX-backend implementation of block(). Original docstring below.
Blocks in the innermost lists are concatenated (see concatenate) along the last dimension (-1), then these are concatenated along the second-last dimension (-2), and so on until the outermost list is reached.
Blocks can be of any dimension, but will not be broadcasted using the normal rules. Instead, leading axes of size 1 are inserted, to make block.ndim the same for all blocks. This is primarily useful for working with scalars, and means that code like np.block([v, 1]) is valid, where v.ndim == 1.
When the nested list is two levels deep, this allows block matrices to be constructed from their components.
New in version 1.13.0.
Returns: block_array – The array assembled from the given blocks.
The dimensionality of the output is equal to the greatest of:
- the dimensionality of all the inputs
- the depth to which the input list is nested
Return type: ndarray
Raises: ValueError –
- If list depths are mismatched - for instance, [[a, b], c] is illegal, and should be spelt [[a, b], [c]]
- If lists are empty - for instance, [[a, b], []]
See also
concatenate()
- Join a sequence of arrays along an existing axis.
stack()
- Join a sequence of arrays along a new axis.
vstack()
- Stack arrays in sequence vertically (row wise).
hstack()
- Stack arrays in sequence horizontally (column wise).
dstack()
- Stack arrays in sequence depth wise (along third axis).
column_stack()
- Stack 1-D arrays as columns into a 2-D array.
vsplit()
- Split an array into multiple sub-arrays vertically (row-wise).
Notes
When called with only scalars, np.block is equivalent to an ndarray call. So np.block([[1, 2], [3, 4]]) is equivalent to np.array([[1, 2], [3, 4]]).
This function does not enforce that the blocks lie on a fixed grid. np.block([[a, b], [c, d]]) is not restricted to arrays of the form:
AAAbb
AAAbb
cccDD
But is also allowed to produce, for some a, b, c, d:
AAAbb
AAAbb
cDDDD
Since concatenation happens along the last axis first, block is _not_ capable of producing the following directly:
AAAbb
cccbb
cccDD
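Such a layout can still be assembled by concatenating along the first axis before the last; a minimal numpy sketch, with shapes chosen here to match the diagram (A: 1x3, b: 2x2, c: 2x3, D: 1x2):
>>> A = np.full((1, 3), 1); b = np.full((2, 2), 2)
>>> c = np.full((2, 3), 3); D = np.full((1, 2), 4)
>>> left = np.concatenate([A, c], axis=0)    # the AAA / ccc / ccc column
>>> right = np.concatenate([b, D], axis=0)   # the bb / bb / DD column
>>> np.concatenate([left, right], axis=1).shape
(3, 5)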
Matlab’s “square bracket stacking”, [A, B, ...; p, q, ...], is equivalent to np.block([[A, B, ...], [p, q, ...]]).
Examples
The most common use of this function is to build a block matrix
>>> A = np.eye(2) * 2
>>> B = np.eye(3) * 3
>>> np.block([
...     [A, np.zeros((2, 3))],
...     [np.ones((3, 2)), B]
... ])
array([[2., 0., 0., 0., 0.],
       [0., 2., 0., 0., 0.],
       [1., 1., 3., 0., 0.],
       [1., 1., 0., 3., 0.],
       [1., 1., 0., 0., 3.]])
With a list of depth 1, block can be used as hstack
>>> np.block([1, 2, 3]) # hstack([1, 2, 3]) array([1, 2, 3])
>>> a = np.array([1, 2, 3]) >>> b = np.array([2, 3, 4]) >>> np.block([a, b, 10]) # hstack([a, b, 10]) array([ 1, 2, 3, 2, 3, 4, 10])
>>> A = np.ones((2, 2), int) >>> B = 2 * A >>> np.block([A, B]) # hstack([A, B]) array([[1, 1, 2, 2], [1, 1, 2, 2]])
With a list of depth 2, block can be used in place of vstack:
>>> a = np.array([1, 2, 3]) >>> b = np.array([2, 3, 4]) >>> np.block([[a], [b]]) # vstack([a, b]) array([[1, 2, 3], [2, 3, 4]])
>>> A = np.ones((2, 2), int) >>> B = 2 * A >>> np.block([[A], [B]]) # vstack([A, B]) array([[1, 1], [1, 1], [2, 2], [2, 2]])
It can also be used in places of atleast_1d and atleast_2d
>>> a = np.array(0) >>> b = np.array([1]) >>> np.block([a]) # atleast_1d(a) array([0]) >>> np.block([b]) # atleast_1d(b) array([1])
>>> np.block([[a]]) # atleast_2d(a) array([[0]]) >>> np.block([[b]]) # atleast_2d(b) array([[1]])
- symjax.tensor.broadcast_arrays(*args)[source]¶ Like Numpy’s broadcast_arrays but doesn’t return views.
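No examples survive in this docstring; the numpy behaviour it mirrors (up to the no-views caveat) is:
>>> x = np.array([[1, 2, 3]])          # shape (1, 3)
>>> y = np.array([[4], [5]])           # shape (2, 1)
>>> [a.shape for a in np.broadcast_arrays(x, y)]
[(2, 3), (2, 3)]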
- symjax.tensor.broadcast_to(arr, shape)[source]¶ Broadcast an array to a new shape.
LAX-backend implementation of broadcast_to(). The JAX version does not necessarily return a view of the input. Original docstring below.
Parameters: shape (tuple) – The shape of the desired array.
Returns: broadcast – A readonly view on the original array with the given shape. It is typically not contiguous. Furthermore, more than one element of a broadcasted array may refer to a single memory location.
Return type: array
Raises: ValueError – If the array is not compatible with the new shape according to NumPy’s broadcasting rules.
Notes
New in version 1.10.0.
Examples
>>> x = np.array([1, 2, 3]) >>> np.broadcast_to(x, (3, 3)) array([[1, 2, 3], [1, 2, 3], [1, 2, 3]])
- symjax.tensor.can_cast(from_, to, casting='safe')¶ Returns True if cast between data types can occur according to the casting rule. If from is a scalar or array scalar, also returns True if the scalar value can be cast without overflow or truncation to an integer.
Parameters: - from (dtype, dtype specifier, scalar, or array) – Data type, scalar, or array to cast from.
- to (dtype or dtype specifier) – Data type to cast to.
- casting ({'no', 'equiv', 'safe', 'same_kind', 'unsafe'}, optional) –
Controls what kind of data casting may occur.
- 'no' means the data types should not be cast at all.
- 'equiv' means only byte-order changes are allowed.
- 'safe' means only casts which can preserve values are allowed.
- 'same_kind' means only safe casts or casts within a kind, like float64 to float32, are allowed.
- 'unsafe' means any data conversions may be done.
Returns: out – True if cast can occur according to the casting rule.
Return type: bool
Notes
Changed in version 1.17.0: Casting between a simple data type and a structured one is possible only for “unsafe” casting. Casting to multiple fields is allowed, but casting from multiple fields is not.
Changed in version 1.9.0: Casting from numeric to string types in ‘safe’ casting mode requires that the string dtype length is long enough to store the maximum integer/float value converted.
See also
dtype(), result_type()
Examples
Basic examples
>>> np.can_cast(np.int32, np.int64) True >>> np.can_cast(np.float64, complex) True >>> np.can_cast(complex, float) False
>>> np.can_cast('i8', 'f8') True >>> np.can_cast('i8', 'f4') False >>> np.can_cast('i4', 'S4') False
Casting scalars
>>> np.can_cast(100, 'i1') True >>> np.can_cast(150, 'i1') False >>> np.can_cast(150, 'u1') True
>>> np.can_cast(3.5e100, np.float32) False >>> np.can_cast(1000.0, np.float32) True
Array scalar checks the value, array does not
>>> np.can_cast(np.array(1000.0), np.float32) True >>> np.can_cast(np.array([1000.0]), np.float32) False
Using the casting rules
>>> np.can_cast('i8', 'i8', 'no') True >>> np.can_cast('<i8', '>i8', 'no') False
>>> np.can_cast('<i8', '>i8', 'equiv') True >>> np.can_cast('<i4', '>i8', 'equiv') False
>>> np.can_cast('<i4', '>i8', 'safe') True >>> np.can_cast('<i8', '>i4', 'safe') False
>>> np.can_cast('<i8', '>i4', 'same_kind') True >>> np.can_cast('<i8', '>u4', 'same_kind') False
>>> np.can_cast('<i8', '>u4', 'unsafe') True
- symjax.tensor.ceil(x)¶ Return the ceiling of the input, element-wise.
LAX-backend implementation of ceil(). Original docstring below.
ceil(x, /, out=None, *, where=True, casting='same_kind', order='K', dtype=None, subok=True[, signature, extobj])
The ceil of the scalar x is the smallest integer i, such that i >= x. It is often denoted as \(\lceil x \rceil\).
Parameters: x (array_like) – Input data.
Returns: y – The ceiling of each element in x, with float dtype. This is a scalar if x is a scalar.
Return type: ndarray or scalar
See also
floor(), trunc(), rint()
Examples
>>> a = np.array([-1.7, -1.5, -0.2, 0.2, 1.5, 1.7, 2.0]) >>> np.ceil(a) array([-1., -1., -0., 1., 2., 2., 2.])
- symjax.tensor.clip(a, a_min=None, a_max=None, out=None)[source]¶ Clip (limit) the values in an array.
LAX-backend implementation of clip(). Original docstring below.
Given an interval, values outside the interval are clipped to the interval edges. For example, if an interval of [0, 1] is specified, values smaller than 0 become 0, and values larger than 1 become 1.
Equivalent to but faster than np.minimum(a_max, np.maximum(a, a_min)).
No check is performed to ensure a_min < a_max.
Parameters: - a (array_like) – Array containing elements to clip.
- a_min (scalar or array_like or None) – Minimum value. If None, clipping is not performed on lower interval edge. Not more than one of a_min and a_max may be None.
- a_max (scalar or array_like or None) – Maximum value. If None, clipping is not performed on upper interval edge. Not more than one of a_min and a_max may be None. If a_min or a_max are array_like, then the three arrays will be broadcasted to match their shapes.
- out (ndarray, optional) – The results will be placed in this array. It may be the input array for in-place clipping. out must be of the right shape to hold the output. Its type is preserved.
Returns: clipped_array – An array with the elements of a, but where values < a_min are replaced with a_min, and those > a_max with a_max.
Return type: ndarray
See also
ufuncs-output-type()
Examples
>>> a = np.arange(10)
>>> np.clip(a, 1, 8)
array([1, 1, 2, 3, 4, 5, 6, 7, 8, 8])
>>> a
array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
>>> np.clip(a, 3, 6, out=a)
array([3, 3, 3, 3, 4, 5, 6, 6, 6, 6])
>>> a = np.arange(10)
>>> a
array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
>>> np.clip(a, [3, 4, 1, 1, 1, 4, 4, 4, 4, 4], 8)
array([3, 4, 2, 3, 4, 5, 6, 7, 8, 8])
- symjax.tensor.column_stack(tup)[source]¶ Stack 1-D arrays as columns into a 2-D array.
LAX-backend implementation of column_stack(). Original docstring below.
Take a sequence of 1-D arrays and stack them as columns to make a single 2-D array. 2-D arrays are stacked as-is, just like with hstack. 1-D arrays are turned into 2-D columns first.
Parameters: tup (sequence of 1-D or 2-D arrays.) – Arrays to stack. All of them must have the same first dimension.
Returns: stacked – The array formed by stacking the given arrays.
Return type: 2-D array
See also
Examples
>>> a = np.array((1,2,3)) >>> b = np.array((2,3,4)) >>> np.column_stack((a,b)) array([[1, 2], [2, 3], [3, 4]])
- symjax.tensor.concatenate(arrays, axis=0)[source]¶ Join a sequence of arrays along an existing axis.
LAX-backend implementation of concatenate(). Original docstring below.
concatenate((a1, a2, …), axis=0, out=None)
- Returns
- res : ndarray
- The concatenated array.
See also
ma.concatenate : Concatenate function that preserves input masks.
array_split : Split an array into multiple sub-arrays of equal or near-equal size.
split : Split array into a list of multiple sub-arrays of equal size.
hsplit : Split array into multiple sub-arrays horizontally (column wise).
vsplit : Split array into multiple sub-arrays vertically (row wise).
dsplit : Split array into multiple sub-arrays along the 3rd axis (depth).
stack : Stack a sequence of arrays along a new axis.
block : Assemble arrays from blocks.
hstack : Stack arrays in sequence horizontally (column wise).
vstack : Stack arrays in sequence vertically (row wise).
dstack : Stack arrays in sequence depth wise (along third dimension).
column_stack : Stack 1-D arrays as columns into a 2-D array.
When one or more of the arrays to be concatenated is a MaskedArray, this function will return a MaskedArray object instead of an ndarray, but the input masks are not preserved. In cases where a MaskedArray is expected as input, use the ma.concatenate function from the masked array module instead.
>>> a = np.array([[1, 2], [3, 4]])
>>> b = np.array([[5, 6]])
>>> np.concatenate((a, b), axis=0)
array([[1, 2],
       [3, 4],
       [5, 6]])
>>> np.concatenate((a, b.T), axis=1)
array([[1, 2, 5],
       [3, 4, 6]])
>>> np.concatenate((a, b), axis=None)
array([1, 2, 3, 4, 5, 6])
This function will not preserve masking of MaskedArray inputs.
>>> a = np.ma.arange(3) >>> a[1] = np.ma.masked >>> b = np.arange(2, 5) >>> a masked_array(data=[0, --, 2], mask=[False, True, False], fill_value=999999) >>> b array([2, 3, 4]) >>> np.concatenate([a, b]) masked_array(data=[0, 1, 2, 2, 3, 4], mask=False, fill_value=999999) >>> np.ma.concatenate([a, b]) masked_array(data=[0, --, 2, 2, 3, 4], mask=[False, True, False, False, False, False], fill_value=999999)
- symjax.tensor.conj(x)¶ Return the complex conjugate, element-wise.
LAX-backend implementation of conjugate(). Original docstring below.
conjugate(x, /, out=None, *, where=True, casting='same_kind', order='K', dtype=None, subok=True[, signature, extobj])
The complex conjugate of a complex number is obtained by changing the sign of its imaginary part.
Parameters: x (array_like) – Input value.
Returns: y – The complex conjugate of x, with the same dtype as x. This is a scalar if x is a scalar.
Return type: ndarray
Notes
conj is an alias for conjugate:
>>> np.conj is np.conjugate True
Examples
>>> np.conjugate(1+2j) (1-2j)
>>> x = np.eye(2) + 1j * np.eye(2) >>> np.conjugate(x) array([[ 1.-1.j, 0.-0.j], [ 0.-0.j, 1.-1.j]])
- symjax.tensor.conjugate(x)[source]¶ Return the complex conjugate, element-wise.
LAX-backend implementation of conjugate(). Original docstring below.
conjugate(x, /, out=None, *, where=True, casting='same_kind', order='K', dtype=None, subok=True[, signature, extobj])
The complex conjugate of a complex number is obtained by changing the sign of its imaginary part.
Parameters: x (array_like) – Input value.
Returns: y – The complex conjugate of x, with the same dtype as x. This is a scalar if x is a scalar.
Return type: ndarray
Notes
conj is an alias for conjugate:
>>> np.conj is np.conjugate True
Examples
>>> np.conjugate(1+2j) (1-2j)
>>> x = np.eye(2) + 1j * np.eye(2) >>> np.conjugate(x) array([[ 1.-1.j, 0.-0.j], [ 0.-0.j, 1.-1.j]])
- symjax.tensor.corrcoef(x, y=None, rowvar=True)[source]¶ Return Pearson product-moment correlation coefficients.
LAX-backend implementation of corrcoef(). Original docstring below.
Please refer to the documentation for cov for more detail. The relationship between the correlation coefficient matrix, R, and the covariance matrix, C, is
\[R_{ij} = \frac{ C_{ij} }{ \sqrt{ C_{ii} C_{jj} } }\]
The values of R are between -1 and 1, inclusive.
Parameters: - x (array_like) – A 1-D or 2-D array containing multiple variables and observations. Each row of x represents a variable, and each column a single observation of all those variables. Also see rowvar below.
- y (array_like, optional) – An additional set of variables and observations. y has the same shape as x.
- rowvar (bool, optional) – If rowvar is True (default), then each row represents a variable, with observations in the columns. Otherwise, the relationship is transposed: each column represents a variable, while the rows contain observations.
Returns: R – The correlation coefficient matrix of the variables.
Return type: ndarray
See also
cov()
- Covariance matrix
Notes
Due to floating point rounding the resulting array may not be Hermitian, the diagonal elements may not be 1, and the elements may not satisfy the inequality abs(a) <= 1. The real and imaginary parts are clipped to the interval [-1, 1] in an attempt to improve on that situation but is not much help in the complex case.
This function accepts but discards arguments bias and ddof. This is for backwards compatibility with previous versions of this function. These arguments had no effect on the return values of the function and can be safely ignored in this and previous versions of numpy.
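No examples survived in this docstring; a small one, consistent with the cov example further down (two perfectly anti-correlated variables):
>>> x = np.array([[0, 1, 2], [2, 1, 0]])
>>> np.corrcoef(x)
array([[ 1., -1.],
       [-1.,  1.]])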
- symjax.tensor.cos(x)¶ Cosine element-wise.
LAX-backend implementation of cos(). Original docstring below.
cos(x, /, out=None, *, where=True, casting='same_kind', order='K', dtype=None, subok=True[, signature, extobj])
Parameters: x (array_like) – Input array in radians.
Returns: y – The corresponding cosine values. This is a scalar if x is a scalar.
Return type: ndarray
Notes
If out is provided, the function writes the result into it, and returns a reference to out. (See Examples)
References
M. Abramowitz and I. A. Stegun, Handbook of Mathematical Functions. New York, NY: Dover, 1972.
Examples
>>> np.cos(np.array([0, np.pi/2, np.pi]))
array([ 1.00000000e+00, 6.12303177e-17, -1.00000000e+00])
>>>
>>> # Example of providing the optional output parameter
>>> out1 = np.array([0], dtype='d')
>>> out2 = np.cos([0.1], out1)
>>> out2 is out1
True
>>>
>>> # Example of ValueError due to provision of shape mis-matched `out`
>>> np.cos(np.zeros((3,3)), np.zeros((2,2)))
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
ValueError: operands could not be broadcast together with shapes (3,3) (2,2)
- symjax.tensor.cosh(x)¶ Hyperbolic cosine, element-wise.
LAX-backend implementation of cosh(). Original docstring below.
cosh(x, /, out=None, *, where=True, casting='same_kind', order='K', dtype=None, subok=True[, signature, extobj])
Equivalent to 1/2 * (np.exp(x) + np.exp(-x)) and np.cos(1j*x).
Parameters: x (array_like) – Input array.
Returns: out – Output array of same shape as x. This is a scalar if x is a scalar.
Return type: ndarray or scalar
Examples
>>> np.cosh(0) 1.0
The hyperbolic cosine describes the shape of a hanging cable:
>>> import matplotlib.pyplot as plt >>> x = np.linspace(-4, 4, 1000) >>> plt.plot(x, np.cosh(x)) >>> plt.show()
- symjax.tensor.count_nonzero(a, axis=None, keepdims=False)[source]¶ Counts the number of non-zero values in the array a.
LAX-backend implementation of count_nonzero(). Original docstring below.
The word “non-zero” is in reference to the Python 2.x built-in method
__nonzero__()
(renamed__bool__()
in Python 3.x) of Python objects that tests an object’s “truthfulness”. For example, any number is considered truthful if it is nonzero, whereas any string is considered truthful if it is not the empty string. Thus, this function (recursively) counts how many elements ina
(and in sub-arrays thereof) have their__nonzero__()
or__bool__()
method evaluated toTrue
.Parameters: - a (array_like) – The array for which to count non-zeros.
- axis (int or tuple, optional) – Axis or tuple of axes along which to count non-zeros. Default is None, meaning that non-zeros will be counted along a flattened version of a.
- keepdims (bool, optional) – If this is set to True, the axes that are counted are left in the result as dimensions with size one. With this option, the result will broadcast correctly against the input array.
Returns: count – Number of non-zero values in the array along a given axis. Otherwise, the total number of non-zero values in the array is returned.
Return type: int or array of int
See also
nonzero()
- Return the coordinates of all the non-zero values.
Examples
>>> np.count_nonzero(np.eye(4))
4
>>> a = np.array([[0, 1, 7, 0],
...               [3, 0, 2, 19]])
>>> np.count_nonzero(a)
5
>>> np.count_nonzero(a, axis=0)
array([1, 1, 2, 1])
>>> np.count_nonzero(a, axis=1)
array([2, 3])
>>> np.count_nonzero(a, axis=1, keepdims=True)
array([[2],
       [3]])
- symjax.tensor.cov(m, y=None, rowvar=True, bias=False, ddof=None, fweights=None, aweights=None)[source]¶ Estimate a covariance matrix, given data and weights.
LAX-backend implementation of cov(). Original docstring below.
Covariance indicates the level to which two variables vary together. If we examine N-dimensional samples, \(X = [x_1, x_2, ... x_N]^T\), then the covariance matrix element \(C_{ij}\) is the covariance of \(x_i\) and \(x_j\). The element \(C_{ii}\) is the variance of \(x_i\).
See the notes for an outline of the algorithm.
Parameters: - m (array_like) – A 1-D or 2-D array containing multiple variables and observations. Each row of m represents a variable, and each column a single observation of all those variables. Also see rowvar below.
- y (array_like, optional) – An additional set of variables and observations. y has the same form as that of m.
- rowvar (bool, optional) – If rowvar is True (default), then each row represents a variable, with observations in the columns. Otherwise, the relationship is transposed: each column represents a variable, while the rows contain observations.
- bias (bool, optional) – Default normalization (False) is by (N - 1), where N is the number of observations given (unbiased estimate). If bias is True, then normalization is by N. These values can be overridden by using the keyword ddof in numpy versions >= 1.5.
- ddof (int, optional) – If not None the default value implied by bias is overridden. Note that ddof=1 will return the unbiased estimate, even if both fweights and aweights are specified, and ddof=0 will return the simple average. See the notes for the details. The default value is None.
- fweights (array_like, int, optional) – 1-D array of integer frequency weights; the number of times each observation vector should be repeated.
- aweights (array_like, optional) – 1-D array of observation vector weights. These relative weights are typically large for observations considered “important” and smaller for observations considered less “important”. If ddof=0 the array of weights can be used to assign probabilities to observation vectors.
Returns: out – The covariance matrix of the variables.
Return type: ndarray
See also
corrcoef()
- Normalized covariance matrix
Notes
Assume that the observations are in the columns of the observation array m and let f = fweights and a = aweights for brevity. The steps to compute the weighted covariance are as follows:
>>> m = np.arange(10, dtype=np.float64)
>>> f = np.arange(10) * 2
>>> a = np.arange(10) ** 2.
>>> ddof = 1
>>> w = f * a
>>> v1 = np.sum(w)
>>> v2 = np.sum(w * a)
>>> m -= np.sum(m * w, axis=None, keepdims=True) / v1
>>> cov = np.dot(m * w, m.T) * v1 / (v1**2 - ddof * v2)
Note that when a == 1, the normalization factor v1 / (v1**2 - ddof * v2) goes over to 1 / (np.sum(f) - ddof) as it should.
Examples
Consider two variables, \(x_0\) and \(x_1\), which correlate perfectly, but in opposite directions:
>>> x = np.array([[0, 2], [1, 1], [2, 0]]).T >>> x array([[0, 1, 2], [2, 1, 0]])
Note how \(x_0\) increases while \(x_1\) decreases. The covariance matrix shows this clearly:
>>> np.cov(x) array([[ 1., -1.], [-1., 1.]])
Note that element \(C_{0,1}\), which shows the correlation between \(x_0\) and \(x_1\), is negative.
Further, note how x and y are combined:
>>> x = [-2.1, -1, 4.3] >>> y = [3, 1.1, 0.12] >>> X = np.stack((x, y), axis=0) >>> np.cov(X) array([[11.71 , -4.286 ], # may vary [-4.286 , 2.144133]]) >>> np.cov(x, y) array([[11.71 , -4.286 ], # may vary [-4.286 , 2.144133]]) >>> np.cov(x) array(11.71)
- symjax.tensor.cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None)[source]¶ Return the cross product of two (arrays of) vectors.
LAX-backend implementation of cross(). Original docstring below.
The cross product of a and b in \(R^3\) is a vector perpendicular to both a and b. If a and b are arrays of vectors, the vectors are defined by the last axis of a and b by default, and these axes can have dimensions 2 or 3. Where the dimension of either a or b is 2, the third component of the input vector is assumed to be zero and the cross product calculated accordingly. In cases where both input vectors have dimension 2, the z-component of the cross product is returned.
Parameters: - a (array_like) – Components of the first vector(s).
- b (array_like) – Components of the second vector(s).
- axisa (int, optional) – Axis of a that defines the vector(s). By default, the last axis.
- axisb (int, optional) – Axis of b that defines the vector(s). By default, the last axis.
- axisc (int, optional) – Axis of c containing the cross product vector(s). Ignored if both input vectors have dimension 2, as the return is scalar. By default, the last axis.
- axis (int, optional) – If defined, the axis of a, b and c that defines the vector(s) and cross product(s). Overrides axisa, axisb and axisc.
Returns: c – Vector cross product(s).
Return type: ndarray
Raises: ValueError
– When the dimension of the vector(s) in a and/or b does not equal 2 or 3.Notes
New in version 1.9.0.
Supports full broadcasting of the inputs.
Examples
Vector cross-product.
>>> x = [1, 2, 3] >>> y = [4, 5, 6] >>> np.cross(x, y) array([-3, 6, -3])
One vector with dimension 2.
>>> x = [1, 2] >>> y = [4, 5, 6] >>> np.cross(x, y) array([12, -6, -3])
Equivalently:
>>> x = [1, 2, 0] >>> y = [4, 5, 6] >>> np.cross(x, y) array([12, -6, -3])
Both vectors with dimension 2.
>>> x = [1,2] >>> y = [4,5] >>> np.cross(x, y) array(-3)
Multiple vector cross-products. Note that the direction of the cross product vector is defined by the right-hand rule.
>>> x = np.array([[1,2,3], [4,5,6]]) >>> y = np.array([[4,5,6], [1,2,3]]) >>> np.cross(x, y) array([[-3, 6, -3], [ 3, -6, 3]])
The orientation of c can be changed using the axisc keyword.
>>> np.cross(x, y, axisc=0) array([[-3, 3], [ 6, -6], [-3, 3]])
Change the vector definition of x and y using axisa and axisb.
>>> x = np.array([[1,2,3], [4,5,6], [7, 8, 9]]) >>> y = np.array([[7, 8, 9], [4,5,6], [1,2,3]]) >>> np.cross(x, y) array([[ -6, 12, -6], [ 0, 0, 0], [ 6, -12, 6]]) >>> np.cross(x, y, axisa=0, axisb=0) array([[-24, 48, -24], [-30, 60, -30], [-36, 72, -36]])
- symjax.tensor.cumsum(a, axis=None, dtype=None, out=None)¶ Return the cumulative sum of the elements along a given axis.
LAX-backend implementation of cumsum(). Original docstring below.
Parameters: - a (array_like) – Input array.
- axis (int, optional) – Axis along which the cumulative sum is computed. The default (None) is to compute the cumsum over the flattened array.
- dtype (dtype, optional) – Type of the returned array and of the accumulator in which the elements are summed. If dtype is not specified, it defaults to the dtype of a, unless a has an integer dtype with a precision less than that of the default platform integer. In that case, the default platform integer is used.
- out (ndarray, optional) – Alternative output array in which to place the result. It must have the same shape and buffer length as the expected output but the type will be cast if necessary. See ufuncs-output-type for more details.
Returns: cumsum_along_axis – A new array holding the result is returned unless out is specified, in which case a reference to out is returned. The result has the same size as a, and the same shape as a if axis is not None or a is a 1-d array.
Return type: ndarray.
See also
sum()
- Sum array elements.
trapz()
- Integration of array values using the composite trapezoidal rule.
diff()
- Calculate the n-th discrete difference along given axis.
Notes
Arithmetic is modular when using integer types, and no error is raised on overflow.
Examples
>>> a = np.array([[1,2,3], [4,5,6]]) >>> a array([[1, 2, 3], [4, 5, 6]]) >>> np.cumsum(a) array([ 1, 3, 6, 10, 15, 21]) >>> np.cumsum(a, dtype=float) # specifies type of output value(s) array([ 1., 3., 6., 10., 15., 21.])
>>> np.cumsum(a,axis=0) # sum over rows for each of the 3 columns array([[1, 2, 3], [5, 7, 9]]) >>> np.cumsum(a,axis=1) # sum over columns for each of the 2 rows array([[ 1, 3, 6], [ 4, 9, 15]])
- symjax.tensor.cumprod(a, axis=None, dtype=None, out=None)¶ Return the cumulative product of elements along a given axis.
LAX-backend implementation of cumprod(). Original docstring below.
Parameters: - a (array_like) – Input array.
- axis (int, optional) – Axis along which the cumulative product is computed. By default the input is flattened.
- dtype (dtype, optional) – Type of the returned array, as well as of the accumulator in which the elements are multiplied. If dtype is not specified, it defaults to the dtype of a, unless a has an integer dtype with a precision less than that of the default platform integer. In that case, the default platform integer is used instead.
- out (ndarray, optional) – Alternative output array in which to place the result. It must have the same shape and buffer length as the expected output but the type of the resulting values will be cast if necessary.
Returns: cumprod – A new array holding the result is returned unless out is specified, in which case a reference to out is returned.
Return type: ndarray
See also
ufuncs-output-type()
Notes
Arithmetic is modular when using integer types, and no error is raised on overflow.
Examples
>>> a = np.array([1,2,3]) >>> np.cumprod(a) # intermediate results 1, 1*2 ... # total product 1*2*3 = 6 array([1, 2, 6]) >>> a = np.array([[1, 2, 3], [4, 5, 6]]) >>> np.cumprod(a, dtype=float) # specify type of output array([ 1., 2., 6., 24., 120., 720.])
The cumulative product for each column (i.e., over the rows) of a:
>>> np.cumprod(a, axis=0) array([[ 1, 2, 3], [ 4, 10, 18]])
The cumulative product for each row (i.e. over the columns) of a:
>>> np.cumprod(a,axis=1) array([[ 1, 2, 6], [ 4, 20, 120]])
- symjax.tensor.cumproduct(a, axis=None, dtype=None, out=None)¶ Return the cumulative product of elements along a given axis.
LAX-backend implementation of cumprod(). Original docstring below.
Parameters: - a (array_like) – Input array.
- axis (int, optional) – Axis along which the cumulative product is computed. By default the input is flattened.
- dtype (dtype, optional) – Type of the returned array, as well as of the accumulator in which the elements are multiplied. If dtype is not specified, it defaults to the dtype of a, unless a has an integer dtype with a precision less than that of the default platform integer. In that case, the default platform integer is used instead.
- out (ndarray, optional) – Alternative output array in which to place the result. It must have the same shape and buffer length as the expected output but the type of the resulting values will be cast if necessary.
Returns: cumprod – A new array holding the result is returned unless out is specified, in which case a reference to out is returned.
Return type: ndarray
See also
ufuncs-output-type()
Notes
Arithmetic is modular when using integer types, and no error is raised on overflow.
Examples
>>> a = np.array([1,2,3]) >>> np.cumprod(a) # intermediate results 1, 1*2 ... # total product 1*2*3 = 6 array([1, 2, 6]) >>> a = np.array([[1, 2, 3], [4, 5, 6]]) >>> np.cumprod(a, dtype=float) # specify type of output array([ 1., 2., 6., 24., 120., 720.])
The cumulative product for each column (i.e., over the rows) of a:
>>> np.cumprod(a, axis=0) array([[ 1, 2, 3], [ 4, 10, 18]])
The cumulative product for each row (i.e. over the columns) of a:
>>> np.cumprod(a,axis=1) array([[ 1, 2, 6], [ 4, 20, 120]])
- symjax.tensor.deg2rad(x)[source]¶ Convert angles from degrees to radians.
LAX-backend implementation of deg2rad(). Original docstring below.
deg2rad(x, /, out=None, *, where=True, casting='same_kind', order='K', dtype=None, subok=True[, signature, extobj])
Parameters: x (array_like) – Angles in degrees.
Returns: y – The corresponding angle in radians. This is a scalar if x is a scalar.
Return type: ndarray
See also
rad2deg()
- Convert angles from radians to degrees.
unwrap()
- Remove large jumps in angle by wrapping.
Notes
New in version 1.3.0.
deg2rad(x)
isx * pi / 180
.Examples
>>> np.deg2rad(180) 3.1415926535897931
- symjax.tensor.degrees(x)¶ Convert angles from radians to degrees.
LAX-backend implementation of rad2deg(). Original docstring below.
rad2deg(x, /, out=None, *, where=True, casting='same_kind', order='K', dtype=None, subok=True[, signature, extobj])
Parameters: x (array_like) – Angle in radians.
Returns: y – The corresponding angle in degrees. This is a scalar if x is a scalar.
Return type: ndarray
See also
deg2rad()
- Convert angles from degrees to radians.
unwrap()
- Remove large jumps in angle by wrapping.
Notes
New in version 1.3.0.
rad2deg(x) is
180 * x / pi
.Examples
>>> np.rad2deg(np.pi/2) 90.0
- symjax.tensor.diag(v, k=0)[source]¶ Extract a diagonal or construct a diagonal array.
LAX-backend implementation of diag(). Original docstring below.
See the more detailed documentation for numpy.diagonal if you use this function to extract a diagonal and wish to write to the resulting array; whether it returns a copy or a view depends on what version of numpy you are using.
- k (int, optional) – Diagonal in question. The default is 0. Use k>0 for diagonals above the main diagonal, and k<0 for diagonals below the main diagonal.
Returns: out – The extracted diagonal or constructed diagonal array.
Return type: ndarray
See also
diagonal()
- Return specified diagonals.
diagflat()
- Create a 2-D array with the flattened input as a diagonal.
trace()
- Sum along diagonals.
triu()
- Upper triangle of an array.
tril()
- Lower triangle of an array.
Examples
>>> x = np.arange(9).reshape((3,3)) >>> x array([[0, 1, 2], [3, 4, 5], [6, 7, 8]])
>>> np.diag(x) array([0, 4, 8]) >>> np.diag(x, k=1) array([1, 5]) >>> np.diag(x, k=-1) array([3, 7])
>>> np.diag(np.diag(x)) array([[0, 0, 0], [0, 4, 0], [0, 0, 8]])
- symjax.tensor.diag_indices(n, ndim=2)[source]¶ Return the indices to access the main diagonal of an array.
LAX-backend implementation of diag_indices(). Original docstring below.
This returns a tuple of indices that can be used to access the main diagonal of an array a with a.ndim >= 2 dimensions and shape (n, n, …, n). For a.ndim = 2 this is the usual diagonal, for a.ndim > 2 this is the set of indices to access a[i, i, ..., i] for i = [0..n-1].
Parameters: - n (int) – The size, along each dimension, of the arrays for which the returned indices can be used.
- ndim (int, optional) – The number of dimensions.
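As in the numpy original, the returned index tuple can be used to address the main diagonal, e.g. to write to it on a plain numpy array (under SymJAX, where tensors are immutable, the same indices would feed a functional index update instead):
>>> di = np.diag_indices(4)
>>> a = np.arange(16).reshape(4, 4)
>>> a[di] = 100
>>> a
array([[100,   1,   2,   3],
       [  4, 100,   6,   7],
       [  8,   9, 100,  11],
       [ 12,  13,  14, 100]])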
- symjax.tensor.diagonal(a, offset=0, axis1=0, axis2=1)[source]¶ Return specified diagonals.
LAX-backend implementation of diagonal(). Original docstring below.
If a is 2-D, returns the diagonal of a with the given offset, i.e., the collection of elements of the form a[i, i+offset]. If a has more than two dimensions, then the axes specified by axis1 and axis2 are used to determine the 2-D sub-array whose diagonal is returned. The shape of the resulting array can be determined by removing axis1 and axis2 and appending an index to the right equal to the size of the resulting diagonals.
In NumPy 1.7 and 1.8, it continues to return a copy of the diagonal, but depending on this fact is deprecated. Writing to the resulting array continues to work as it used to, but a FutureWarning is issued.
Starting in NumPy 1.9 it returns a read-only view on the original array. Attempting to write to the resulting array will produce an error.
In some future release, it will return a read/write view and writing to the returned array will alter your original array. The returned array will have the same type as the input array.
If you don’t write to the array returned by this function, then you can just ignore all of the above.
If you depend on the current behavior, then we suggest copying the returned array explicitly, i.e., use np.diagonal(a).copy() instead of just np.diagonal(a). This will work with both past and future versions of NumPy.
Parameters: - a (array_like) – Array from which the diagonals are taken.
- offset (int, optional) – Offset of the diagonal from the main diagonal. Can be positive or negative. Defaults to main diagonal (0).
- axis1 (int, optional) – Axis to be used as the first axis of the 2-D sub-arrays from which the diagonals should be taken. Defaults to first axis (0).
- axis2 (int, optional) – Axis to be used as the second axis of the 2-D sub-arrays from which the diagonals should be taken. Defaults to second axis (1).
Returns: array_of_diagonals – If a is 2-D, then a 1-D array containing the diagonal and of the same type as a is returned unless a is a matrix, in which case a 1-D array rather than a (2-D) matrix is returned in order to maintain backward compatibility.
If a.ndim > 2, then the dimensions specified by axis1 and axis2 are removed, and a new axis inserted at the end corresponding to the diagonal.
Return type: ndarray
Raises: ValueError – If the dimension of a is less than 2.
See also
Examples
>>> a = np.arange(4).reshape(2,2) >>> a array([[0, 1], [2, 3]]) >>> a.diagonal() array([0, 3]) >>> a.diagonal(1) array([1])
A 3-D example:
>>> a = np.arange(8).reshape(2,2,2); a array([[[0, 1], [2, 3]], [[4, 5], [6, 7]]]) >>> a.diagonal(0, # Main diagonals of two arrays created by skipping ... 0, # across the outer(left)-most axis last and ... 1) # the "middle" (row) axis first. array([[0, 6], [1, 7]])
The sub-arrays whose main diagonals we just obtained; note that each corresponds to fixing the right-most (column) axis, and that the diagonals are “packed” in rows.
>>> a[:,:,0] # main diagonal is [0 6] array([[0, 2], [4, 6]]) >>> a[:,:,1] # main diagonal is [1 7] array([[1, 3], [5, 7]])
The anti-diagonal can be obtained by reversing the order of elements using either numpy.flipud or numpy.fliplr.
>>> a = np.arange(9).reshape(3, 3) >>> a array([[0, 1, 2], [3, 4, 5], [6, 7, 8]]) >>> np.fliplr(a).diagonal() # Horizontal flip array([2, 4, 6]) >>> np.flipud(a).diagonal() # Vertical flip array([6, 4, 2])
Note that the order in which the diagonal is retrieved varies depending on the flip function.
- symjax.tensor.divide(x1, x2)¶ Returns a true division of the inputs, element-wise.
LAX-backend implementation of true_divide(). Original docstring below.
true_divide(x1, x2, /, out=None, *, where=True, casting='same_kind', order='K', dtype=None, subok=True[, signature, extobj])
Instead of the Python traditional ‘floor division’, this returns a true division. True division adjusts the output type to present the best answer, regardless of input types.
Parameters: - x1 (array_like) – Dividend array.
- x2 (array_like) – Divisor array. If x1.shape != x2.shape, they must be broadcastable to a common shape (which becomes the shape of the output).
Returns: out – This is a scalar if both x1 and x2 are scalars.
Return type: ndarray or scalar
Notes
In Python, // is the floor division operator and / the true division operator. The true_divide(x1, x2) function is equivalent to true division in Python.
Examples
>>> x = np.arange(5) >>> np.true_divide(x, 4) array([ 0. , 0.25, 0.5 , 0.75, 1. ])
>>> x/4 array([ 0. , 0.25, 0.5 , 0.75, 1. ])
>>> x//4 array([0, 0, 0, 0, 1])
- symjax.tensor.divmod(x1, x2)[source]¶ Return element-wise quotient and remainder simultaneously.
LAX-backend implementation of divmod(). Original docstring below.
divmod(x1, x2[, out1, out2], / [, out=(None, None)], *, where=True, casting='same_kind', order='K', dtype=None, subok=True[, signature, extobj])
New in version 1.13.0.
np.divmod(x, y)
is equivalent to(x // y, x % y)
, but faster because it avoids redundant work. It is used to implement the Python built-in functiondivmod
on NumPy arrays.Parameters: - x1 (array_like) – Dividend array.
- x2 (array_like) – Divisor array.
If
x1.shape != x2.shape
, they must be broadcastable to a common shape (which becomes the shape of the output).
Returns: - out1 (ndarray) – Element-wise quotient resulting from floor division. This is a scalar if both x1 and x2 are scalars.
- out2 (ndarray) – Element-wise remainder from floor division. This is a scalar if both x1 and x2 are scalars.
See also
floor_divide()
- Equivalent to Python’s // operator.
remainder()
- Equivalent to Python’s % operator.
modf()
- Equivalent to divmod(x, 1) for positive x with the return values switched.
>>> np.divmod(np.arange(5), 3) (array([0, 0, 0, 1, 1]), array([0, 1, 2, 0, 1]))
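The stated equivalence with (x // y, x % y) is easy to verify directly:
>>> q, r = np.divmod(np.arange(5), 3)
>>> np.array_equal(q, np.arange(5) // 3) and np.array_equal(r, np.arange(5) % 3)
True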
- symjax.tensor.dot(a, b, *, precision=None)[source]¶ Dot product of two arrays.
LAX-backend implementation of dot(). In addition to the original NumPy arguments listed below, also supports precision for extra control over matrix-multiplication precision on supported devices. precision may be set to None, which means default precision for the backend, a lax.Precision enum value (Precision.DEFAULT, Precision.HIGH or Precision.HIGHEST) or a tuple of two lax.Precision enums indicating separate precision for each argument.
Original docstring below.
dot(a, b, out=None)
If both a and b are 1-D arrays, it is inner product of vectors (without complex conjugation).
If both a and b are 2-D arrays, it is matrix multiplication, but using
matmul()
ora @ b
is preferred.If either a or b is 0-D (scalar), it is equivalent to
multiply()
and usingnumpy.multiply(a, b)
ora * b
is preferred.If a is an N-D array and b is a 1-D array, it is a sum product over the last axis of a and b.
If a is an N-D array and b is an M-D array (where
M>=2
), it is a sum product over the last axis of a and the second-to-last axis of b:dot(a, b)[i,j,k,m] = sum(a[i,j,:] * b[k,:,m])
- Returns
- output : ndarray
- Returns the dot product of a and b. If a and b are both scalars or both 1-D arrays then a scalar is returned; otherwise an array is returned. If out is given, then it is returned.
- Raises
- ValueError
- If the last dimension of a is not the same size as the second-to-last dimension of b.
See also
vdot : Complex-conjugating dot product.
tensordot : Sum products over arbitrary axes.
einsum : Einstein summation convention.
matmul : ‘@’ operator as method with out parameter.
>>> np.dot(3, 4) 12
Neither argument is complex-conjugated:
>>> np.dot([2j, 3j], [2j, 3j]) (-13+0j)
For 2-D arrays it is the matrix product:
>>> a = [[1, 0], [0, 1]] >>> b = [[4, 1], [2, 2]] >>> np.dot(a, b) array([[4, 1], [2, 2]])
>>> a = np.arange(3*4*5*6).reshape((3,4,5,6)) >>> b = np.arange(3*4*5*6)[::-1].reshape((5,4,6,3)) >>> np.dot(a, b)[2,3,2,1,2,2] 499128 >>> sum(a[2,3,2,:] * b[1,2,:,2]) 499128
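The JAX-specific precision argument documented above has no NumPy counterpart; a hedged sketch of how it is passed, using jax.numpy (the backend these wrappers forward to) directly:
>>> import jax.numpy as jnp
>>> from jax import lax
>>> a = jnp.ones((2, 2)); b = jnp.ones((2, 2))
>>> jnp.dot(a, b, precision=lax.Precision.HIGHEST)  # doctest: +SKIP
DeviceArray([[2., 2.],
             [2., 2.]], dtype=float32)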
- symjax.tensor.dsplit(ary, indices_or_sections)¶ Split array into multiple sub-arrays along the 3rd axis (depth).
LAX-backend implementation of dsplit(). Original docstring below.
Please refer to the split documentation. dsplit is equivalent to split with axis=2, the array is always split along the third axis provided the array dimension is greater than or equal to 3.
See also
split : Split an array into multiple sub-arrays of equal size.
>>> x = np.arange(16.0).reshape(2, 2, 4) >>> x array([[[ 0., 1., 2., 3.], [ 4., 5., 6., 7.]], [[ 8., 9., 10., 11.], [12., 13., 14., 15.]]]) >>> np.dsplit(x, 2) [array([[[ 0., 1.], [ 4., 5.]], [[ 8., 9.], [12., 13.]]]), array([[[ 2., 3.], [ 6., 7.]], [[10., 11.], [14., 15.]]])] >>> np.dsplit(x, np.array([3, 6])) [array([[[ 0., 1., 2.], [ 4., 5., 6.]], [[ 8., 9., 10.], [12., 13., 14.]]]), array([[[ 3.], [ 7.]], [[11.], [15.]]]), array([], shape=(2, 2, 0), dtype=float64)]
- symjax.tensor.dstack(tup)[source]¶ Stack arrays in sequence depth wise (along third axis).
LAX-backend implementation of dstack(). Original docstring below.
This is equivalent to concatenation along the third axis after 2-D arrays of shape (M,N) have been reshaped to (M,N,1) and 1-D arrays of shape (N,) have been reshaped to (1,N,1). Rebuilds arrays divided by dsplit.
This function makes most sense for arrays with up to 3 dimensions. For instance, for pixel-data with a height (first axis), width (second axis), and r/g/b channels (third axis). The functions concatenate, stack and block provide more general stacking and concatenation operations.
Parameters: tup (sequence of arrays) – The arrays must have the same shape along all but the third axis. 1-D or 2-D arrays must have the same shape.
Returns: stacked – The array formed by stacking the given arrays, will be at least 3-D.
Return type: ndarray
See also
concatenate()
- Join a sequence of arrays along an existing axis.
stack()
- Join a sequence of arrays along a new axis.
block()
- Assemble an nd-array from nested lists of blocks.
vstack()
- Stack arrays in sequence vertically (row wise).
hstack()
- Stack arrays in sequence horizontally (column wise).
column_stack()
- Stack 1-D arrays as columns into a 2-D array.
dsplit()
- Split array along third axis.
Examples
>>> a = np.array((1,2,3)) >>> b = np.array((2,3,4)) >>> np.dstack((a,b)) array([[[1, 2], [2, 3], [3, 4]]])
>>> a = np.array([[1],[2],[3]]) >>> b = np.array([[2],[3],[4]]) >>> np.dstack((a,b)) array([[[1, 2]], [[2, 3]], [[3, 4]]])
- symjax.tensor.einsum(*operands, out=None, optimize='greedy', precision=None)[source]¶ Evaluates the Einstein summation convention on the operands.
LAX-backend implementation of einsum(). In addition to the original NumPy arguments listed below, also supports precision for extra control over matrix-multiplication precision on supported devices. precision may be set to None, which means default precision for the backend, a lax.Precision enum value (Precision.DEFAULT, Precision.HIGH or Precision.HIGHEST) or a tuple of two lax.Precision enums indicating separate precision for each argument.
Original docstring below.
einsum(subscripts, *operands, out=None, dtype=None, order='K', casting='safe', optimize=False)
Using the Einstein summation convention, many common multi-dimensional, linear algebraic array operations can be represented in a simple fashion. In implicit mode einsum computes these values.
In explicit mode, einsum provides further flexibility to compute other array operations that might not be considered classical Einstein summation operations, by disabling, or forcing summation over specified subscript labels.
See the notes and examples for clarification.
- Returns
- output : ndarray
- The calculation based on the Einstein summation convention.
See also
einsum_path, dot, inner, outer, tensordot, linalg.multi_dot
New in version 1.6.0.
The Einstein summation convention can be used to compute many multi-dimensional, linear algebraic array operations. einsum provides a succinct way of representing these.
A non-exhaustive list of these operations, which can be computed by einsum, is shown below along with examples:
- Trace of an array, numpy.trace().
- Return a diagonal, numpy.diag().
- Array axis summations, numpy.sum().
- Transpositions and permutations, numpy.transpose().
- Matrix multiplication and dot product, numpy.matmul(), numpy.dot().
- Vector inner and outer products, numpy.inner(), numpy.outer().
- Broadcasting, element-wise and scalar multiplication, numpy.multiply().
- Tensor contractions, numpy.tensordot().
- Chained array operations, in efficient calculation order, numpy.einsum_path().
The subscripts string is a comma-separated list of subscript labels, where each label refers to a dimension of the corresponding operand. Whenever a label is repeated it is summed, so
np.einsum('i,i', a, b)
is equivalent tonp.inner(a,b)
. If a label appears only once, it is not summed, sonp.einsum('i', a)
produces a view ofa
with no changes. A further examplenp.einsum('ij,jk', a, b)
describes traditional matrix multiplication and is equivalent tonp.matmul(a,b)
. Repeated subscript labels in one operand take the diagonal. For example,np.einsum('ii', a)
is equivalent tonp.trace(a)
.In implicit mode, the chosen subscripts are important since the axes of the output are reordered alphabetically. This means that
np.einsum('ij', a)
doesn’t affect a 2D array, whilenp.einsum('ji', a)
takes its transpose. Additionally,np.einsum('ij,jk', a, b)
returns a matrix multiplication, while,np.einsum('ij,jh', a, b)
returns the transpose of the multiplication since subscript ‘h’ precedes subscript ‘i’.In explicit mode the output can be directly controlled by specifying output subscript labels. This requires the identifier ‘->’ as well as the list of output subscript labels. This feature increases the flexibility of the function since summing can be disabled or forced when required. The call
np.einsum('i->', a)
is likenp.sum(a, axis=-1)
, andnp.einsum('ii->i', a)
is likenp.diag(a)
. The difference is that einsum does not allow broadcasting by default. Additionallynp.einsum('ij,jh->ih', a, b)
directly specifies the order of the output subscript labels and therefore returns matrix multiplication, unlike the example above in implicit mode.To enable and control broadcasting, use an ellipsis. Default NumPy-style broadcasting is done by adding an ellipsis to the left of each term, like
np.einsum('...ii->...i', a)
. To take the trace along the first and last axes, you can donp.einsum('i...i', a)
, or to do a matrix-matrix product with the left-most indices instead of rightmost, one can donp.einsum('ij...,jk...->ik...', a, b)
.When there is only one operand, no axes are summed, and no output parameter is provided, a view into the operand is returned instead of a new array. Thus, taking the diagonal as
np.einsum('ii->i', a)
produces a view (changed in version 1.10.0).einsum also provides an alternative way to provide the subscripts and operands as
einsum(op0, sublist0, op1, sublist1, ..., [sublistout])
. If the output shape is not provided in this format einsum will be calculated in implicit mode, otherwise it will be performed explicitly. The examples below have corresponding einsum calls with the two parameter methods.New in version 1.10.0.
Views returned from einsum are now writeable whenever the input array is writeable. For example,
np.einsum('ijk...->kji...', a)
will now have the same effect asnp.swapaxes(a, 0, 2)
andnp.einsum('ii->i', a)
will return a writeable view of the diagonal of a 2D array.New in version 1.12.0.
Added the
optimize
argument which will optimize the contraction order of an einsum expression. For a contraction with three or more operands this can greatly increase the computational efficiency at the cost of a larger memory footprint during computation.Typically a ‘greedy’ algorithm is applied which empirical tests have shown returns the optimal path in the majority of cases. In some cases ‘optimal’ will return the superlative path through a more expensive, exhaustive search. For iterative calculations it may be advisable to calculate the optimal path once and reuse that path by supplying it as an argument. An example is given below.
See numpy.einsum_path() for more details.
>>> a = np.arange(25).reshape(5,5)
>>> b = np.arange(5)
>>> c = np.arange(6).reshape(2,3)
Trace of a matrix:
>>> np.einsum('ii', a) 60 >>> np.einsum(a, [0,0]) 60 >>> np.trace(a) 60
Extract the diagonal (requires explicit form):
>>> np.einsum('ii->i', a) array([ 0, 6, 12, 18, 24]) >>> np.einsum(a, [0,0], [0]) array([ 0, 6, 12, 18, 24]) >>> np.diag(a) array([ 0, 6, 12, 18, 24])
Sum over an axis (requires explicit form):
>>> np.einsum('ij->i', a) array([ 10, 35, 60, 85, 110]) >>> np.einsum(a, [0,1], [0]) array([ 10, 35, 60, 85, 110]) >>> np.sum(a, axis=1) array([ 10, 35, 60, 85, 110])
For higher dimensional arrays summing a single axis can be done with ellipsis:
>>> np.einsum('...j->...', a) array([ 10, 35, 60, 85, 110]) >>> np.einsum(a, [Ellipsis,1], [Ellipsis]) array([ 10, 35, 60, 85, 110])
Compute a matrix transpose, or reorder any number of axes:
>>> np.einsum('ji', c) array([[0, 3], [1, 4], [2, 5]]) >>> np.einsum('ij->ji', c) array([[0, 3], [1, 4], [2, 5]]) >>> np.einsum(c, [1,0]) array([[0, 3], [1, 4], [2, 5]]) >>> np.transpose(c) array([[0, 3], [1, 4], [2, 5]])
Vector inner products:
>>> np.einsum('i,i', b, b) 30 >>> np.einsum(b, [0], b, [0]) 30 >>> np.inner(b,b) 30
Matrix vector multiplication:
>>> np.einsum('ij,j', a, b) array([ 30, 80, 130, 180, 230]) >>> np.einsum(a, [0,1], b, [1]) array([ 30, 80, 130, 180, 230]) >>> np.dot(a, b) array([ 30, 80, 130, 180, 230]) >>> np.einsum('...j,j', a, b) array([ 30, 80, 130, 180, 230])
Broadcasting and scalar multiplication:
>>> np.einsum('..., ...', 3, c) array([[ 0, 3, 6], [ 9, 12, 15]]) >>> np.einsum(',ij', 3, c) array([[ 0, 3, 6], [ 9, 12, 15]]) >>> np.einsum(3, [Ellipsis], c, [Ellipsis]) array([[ 0, 3, 6], [ 9, 12, 15]]) >>> np.multiply(3, c) array([[ 0, 3, 6], [ 9, 12, 15]])
Vector outer product:
>>> np.einsum('i,j', np.arange(2)+1, b) array([[0, 1, 2, 3, 4], [0, 2, 4, 6, 8]]) >>> np.einsum(np.arange(2)+1, [0], b, [1]) array([[0, 1, 2, 3, 4], [0, 2, 4, 6, 8]]) >>> np.outer(np.arange(2)+1, b) array([[0, 1, 2, 3, 4], [0, 2, 4, 6, 8]])
Tensor contraction:
>>> a = np.arange(60.).reshape(3,4,5) >>> b = np.arange(24.).reshape(4,3,2) >>> np.einsum('ijk,jil->kl', a, b) array([[4400., 4730.], [4532., 4874.], [4664., 5018.], [4796., 5162.], [4928., 5306.]]) >>> np.einsum(a, [0,1,2], b, [1,0,3], [2,3]) array([[4400., 4730.], [4532., 4874.], [4664., 5018.], [4796., 5162.], [4928., 5306.]]) >>> np.tensordot(a,b, axes=([1,0],[0,1])) array([[4400., 4730.], [4532., 4874.], [4664., 5018.], [4796., 5162.], [4928., 5306.]])
Writeable returned arrays (since version 1.10.0):
>>> a = np.zeros((3, 3)) >>> np.einsum('ii->i', a)[:] = 1 >>> a array([[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]])
Example of ellipsis use:
>>> a = np.arange(6).reshape((3,2)) >>> b = np.arange(12).reshape((4,3)) >>> np.einsum('ki,jk->ij', a, b) array([[10, 28, 46, 64], [13, 40, 67, 94]]) >>> np.einsum('ki,...k->i...', a, b) array([[10, 28, 46, 64], [13, 40, 67, 94]]) >>> np.einsum('k...,jk', a, b) array([[10, 28, 46, 64], [13, 40, 67, 94]])
Chained array operations. For more complicated contractions, speed ups might be achieved by repeatedly computing a ‘greedy’ path or pre-computing the ‘optimal’ path and repeatedly applying it, using an einsum_path insertion (since version 1.12.0). Performance improvements can be particularly significant with larger arrays:
>>> a = np.ones(64).reshape(2,4,8)
Basic einsum: ~1520ms (benchmarked on 3.1GHz Intel i5.)
>>> for iteration in range(500):
...     _ = np.einsum('ijk,ilm,njm,nlk,abc->',a,a,a,a,a)
Sub-optimal einsum (due to repeated path calculation time): ~330ms
>>> for iteration in range(500):
...     _ = np.einsum('ijk,ilm,njm,nlk,abc->',a,a,a,a,a, optimize='optimal')
Greedy einsum (faster optimal path approximation): ~160ms
>>> for iteration in range(500):
...     _ = np.einsum('ijk,ilm,njm,nlk,abc->',a,a,a,a,a, optimize='greedy')
Optimal einsum (best usage pattern in some use cases): ~110ms
>>> path = np.einsum_path('ijk,ilm,njm,nlk,abc->',a,a,a,a,a, optimize='optimal')[0]
>>> for iteration in range(500):
...     _ = np.einsum('ijk,ilm,njm,nlk,abc->',a,a,a,a,a, optimize=path)
-
symjax.tensor.
equal
(x1, x2)¶ Return (x1 == x2) element-wise.
LAX-backend implementation of
equal()
. Original docstring below.

equal(x1, x2, /, out=None, *, where=True, casting='same_kind', order='K', dtype=None, subok=True[, signature, extobj])
Parameters: x1, x2 (array_like) – Input arrays. If x1.shape != x2.shape, they must be broadcastable to a common shape (which becomes the shape of the output).
Returns: out – Output array, element-wise comparison of x1 and x2. Typically of type bool, unless dtype=object is passed. This is a scalar if both x1 and x2 are scalars.
Return type: ndarray or scalar
See also
not_equal()
,greater_equal()
,less_equal()
,greater()
,less()
Examples
>>> np.equal([0, 1, 3], np.arange(3)) array([ True, True, False])
What is compared are values, not types. So an int (1) and an array of length one can evaluate as True:
>>> np.equal(1, np.ones(1)) array([ True])
-
symjax.tensor.
empty
(shape, dtype=None)¶ Return a new array of given shape and type, filled with zeros.
LAX-backend implementation of
zeros()
. Original docstring below.

zeros(shape, dtype=float, order='C')
Returns: out – Array of zeros with the given shape, dtype, and order.
Return type: ndarray

See also
zeros_like()
- Return an array of zeros with shape and type of input.
empty()
- Return a new uninitialized array.
ones()
- Return a new array setting values to one.
full()
- Return a new array of given shape filled with value.
>>> np.zeros(5) array([ 0., 0., 0., 0., 0.])
>>> np.zeros((5,), dtype=int) array([0, 0, 0, 0, 0])
>>> np.zeros((2, 1)) array([[ 0.], [ 0.]])
>>> s = (2,2) >>> np.zeros(s) array([[ 0., 0.], [ 0., 0.]])
>>> np.zeros((2,), dtype=[('x', 'i4'), ('y', 'i4')]) # custom dtype array([(0, 0), (0, 0)], dtype=[('x', '<i4'), ('y', '<i4')])
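Note that the docstring shown above is the one for zeros(): the LAX backend has no notion of uninitialized memory, so this empty simply reuses the zeros() implementation. A minimal sketch of the consequence, assuming the conventional import of symjax.tensor as T:

>>> import symjax.tensor as T
>>> # `empty` cannot return uninitialized memory under the LAX backend,
>>> # so both calls below build the same zero-filled (2, 3) tensor.
>>> a = T.empty((2, 3))
>>> b = T.zeros((2, 3))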
-
symjax.tensor.
empty_like
(a, dtype=None, shape=None)¶ Return an array of zeros with the same shape and type as a given array.
LAX-backend implementation of
zeros_like()
. Original docstring below.

Parameters: - a (array_like) – The shape and data-type of a define these same attributes of the returned array.
- dtype (data-type, optional) – Overrides the data type of the result.
- shape (int or sequence of ints, optional.) – Overrides the shape of the result. If order=’K’ and the number of dimensions is unchanged, will try to keep order, otherwise, order=’C’ is implied.
Returns: out – Array of zeros with the same shape and type as a.
Return type: ndarray
See also
empty_like()
- Return an empty array with shape and type of input.
ones_like()
- Return an array of ones with shape and type of input.
full_like()
- Return a new array with shape of input filled with value.
zeros()
- Return a new array setting values to zero.
Examples
>>> x = np.arange(6) >>> x = x.reshape((2, 3)) >>> x array([[0, 1, 2], [3, 4, 5]]) >>> np.zeros_like(x) array([[0, 0, 0], [0, 0, 0]])
>>> y = np.arange(3, dtype=float) >>> y array([0., 1., 2.]) >>> np.zeros_like(y) array([0., 0., 0.])
-
symjax.tensor.
exp
(x)¶ Calculate the exponential of all elements in the input array.
LAX-backend implementation of
exp()
. Original docstring below.

exp(x, /, out=None, *, where=True, casting='same_kind', order='K', dtype=None, subok=True[, signature, extobj])
Parameters: x (array_like) – Input values. Returns: out – Output array, element-wise exponential of x. This is a scalar if x is a scalar. Return type: ndarray or scalar See also
Notes
The irrational number
e
is also known as Euler’s number. It is approximately 2.718281, and is the base of the natural logarithm, ln (this means that, if \(x = \ln y = \log_e y\), then \(e^x = y\)). For real input, exp(x) is always positive.

For complex arguments,
x = a + ib
, we can write \(e^x = e^a e^{ib}\). The first term, \(e^a\), is already known (it is the real argument, described above). The second term, \(e^{ib}\), is \(\cos b + i \sin b\), a function with magnitude 1 and a periodic phase.
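A quick numeric check of this decomposition (a NumPy sketch; the symjax.tensor version follows the same semantics):

>>> z = 1.0 + 2.0j
>>> lhs = np.exp(z)
>>> rhs = np.exp(z.real) * (np.cos(z.imag) + 1j * np.sin(z.imag))
>>> np.allclose(lhs, rhs)
True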
References

[1] Wikipedia, “Exponential function”, https://en.wikipedia.org/wiki/Exponential_function
[2] M. Abramowitz and I. A. Stegun, “Handbook of Mathematical Functions with Formulas, Graphs, and Mathematical Tables,” Dover, 1964, p. 69, http://www.math.sfu.ca/~cbm/aands/page_69.htm

Examples
Plot the magnitude and phase of
exp(x)
in the complex plane:

>>> import matplotlib.pyplot as plt
>>> x = np.linspace(-2*np.pi, 2*np.pi, 100)
>>> xx = x + 1j * x[:, np.newaxis]  # a + ib over complex plane
>>> out = np.exp(xx)
>>> plt.subplot(121)
>>> plt.imshow(np.abs(out),
...            extent=[-2*np.pi, 2*np.pi, -2*np.pi, 2*np.pi], cmap='gray')
>>> plt.title('Magnitude of exp(x)')
>>> plt.subplot(122)
>>> plt.imshow(np.angle(out),
...            extent=[-2*np.pi, 2*np.pi, -2*np.pi, 2*np.pi], cmap='hsv')
>>> plt.title('Phase (angle) of exp(x)')
>>> plt.show()
-
symjax.tensor.
exp2
(x)[source]¶ Calculate 2**p for all p in the input array.
LAX-backend implementation of
exp2()
. Original docstring below.

exp2(x, /, out=None, *, where=True, casting='same_kind', order='K', dtype=None, subok=True[, signature, extobj])
Parameters: x (array_like) – Input values. Returns: out – Element-wise 2 to the power x. This is a scalar if x is a scalar. Return type: ndarray or scalar See also
Notes
New in version 1.3.0.
Examples
>>> np.exp2([2, 3]) array([ 4., 8.])
-
symjax.tensor.
expand_dims
(a, axis: Union[int, Tuple[int, ...]])[source]¶ Expand the shape of an array.
LAX-backend implementation of
expand_dims()
. Original docstring below.

Insert a new axis that will appear at the axis position in the expanded array shape.
Parameters: - a (array_like) – Input array.
- axis (int or tuple of ints) – Position in the expanded axes where the new axis (or axes) is placed.
Returns: result – View of a with the number of dimensions increased.
Return type: ndarray
See also
squeeze()
- The inverse operation, removing singleton dimensions
reshape()
- Insert, remove, and combine dimensions, and resize existing ones
doc.indexing()
,atleast_1d()
,atleast_2d()
,atleast_3d()
Examples
>>> x = np.array([1, 2])
>>> x.shape
(2,)
The following is equivalent to
x[np.newaxis, :]
orx[np.newaxis]
:

>>> y = np.expand_dims(x, axis=0)
>>> y
array([[1, 2]])
>>> y.shape
(1, 2)
The following is equivalent to
x[:, np.newaxis]
:

>>> y = np.expand_dims(x, axis=1)
>>> y
array([[1],
       [2]])
>>> y.shape
(2, 1)
axis
may also be a tuple:

>>> y = np.expand_dims(x, axis=(0, 1))
>>> y
array([[[1, 2]]])
>>> y = np.expand_dims(x, axis=(2, 0))
>>> y
array([[[1],
        [2]]])
Note that some examples may use
None
instead ofnp.newaxis
. These are the same objects:

>>> np.newaxis is None
True
-
symjax.tensor.
expm1
(x)¶ Calculate
exp(x) - 1
for all elements in the array.

LAX-backend implementation of
expm1()
. Original docstring below.

expm1(x, /, out=None, *, where=True, casting='same_kind', order='K', dtype=None, subok=True[, signature, extobj])
Parameters: x (array_like) – Input values. Returns: out – Element-wise exponential minus one: out = exp(x) - 1
. This is a scalar if x is a scalar.
Return type: ndarray or scalar
See also
log1p()
log(1 + x)
, the inverse of expm1.
Notes
This function provides greater precision than
exp(x) - 1
for small values ofx
.Examples
The true value of
exp(1e-10) - 1
is1.00000000005e-10
to about 32 significant digits. This example shows the superiority of expm1 in this case.

>>> np.expm1(1e-10)
1.00000000005e-10
>>> np.exp(1e-10) - 1
1.000000082740371e-10
-
symjax.tensor.
eye
(N, M=None, k=0, dtype=None)[source]¶ Return a 2-D array with ones on the diagonal and zeros elsewhere.
LAX-backend implementation of
eye()
. Original docstring below.

Parameters: - N (int) – Number of rows in the output.
- M (int, optional) – Number of columns in the output. If None, defaults to N.
- k (int, optional) – Index of the diagonal: 0 (the default) refers to the main diagonal, a positive value refers to an upper diagonal, and a negative value to a lower diagonal.
- dtype (data-type, optional) – Data-type of the returned array.
Returns: I – An array where all elements are equal to zero, except for the k-th diagonal, whose values are equal to one.
Return type: ndarray of shape (N,M)
See also
identity()
- (almost) equivalent function
diag()
- diagonal 2-D array from a 1-D array specified by the user.
Examples
>>> np.eye(2, dtype=int)
array([[1, 0],
       [0, 1]])
>>> np.eye(3, k=1)
array([[0., 1., 0.],
       [0., 0., 1.],
       [0., 0., 0.]])
-
symjax.tensor.
fabs
(x)¶ Compute the absolute values element-wise.
LAX-backend implementation of
fabs()
. Original docstring below.

fabs(x, /, out=None, *, where=True, casting='same_kind', order='K', dtype=None, subok=True[, signature, extobj])
This function returns the absolute values (positive magnitude) of the data in x. Complex values are not handled, use absolute to find the absolute values of complex data.
Parameters: x (array_like) – The array of numbers for which the absolute values are required. If x is a scalar, the result y will also be a scalar. Returns: y – The absolute values of x, the returned values are always floats. This is a scalar if x is a scalar. Return type: ndarray or scalar See also
absolute()
- Absolute values including complex types.
Examples
>>> np.fabs(-1) 1.0 >>> np.fabs([-1.2, 1.2]) array([ 1.2, 1.2])
-
symjax.tensor.
fix
(x, out=None)[source]¶ Round to nearest integer towards zero.
LAX-backend implementation of
fix()
. Original docstring below.

Round an array of floats element-wise to nearest integer towards zero. The rounded values are returned as floats.
Parameters: - x (array_like) – An array of floats to be rounded
- out (ndarray, optional) – A location into which the result is stored. If provided, it must have a shape that the input broadcasts to. If not provided or None, a freshly-allocated array is returned.
Returns: out – A float array with the same dimensions as the input. If second argument is not supplied then a float array is returned with the rounded values.
If a second argument is supplied the result is stored there. The return value out is then a reference to that array.
Return type: ndarray of floats
Examples
>>> np.fix(3.14) 3.0 >>> np.fix(3) 3.0 >>> np.fix([2.1, 2.9, -2.1, -2.9]) array([ 2., 2., -2., -2.])
-
symjax.tensor.
flip
(m, axis=None)[source]¶ Reverse the order of elements in an array along the given axis.
LAX-backend implementation of
flip()
. Original docstring below.

The shape of the array is preserved, but the elements are reordered.
New in version 1.12.0.
Parameters: - m (array_like) – Input array.
- axis (None or int or tuple of ints, optional) – Axis or axes along which to flip over. The default, axis=None, will flip over all of the axes of the input array. If axis is negative it counts from the last to the first axis.
Returns: out – A view of m with the entries of axis reversed. Since a view is returned, this operation is done in constant time.
Return type: array_like
Notes
flip(m, 0) is equivalent to flipud(m).
flip(m, 1) is equivalent to fliplr(m).
flip(m, n) corresponds to m[...,::-1,...] with ::-1 at position n.

flip(m) corresponds to m[::-1,::-1,...,::-1] with ::-1 at all positions.

flip(m, (0, 1)) corresponds to m[::-1,::-1,...] with ::-1 at position 0 and position 1.

Examples
>>> A = np.arange(8).reshape((2,2,2))
>>> A
array([[[0, 1],
        [2, 3]],
       [[4, 5],
        [6, 7]]])
>>> np.flip(A, 0)
array([[[4, 5],
        [6, 7]],
       [[0, 1],
        [2, 3]]])
>>> np.flip(A, 1)
array([[[2, 3],
        [0, 1]],
       [[6, 7],
        [4, 5]]])
>>> np.flip(A)
array([[[7, 6],
        [5, 4]],
       [[3, 2],
        [1, 0]]])
>>> np.flip(A, (0, 2))
array([[[5, 4],
        [7, 6]],
       [[1, 0],
        [3, 2]]])
>>> A = np.random.randn(3,4,5)
>>> np.all(np.flip(A,2) == A[:,:,::-1,...])
True
-
symjax.tensor.
fliplr
(m)[source]¶ Flip array in the left/right direction.
LAX-backend implementation of
fliplr()
. Original docstring below.

Flip the entries in each row in the left/right direction. Columns are preserved, but appear in a different order than before.
Parameters: m (array_like) – Input array, must be at least 2-D. Returns: f – A view of m with the columns reversed. Since a view is returned, this operation is \(\mathcal O(1)\). Return type: ndarray Notes
Equivalent to m[:,::-1]. Requires the array to be at least 2-D.
Examples
>>> A = np.diag([1.,2.,3.]) >>> A array([[1., 0., 0.], [0., 2., 0.], [0., 0., 3.]]) >>> np.fliplr(A) array([[0., 0., 1.], [0., 2., 0.], [3., 0., 0.]])
>>> A = np.random.randn(2,3,5) >>> np.all(np.fliplr(A) == A[:,::-1,...]) True
-
symjax.tensor.
flipud
(m)[source]¶ Flip array in the up/down direction.
LAX-backend implementation of
flipud()
. Original docstring below.

Flip the entries in each column in the up/down direction. Rows are preserved, but appear in a different order than before.
Parameters: m (array_like) – Input array. Returns: out – A view of m with the rows reversed. Since a view is returned, this operation is \(\mathcal O(1)\). Return type: array_like Notes
Equivalent to
m[::-1,...]
. Does not require the array to be two-dimensional.Examples
>>> A = np.diag([1.0, 2, 3]) >>> A array([[1., 0., 0.], [0., 2., 0.], [0., 0., 3.]]) >>> np.flipud(A) array([[0., 0., 3.], [0., 2., 0.], [1., 0., 0.]])
>>> A = np.random.randn(2,3,5) >>> np.all(np.flipud(A) == A[::-1,...]) True
>>> np.flipud([1,2]) array([2, 1])
-
symjax.tensor.
float_power
(x1, x2)¶ First array elements raised to powers from second array, element-wise.
LAX-backend implementation of
float_power()
. Original docstring below.

float_power(x1, x2, /, out=None, *, where=True, casting='same_kind', order='K', dtype=None, subok=True[, signature, extobj])
Raise each base in x1 to the positionally-corresponding power in x2. x1 and x2 must be broadcastable to the same shape. This differs from the power function in that integers, float16, and float32 are promoted to floats with a minimum precision of float64 so that the result is always inexact. The intent is that the function will return a usable result for negative powers and seldom overflow for positive powers.
New in version 1.12.0.
Parameters: - x1 (array_like) – The bases.
- x2 (array_like) – The exponents.
If
x1.shape != x2.shape
, they must be broadcastable to a common shape (which becomes the shape of the output).
Returns: y – The bases in x1 raised to the exponents in x2. This is a scalar if both x1 and x2 are scalars.
Return type: ndarray
See also
power()
- power function that preserves type
Examples
Cube each element in a list.
>>> x1 = range(6) >>> x1 [0, 1, 2, 3, 4, 5] >>> np.float_power(x1, 3) array([ 0., 1., 8., 27., 64., 125.])
Raise the bases to different exponents.
>>> x2 = [1.0, 2.0, 3.0, 3.0, 2.0, 1.0] >>> np.float_power(x1, x2) array([ 0., 1., 8., 27., 16., 5.])
The effect of broadcasting.
>>> x2 = np.array([[1, 2, 3, 3, 2, 1], [1, 2, 3, 3, 2, 1]]) >>> x2 array([[1, 2, 3, 3, 2, 1], [1, 2, 3, 3, 2, 1]]) >>> np.float_power(x1, x2) array([[ 0., 1., 8., 27., 16., 5.], [ 0., 1., 8., 27., 16., 5.]])
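The promotion rule is also what makes negative integer powers usable here, where power on integer inputs cannot represent them (a small illustrative check):

>>> np.float_power(2, -1)
0.5
>>> np.float_power([2, 4], 2).dtype
dtype('float64')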
-
symjax.tensor.
floor
(x)¶ Return the floor of the input, element-wise.
LAX-backend implementation of
floor()
. Original docstring below.

floor(x, /, out=None, *, where=True, casting='same_kind', order='K', dtype=None, subok=True[, signature, extobj])
The floor of the scalar x is the largest integer i, such that i <= x. It is often denoted as \(\lfloor x \rfloor\).
Parameters: x (array_like) – Input data. Returns: y – The floor of each element in x. This is a scalar if x is a scalar. Return type: ndarray or scalar See also
ceil()
,trunc()
,rint()
Notes
Some spreadsheet programs calculate the “floor-towards-zero”, in other words
floor(-2.5) == -2
. NumPy instead uses the definition of floor where floor(-2.5) == -3.Examples
>>> a = np.array([-1.7, -1.5, -0.2, 0.2, 1.5, 1.7, 2.0]) >>> np.floor(a) array([-2., -2., -1., 0., 1., 1., 2.])
-
symjax.tensor.
floor_divide
(x1, x2)[source]¶ Return the largest integer smaller or equal to the division of the inputs. It is equivalent to the Python // operator and pairs with the Python % (remainder) function, so that a = a % b + b * (a // b) up to roundoff.

LAX-backend implementation of
floor_divide()
. Original docstring below.

floor_divide(x1, x2, /, out=None, *, where=True, casting='same_kind', order='K', dtype=None, subok=True[, signature, extobj])
Parameters: - x1 (array_like) – Numerator.
- x2 (array_like) – Denominator.
If
x1.shape != x2.shape
, they must be broadcastable to a common shape (which becomes the shape of the output). - out (ndarray, None, or tuple of ndarray and None, optional) – A location into which the result is stored. If provided, it must have a shape that the inputs broadcast to. If not provided or None, a freshly-allocated array is returned. A tuple (possible only as a keyword argument) must have length equal to the number of outputs.
Returns: y – y = floor(x1/x2) This is a scalar if both x1 and x2 are scalars.
Return type: ndarray
See also
remainder()
- Remainder complementary to floor_divide.
divmod()
- Simultaneous floor division and remainder.
divide()
- Standard division.
floor()
- Round a number to the nearest integer toward minus infinity.
ceil()
- Round a number to the nearest integer toward infinity.
Examples
>>> np.floor_divide(7,3)
2
>>> np.floor_divide([1., 2., 3., 4.], 2.5)
array([ 0.,  0.,  1.,  1.])
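The pairing with % quoted above can be checked directly (a small illustrative sketch; exact for these values):

>>> a, b = np.array([7., -7., 7.5]), np.array([3., 3., -2.5])
>>> np.allclose(a, a % b + b * np.floor_divide(a, b))
True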
-
symjax.tensor.
fmod
(x1, x2)[source]¶ Return the element-wise remainder of division.
LAX-backend implementation of
fmod()
. Original docstring below.

fmod(x1, x2, /, out=None, *, where=True, casting='same_kind', order='K', dtype=None, subok=True[, signature, extobj])
This is the NumPy implementation of the C library function fmod; the remainder has the same sign as the dividend x1. It is equivalent to the Matlab(TM)
rem
function and should not be confused with the Python modulus operatorx1 % x2
.Parameters: - x1 (array_like) – Dividend.
- x2 (array_like) – Divisor.
If
x1.shape != x2.shape
, they must be broadcastable to a common shape (which becomes the shape of the output).
Returns: y – The remainder of the division of x1 by x2. This is a scalar if both x1 and x2 are scalars.
Return type: array_like
Notes
The result of the modulo operation for negative dividend and divisors is bound by conventions. For fmod, the sign of result is the sign of the dividend, while for remainder the sign of the result is the sign of the divisor. The fmod function is equivalent to the Matlab(TM)
rem
function.Examples
>>> np.fmod([-3, -2, -1, 1, 2, 3], 2)
array([-1,  0, -1,  1,  0,  1])
>>> np.remainder([-3, -2, -1, 1, 2, 3], 2)
array([1, 0, 1, 1, 0, 1])

>>> np.fmod([5, 3], [2, 2.])
array([ 1.,  1.])
>>> a = np.arange(-3, 3).reshape(3, 2)
>>> a
array([[-3, -2],
       [-1,  0],
       [ 1,  2]])
>>> np.fmod(a, [2,2])
array([[-1,  0],
       [-1,  0],
       [ 1,  0]])
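To make the warning above concrete: fmod and Python's % operator (i.e. np.remainder) disagree whenever the operands have opposite signs (illustrative sketch):

>>> np.fmod(-5, 3)   # sign follows the dividend
-2
>>> -5 % 3           # Python %: sign follows the divisor
1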
-
symjax.tensor.
full
(shape, fill_value, dtype=None)[source]¶ Return a new array of given shape and type, filled with fill_value.
LAX-backend implementation of
full()
. Original docstring below.

Parameters: - shape (int or sequence of ints) – Shape of the new array, e.g.,
(2, 3)
or2
. - fill_value (scalar or array_like) – Fill value.
- dtype (data-type, optional) – The desired data-type for the array. The default, None, means np.array(fill_value).dtype.
Returns: out – Array of fill_value with the given shape, dtype, and order.
Return type: ndarray
See also
full_like()
- Return a new array with shape of input filled with value.
empty()
- Return a new uninitialized array.
ones()
- Return a new array setting values to one.
zeros()
- Return a new array setting values to zero.
Examples
>>> np.full((2, 2), np.inf) array([[inf, inf], [inf, inf]]) >>> np.full((2, 2), 10) array([[10, 10], [10, 10]])
>>> np.full((2, 2), [1, 2]) array([[1, 2], [1, 2]])
-
symjax.tensor.
full_like
(a, fill_value, dtype=None, shape=None)[source]¶ Return a full array with the same shape and type as a given array.
LAX-backend implementation of
full_like()
. Original docstring below.

Parameters: - a (array_like) – The shape and data-type of a define these same attributes of the returned array.
- fill_value (scalar) – Fill value.
- dtype (data-type, optional) – Overrides the data type of the result.
- shape (int or sequence of ints, optional.) – Overrides the shape of the result. If order=’K’ and the number of dimensions is unchanged, will try to keep order, otherwise, order=’C’ is implied.
Returns: out – Array of fill_value with the same shape and type as a.
Return type: ndarray
See also
empty_like()
- Return an empty array with shape and type of input.
ones_like()
- Return an array of ones with shape and type of input.
zeros_like()
- Return an array of zeros with shape and type of input.
full()
- Return a new array of given shape filled with value.
Examples
>>> x = np.arange(6, dtype=int) >>> np.full_like(x, 1) array([1, 1, 1, 1, 1, 1]) >>> np.full_like(x, 0.1) array([0, 0, 0, 0, 0, 0]) >>> np.full_like(x, 0.1, dtype=np.double) array([0.1, 0.1, 0.1, 0.1, 0.1, 0.1]) >>> np.full_like(x, np.nan, dtype=np.double) array([nan, nan, nan, nan, nan, nan])
>>> y = np.arange(6, dtype=np.double) >>> np.full_like(y, 0.1) array([0.1, 0.1, 0.1, 0.1, 0.1, 0.1])
-
symjax.tensor.
gcd
(x1, x2)[source]¶ Returns the greatest common divisor of
|x1|
and|x2|
LAX-backend implementation of
gcd()
. Original docstring below.

gcd(x1, x2, /, out=None, *, where=True, casting='same_kind', order='K', dtype=None, subok=True[, signature, extobj])
Parameters: x1, x2 (array_like) – Arrays of values. If x1.shape != x2.shape, they must be broadcastable to a common shape (which becomes the shape of the output).
Returns: y – The greatest common divisor of the absolute value of the inputs. This is a scalar if both x1 and x2 are scalars.
Return type: ndarray or scalar
See also
lcm()
- The lowest common multiple
Examples
>>> np.gcd(12, 20) 4 >>> np.gcd.reduce([15, 25, 35]) 5 >>> np.gcd(np.arange(6), 20) array([20, 1, 2, 1, 4, 5])
-
symjax.tensor.
geomspace
(start, stop, num=50, endpoint=True, dtype=None, axis=0)[source]¶ Return numbers spaced evenly on a log scale (a geometric progression).
LAX-backend implementation of
geomspace()
. Original docstring below.

This is similar to logspace, but with endpoints specified directly. Each output sample is a constant multiple of the previous.
Changed in version 1.16.0: Non-scalar start and stop are now supported.
Parameters: - start (array_like) – The starting value of the sequence.
- stop (array_like) – The final value of the sequence, unless endpoint is False.
In that case,
num + 1
values are spaced over the interval in log-space, of which all but the last (a sequence of length num) are returned. - num (integer, optional) – Number of samples to generate. Default is 50.
- endpoint (boolean, optional) – If true, stop is the last sample. Otherwise, it is not included. Default is True.
- dtype (dtype) – The type of the output array. If dtype is not given, infer the data type from the other input arguments.
- axis (int, optional) – The axis in the result to store the samples. Relevant only if start or stop are array-like. By default (0), the samples will be along a new axis inserted at the beginning. Use -1 to get an axis at the end.
Returns: samples – num samples, equally spaced on a log scale.
Return type: ndarray
See also
logspace()
- Similar to geomspace, but with endpoints specified using log and base.
linspace()
- Similar to geomspace, but with arithmetic instead of geometric progression.
arange()
- Similar to linspace, with the step size specified instead of the number of samples.
Notes
If the inputs or dtype are complex, the output will follow a logarithmic spiral in the complex plane. (There are an infinite number of spirals passing through two points; the output will follow the shortest such path.)
Examples
>>> np.geomspace(1, 1000, num=4) array([ 1., 10., 100., 1000.]) >>> np.geomspace(1, 1000, num=3, endpoint=False) array([ 1., 10., 100.]) >>> np.geomspace(1, 1000, num=4, endpoint=False) array([ 1. , 5.62341325, 31.6227766 , 177.827941 ]) >>> np.geomspace(1, 256, num=9) array([ 1., 2., 4., 8., 16., 32., 64., 128., 256.])
Note that the above may not produce exact integers:
>>> np.geomspace(1, 256, num=9, dtype=int) array([ 1, 2, 4, 7, 16, 32, 63, 127, 256]) >>> np.around(np.geomspace(1, 256, num=9)).astype(int) array([ 1, 2, 4, 8, 16, 32, 64, 128, 256])
Negative, decreasing, and complex inputs are allowed:
>>> np.geomspace(1000, 1, num=4) array([1000., 100., 10., 1.]) >>> np.geomspace(-1000, -1, num=4) array([-1000., -100., -10., -1.]) >>> np.geomspace(1j, 1000j, num=4) # Straight line array([0. +1.j, 0. +10.j, 0. +100.j, 0.+1000.j]) >>> np.geomspace(-1+0j, 1+0j, num=5) # Circle array([-1.00000000e+00+1.22464680e-16j, -7.07106781e-01+7.07106781e-01j, 6.12323400e-17+1.00000000e+00j, 7.07106781e-01+7.07106781e-01j, 1.00000000e+00+0.00000000e+00j])
Graphical illustration of
endpoint
parameter:>>> import matplotlib.pyplot as plt >>> N = 10 >>> y = np.zeros(N) >>> plt.semilogx(np.geomspace(1, 1000, N, endpoint=True), y + 1, 'o') [<matplotlib.lines.Line2D object at 0x...>] >>> plt.semilogx(np.geomspace(1, 1000, N, endpoint=False), y + 2, 'o') [<matplotlib.lines.Line2D object at 0x...>] >>> plt.axis([0.5, 2000, 0, 3]) [0.5, 2000, 0, 3] >>> plt.grid(True, color='0.7', linestyle='-', which='both', axis='both') >>> plt.show()
-
symjax.tensor.
greater
(x1, x2)¶ Return the truth value of (x1 > x2) element-wise.
LAX-backend implementation of
greater()
. Original docstring below.

greater(x1, x2, /, out=None, *, where=True, casting='same_kind', order='K', dtype=None, subok=True[, signature, extobj])
Parameters: x1, x2 (array_like) – Input arrays. If x1.shape != x2.shape, they must be broadcastable to a common shape (which becomes the shape of the output).
Returns: out – Output array, element-wise comparison of x1 and x2. Typically of type bool, unless dtype=object is passed. This is a scalar if both x1 and x2 are scalars.
Return type: ndarray or scalar
Examples
>>> np.greater([4,2],[2,2]) array([ True, False])
If the inputs are ndarrays, then np.greater is equivalent to ‘>’.
>>> a = np.array([4,2]) >>> b = np.array([2,2]) >>> a > b array([ True, False])
-
symjax.tensor.
greater_equal
(x1, x2)¶ Return the truth value of (x1 >= x2) element-wise.
LAX-backend implementation of
greater_equal()
. Original docstring below.

greater_equal(x1, x2, /, out=None, *, where=True, casting='same_kind', order='K', dtype=None, subok=True[, signature, extobj])
Parameters: x1, x2 (array_like) – Input arrays. If x1.shape != x2.shape, they must be broadcastable to a common shape (which becomes the shape of the output).
Returns: out – Output array, element-wise comparison of x1 and x2. Typically of type bool, unless dtype=object is passed. This is a scalar if both x1 and x2 are scalars.
Return type: bool or ndarray of bool
Examples
>>> np.greater_equal([4, 2, 1], [2, 2, 2]) array([ True, True, False])
-
symjax.tensor.
heaviside
(x1, x2)[source]¶ Compute the Heaviside step function.
LAX-backend implementation of
heaviside()
. Original docstring below.

heaviside(x1, x2, /, out=None, *, where=True, casting='same_kind', order='K', dtype=None, subok=True[, signature, extobj])
The Heaviside step function is defined as:
                     0   if x1 < 0
heaviside(x1, x2) = x2   if x1 == 0
                     1   if x1 > 0
where x2 is often taken to be 0.5, but 0 and 1 are also sometimes used.
Parameters: - x1 (array_like) – Input values.
- x2 (array_like) – The value of the function when x1 is 0.
If
x1.shape != x2.shape
, they must be broadcastable to a common shape (which becomes the shape of the output).
Returns: out – The output array, element-wise Heaviside step function of x1. This is a scalar if both x1 and x2 are scalars.
Return type: ndarray or scalar
Notes
New in version 1.13.0.
References
Examples
>>> np.heaviside([-1.5, 0, 2.0], 0.5) array([ 0. , 0.5, 1. ]) >>> np.heaviside([-1.5, 0, 2.0], 1) array([ 0., 1., 1.])
-
symjax.tensor.
hsplit
(ary, indices_or_sections)¶ Split an array into multiple sub-arrays horizontally (column-wise).
LAX-backend implementation of
hsplit()
. Original docstring below.

Please refer to the split documentation. hsplit is equivalent to split with
axis=1
, the array is always split along the second axis regardless of the array dimension.

See also
split()
- Split an array into multiple sub-arrays of equal size.
>>> x = np.arange(16.0).reshape(4, 4) >>> x array([[ 0., 1., 2., 3.], [ 4., 5., 6., 7.], [ 8., 9., 10., 11.], [12., 13., 14., 15.]]) >>> np.hsplit(x, 2) [array([[ 0., 1.], [ 4., 5.], [ 8., 9.], [12., 13.]]), array([[ 2., 3.], [ 6., 7.], [10., 11.], [14., 15.]])] >>> np.hsplit(x, np.array([3, 6])) [array([[ 0., 1., 2.], [ 4., 5., 6.], [ 8., 9., 10.], [12., 13., 14.]]), array([[ 3.], [ 7.], [11.], [15.]]), array([], shape=(4, 0), dtype=float64)]
With a higher dimensional array the split is still along the second axis.
>>> x = np.arange(8.0).reshape(2, 2, 2) >>> x array([[[0., 1.], [2., 3.]], [[4., 5.], [6., 7.]]]) >>> np.hsplit(x, 2) [array([[[0., 1.]], [[4., 5.]]]), array([[[2., 3.]], [[6., 7.]]])]
-
symjax.tensor.
hstack
(tup)[source]¶ Stack arrays in sequence horizontally (column wise).
LAX-backend implementation of
hstack()
. Original docstring below.

This is equivalent to concatenation along the second axis, except for 1-D arrays where it concatenates along the first axis. Rebuilds arrays divided by hsplit.
This function makes most sense for arrays with up to 3 dimensions. For instance, for pixel-data with a height (first axis), width (second axis), and r/g/b channels (third axis). The functions concatenate, stack and block provide more general stacking and concatenation operations.
Parameters: tup (sequence of ndarrays) – The arrays must have the same shape along all but the second axis, except 1-D arrays which can be any length. Returns: stacked – The array formed by stacking the given arrays. Return type: ndarray See also
concatenate()
- Join a sequence of arrays along an existing axis.
stack()
- Join a sequence of arrays along a new axis.
block()
- Assemble an nd-array from nested lists of blocks.
vstack()
- Stack arrays in sequence vertically (row wise).
dstack()
- Stack arrays in sequence depth wise (along third axis).
column_stack()
- Stack 1-D arrays as columns into a 2-D array.
hsplit()
- Split an array into multiple sub-arrays horizontally (column-wise).
Examples
>>> a = np.array((1,2,3)) >>> b = np.array((2,3,4)) >>> np.hstack((a,b)) array([1, 2, 3, 2, 3, 4]) >>> a = np.array([[1],[2],[3]]) >>> b = np.array([[2],[3],[4]]) >>> np.hstack((a,b)) array([[1, 2], [2, 3], [3, 4]])
-
symjax.tensor.
hypot
(x1, x2)[source]¶ Given the “legs” of a right triangle, return its hypotenuse.
LAX-backend implementation of
hypot()
. Original docstring below.

hypot(x1, x2, /, out=None, *, where=True, casting='same_kind', order='K', dtype=None, subok=True[, signature, extobj])
Equivalent to
sqrt(x1**2 + x2**2)
, element-wise. If x1 or x2 is scalar_like (i.e., unambiguously cast-able to a scalar type), it is broadcast for use with each element of the other argument. (See Examples)

Parameters: x1, x2 (array_like) – Leg of the triangle(s). If x1.shape != x2.shape, they must be broadcastable to a common shape (which becomes the shape of the output).
Returns: z – The hypotenuse of the triangle(s). This is a scalar if both x1 and x2 are scalars.
Return type: ndarray

Examples
>>> np.hypot(3*np.ones((3, 3)), 4*np.ones((3, 3))) array([[ 5., 5., 5.], [ 5., 5., 5.], [ 5., 5., 5.]])
Example showing broadcast of scalar_like argument:
>>> np.hypot(3*np.ones((3, 3)), [4]) array([[ 5., 5., 5.], [ 5., 5., 5.], [ 5., 5., 5.]])
-
symjax.tensor.
identity
(n, dtype=None)[source]¶ Return the identity array.
LAX-backend implementation of
identity()
. Original docstring below.

The identity array is a square array with ones on the main diagonal.
Parameters: - n (int) – Number of rows (and columns) in n x n output.
- dtype (data-type, optional) – Data-type of the output. Defaults to
float
.
Returns: out – n x n array with its main diagonal set to one, and all other elements 0.
Return type: ndarray
Examples
>>> np.identity(3) array([[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]])
-
symjax.tensor.
imag
(val)[source]¶ Return the imaginary part of the complex argument.
LAX-backend implementation of
imag()
. Original docstring below.Parameters: val (array_like) – Input array. Returns: out – The imaginary component of the complex argument. If val is real, the type of val is used for the output. If val has complex elements, the returned type is float. Return type: ndarray or scalar Examples
>>> a = np.array([1+2j, 3+4j, 5+6j]) >>> a.imag array([2., 4., 6.]) >>> a.imag = np.array([8, 10, 12]) >>> a array([1. +8.j, 3.+10.j, 5.+12.j]) >>> np.imag(1 + 1j) 1.0
-
symjax.tensor.
inner
(a, b, *, precision=None)[source]¶ Inner product of two arrays.
LAX-backend implementation of
inner()
. In addition to the original NumPy arguments listed below, also supportsprecision
for extra control over matrix-multiplication precision on supported devices.precision
may be set toNone
, which means default precision for the backend, alax.Precision
enum value (Precision.DEFAULT
,Precision.HIGH
orPrecision.HIGHEST
) or a tuple of twolax.Precision
enums indicating separate precision for each argument.Original docstring below.
inner(a, b)
Ordinary inner product of vectors for 1-D arrays (without complex conjugation), in higher dimensions a sum product over the last axes.
Returns: out – out.shape = a.shape[:-1] + b.shape[:-1]
Return type: ndarray
Raises: ValueError – If the last dimension of a and b has different size.

See also
tensordot()
- Sum products over arbitrary axes.
dot()
- Generalised matrix product, using second last dimension of b.
einsum()
- Einstein summation convention.
For vectors (1-D arrays) it computes the ordinary inner-product:
np.inner(a, b) = sum(a[:]*b[:])
More generally, if ndim(a) = r > 0 and ndim(b) = s > 0:
np.inner(a, b) = np.tensordot(a, b, axes=(-1,-1))
or explicitly:
np.inner(a, b)[i0,...,ir-1,j0,...,js-1] = sum(a[i0,...,ir-1,:]*b[j0,...,js-1,:])
In addition a or b may be scalars, in which case:
np.inner(a,b) = a*b
Ordinary inner product for vectors:
>>> a = np.array([1,2,3]) >>> b = np.array([0,1,0]) >>> np.inner(a, b) 2
A multidimensional example:
>>> a = np.arange(24).reshape((2,3,4)) >>> b = np.arange(4) >>> np.inner(a, b) array([[ 14, 38, 62], [ 86, 110, 134]])
An example where b is a scalar:
>>> np.inner(np.eye(2), 7) array([[7., 0.], [0., 7.]])
-
symjax.tensor.
isclose
(a, b, rtol=1e-05, atol=1e-08, equal_nan=False)[source]¶ Returns a boolean array where two arrays are element-wise equal within a tolerance.
LAX-backend implementation of
isclose()
. Original docstring below.

The tolerance values are positive, typically very small numbers. The relative difference (rtol * abs(b)) and the absolute difference atol are added together to compare against the absolute difference between a and b.
Warning
The default atol is not appropriate for comparing numbers that are much smaller than one (see Notes).
Parameters: - b (a,) – Input arrays to compare.
- rtol (float) – The relative tolerance parameter (see Notes).
- atol (float) – The absolute tolerance parameter (see Notes).
- equal_nan (bool) – Whether to compare NaN’s as equal. If True, NaN’s in a will be considered equal to NaN’s in b in the output array.
Returns: y – Returns a boolean array of where a and b are equal within the given tolerance. If both a and b are scalars, returns a single boolean value.
Return type: array_like
See also
Notes
New in version 1.7.0.
For finite values, isclose uses the following equation to test whether two floating point values are equivalent.
absolute(a - b) <= (atol + rtol * absolute(b))

Unlike the built-in math.isclose, the above equation is not symmetric in a and b – it assumes b is the reference value – so that isclose(a, b) might be different from isclose(b, a). Furthermore, the default value of atol is not zero, and is used to determine what small values should be considered close to zero. The default value is appropriate for expected values of order unity: if the expected values are significantly smaller than one, it can result in false positives. atol should be carefully selected for the use case at hand. A zero value for atol will result in False if either a or b is zero.
Examples
>>> np.isclose([1e10,1e-7], [1.00001e10,1e-8])
array([ True, False])
>>> np.isclose([1e10,1e-8], [1.00001e10,1e-9])
array([ True,  True])
>>> np.isclose([1e10,1e-8], [1.0001e10,1e-9])
array([False,  True])
>>> np.isclose([1.0, np.nan], [1.0, np.nan])
array([ True, False])
>>> np.isclose([1.0, np.nan], [1.0, np.nan], equal_nan=True)
array([ True,  True])
>>> np.isclose([1e-8, 1e-7], [0.0, 0.0])
array([ True, False])
>>> np.isclose([1e-100, 1e-7], [0.0, 0.0], atol=0.0)
array([False, False])
>>> np.isclose([1e-10, 1e-10], [1e-20, 0.0])
array([ True,  True])
>>> np.isclose([1e-10, 1e-10], [1e-20, 0.999999e-10], atol=0.0)
array([False,  True])
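The asymmetry noted above is easy to trigger once atol is removed, since only b is scaled by rtol (illustrative sketch, exact in float64):

>>> np.isclose(1.0, 1.5, rtol=0.4, atol=0.0)   # |a-b| <= 0.4*|1.5|
True
>>> np.isclose(1.5, 1.0, rtol=0.4, atol=0.0)   # |a-b| <= 0.4*|1.0|
False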
-
symjax.tensor.
iscomplex
(x)[source]¶ Returns a bool array, where True if input element is complex.
LAX-backend implementation of
iscomplex()
. Original docstring below.

What is tested is whether the input has a non-zero imaginary part, not if the input type is complex.
Parameters: x (array_like) – Input array. Returns: out – Output array. Return type: ndarray of bools Examples
>>> np.iscomplex([1+1j, 1+0j, 4.5, 3, 2, 2j]) array([ True, False, False, False, False, True])
-
symjax.tensor.
isfinite
(x)[source]¶ Test element-wise for finiteness (not infinity or not Not a Number).
LAX-backend implementation of
isfinite()
. Original docstring below.

isfinite(x, /, out=None, *, where=True, casting='same_kind', order='K', dtype=None, subok=True[, signature, extobj])
The result is returned as a boolean array.
Parameters: x (array_like) – Input values. Returns: y – True where x
is not positive infinity, negative infinity, or NaN; false otherwise. This is a scalar if x is a scalar.Return type: ndarray, bool See also
Notes
Not a Number, positive infinity and negative infinity are considered to be non-finite.
NumPy uses the IEEE Standard for Binary Floating-Point for Arithmetic (IEEE 754). This means that Not a Number is not equivalent to infinity. Also that positive infinity is not equivalent to negative infinity. But infinity is equivalent to positive infinity. Errors result if the second argument is also supplied when x is a scalar input, or if first and second arguments have different shapes.
Examples
>>> np.isfinite(1) True >>> np.isfinite(0) True >>> np.isfinite(np.nan) False >>> np.isfinite(np.inf) False >>> np.isfinite(np.NINF) False >>> np.isfinite([np.log(-1.),1.,np.log(0)]) array([False, True, False])
>>> x = np.array([-np.inf, 0., np.inf]) >>> y = np.array([2, 2, 2]) >>> np.isfinite(x, y) array([0, 1, 0]) >>> y array([0, 1, 0])
-
symjax.tensor.
isinf
(x)[source]¶ Test element-wise for positive or negative infinity.
LAX-backend implementation of
isinf()
. Original docstring below.

isinf(x, /, out=None, *, where=True, casting='same_kind', order='K', dtype=None, subok=True[, signature, extobj])
Returns a boolean array of the same shape as x, True where
x == +/-inf
, otherwise False.Parameters: x (array_like) – Input values Returns: y – True where x
is positive or negative infinity, false otherwise. This is a scalar if x is a scalar.Return type: bool (scalar) or boolean ndarray See also
Notes
NumPy uses the IEEE Standard for Binary Floating-Point for Arithmetic (IEEE 754).
Errors result if the second argument is supplied when the first argument is a scalar, or if the first and second arguments have different shapes.
Examples
>>> np.isinf(np.inf) True >>> np.isinf(np.nan) False >>> np.isinf(np.NINF) True >>> np.isinf([np.inf, -np.inf, 1.0, np.nan]) array([ True, True, False, False])
>>> x = np.array([-np.inf, 0., np.inf]) >>> y = np.array([2, 2, 2]) >>> np.isinf(x, y) array([1, 0, 1]) >>> y array([1, 0, 1])
-
symjax.tensor.
isnan
(x)[source]¶ Test element-wise for NaN and return result as a boolean array.
LAX-backend implementation of
isnan()
. Original docstring below.

isnan(x, /, out=None, *, where=True, casting='same_kind', order='K', dtype=None, subok=True[, signature, extobj])
Parameters: x (array_like) – Input array. Returns: y – True where x
is NaN, false otherwise. This is a scalar if x is a scalar.Return type: ndarray or bool See also
isinf()
,isneginf()
,isposinf()
,isfinite()
,isnat()
Notes
NumPy uses the IEEE Standard for Binary Floating-Point for Arithmetic (IEEE 754). This means that Not a Number is not equivalent to infinity.
Examples
>>> np.isnan(np.nan) True >>> np.isnan(np.inf) False >>> np.isnan([np.log(-1.),1.,np.log(0)]) array([ True, False, False])
-
symjax.tensor.
isneginf
(x, out=None)¶ Test element-wise for negative infinity, return result as bool array.
LAX-backend implementation of
isneginf()
. Original docstring below.

Parameters: - x (array_like) – The input array.
- out (array_like, optional) – A location into which the result is stored. If provided, it must have a shape that the input broadcasts to. If not provided or None, a freshly-allocated boolean array is returned.
Returns: out – A boolean array with the same dimensions as the input. If second argument is not supplied then a numpy boolean array is returned with values True where the corresponding element of the input is negative infinity and values False where the element of the input is not negative infinity.
If a second argument is supplied the result is stored there. If the type of that array is a numeric type the result is represented as zeros and ones, if the type is boolean then as False and True. The return value out is then a reference to that array.
Return type: ndarray
See also
Notes
NumPy uses the IEEE Standard for Binary Floating-Point for Arithmetic (IEEE 754).
Errors result if the second argument is also supplied when x is a scalar input, if first and second arguments have different shapes, or if the first argument has complex values.
Examples
>>> np.isneginf(np.NINF) True >>> np.isneginf(np.inf) False >>> np.isneginf(np.PINF) False >>> np.isneginf([-np.inf, 0., np.inf]) array([ True, False, False])
>>> x = np.array([-np.inf, 0., np.inf]) >>> y = np.array([2, 2, 2]) >>> np.isneginf(x, y) array([1, 0, 0]) >>> y array([1, 0, 0])
-
symjax.tensor.
isposinf
(x, out=None)¶ Test element-wise for positive infinity, return result as bool array.
LAX-backend implementation of
isposinf()
. Original docstring below.

Parameters: - x (array_like) – The input array.
- out (array_like, optional) – A location into which the result is stored. If provided, it must have a shape that the input broadcasts to. If not provided or None, a freshly-allocated boolean array is returned.
Returns: out – A boolean array with the same dimensions as the input. If second argument is not supplied then a boolean array is returned with values True where the corresponding element of the input is positive infinity and values False where the element of the input is not positive infinity.
If a second argument is supplied the result is stored there. If the type of that array is a numeric type the result is represented as zeros and ones, if the type is boolean then as False and True. The return value out is then a reference to that array.
Return type: ndarray
See also
Notes
NumPy uses the IEEE Standard for Binary Floating-Point for Arithmetic (IEEE 754).
Errors result if the second argument is also supplied when x is a scalar input, if first and second arguments have different shapes, or if the first argument has complex values
Examples
>>> np.isposinf(np.PINF) True >>> np.isposinf(np.inf) True >>> np.isposinf(np.NINF) False >>> np.isposinf([-np.inf, 0., np.inf]) array([False, False, True])
>>> x = np.array([-np.inf, 0., np.inf]) >>> y = np.array([2, 2, 2]) >>> np.isposinf(x, y) array([0, 0, 1]) >>> y array([0, 0, 1])
-
symjax.tensor.
isreal
(x)[source]¶ Returns a bool array, where True if input element is real.
LAX-backend implementation of
isreal()
. Original docstring below.

If element has complex type with zero complex part, the return value for that element is True.
Parameters: x (array_like) – Input array. Returns: out – Boolean array of same shape as x. Return type: ndarray, bool Examples
>>> np.isreal([1+1j, 1+0j, 4.5, 3, 2, 2j]) array([False, True, True, True, True, False])
-
symjax.tensor.
isscalar
(element)[source]¶ Returns True if the type of element is a scalar type.
LAX-backend implementation of
isscalar()
. Original docstring below.

Parameters: element (any) – Input argument, can be of any type and shape.
Returns: val – True if element is a scalar type, False if it is not.
Return type: bool
See also
ndim()
- Get the number of dimensions of an array
Notes
If you need a stricter way to identify a numerical scalar, use
isinstance(x, numbers.Number)
, as that returnsFalse
for most non-numerical elements such as strings.In most cases
np.ndim(x) == 0
should be used instead of this function, as that will also return true for 0d arrays. This is how numpy overloads functions in the style of thedx
arguments to gradient and thebins
argument to histogram. Some key differences:x isscalar(x)
np.ndim(x) == 0
PEP 3141 numeric objects (including builtins) True
True
builtin string and buffer objects True
True
other builtin objects, like pathlib.Path, Exception, the result of re.compile False
True
third-party objects like matplotlib.figure.Figure False
True
zero-dimensional numpy arrays False
True
other numpy arrays False
False
list, tuple, and other sequence objects False
False
Examples
>>> np.isscalar(3.1) True >>> np.isscalar(np.array(3.1)) False >>> np.isscalar([3.1]) False >>> np.isscalar(False) True >>> np.isscalar('numpy') True
NumPy supports PEP 3141 numbers:
>>> from fractions import Fraction >>> np.isscalar(Fraction(5, 17)) True >>> from numbers import Number >>> np.isscalar(Number()) True
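The 0-d array row of the table above is the one that most often surprises; np.ndim covers it where isscalar does not (illustrative sketch):

>>> np.isscalar(np.array(3.1))
False
>>> np.ndim(np.array(3.1)) == 0
True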
-
symjax.tensor.
issubdtype
(arg1, arg2)[source]¶ Returns True if first argument is a typecode lower/equal in type hierarchy.
LAX-backend implementation of
issubdtype()
. Original docstring below.

Parameters: arg1, arg2 – dtype or string representing a typecode.
Returns: out
Return type: bool

Examples
>>> np.issubdtype('S1', np.string_) True >>> np.issubdtype(np.float64, np.float32) False
-
symjax.tensor.
issubsctype
(arg1, arg2)¶ Determine if the first argument is a subclass of the second argument.
Parameters: arg1, arg2 – Data-types.
Returns: out – The result.
Return type: bool
See also
issctype()
,issubdtype()
,obj2sctype()
Examples
>>> np.issubsctype('S8', str) False >>> np.issubsctype(np.array([1]), int) True >>> np.issubsctype(np.array([1]), float) False
-
symjax.tensor.
ix_
(*args)[source]¶ Construct an open mesh from multiple sequences.
LAX-backend implementation of
ix_()
. Original docstring below.

This function takes N 1-D sequences and returns N outputs with N dimensions each, such that the shape is 1 in all but one dimension and the dimension with the non-unit shape value cycles through all N dimensions.
Using ix_ one can quickly construct index arrays that will index the cross product.
a[np.ix_([1,3],[2,5])]
returns the array[[a[1,2] a[1,5]], [a[3,2] a[3,5]]]
.Parameters: args (1-D sequences) – Each sequence should be of integer or boolean type. Boolean sequences will be interpreted as boolean masks for the corresponding dimension (equivalent to passing in np.nonzero(boolean_sequence)
).Returns: out – N arrays with N dimensions each, with N the number of input sequences. Together these arrays form an open mesh. Return type: tuple of ndarrays See also
ogrid()
,mgrid()
,meshgrid()
Examples
>>> a = np.arange(10).reshape(2, 5) >>> a array([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]) >>> ixgrid = np.ix_([0, 1], [2, 4]) >>> ixgrid (array([[0], [1]]), array([[2, 4]])) >>> ixgrid[0].shape, ixgrid[1].shape ((2, 1), (1, 2)) >>> a[ixgrid] array([[2, 4], [7, 9]])
>>> ixgrid = np.ix_([True, True], [2, 4]) >>> a[ixgrid] array([[2, 4], [7, 9]]) >>> ixgrid = np.ix_([True, True], [False, False, True, False, True]) >>> a[ixgrid] array([[2, 4], [7, 9]])
-
symjax.tensor.
kron
(a, b)[source]¶ Kronecker product of two arrays.
LAX-backend implementation of
kron()
. Original docstring below.

Computes the Kronecker product, a composite array made of blocks of the second array scaled by the first.
Parameters: a, b (array_like) – Input arrays.
Returns: out
Return type: ndarray
See also
outer()
- The outer product
Notes
The function assumes that the number of dimensions of a and b are the same, if necessary prepending the smallest with ones. If a.shape = (r0,r1,..,rN) and b.shape = (s0,s1,...,sN), the Kronecker product has shape (r0*s0, r1*s1, ..., rN*sN). The elements are products of elements from a and b, organized explicitly by:
kron(a,b)[k0,k1,...,kN] = a[i0,i1,...,iN] * b[j0,j1,...,jN]
where:
kt = it * st + jt, t = 0,...,N
In the common 2-D case (N=1), the block structure can be visualized:
[[ a[0,0]*b, a[0,1]*b, ... , a[0,-1]*b ], [ ... ... ], [ a[-1,0]*b, a[-1,1]*b, ... , a[-1,-1]*b ]]
Examples
>>> np.kron([1,10,100], [5,6,7]) array([ 5, 6, 7, ..., 500, 600, 700]) >>> np.kron([5,6,7], [1,10,100]) array([ 5, 50, 500, ..., 7, 70, 700])
>>> np.kron(np.eye(2), np.ones((2,2))) array([[1., 1., 0., 0.], [1., 1., 0., 0.], [0., 0., 1., 1.], [0., 0., 1., 1.]])
>>> a = np.arange(100).reshape((2,5,2,5)) >>> b = np.arange(24).reshape((2,3,4)) >>> c = np.kron(a,b) >>> c.shape (2, 10, 6, 20) >>> I = (1,3,0,2) >>> J = (0,2,1) >>> J1 = (0,) + J # extend to ndim=4 >>> S1 = (1,) + b.shape >>> K = tuple(np.array(I) * np.array(S1) + np.array(J1)) >>> c[K] == a[I]*b[J] True
-
symjax.tensor.
lcm
(x1, x2)[source]¶ Returns the lowest common multiple of
|x1|
and|x2|
LAX-backend implementation of
lcm()
. Original docstring below.

lcm(x1, x2, /, out=None, *, where=True, casting='same_kind', order='K', dtype=None, subok=True[, signature, extobj])
Parameters: x1, x2 (array_like) – Arrays of values. If x1.shape != x2.shape, they must be broadcastable to a common shape (which becomes the shape of the output).
Returns: y – The lowest common multiple of the absolute value of the inputs. This is a scalar if both x1 and x2 are scalars.
Return type: ndarray or scalar
See also
gcd()
- The greatest common divisor
Examples
>>> np.lcm(12, 20) 60 >>> np.lcm.reduce([3, 12, 20]) 60 >>> np.lcm.reduce([40, 12, 20]) 120 >>> np.lcm(np.arange(6), 20) array([ 0, 20, 20, 60, 20, 20])
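For positive integers, gcd and lcm are tied together by gcd(a, b) * lcm(a, b) == a * b, which makes a handy sanity check (illustrative sketch):

>>> a, b = 12, 20
>>> np.gcd(a, b) * np.lcm(a, b) == a * b   # 4 * 60 == 240
True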
-
symjax.tensor.
left_shift
(x1, x2)¶ Shift the bits of an integer to the left.
LAX-backend implementation of
left_shift()
. Original docstring below.

left_shift(x1, x2, /, out=None, *, where=True, casting='same_kind', order='K', dtype=None, subok=True[, signature, extobj])
Bits are shifted to the left by appending x2 0s at the right of x1. Since the internal representation of numbers is in binary format, this operation is equivalent to multiplying x1 by
2**x2
.Parameters: - x1 (array_like of integer type) – Input values.
- x2 (array_like of integer type) – Number of zeros to append to x1. Has to be non-negative.
If
x1.shape != x2.shape
, they must be broadcastable to a common shape (which becomes the shape of the output).
Returns: out – Return x1 with bits shifted x2 times to the left. This is a scalar if both x1 and x2 are scalars.
Return type: array of integer type
See also
right_shift()
- Shift the bits of an integer to the right.
binary_repr()
- Return the binary representation of the input number as a string.
Examples
>>> np.binary_repr(5) '101' >>> np.left_shift(5, 2) 20 >>> np.binary_repr(20) '10100'
>>> np.left_shift(5, [1,2,3]) array([10, 20, 40])
Note that the dtype of the second argument may change the dtype of the result and can lead to unexpected results in some cases (see Casting Rules):
>>> a = np.left_shift(np.uint8(255), 1) # Expect 254 >>> print(a, type(a)) # Unexpected result due to upcasting 510 <class 'numpy.int64'> >>> b = np.left_shift(np.uint8(255), np.uint8(1)) >>> print(b, type(b)) 254 <class 'numpy.uint8'>
-
symjax.tensor.
less
(x1, x2)¶ Return the truth value of (x1 < x2) element-wise.
LAX-backend implementation of
less()
. Original docstring below.

less(x1, x2, /, out=None, *, where=True, casting='same_kind', order='K', dtype=None, subok=True[, signature, extobj])
Parameters: x1, x2 (array_like) – Input arrays. If x1.shape != x2.shape, they must be broadcastable to a common shape (which becomes the shape of the output).
Returns: out – Output array, element-wise comparison of x1 and x2. Typically of type bool, unless dtype=object is passed. This is a scalar if both x1 and x2 are scalars.
Return type: ndarray or scalar
See also
greater()
,less_equal()
,greater_equal()
,equal()
,not_equal()
Examples
>>> np.less([1, 2], [2, 2]) array([ True, False])
-
symjax.tensor.
less_equal
(x1, x2)¶ Return the truth value of (x1 <= x2) element-wise.
LAX-backend implementation of
less_equal()
. Original docstring below.

less_equal(x1, x2, /, out=None, *, where=True, casting='same_kind', order='K', dtype=None, subok=True[, signature, extobj])
Parameters: x1, x2 (array_like) – Input arrays. If x1.shape != x2.shape, they must be broadcastable to a common shape (which becomes the shape of the output).
Returns: out – Output array, element-wise comparison of x1 and x2. Typically of type bool, unless dtype=object is passed. This is a scalar if both x1 and x2 are scalars.
Return type: ndarray or scalar
Examples
>>> np.less_equal([4, 2, 1], [2, 2, 2]) array([False, True, True])
-
symjax.tensor.
linspace
(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis=0)[source]¶ Return evenly spaced numbers over a specified interval.
LAX-backend implementation of
linspace()
. Original docstring below.

Returns num evenly spaced samples, calculated over the interval [start, stop].
The endpoint of the interval can optionally be excluded.
Changed in version 1.16.0: Non-scalar start and stop are now supported.
Parameters: - start (array_like) – The starting value of the sequence.
- stop (array_like) – The end value of the sequence, unless endpoint is set to False.
In that case, the sequence consists of all but the last of
num + 1
evenly spaced samples, so that stop is excluded. Note that the step size changes when endpoint is False. - num (int, optional) – Number of samples to generate. Default is 50. Must be non-negative.
- endpoint (bool, optional) – If True, stop is the last sample. Otherwise, it is not included. Default is True.
- retstep (bool, optional) – If True, return (samples, step), where step is the spacing between samples.
- dtype (dtype, optional) – The type of the output array. If dtype is not given, infer the data type from the other input arguments.
- axis (int, optional) – The axis in the result to store the samples. Relevant only if start or stop are array-like. By default (0), the samples will be along a new axis inserted at the beginning. Use -1 to get an axis at the end.
Returns: samples (ndarray) – There are num equally spaced samples in the closed interval
[start, stop]
or the half-open interval[start, stop)
(depending on whether endpoint is True or False).step (float, optional) – Only returned if retstep is True
Size of spacing between samples.
See also
arange()
- Similar to linspace, but uses a step size (instead of the number of samples).
geomspace()
- Similar to linspace, but with numbers spaced evenly on a log scale (a geometric progression).
logspace()
- Similar to geomspace, but with the end points specified as logarithms.
Examples
>>> np.linspace(2.0, 3.0, num=5) array([2. , 2.25, 2.5 , 2.75, 3. ]) >>> np.linspace(2.0, 3.0, num=5, endpoint=False) array([2. , 2.2, 2.4, 2.6, 2.8]) >>> np.linspace(2.0, 3.0, num=5, retstep=True) (array([2. , 2.25, 2.5 , 2.75, 3. ]), 0.25)
Graphical illustration:
>>> import matplotlib.pyplot as plt >>> N = 8 >>> y = np.zeros(N) >>> x1 = np.linspace(0, 10, N, endpoint=True) >>> x2 = np.linspace(0, 10, N, endpoint=False) >>> plt.plot(x1, y, 'o') [<matplotlib.lines.Line2D object at 0x...>] >>> plt.plot(x2, y + 0.5, 'o') [<matplotlib.lines.Line2D object at 0x...>] >>> plt.ylim([-0.5, 1]) (-0.5, 1) >>> plt.show()
-
symjax.tensor.
log
(x)¶ Natural logarithm, element-wise.
LAX-backend implementation of
log()
. Original docstring below.

log(x, /, out=None, *, where=True, casting='same_kind', order='K', dtype=None, subok=True[, signature, extobj])
The natural logarithm log is the inverse of the exponential function, so that log(exp(x)) = x. The natural logarithm is the logarithm in base e.
Parameters: x (array_like) – Input value. Returns: y – The natural logarithm of x, element-wise. This is a scalar if x is a scalar. Return type: ndarray Notes
Logarithm is a multivalued function: for each x there is an infinite number of z such that exp(z) = x. The convention is to return the z whose imaginary part lies in [-pi, pi].
For real-valued input data types, log always returns real output. For each value that cannot be expressed as a real number or infinity, it yields nan and sets the invalid floating point error flag.
For complex-valued input, log is a complex analytical function that has a branch cut [-inf, 0] and is continuous from above on it. log handles the floating-point negative zero as an infinitesimal negative number, conforming to the C99 standard.
References
[1] M. Abramowitz and I.A. Stegun, "Handbook of Mathematical Functions", 10th printing, 1964, pp. 67. http://www.math.sfu.ca/~cbm/aands/
[2] Wikipedia, "Logarithm". https://en.wikipedia.org/wiki/Logarithm
Examples
>>> np.log([1, np.e, np.e**2, 0])
array([  0.,   1.,   2., -Inf])
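To make the branch-cut convention from the Notes concrete (plain NumPy, for orientation): on the cut itself the result is continuous from above, so the imaginary part comes out as +pi.
>>> np.log(-1. + 0j)  # on the branch cut, continuous from above: Im(z) = +pi
3.141592653589793j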
-
symjax.tensor.
log10
(x)[source]¶ Return the base 10 logarithm of the input array, element-wise.
LAX-backend implementation of
log10()
. Original docstring below.
log10(x, /, out=None, *, where=True, casting='same_kind', order='K', dtype=None, subok=True[, signature, extobj])
Parameters: x (array_like) – Input values.
Returns: y – The logarithm to the base 10 of x, element-wise. NaNs are returned where x is negative. This is a scalar if x is a scalar.
Return type: ndarray
See also
emath.log10()
Notes
Logarithm is a multivalued function: for each x there is an infinite number of z such that 10**z = x. The convention is to return the z whose imaginary part lies in [-pi, pi].
For real-valued input data types, log10 always returns real output. For each value that cannot be expressed as a real number or infinity, it yields nan and sets the invalid floating point error flag.
For complex-valued input, log10 is a complex analytical function that has a branch cut [-inf, 0] and is continuous from above on it. log10 handles the floating-point negative zero as an infinitesimal negative number, conforming to the C99 standard.
References
[1] M. Abramowitz and I.A. Stegun, "Handbook of Mathematical Functions", 10th printing, 1964, pp. 67. http://www.math.sfu.ca/~cbm/aands/
[2] Wikipedia, "Logarithm". https://en.wikipedia.org/wiki/Logarithm
Examples
>>> np.log10([1e-15, -3.])
array([-15.,  nan])
-
symjax.tensor.
log1p
(x)¶ Return the natural logarithm of one plus the input array, element-wise.
LAX-backend implementation of
log1p()
. Original docstring below.
log1p(x, /, out=None, *, where=True, casting='same_kind', order='K', dtype=None, subok=True[, signature, extobj])
Calculates log(1 + x).
Parameters: x (array_like) – Input values.
Returns: y – Natural logarithm of 1 + x, element-wise. This is a scalar if x is a scalar.
Return type: ndarray
See also
expm1()
- exp(x) - 1, the inverse of log1p.
Notes
For real-valued input, log1p is accurate also for x so small that 1 + x == 1 in floating-point accuracy.
Logarithm is a multivalued function: for each x there is an infinite number of z such that exp(z) = 1 + x. The convention is to return the z whose imaginary part lies in [-pi, pi].
For real-valued input data types, log1p always returns real output. For each value that cannot be expressed as a real number or infinity, it yields nan and sets the invalid floating point error flag.
For complex-valued input, log1p is a complex analytical function that has a branch cut [-inf, -1] and is continuous from above on it. log1p handles the floating-point negative zero as an infinitesimal negative number, conforming to the C99 standard.
References
[1] M. Abramowitz and I.A. Stegun, "Handbook of Mathematical Functions", 10th printing, 1964, pp. 67. http://www.math.sfu.ca/~cbm/aands/
[2] Wikipedia, "Logarithm". https://en.wikipedia.org/wiki/Logarithm
Examples
>>> np.log1p(1e-99)
1e-99
>>> np.log(1 + 1e-99)
0.0
-
symjax.tensor.
log2
(x)[source]¶ Base-2 logarithm of x.
LAX-backend implementation of
log2()
. Original docstring below.
log2(x, /, out=None, *, where=True, casting='same_kind', order='K', dtype=None, subok=True[, signature, extobj])
Parameters: x (array_like) – Input values.
Returns: y – Base-2 logarithm of x. This is a scalar if x is a scalar.
Return type: ndarray
Notes
New in version 1.3.0.
Logarithm is a multivalued function: for each x there is an infinite number of z such that 2**z = x. The convention is to return the z whose imaginary part lies in [-pi, pi].
For real-valued input data types, log2 always returns real output. For each value that cannot be expressed as a real number or infinity, it yields nan and sets the invalid floating point error flag.
For complex-valued input, log2 is a complex analytical function that has a branch cut [-inf, 0] and is continuous from above on it. log2 handles the floating-point negative zero as an infinitesimal negative number, conforming to the C99 standard.
Examples
>>> x = np.array([0, 1, 2, 2**4])
>>> np.log2(x)
array([-Inf,   0.,   1.,   4.])
>>> xi = np.array([0+1.j, 1, 2+0.j, 4.j])
>>> np.log2(xi)
array([ 0.+2.26618007j,  0.+0.j        ,  1.+0.j        ,  2.+2.26618007j])
-
symjax.tensor.
logaddexp
(x1, x2)[source]¶ Logarithm of the sum of exponentiations of the inputs.
LAX-backend implementation of
logaddexp()
. Original docstring below.
logaddexp(x1, x2, /, out=None, *, where=True, casting='same_kind', order='K', dtype=None, subok=True[, signature, extobj])
Calculates log(exp(x1) + exp(x2)). This function is useful in statistics where the calculated probabilities of events may be so small as to exceed the range of normal floating point numbers. In such cases the logarithm of the calculated probability is stored. This function allows adding probabilities stored in such a fashion.
Parameters: x1, x2 (array_like) – Input values. If x1.shape != x2.shape, they must be broadcastable to a common shape (which becomes the shape of the output).
Returns: result – Logarithm of exp(x1) + exp(x2). This is a scalar if both x1 and x2 are scalars.
Return type: ndarray
See also
logaddexp2()
- Logarithm of the sum of exponentiations of inputs in base 2.
Notes
New in version 1.3.0.
Examples
>>> prob1 = np.log(1e-50)
>>> prob2 = np.log(2.5e-50)
>>> prob12 = np.logaddexp(prob1, prob2)
>>> prob12
-113.87649168120691
>>> np.exp(prob12)
3.5000000000000057e-50
-
symjax.tensor.
logaddexp2
(x1, x2)[source]¶ Logarithm of the sum of exponentiations of the inputs in base-2.
LAX-backend implementation of
logaddexp2()
. Original docstring below.
logaddexp2(x1, x2, /, out=None, *, where=True, casting='same_kind', order='K', dtype=None, subok=True[, signature, extobj])
Calculates log2(2**x1 + 2**x2). This function is useful in machine learning when the calculated probabilities of events may be so small as to exceed the range of normal floating point numbers. In such cases the base-2 logarithm of the calculated probability can be used instead. This function allows adding probabilities stored in such a fashion.
Parameters: x1, x2 (array_like) – Input values. If x1.shape != x2.shape, they must be broadcastable to a common shape (which becomes the shape of the output).
Returns: result – Base-2 logarithm of 2**x1 + 2**x2. This is a scalar if both x1 and x2 are scalars.
Return type: ndarray
See also
logaddexp()
- Logarithm of the sum of exponentiations of the inputs.
Notes
New in version 1.3.0.
Examples
>>> prob1 = np.log2(1e-50)
>>> prob2 = np.log2(2.5e-50)
>>> prob12 = np.logaddexp2(prob1, prob2)
>>> prob1, prob2, prob12
(-166.09640474436813, -164.77447664948076, -164.28904982231052)
>>> 2**prob12
3.4999999999999914e-50
-
symjax.tensor.
logical_and
(*args)¶ Compute the truth value of x1 AND x2 element-wise.
LAX-backend implementation of
logical_and()
. Original docstring below.
logical_and(x1, x2, /, out=None, *, where=True, casting='same_kind', order='K', dtype=None, subok=True[, signature, extobj])
Parameters: - x1, x2 (array_like) – Input arrays. If x1.shape != x2.shape, they must be broadcastable to a common shape (which becomes the shape of the output).
- out (ndarray, None, or tuple of ndarray and None, optional) – A location into which the result is stored. If provided, it must have a shape that the inputs broadcast to. If not provided or None, a freshly-allocated array is returned. A tuple (possible only as a keyword argument) must have length equal to the number of outputs.
- where (array_like, optional) – This condition is broadcast over the input. At locations where the condition is True, the out array will be set to the ufunc result. Elsewhere, the out array will retain its original value. Note that if an uninitialized out array is created via the default out=None, locations within it where the condition is False will remain uninitialized.
- **kwargs – For other keyword-only arguments, see the ufunc docs.
Returns: y – Boolean result of the logical AND operation applied to the elements of x1 and x2; the shape is determined by broadcasting. This is a scalar if both x1 and x2 are scalars.
Return type: ndarray or bool
Examples
>>> np.logical_and(True, False)
False
>>> np.logical_and([True, False], [False, False])
array([False, False])
>>> x = np.arange(5)
>>> np.logical_and(x>1, x<4)
array([False, False,  True,  True, False])
-
symjax.tensor.
logical_not
(*args)¶ Compute the truth value of NOT x element-wise.
LAX-backend implementation of
logical_not()
. Original docstring below.
logical_not(x, /, out=None, *, where=True, casting='same_kind', order='K', dtype=None, subok=True[, signature, extobj])
Parameters: - x (array_like) – Logical NOT is applied to the elements of x.
- out (ndarray, None, or tuple of ndarray and None, optional) – A location into which the result is stored. If provided, it must have a shape that the inputs broadcast to. If not provided or None, a freshly-allocated array is returned. A tuple (possible only as a keyword argument) must have length equal to the number of outputs.
- where (array_like, optional) – This condition is broadcast over the input. At locations where the condition is True, the out array will be set to the ufunc result. Elsewhere, the out array will retain its original value. Note that if an uninitialized out array is created via the default out=None, locations within it where the condition is False will remain uninitialized.
- **kwargs – For other keyword-only arguments, see the ufunc docs.
Returns: y – Boolean result with the same shape as x of the NOT operation on elements of x. This is a scalar if x is a scalar.
Return type: bool or ndarray of bool
Examples
>>> np.logical_not(3)
False
>>> np.logical_not([True, False, 0, 1])
array([False,  True,  True, False])
>>> x = np.arange(5)
>>> np.logical_not(x<3)
array([False, False, False,  True,  True])
-
symjax.tensor.
logical_or
(*args)¶ Compute the truth value of x1 OR x2 element-wise.
LAX-backend implementation of
logical_or()
. Original docstring below.
logical_or(x1, x2, /, out=None, *, where=True, casting='same_kind', order='K', dtype=None, subok=True[, signature, extobj])
Parameters: - x1, x2 (array_like) – Logical OR is applied to the elements of x1 and x2. If x1.shape != x2.shape, they must be broadcastable to a common shape (which becomes the shape of the output).
- out (ndarray, None, or tuple of ndarray and None, optional) – A location into which the result is stored. If provided, it must have a shape that the inputs broadcast to. If not provided or None, a freshly-allocated array is returned. A tuple (possible only as a keyword argument) must have length equal to the number of outputs.
- where (array_like, optional) – This condition is broadcast over the input. At locations where the condition is True, the out array will be set to the ufunc result. Elsewhere, the out array will retain its original value. Note that if an uninitialized out array is created via the default out=None, locations within it where the condition is False will remain uninitialized.
- **kwargs – For other keyword-only arguments, see the ufunc docs.
Returns: y – Boolean result of the logical OR operation applied to the elements of x1 and x2; the shape is determined by broadcasting. This is a scalar if both x1 and x2 are scalars.
Return type: ndarray or bool
Examples
>>> np.logical_or(True, False)
True
>>> np.logical_or([True, False], [False, False])
array([ True, False])
>>> x = np.arange(5)
>>> np.logical_or(x < 1, x > 3)
array([ True, False, False, False,  True])
-
symjax.tensor.
logical_xor
(*args)¶ Compute the truth value of x1 XOR x2, element-wise.
LAX-backend implementation of
logical_xor()
. Original docstring below.
logical_xor(x1, x2, /, out=None, *, where=True, casting='same_kind', order='K', dtype=None, subok=True[, signature, extobj])
Parameters: - x1, x2 (array_like) – Logical XOR is applied to the elements of x1 and x2. If x1.shape != x2.shape, they must be broadcastable to a common shape (which becomes the shape of the output).
- out (ndarray, None, or tuple of ndarray and None, optional) – A location into which the result is stored. If provided, it must have a shape that the inputs broadcast to. If not provided or None, a freshly-allocated array is returned. A tuple (possible only as a keyword argument) must have length equal to the number of outputs.
- where (array_like, optional) – This condition is broadcast over the input. At locations where the condition is True, the out array will be set to the ufunc result. Elsewhere, the out array will retain its original value. Note that if an uninitialized out array is created via the default out=None, locations within it where the condition is False will remain uninitialized.
- **kwargs – For other keyword-only arguments, see the ufunc docs.
Returns: y – Boolean result of the logical XOR operation applied to the elements of x1 and x2; the shape is determined by broadcasting. This is a scalar if both x1 and x2 are scalars.
Return type: bool or ndarray of bool
Examples
>>> np.logical_xor(True, False)
True
>>> np.logical_xor([True, True, False, False], [True, False, True, False])
array([False,  True,  True, False])
>>> x = np.arange(5)
>>> np.logical_xor(x < 1, x > 3)
array([ True, False, False, False,  True])
Simple example showing support of broadcasting
>>> np.logical_xor(0, np.eye(2))
array([[ True, False],
       [False,  True]])
-
symjax.tensor.
logspace
(start, stop, num=50, endpoint=True, base=10.0, dtype=None, axis=0)[source]¶ Return numbers spaced evenly on a log scale.
LAX-backend implementation of
logspace()
. Original docstring below.
In linear space, the sequence starts at base ** start (base to the power of start) and ends with base ** stop (see endpoint below).
Changed in version 1.16.0: Non-scalar start and stop are now supported.
Parameters: - start (array_like) – base ** start is the starting value of the sequence.
- stop (array_like) – base ** stop is the final value of the sequence, unless endpoint is False. In that case, num + 1 values are spaced over the interval in log-space, of which all but the last (a sequence of length num) are returned.
- num (integer, optional) – Number of samples to generate. Default is 50.
- endpoint (boolean, optional) – If true, stop is the last sample. Otherwise, it is not included. Default is True.
- base (float, optional) – The base of the log space. The step size between the elements in ln(samples) / ln(base) (or log_base(samples)) is uniform. Default is 10.0.
- axis (int, optional) – The axis in the result to store the samples. Relevant only if start or stop are array-like. By default (0), the samples will be along a new axis inserted at the beginning. Use -1 to get an axis at the end.
Returns: samples – num samples, equally spaced on a log scale.
Return type: ndarray
See also
arange()
- Similar to linspace, with the step size specified instead of the number of samples. Note that, when used with a float endpoint, the endpoint may or may not be included.
linspace()
- Similar to logspace, but with the samples uniformly distributed in linear space, instead of log space.
geomspace()
- Similar to logspace, but with endpoints specified directly.
Notes
Logspace is equivalent to the code
>>> y = np.linspace(start, stop, num=num, endpoint=endpoint)
... # doctest: +SKIP
>>> power(base, y).astype(dtype)
... # doctest: +SKIP
Examples
>>> np.logspace(2.0, 3.0, num=4)
array([ 100.        ,  215.443469  ,  464.15888336, 1000.        ])
>>> np.logspace(2.0, 3.0, num=4, endpoint=False)
array([100.        , 177.827941  , 316.22776602, 562.34132519])
>>> np.logspace(2.0, 3.0, num=4, base=2.0)
array([4.        , 5.0396842 , 6.34960421, 8.        ])
Graphical illustration:
>>> import matplotlib.pyplot as plt
>>> N = 10
>>> x1 = np.logspace(0.1, 1, N, endpoint=True)
>>> x2 = np.logspace(0.1, 1, N, endpoint=False)
>>> y = np.zeros(N)
>>> plt.plot(x1, y, 'o')
[<matplotlib.lines.Line2D object at 0x...>]
>>> plt.plot(x2, y + 0.5, 'o')
[<matplotlib.lines.Line2D object at 0x...>]
>>> plt.ylim([-0.5, 1])
(-0.5, 1)
>>> plt.show()
-
symjax.tensor.
matmul
(a, b, *, precision=None)[source]¶ Matrix product of two arrays.
LAX-backend implementation of
matmul()
. In addition to the original NumPy arguments listed below, also supports precision for extra control over matrix-multiplication precision on supported devices. precision may be set to None, which means default precision for the backend, a lax.Precision enum value (Precision.DEFAULT, Precision.HIGH or Precision.HIGHEST) or a tuple of two lax.Precision enums indicating separate precision for each argument.
Original docstring below.
matmul(x1, x2, /, out=None, *, casting=’same_kind’, order=’K’, dtype=None, subok=True[, signature, extobj])
Parameters: out (ndarray, optional) – A location into which the result is stored. If provided, it must have a shape that matches the signature (n,k),(k,m)->(n,m). If not provided or None, a freshly-allocated array is returned.
Returns: y – The matrix product of the inputs. This is a scalar only when both x1, x2 are 1-d vectors.
Return type: ndarray
Raises: ValueError – If the last dimension of a is not the same size as the second-to-last dimension of b, or if a scalar value is passed in.
See also
vdot()
- Complex-conjugating dot product.
tensordot()
- Sum products over arbitrary axes.
einsum()
- Einstein summation convention.
dot()
- alternative matrix product with different broadcasting rules.
Notes
The behavior depends on the arguments in the following way.
- If both arguments are 2-D they are multiplied like conventional matrices.
- If either argument is N-D, N > 2, it is treated as a stack of matrices residing in the last two indexes and broadcast accordingly.
- If the first argument is 1-D, it is promoted to a matrix by prepending a 1 to its dimensions. After matrix multiplication the prepended 1 is removed.
- If the second argument is 1-D, it is promoted to a matrix by appending a 1 to its dimensions. After matrix multiplication the appended 1 is removed.
matmul differs from dot in two important ways:
- Multiplication by scalars is not allowed, use * instead.
- Stacks of matrices are broadcast together as if the matrices were elements, respecting the signature (n,k),(k,m)->(n,m):
>>> a = np.ones([9, 5, 7, 4])
>>> c = np.ones([9, 5, 4, 3])
>>> np.dot(a, c).shape
(9, 5, 7, 9, 5, 3)
>>> np.matmul(a, c).shape
(9, 5, 7, 3)
>>> # n is 7, k is 4, m is 3
The matmul function implements the semantics of the @ operator introduced in Python 3.5 following PEP465.
Examples
For 2-D arrays it is the matrix product:
>>> a = np.array([[1, 0],
...               [0, 1]])
>>> b = np.array([[4, 1],
...               [2, 2]])
>>> np.matmul(a, b)
array([[4, 1],
       [2, 2]])
For 2-D mixed with 1-D, the result is the usual matrix-vector product.
>>> a = np.array([[1, 0],
...               [0, 1]])
>>> b = np.array([1, 2])
>>> np.matmul(a, b)
array([1, 2])
>>> np.matmul(b, a)
array([1, 2])
Broadcasting is conventional for stacks of arrays
>>> a = np.arange(2 * 2 * 4).reshape((2, 2, 4))
>>> b = np.arange(2 * 2 * 4).reshape((2, 4, 2))
>>> np.matmul(a,b).shape
(2, 2, 2)
>>> np.matmul(a, b)[0, 1, 1]
98
>>> sum(a[0, 1, :] * b[0 , :, 1])
98
Vector, vector returns the scalar inner product, but neither argument is complex-conjugated:
>>> np.matmul([2j, 3j], [2j, 3j])
(-13+0j)
Scalar multiplication raises an error.
>>> np.matmul([1,2], 3)
Traceback (most recent call last):
...
ValueError: matmul: Input operand 1 does not have enough dimensions ...
New in version 1.10.0.
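Because the precision keyword is specific to the LAX backend, a brief sketch may help. This uses jax.numpy directly to stay self-contained; the shapes are arbitrary, and the effect of the precision setting depends on the device:
>>> import jax.numpy as jnp
>>> from jax import lax
>>> a = jnp.ones((64, 32))
>>> b = jnp.ones((32, 16))
>>> # request the highest-precision accumulation available on the device
>>> c = jnp.matmul(a, b, precision=lax.Precision.HIGHEST)
>>> c.shape
(64, 16)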
-
symjax.tensor.
max
(a, axis=None, out=None, keepdims=None, initial=None, where=None)[source]¶ Return the maximum of an array or maximum along an axis.
LAX-backend implementation of
amax()
. Original docstring below.
Parameters: - a (array_like) – Input data.
- axis (None or int or tuple of ints, optional) – Axis or axes along which to operate. By default, flattened input is used.
- out (ndarray, optional) – Alternative output array in which to place the result. Must be of the same shape and buffer length as the expected output. See ufuncs-output-type for more details.
- keepdims (bool, optional) – If this is set to True, the axes which are reduced are left in the result as dimensions with size one. With this option, the result will broadcast correctly against the input array.
- initial (scalar, optional) – The minimum value of an output element. Must be present to allow computation on empty slice. See ~numpy.ufunc.reduce for details.
- where (array_like of bool, optional) – Elements to compare for the maximum. See ~numpy.ufunc.reduce for details.
Returns: amax – Maximum of a. If axis is None, the result is a scalar value. If axis is given, the result is an array of dimension a.ndim - 1.
Return type: ndarray or scalar
See also
amin()
- The minimum value of an array along a given axis, propagating any NaNs.
nanmax()
- The maximum value of an array along a given axis, ignoring any NaNs.
maximum()
- Element-wise maximum of two arrays, propagating any NaNs.
fmax()
- Element-wise maximum of two arrays, ignoring any NaNs.
argmax()
- Return the indices of the maximum values.
Notes
NaN values are propagated, that is if at least one item is NaN, the corresponding max value will be NaN as well. To ignore NaN values (MATLAB behavior), please use nanmax.
Don't use amax for element-wise comparison of 2 arrays; when a.shape[0] is 2, maximum(a[0], a[1]) is faster than amax(a, axis=0).
Examples
>>> a = np.arange(4).reshape((2,2))
>>> a
array([[0, 1],
       [2, 3]])
>>> np.amax(a)           # Maximum of the flattened array
3
>>> np.amax(a, axis=0)   # Maxima along the first axis
array([2, 3])
>>> np.amax(a, axis=1)   # Maxima along the second axis
array([1, 3])
>>> np.amax(a, where=[False, True], initial=-1, axis=0)
array([-1,  3])
>>> b = np.arange(5, dtype=float)
>>> b[2] = np.NaN
>>> np.amax(b)
nan
>>> np.amax(b, where=~np.isnan(b), initial=-1)
4.0
>>> np.nanmax(b)
4.0
You can use an initial value to compute the maximum of an empty slice, or to initialize it to a different value:
>>> np.max([[-50], [10]], axis=-1, initial=0)
array([ 0, 10])
Notice that the initial value is used as one of the elements for which the maximum is determined, unlike the default argument of Python's max function, which is only used for empty iterables.
>>> np.max([5], initial=6)
6
>>> max([5], default=6)
5
-
symjax.tensor.
maximum
(x1, x2)¶ Element-wise maximum of array elements.
LAX-backend implementation of
maximum()
. Original docstring below.
maximum(x1, x2, /, out=None, *, where=True, casting='same_kind', order='K', dtype=None, subok=True[, signature, extobj])
Compares two arrays and returns a new array containing the element-wise maxima. If one of the elements being compared is a NaN, then that element is returned. If both elements are NaNs then the first is returned. The latter distinction is important for complex NaNs, which are defined as at least one of the real or imaginary parts being a NaN. The net effect is that NaNs are propagated.
Parameters: x1, x2 (array_like) – The arrays holding the elements to be compared. If x1.shape != x2.shape, they must be broadcastable to a common shape (which becomes the shape of the output).
Returns: y – The maximum of x1 and x2, element-wise. This is a scalar if both x1 and x2 are scalars.
Return type: ndarray or scalar
Notes
The maximum is equivalent to np.where(x1 >= x2, x1, x2) when neither x1 nor x2 are NaNs, but it is faster and does proper broadcasting.
Examples
>>> np.maximum([2, 3, 4], [1, 5, 2])
array([2, 5, 4])
>>> np.maximum(np.eye(2), [0.5, 2])  # broadcasting
array([[ 1. ,  2. ],
       [ 0.5,  2. ]])
>>> np.maximum([np.nan, 0, np.nan], [0, np.nan, np.nan])
array([nan, nan, nan])
>>> np.maximum(np.Inf, 1)
inf
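To make the equivalence stated in the Notes concrete, and to show exactly where it breaks down with NaNs, a small NumPy illustration:
>>> x1 = np.array([1., np.nan])
>>> x2 = np.array([2., 0.])
>>> np.maximum(x1, x2)           # propagates the NaN
array([ 2., nan])
>>> np.where(x1 >= x2, x1, x2)   # NaN compares False, so the NaN is silently dropped
array([2., 0.])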
-
symjax.tensor.
mean
(a, axis=None, dtype=None, out=None, keepdims=False)[source]¶ Compute the arithmetic mean along the specified axis.
LAX-backend implementation of
mean()
. Original docstring below.
Returns the average of the array elements. The average is taken over the flattened array by default, otherwise over the specified axis. float64 intermediate and return values are used for integer inputs.
Parameters: - a (array_like) – Array containing numbers whose mean is desired. If a is not an array, a conversion is attempted.
- axis (None or int or tuple of ints, optional) – Axis or axes along which the means are computed. The default is to compute the mean of the flattened array.
- dtype (data-type, optional) – Type to use in computing the mean. For integer inputs, the default is float64; for floating point inputs, it is the same as the input dtype.
- out (ndarray, optional) – Alternate output array in which to place the result. The default is None; if provided, it must have the same shape as the expected output, but the type will be cast if necessary. See ufuncs-output-type for more details.
- keepdims (bool, optional) – If this is set to True, the axes which are reduced are left in the result as dimensions with size one. With this option, the result will broadcast correctly against the input array.
Returns: m – If out=None, returns a new array containing the mean values, otherwise a reference to the output array is returned.
Return type: ndarray, see dtype parameter above
Notes
The arithmetic mean is the sum of the elements along the axis divided by the number of elements.
Note that for floating-point input, the mean is computed using the same precision the input has. Depending on the input data, this can cause the results to be inaccurate, especially for float32 (see example below). Specifying a higher-precision accumulator using the dtype keyword can alleviate this issue.
By default, float16 results are computed using float32 intermediates for extra precision.
Examples
>>> a = np.array([[1, 2], [3, 4]])
>>> np.mean(a)
2.5
>>> np.mean(a, axis=0)
array([2., 3.])
>>> np.mean(a, axis=1)
array([1.5, 3.5])
In single precision, mean can be inaccurate:
>>> a = np.zeros((2, 512*512), dtype=np.float32)
>>> a[0, :] = 1.0
>>> a[1, :] = 0.1
>>> np.mean(a)
0.54999924
Computing the mean in float64 is more accurate:
>>> np.mean(a, dtype=np.float64)
0.55000000074505806 # may vary
-
symjax.tensor.
median
(a, axis=None, out=None, overwrite_input=False, keepdims=False)[source]¶ Compute the median along the specified axis.
LAX-backend implementation of
median()
. Original docstring below.
Returns the median of the array elements.
Parameters: - a (array_like) – Input array or object that can be converted to an array.
- axis ({int, sequence of int, None}, optional) – Axis or axes along which the medians are computed. The default is to compute the median along a flattened version of the array. A sequence of axes is supported since version 1.9.0.
- out (ndarray, optional) – Alternative output array in which to place the result. It must have the same shape and buffer length as the expected output, but the type (of the output) will be cast if necessary.
- overwrite_input (bool, optional) – If True, allow use of memory of input array a for calculations. The input array will be modified by the call to median; treat its contents as undefined afterwards. Default is False.
- keepdims (bool, optional) – If this is set to True, the axes which are reduced are left in the result as dimensions with size one. With this option, the result will broadcast correctly against the original arr.
Returns: median – A new array holding the result. If the input contains integers or floats smaller than float64, then the output data-type is np.float64. Otherwise, the data-type of the output is the same as that of the input. If out is specified, that array is returned instead.
Return type: ndarray
Notes
Given a vector V of length N, the median of V is the middle value of a sorted copy of V (call it V_sorted): V_sorted[(N-1)/2] when N is odd, and the average of the two middle values of V_sorted when N is even.
Examples
>>> a = np.array([[10, 7, 4], [3, 2, 1]])
>>> a
array([[10,  7,  4],
       [ 3,  2,  1]])
>>> np.median(a)
3.5
>>> np.median(a, axis=0)
array([6.5, 4.5, 2.5])
>>> np.median(a, axis=1)
array([7., 2.])
>>> m = np.median(a, axis=0)
>>> out = np.zeros_like(m)
>>> np.median(a, axis=0, out=m)
array([6.5, 4.5, 2.5])
>>> m
array([6.5, 4.5, 2.5])
>>> b = a.copy()
>>> np.median(b, axis=1, overwrite_input=True)
array([7., 2.])
>>> assert not np.all(a==b)
>>> b = a.copy()
>>> np.median(b, axis=None, overwrite_input=True)
3.5
>>> assert not np.all(a==b)
-
symjax.tensor.
meshgrid
(*args, **kwargs)[source]¶ Return coordinate matrices from coordinate vectors.
LAX-backend implementation of
meshgrid()
. Original docstring below.
Make N-D coordinate arrays for vectorized evaluations of N-D scalar/vector fields over N-D grids, given one-dimensional coordinate arrays x1, x2,…, xn.
Changed in version 1.9: 1-D and 0-D cases are allowed.
Parameters: - indexing ({'xy', 'ij'}, optional) – Cartesian (‘xy’, default) or matrix (‘ij’) indexing of output. See Notes for more details.
- sparse (bool, optional) – If True a sparse grid is returned in order to conserve memory. Default is False.
- copy (bool, optional) – If False, views into the original arrays are returned in order to conserve memory. Default is True. Please note that sparse=False, copy=False will likely return non-contiguous arrays. Furthermore, more than one element of a broadcast array may refer to a single memory location. If you need to write to the arrays, make copies first.
Returns: X1, X2,…, XN – For vectors x1, x2,…, xn with lengths Ni=len(xi), return (N1, N2, N3,...,Nn) shaped arrays if indexing='ij' or (N2, N1, N3,...,Nn) shaped arrays if indexing='xy', with the elements of xi repeated to fill the matrix along the first dimension for x1, the second for x2 and so on.
Return type: ndarray
Notes
This function supports both indexing conventions through the indexing keyword argument. Giving the string ‘ij’ returns a meshgrid with matrix indexing, while ‘xy’ returns a meshgrid with Cartesian indexing. In the 2-D case with inputs of length M and N, the outputs are of shape (N, M) for ‘xy’ indexing and (M, N) for ‘ij’ indexing. In the 3-D case with inputs of length M, N and P, outputs are of shape (N, M, P) for ‘xy’ indexing and (M, N, P) for ‘ij’ indexing. The difference is illustrated by the following code snippet:
xv, yv = np.meshgrid(x, y, sparse=False, indexing='ij')
for i in range(nx):
    for j in range(ny):
        # treat xv[i,j], yv[i,j]

xv, yv = np.meshgrid(x, y, sparse=False, indexing='xy')
for i in range(nx):
    for j in range(ny):
        # treat xv[j,i], yv[j,i]
In the 1-D and 0-D case, the indexing and sparse keywords have no effect.
See also
index_tricks.mgrid()
- Construct a multi-dimensional “meshgrid” using indexing notation.
index_tricks.ogrid()
- Construct an open multi-dimensional “meshgrid” using indexing notation.
Examples
>>> nx, ny = (3, 2)
>>> x = np.linspace(0, 1, nx)
>>> y = np.linspace(0, 1, ny)
>>> xv, yv = np.meshgrid(x, y)
>>> xv
array([[0. , 0.5, 1. ],
       [0. , 0.5, 1. ]])
>>> yv
array([[0., 0., 0.],
       [1., 1., 1.]])
>>> xv, yv = np.meshgrid(x, y, sparse=True)  # make sparse output arrays
>>> xv
array([[0. , 0.5, 1. ]])
>>> yv
array([[0.],
       [1.]])
meshgrid is very useful to evaluate functions on a grid.
>>> import matplotlib.pyplot as plt
>>> x = np.arange(-5, 5, 0.1)
>>> y = np.arange(-5, 5, 0.1)
>>> xx, yy = np.meshgrid(x, y, sparse=True)
>>> z = np.sin(xx**2 + yy**2) / (xx**2 + yy**2)
>>> h = plt.contourf(x,y,z)
>>> plt.show()
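The 'xy' versus 'ij' convention from the Notes comes down to output shapes, which a two-line NumPy check makes explicit:
>>> x = np.arange(3)   # length M = 3
>>> y = np.arange(2)   # length N = 2
>>> np.meshgrid(x, y, indexing='xy')[0].shape   # Cartesian: (N, M)
(2, 3)
>>> np.meshgrid(x, y, indexing='ij')[0].shape   # matrix: (M, N)
(3, 2)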
-
symjax.tensor.
min
(a, axis=None, out=None, keepdims=None, initial=None, where=None)[source]¶ Return the minimum of an array or minimum along an axis.
LAX-backend implementation of
amin()
. Original docstring below.
Parameters: - a (array_like) – Input data.
- axis (None or int or tuple of ints, optional) – Axis or axes along which to operate. By default, flattened input is used.
- out (ndarray, optional) – Alternative output array in which to place the result. Must be of the same shape and buffer length as the expected output. See ufuncs-output-type for more details.
- keepdims (bool, optional) – If this is set to True, the axes which are reduced are left in the result as dimensions with size one. With this option, the result will broadcast correctly against the input array.
- initial (scalar, optional) – The maximum value of an output element. Must be present to allow computation on empty slice. See ~numpy.ufunc.reduce for details.
- where (array_like of bool, optional) – Elements to compare for the minimum. See ~numpy.ufunc.reduce for details.
Returns: amin – Minimum of a. If axis is None, the result is a scalar value. If axis is given, the result is an array of dimension a.ndim - 1.
Return type: ndarray or scalar
See also
amax()
- The maximum value of an array along a given axis, propagating any NaNs.
nanmin()
- The minimum value of an array along a given axis, ignoring any NaNs.
minimum()
- Element-wise minimum of two arrays, propagating any NaNs.
fmin()
- Element-wise minimum of two arrays, ignoring any NaNs.
argmin()
- Return the indices of the minimum values.
Notes
NaN values are propagated, that is if at least one item is NaN, the corresponding min value will be NaN as well. To ignore NaN values (MATLAB behavior), please use nanmin.
Don't use amin for element-wise comparison of 2 arrays; when a.shape[0] is 2, minimum(a[0], a[1]) is faster than amin(a, axis=0).
Examples
>>> a = np.arange(4).reshape((2,2))
>>> a
array([[0, 1],
       [2, 3]])
>>> np.amin(a)           # Minimum of the flattened array
0
>>> np.amin(a, axis=0)   # Minima along the first axis
array([0, 1])
>>> np.amin(a, axis=1)   # Minima along the second axis
array([0, 2])
>>> np.amin(a, where=[False, True], initial=10, axis=0)
array([10,  1])
>>> b = np.arange(5, dtype=float)
>>> b[2] = np.NaN
>>> np.amin(b)
nan
>>> np.amin(b, where=~np.isnan(b), initial=10)
0.0
>>> np.nanmin(b)
0.0
>>> np.min([[-50], [10]], axis=-1, initial=0)
array([-50,   0])
Notice that the initial value is used as one of the elements for which the minimum is determined, unlike the default argument of Python's min function, which is only used for empty iterables.
Notice that this isn't the same as Python's default argument.
>>> np.min([6], initial=5)
5
>>> min([6], default=5)
6
-
symjax.tensor.
minimum
(x1, x2)¶ Element-wise minimum of array elements.
LAX-backend implementation of
minimum()
. Original docstring below.
minimum(x1, x2, /, out=None, *, where=True, casting='same_kind', order='K', dtype=None, subok=True[, signature, extobj])
Compares two arrays and returns a new array containing the element-wise minima. If one of the elements being compared is a NaN, then that element is returned. If both elements are NaNs then the first is returned. The latter distinction is important for complex NaNs, which are defined as at least one of the real or imaginary parts being a NaN. The net effect is that NaNs are propagated.
Parameters: x1, x2 (array_like) – The arrays holding the elements to be compared. If x1.shape != x2.shape, they must be broadcastable to a common shape (which becomes the shape of the output).
Returns: y – The minimum of x1 and x2, element-wise. This is a scalar if both x1 and x2 are scalars.
Return type: ndarray or scalar
Notes
The minimum is equivalent to np.where(x1 <= x2, x1, x2) when neither x1 nor x2 are NaNs, but it is faster and does proper broadcasting.
Examples
>>> np.minimum([2, 3, 4], [1, 5, 2])
array([1, 3, 2])
>>> np.minimum(np.eye(2), [0.5, 2])  # broadcasting
array([[ 0.5,  0. ],
       [ 0. ,  1. ]])
>>> np.minimum([np.nan, 0, np.nan],[0, np.nan, np.nan])
array([nan, nan, nan])
>>> np.minimum(-np.Inf, 1)
-inf
-
symjax.tensor.
mod
(x1, x2)¶ Return element-wise remainder of division.
LAX-backend implementation of
remainder()
. Original docstring below.
remainder(x1, x2, /, out=None, *, where=True, casting='same_kind', order='K', dtype=None, subok=True[, signature, extobj])
Computes the remainder complementary to the floor_divide function. It is equivalent to the Python modulus operator x1 % x2 and has the same sign as the divisor x2. The MATLAB function equivalent to np.remainder is mod.
Warning
This should not be confused with:
- Python 3.7's math.remainder and C's remainder, which compute the IEEE remainder and are the complement to round(x1 / x2).
- The MATLAB rem function, or the C % operator, which is the complement to int(x1 / x2).
Parameters: - x1 (array_like) – Dividend array.
- x2 (array_like) – Divisor array. If x1.shape != x2.shape, they must be broadcastable to a common shape (which becomes the shape of the output).
Returns: y – The element-wise remainder of the quotient floor_divide(x1, x2). This is a scalar if both x1 and x2 are scalars.
Return type: ndarray
See also
floor_divide()
- Equivalent of Python // operator.
divmod()
- Simultaneous floor division and remainder.
fmod()
- Equivalent of the MATLAB rem function.
Notes
Returns 0 when x2 is 0 and both x1 and x2 are (arrays of) integers.
mod is an alias of remainder.
Examples
>>> np.remainder([4, 7], [2, 3])
array([0, 1])
>>> np.remainder(np.arange(7), 5)
array([0, 1, 2, 3, 4, 0, 1])
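The sign convention that the Warning above contrasts with C and MATLAB is easiest to see with negative operands; np.fmod follows the C convention instead:
>>> np.remainder([-3, 3], 2)  # sign follows the divisor (Python % semantics)
array([1, 1])
>>> np.fmod([-3, 3], 2)       # sign follows the dividend (C / MATLAB rem semantics)
array([-1,  1])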
-
symjax.tensor.
moveaxis
(a, source, destination)[source]¶ Move axes of an array to new positions.
LAX-backend implementation of
moveaxis()
. Original docstring below.
Other axes remain in their original order.
New in version 1.11.0.
Parameters: - a (np.ndarray) – The array whose axes should be reordered.
- source (int or sequence of int) – Original positions of the axes to move. These must be unique.
- destination (int or sequence of int) – Destination positions for each of the original axes. These must also be unique.
Returns: result – Array with moved axes. This array is a view of the input array.
Return type: np.ndarray
See also
transpose()
- Permute the dimensions of an array.
swapaxes()
- Interchange two axes of an array.
Examples
>>> x = np.zeros((3, 4, 5))
>>> np.moveaxis(x, 0, -1).shape
(4, 5, 3)
>>> np.moveaxis(x, -1, 0).shape
(5, 3, 4)
These all achieve the same result:
>>> np.transpose(x).shape
(5, 4, 3)
>>> np.swapaxes(x, 0, -1).shape
(5, 4, 3)
>>> np.moveaxis(x, [0, 1], [-1, -2]).shape
(5, 4, 3)
>>> np.moveaxis(x, [0, 1, 2], [-1, -2, -3]).shape
(5, 4, 3)
-
symjax.tensor.
multiply
(x1, x2)¶ Multiply arguments element-wise.
LAX-backend implementation of
multiply()
. Original docstring below.
multiply(x1, x2, /, out=None, *, where=True, casting='same_kind', order='K', dtype=None, subok=True[, signature, extobj])
Parameters: x1, x2 (array_like) – Input arrays to be multiplied. If x1.shape != x2.shape, they must be broadcastable to a common shape (which becomes the shape of the output).
Returns: y – The product of x1 and x2, element-wise. This is a scalar if both x1 and x2 are scalars.
Return type: ndarray
Notes
Equivalent to x1 * x2 in terms of array broadcasting.
Examples
>>> np.multiply(2.0, 4.0)
8.0
>>> x1 = np.arange(9.0).reshape((3, 3))
>>> x2 = np.arange(3.0)
>>> np.multiply(x1, x2)
array([[ 0.,  1.,  4.],
       [ 0.,  4., 10.],
       [ 0.,  7., 16.]])
-
symjax.tensor.
nan_to_num
(x, copy=True, nan=0.0, posinf=None, neginf=None)[source]¶ Replace NaN with zero and infinity with large finite numbers (default behaviour) or with the numbers defined by the user using the nan, posinf and/or neginf keywords.
LAX-backend implementation of
nan_to_num()
. Original docstring below.
If x is inexact, NaN is replaced by zero or by the user defined value in the nan keyword, infinity is replaced by the largest finite floating point value representable by x.dtype or by the user defined value in the posinf keyword, and -infinity is replaced by the most negative finite floating point value representable by x.dtype or by the user defined value in the neginf keyword.
For complex dtypes, the above is applied to each of the real and imaginary components of x separately.
If x is not inexact, then no replacements are made.
Parameters: - x (scalar or array_like) – Input data.
- copy (bool, optional) –
Whether to create a copy of x (True) or to replace values in-place (False). The in-place operation only occurs if casting to an array does not require a copy. Default is True.
New in version 1.13.
- nan (int, float, optional) –
Value to be used to fill NaN values. If no value is passed then NaN values will be replaced with 0.0.
New in version 1.17.
- posinf (int, float, optional) –
Value to be used to fill positive infinity values. If no value is passed then positive infinity values will be replaced with a very large number.
New in version 1.17.
- neginf (int, float, optional) –
Value to be used to fill negative infinity values. If no value is passed then negative infinity values will be replaced with a very small (or negative) number.
New in version 1.17.
Returns: out – x, with the non-finite values replaced. If copy is False, this may be x itself.
Return type: ndarray
See also
isinf()
- Shows which elements are positive or negative infinity.
isneginf()
- Shows which elements are negative infinity.
isposinf()
- Shows which elements are positive infinity.
isnan()
- Shows which elements are Not a Number (NaN).
isfinite()
- Shows which elements are finite (not NaN, not infinity)
Notes
NumPy uses the IEEE Standard for Binary Floating-Point for Arithmetic (IEEE 754). This means that Not a Number is not equivalent to infinity.
Examples
>>> np.nan_to_num(np.inf)
1.7976931348623157e+308
>>> np.nan_to_num(-np.inf)
-1.7976931348623157e+308
>>> np.nan_to_num(np.nan)
0.0
>>> x = np.array([np.inf, -np.inf, np.nan, -128, 128])
>>> np.nan_to_num(x)
array([ 1.79769313e+308, -1.79769313e+308,  0.00000000e+000, # may vary
       -1.28000000e+002,  1.28000000e+002])
>>> np.nan_to_num(x, nan=-9999, posinf=33333333, neginf=33333333)
array([ 3.3333333e+07,  3.3333333e+07, -9.9990000e+03, -1.2800000e+02,
        1.2800000e+02])
>>> y = np.array([complex(np.inf, np.nan), np.nan, complex(np.nan, np.inf)])
>>> np.nan_to_num(y)
array([ 1.79769313e+308 +0.00000000e+000j, # may vary
        0.00000000e+000 +0.00000000e+000j,
        0.00000000e+000 +1.79769313e+308j])
>>> np.nan_to_num(y, nan=111111, posinf=222222)
array([222222.+111111.j, 111111.     +0.j, 111111.+222222.j])
-
symjax.tensor.
nancumprod
(a, axis=None, dtype=None, out=None)¶ Return the cumulative product of array elements over a given axis treating Not a Numbers (NaNs) as one. The cumulative product does not change when NaNs are encountered and leading NaNs are replaced by ones.
LAX-backend implementation of
nancumprod()
. Original docstring below.
Ones are returned for slices that are all-NaN or empty.
New in version 1.12.0.
Parameters: - a (array_like) – Input array.
- axis (int, optional) – Axis along which the cumulative product is computed. By default the input is flattened.
- dtype (dtype, optional) – Type of the returned array, as well as of the accumulator in which the elements are multiplied. If dtype is not specified, it defaults to the dtype of a, unless a has an integer dtype with a precision less than that of the default platform integer. In that case, the default platform integer is used instead.
- out (ndarray, optional) – Alternative output array in which to place the result. It must have the same shape and buffer length as the expected output but the type of the resulting values will be cast if necessary.
Returns: nancumprod – A new array holding the result is returned unless out is specified, in which case it is returned.
Return type: ndarray
See also
numpy.cumprod()
- Cumulative product across array propagating NaNs.
isnan()
- Show which elements are NaN.
Examples
>>> np.nancumprod(1)
array([1])
>>> np.nancumprod([1])
array([1])
>>> np.nancumprod([1, np.nan])
array([1., 1.])
>>> a = np.array([[1, 2], [3, np.nan]])
>>> np.nancumprod(a)
array([1., 2., 6., 6.])
>>> np.nancumprod(a, axis=0)
array([[1., 2.],
       [3., 2.]])
>>> np.nancumprod(a, axis=1)
array([[1., 2.],
       [3., 3.]])
-
symjax.tensor.
nancumsum
(a, axis=None, dtype=None, out=None)¶ Return the cumulative sum of array elements over a given axis treating Not a Numbers (NaNs) as zero. The cumulative sum does not change when NaNs are encountered and leading NaNs are replaced by zeros.
LAX-backend implementation of
nancumsum()
. Original docstring below.
Zeros are returned for slices that are all-NaN or empty.
New in version 1.12.0.
Parameters: - a (array_like) – Input array.
- axis (int, optional) – Axis along which the cumulative sum is computed. The default (None) is to compute the cumsum over the flattened array.
- dtype (dtype, optional) – Type of the returned array and of the accumulator in which the elements are summed. If dtype is not specified, it defaults to the dtype of a, unless a has an integer dtype with a precision less than that of the default platform integer. In that case, the default platform integer is used.
- out (ndarray, optional) – Alternative output array in which to place the result. It must have the same shape and buffer length as the expected output but the type will be cast if necessary. See ufuncs-output-type for more details.
Returns: nancumsum – A new array holding the result is returned unless out is specified, in which case it is returned. The result has the same size as a, and the same shape as a if axis is not None or a is a 1-d array.
Return type: ndarray
See also
numpy.cumsum()
- Cumulative sum across array propagating NaNs.
isnan()
- Show which elements are NaN.
Examples
>>> np.nancumsum(1)
array([1])
>>> np.nancumsum([1])
array([1])
>>> np.nancumsum([1, np.nan])
array([1., 1.])
>>> a = np.array([[1, 2], [3, np.nan]])
>>> np.nancumsum(a)
array([1., 3., 6., 6.])
>>> np.nancumsum(a, axis=0)
array([[1., 2.],
       [4., 2.]])
>>> np.nancumsum(a, axis=1)
array([[1., 3.],
       [3., 3.]])
-
symjax.tensor.
nanmax
(a, axis=None, out=None, keepdims=None)[source]¶ Return the maximum of an array or maximum along an axis, ignoring any NaNs. When all-NaN slices are encountered a RuntimeWarning is raised and NaN is returned for that slice.
LAX-backend implementation of
nanmax()
. Original docstring below.
Parameters: - a (array_like) – Array containing numbers whose maximum is desired. If a is not an array, a conversion is attempted.
- axis ({int, tuple of int, None}, optional) – Axis or axes along which the maximum is computed. The default is to compute the maximum of the flattened array.
- out (ndarray, optional) – Alternate output array in which to place the result. The default is None; if provided, it must have the same shape as the expected output, but the type will be cast if necessary. See ufuncs-output-type for more details.
- keepdims (bool, optional) – If this is set to True, the axes which are reduced are left in the result as dimensions with size one. With this option, the result will broadcast correctly against the original a.
Returns: nanmax – An array with the same shape as a, with the specified axis removed. If a is a 0-d array, or if axis is None, an ndarray scalar is returned. The same dtype as a is returned.
Return type: ndarray
See also
nanmin()
- The minimum value of an array along a given axis, ignoring any NaNs.
amax()
- The maximum value of an array along a given axis, propagating any NaNs.
fmax()
- Element-wise maximum of two arrays, ignoring any NaNs.
maximum()
- Element-wise maximum of two arrays, propagating any NaNs.
isnan()
- Shows which elements are Not a Number (NaN).
isfinite()
- Shows which elements are neither NaN nor infinity.
Notes
NumPy uses the IEEE Standard for Binary Floating-Point for Arithmetic (IEEE 754). This means that Not a Number is not equivalent to infinity. Positive infinity is treated as a very large number and negative infinity is treated as a very small (i.e. negative) number.
If the input has an integer type the function is equivalent to np.max.
Examples
>>> a = np.array([[1, 2], [3, np.nan]])
>>> np.nanmax(a)
3.0
>>> np.nanmax(a, axis=0)
array([3., 2.])
>>> np.nanmax(a, axis=1)
array([2., 3.])
When positive infinity and negative infinity are present:
>>> np.nanmax([1, 2, np.nan, np.NINF])
2.0
>>> np.nanmax([1, 2, np.nan, np.inf])
inf
-
symjax.tensor.
nanmin
(a, axis=None, out=None, keepdims=None)[source]¶ Return minimum of an array or minimum along an axis, ignoring any NaNs. When all-NaN slices are encountered a RuntimeWarning is raised and NaN is returned for that slice.
LAX-backend implementation of
nanmin()
. Original docstring below.
Parameters: - a (array_like) – Array containing numbers whose minimum is desired. If a is not an array, a conversion is attempted.
- axis ({int, tuple of int, None}, optional) – Axis or axes along which the minimum is computed. The default is to compute the minimum of the flattened array.
- out (ndarray, optional) – Alternate output array in which to place the result. The default is None; if provided, it must have the same shape as the expected output, but the type will be cast if necessary. See ufuncs-output-type for more details.
- keepdims (bool, optional) – If this is set to True, the axes which are reduced are left in the result as dimensions with size one. With this option, the result will broadcast correctly against the original a.
Returns: nanmin – An array with the same shape as a, with the specified axis removed. If a is a 0-d array, or if axis is None, an ndarray scalar is returned. The same dtype as a is returned.
Return type: ndarray
See also
nanmax()
- The maximum value of an array along a given axis, ignoring any NaNs.
amin()
- The minimum value of an array along a given axis, propagating any NaNs.
fmin()
- Element-wise minimum of two arrays, ignoring any NaNs.
minimum()
- Element-wise minimum of two arrays, propagating any NaNs.
isnan()
- Shows which elements are Not a Number (NaN).
isfinite()
- Shows which elements are neither NaN nor infinity.
Notes
NumPy uses the IEEE Standard for Binary Floating-Point for Arithmetic (IEEE 754). This means that Not a Number is not equivalent to infinity. Positive infinity is treated as a very large number and negative infinity is treated as a very small (i.e. negative) number.
If the input has an integer type the function is equivalent to np.min.
Examples
>>> a = np.array([[1, 2], [3, np.nan]])
>>> np.nanmin(a)
1.0
>>> np.nanmin(a, axis=0)
array([1., 2.])
>>> np.nanmin(a, axis=1)
array([1., 3.])
When positive infinity and negative infinity are present:
>>> np.nanmin([1, 2, np.nan, np.inf])
1.0
>>> np.nanmin([1, 2, np.nan, np.NINF])
-inf
-
symjax.tensor.
nanprod
(a, axis=None, dtype=None, out=None, keepdims=None)[source]¶ Return the product of array elements over a given axis treating Not a Numbers (NaNs) as ones.
LAX-backend implementation of
nanprod()
. Original docstring below.
One is returned for slices that are all-NaN or empty.
New in version 1.10.0.
Parameters: - a (array_like) – Array containing numbers whose product is desired. If a is not an array, a conversion is attempted.
- axis ({int, tuple of int, None}, optional) – Axis or axes along which the product is computed. The default is to compute the product of the flattened array.
- dtype (data-type, optional) – The type of the returned array and of the accumulator in which the elements are summed. By default, the dtype of a is used. An exception is when a has an integer type with less precision than the platform (u)intp. In that case, the default will be either (u)int32 or (u)int64 depending on whether the platform is 32 or 64 bits. For inexact inputs, dtype must be inexact.
- out (ndarray, optional) – Alternate output array in which to place the result. The default is None. If provided, it must have the same shape as the expected output, but the type will be cast if necessary. See ufuncs-output-type for more details. The casting of NaN to integer can yield unexpected results.
- keepdims (bool, optional) – If True, the axes which are reduced are left in the result as dimensions with size one. With this option, the result will broadcast correctly against the original arr.
Returns: nanprod – A new array holding the result is returned unless out is specified, in which case it is returned.
Return type: ndarray
See also
numpy.prod()
- Product across array propagating NaNs.
isnan()
- Show which elements are NaN.
Examples
>>> np.nanprod(1)
1
>>> np.nanprod([1])
1
>>> np.nanprod([1, np.nan])
1.0
>>> a = np.array([[1, 2], [3, np.nan]])
>>> np.nanprod(a)
6.0
>>> np.nanprod(a, axis=0)
array([3., 2.])
-
symjax.tensor.
nansum
(a, axis=None, dtype=None, out=None, keepdims=None)[source]¶ Return the sum of array elements over a given axis treating Not a Numbers (NaNs) as zero.
LAX-backend implementation of
nansum()
. Original docstring below.
In NumPy versions <= 1.9.0 NaN is returned for slices that are all-NaN or empty. In later versions zero is returned.
Parameters: - a (array_like) – Array containing numbers whose sum is desired. If a is not an array, a conversion is attempted.
- axis ({int, tuple of int, None}, optional) – Axis or axes along which the sum is computed. The default is to compute the sum of the flattened array.
- dtype (data-type, optional) – The type of the returned array and of the accumulator in which the elements are summed. By default, the dtype of a is used. An exception is when a has an integer type with less precision than the platform (u)intp. In that case, the default will be either (u)int32 or (u)int64 depending on whether the platform is 32 or 64 bits. For inexact inputs, dtype must be inexact.
- out (ndarray, optional) – Alternate output array in which to place the result. The default is None. If provided, it must have the same shape as the expected output, but the type will be cast if necessary. See ufuncs-output-type for more details. The casting of NaN to integer can yield unexpected results.
- keepdims (bool, optional) – If this is set to True, the axes which are reduced are left in the result as dimensions with size one. With this option, the result will broadcast correctly against the original a.
Returns: nansum – A new array holding the result is returned unless out is specified, in which case it is returned. The result has the same size as a, and the same shape as a if axis is not None or a is a 1-d array.
Return type: ndarray
See also
numpy.sum()
- Sum across array propagating NaNs.
isnan()
- Show which elements are NaN.
isfinite()
- Show which elements are not NaN or +/-inf.
Notes
If both positive and negative infinity are present, the sum will be Not A Number (NaN).
Examples
>>> np.nansum(1)
1
>>> np.nansum([1])
1
>>> np.nansum([1, np.nan])
1.0
>>> a = np.array([[1, 1], [1, np.nan]])
>>> np.nansum(a)
3.0
>>> np.nansum(a, axis=0)
array([2., 1.])
>>> np.nansum([1, np.nan, np.inf])
inf
>>> np.nansum([1, np.nan, np.NINF])
-inf
>>> from numpy.testing import suppress_warnings
>>> with suppress_warnings() as sup:
...     sup.filter(RuntimeWarning)
...     np.nansum([1, np.nan, np.inf, -np.inf])  # both +/- infinity present
nan
-
symjax.tensor.
negative
(x)¶ Numerical negative, element-wise.
LAX-backend implementation of
negative()
. Original docstring below.
negative(x, /, out=None, *, where=True, casting='same_kind', order='K', dtype=None, subok=True[, signature, extobj])
Parameters: x (array_like or scalar) – Input array.
Returns: y – Returned array or scalar: y = -x. This is a scalar if x is a scalar.
Return type: ndarray or scalar
Examples
>>> np.negative([1.,-1.])
array([-1.,  1.])
-
symjax.tensor.
nextafter
(x1, x2)¶ Return the next floating-point value after x1 towards x2, element-wise.
LAX-backend implementation of
nextafter()
. Note that in some environments flush-denormal-to-zero semantics is used. This means that, around zero, this function returns strictly non-zero values which appear as zero in any operations. Consider this example:
>>> jnp.nextafter(0, 1)      # denormal numbers are representable
DeviceArray(1.e-45, dtype=float32)
>>> jnp.nextafter(0, 1) * 1  # but are flushed to zero
DeviceArray(0., dtype=float32)
For the smallest usable (i.e. normal) float, use tiny of jnp.finfo. Original docstring below.
nextafter(x1, x2, /, out=None, *, where=True, casting='same_kind', order='K', dtype=None, subok=True[, signature, extobj])
Parameters: - x1 (array_like) – Values to find the next representable value of.
- x2 (array_like) – The direction where to look for the next representable value of x1. If x1.shape != x2.shape, they must be broadcastable to a common shape (which becomes the shape of the output).
Returns: out – The next representable values of x1 in the direction of x2. This is a scalar if both x1 and x2 are scalars.
Return type: ndarray or scalar
Examples
>>> eps = np.finfo(np.float64).eps >>> np.nextafter(1, 2) == eps + 1 True >>> np.nextafter([1, 2], [2, 1]) == [eps + 1, 2 - eps] array([ True, True])
- symjax.tensor.nonzero(a)[source]¶ Return the indices of the elements that are non-zero.
LAX-backend implementation of nonzero(). At present, JAX does not support JIT-compilation of jax.numpy.nonzero() because its output shape is data-dependent.
Original docstring below.
Returns a tuple of arrays, one for each dimension of a, containing the indices of the non-zero elements in that dimension. The values in a are always tested and returned in row-major, C-style order.
To group the indices by element, rather than dimension, use argwhere, which returns a row for each non-zero element.
Note
When called on a zero-d array or scalar, nonzero(a) is treated as nonzero(atleast_1d(a)).
Deprecated since version 1.17.0: Use atleast_1d explicitly if this behavior is deliberate.
Parameters: a (array_like) – Input array.
Returns: tuple_of_arrays – Indices of elements that are non-zero.
Return type: tuple
See also
flatnonzero()
- Return indices that are non-zero in the flattened version of the input array.
ndarray.nonzero()
- Equivalent ndarray method.
count_nonzero()
- Counts the number of non-zero elements in the input array.
Notes
While the nonzero values can be obtained with a[nonzero(a)], it is recommended to use x[x.astype(bool)] or x[x != 0] instead, which will correctly handle 0-d arrays.
Examples
>>> x = np.array([[3, 0, 0], [0, 4, 0], [5, 6, 0]]) >>> x array([[3, 0, 0], [0, 4, 0], [5, 6, 0]]) >>> np.nonzero(x) (array([0, 1, 2, 2]), array([0, 1, 0, 1]))
>>> x[np.nonzero(x)] array([3, 4, 5, 6]) >>> np.transpose(np.nonzero(x)) array([[0, 0], [1, 1], [2, 0], [2, 1]])
A common use for nonzero is to find the indices of an array where a condition is True. Given an array a, the condition a > 3 is a boolean array and since False is interpreted as 0, np.nonzero(a > 3) yields the indices of a where the condition is true.
>>> a = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
>>> a > 3
array([[False, False, False],
       [ True,  True,  True],
       [ True,  True,  True]])
>>> np.nonzero(a > 3)
(array([1, 1, 1, 2, 2, 2]), array([0, 1, 2, 0, 1, 2]))
Using this result to index a is equivalent to using the mask directly:
>>> a[np.nonzero(a > 3)] array([4, 5, 6, 7, 8, 9]) >>> a[a > 3] # prefer this spelling array([4, 5, 6, 7, 8, 9])
nonzero can also be called as a method of the array.
>>> (a > 3).nonzero()
(array([1, 1, 1, 2, 2, 2]), array([0, 1, 2, 0, 1, 2]))
- symjax.tensor.not_equal(x1, x2)¶ Return (x1 != x2) element-wise.
LAX-backend implementation of not_equal(). Original docstring below.
not_equal(x1, x2, /, out=None, *, where=True, casting='same_kind', order='K', dtype=None, subok=True[, signature, extobj])
Parameters: x1, x2 (array_like) – Input arrays. If x1.shape != x2.shape, they must be broadcastable to a common shape (which becomes the shape of the output).
Returns: out – Output array, element-wise comparison of x1 and x2. Typically of type bool, unless dtype=object is passed. This is a scalar if both x1 and x2 are scalars.
Return type: ndarray or scalar
Examples
>>> np.not_equal([1.,2.], [1., 3.]) array([False, True]) >>> np.not_equal([1, 2], [[1, 3],[1, 4]]) array([[False, True], [False, True]])
- symjax.tensor.ones(shape, dtype=None)[source]¶ Return a new array of given shape and type, filled with ones.
LAX-backend implementation of ones(). Original docstring below.
Parameters: - shape (int or sequence of ints) – Shape of the new array, e.g., (2, 3) or 2.
- dtype (data-type, optional) – The desired data-type for the array, e.g., numpy.int8. Default is numpy.float64.
Returns: out – Array of ones with the given shape, dtype, and order.
Return type: ndarray
See also
ones_like()
- Return an array of ones with shape and type of input.
empty()
- Return a new uninitialized array.
zeros()
- Return a new array setting values to zero.
full()
- Return a new array of given shape filled with value.
Examples
>>> np.ones(5) array([1., 1., 1., 1., 1.])
>>> np.ones((5,), dtype=int) array([1, 1, 1, 1, 1])
>>> np.ones((2, 1)) array([[1.], [1.]])
>>> s = (2,2) >>> np.ones(s) array([[1., 1.], [1., 1.]])
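Under SymJAX, ones yields a symbolic constant that can be combined with other graph nodes before compilation; a minimal sketch, again assuming the symjax.function API from the tutorials (bias, scaled and f are illustrative names):

import symjax
import symjax.tensor as T

bias = T.ones((3,), dtype='float32')   # symbolic constant node
scaled = 2.0 * bias                    # still symbolic
f = symjax.function(outputs=scaled)    # no inputs, only a constant graph
f()                                    # expected: array([2., 2., 2.], dtype=float32)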
- symjax.tensor.outer(a, b, out=None)[source]¶ Compute the outer product of two vectors.
LAX-backend implementation of outer(). Original docstring below.
Given two vectors, a = [a0, a1, ..., aM] and b = [b0, b1, ..., bN], the outer product [1] is:
[[a0*b0  a0*b1 ... a0*bN ]
 [a1*b0    .
 [ ...          .
 [aM*b0            aM*bN ]]
Parameters: - a ((M,) array_like) – First input vector. Input is flattened if not already 1-dimensional.
- b ((N,) array_like) – Second input vector. Input is flattened if not already 1-dimensional.
- out ((M, N) ndarray, optional) – A location where the result is stored.
Returns: out –
out[i, j] = a[i] * b[j]
Return type: (M, N) ndarray
See also
einsum()
- einsum('i,j->ij', a.ravel(), b.ravel()) is the equivalent.
ufunc.outer()
- A generalization to dimensions other than 1D and other operations. np.multiply.outer(a.ravel(), b.ravel()) is the equivalent.
tensordot()
- np.tensordot(a.ravel(), b.ravel(), axes=((), ())) is the equivalent.
References
[1] G. H. Golub and C. F. Van Loan, Matrix Computations, 3rd ed., Baltimore, MD, Johns Hopkins University Press, 1996, pg. 8.
Examples
Make a (very coarse) grid for computing a Mandelbrot set:
>>> rl = np.outer(np.ones((5,)), np.linspace(-2, 2, 5))
>>> rl
array([[-2., -1.,  0.,  1.,  2.],
       [-2., -1.,  0.,  1.,  2.],
       [-2., -1.,  0.,  1.,  2.],
       [-2., -1.,  0.,  1.,  2.],
       [-2., -1.,  0.,  1.,  2.]])
>>> im = np.outer(1j*np.linspace(2, -2, 5), np.ones((5,)))
>>> im
array([[0.+2.j, 0.+2.j, 0.+2.j, 0.+2.j, 0.+2.j],
       [0.+1.j, 0.+1.j, 0.+1.j, 0.+1.j, 0.+1.j],
       [0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j],
       [0.-1.j, 0.-1.j, 0.-1.j, 0.-1.j, 0.-1.j],
       [0.-2.j, 0.-2.j, 0.-2.j, 0.-2.j, 0.-2.j]])
>>> grid = rl + im
>>> grid
array([[-2.+2.j, -1.+2.j,  0.+2.j,  1.+2.j,  2.+2.j],
       [-2.+1.j, -1.+1.j,  0.+1.j,  1.+1.j,  2.+1.j],
       [-2.+0.j, -1.+0.j,  0.+0.j,  1.+0.j,  2.+0.j],
       [-2.-1.j, -1.-1.j,  0.-1.j,  1.-1.j,  2.-1.j],
       [-2.-2.j, -1.-2.j,  0.-2.j,  1.-2.j,  2.-2.j]])
An example using a “vector” of letters:
>>> x = np.array(['a', 'b', 'c'], dtype=object) >>> np.outer(x, [1, 2, 3]) array([['a', 'aa', 'aaa'], ['b', 'bb', 'bbb'], ['c', 'cc', 'ccc']], dtype=object)
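The See also entries above are exact equivalences; the following quick check (plain NumPy, illustrative only) confirms them:

import numpy as np

a = np.arange(3)
b = np.arange(4)
ref = np.outer(a, b)
assert np.array_equal(ref, np.einsum('i,j->ij', a.ravel(), b.ravel()))
assert np.array_equal(ref, np.multiply.outer(a.ravel(), b.ravel()))
assert np.array_equal(ref, np.tensordot(a.ravel(), b.ravel(), axes=((), ())))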
- symjax.tensor.pad(array, pad_width, mode='constant', constant_values=0, stat_length=None)[source]¶ Pad an array.
LAX-backend implementation of pad(). Original docstring below.
Parameters: - array (array_like of rank N) – The array to pad.
- pad_width ({sequence, array_like, int}) – Number of values padded to the edges of each axis. ((before_1, after_1), … (before_N, after_N)) unique pad widths for each axis. ((before, after),) yields same before and after pad for each axis. (pad,) or int is a shortcut for before = after = pad width for all axes.
- mode (str or function, optional) – One of the following string values or a user supplied function.
- stat_length (sequence or int, optional) – Used in ‘maximum’, ‘mean’, ‘median’, and ‘minimum’. Number of values at edge of each axis used to calculate the statistic value.
- constant_values (sequence or scalar, optional) – Used in ‘constant’. The values to set the padded values for each axis.
Returns: pad – Padded array of rank equal to array with shape increased according to pad_width.
Return type: ndarray
Notes
New in version 1.7.0.
For an array with rank greater than 1, some of the padding of later axes is calculated from padding of previous axes. This is easiest to think about with a rank 2 array where the corners of the padded array are calculated by using padded values from the first axis.
The padding function, if used, should modify a rank 1 array in-place. It has the following signature:
padding_func(vector, iaxis_pad_width, iaxis, kwargs)
where
- vector : ndarray
- A rank 1 array already padded with zeros. Padded values are vector[:iaxis_pad_width[0]] and vector[-iaxis_pad_width[1]:].
- iaxis_pad_width : tuple
- A 2-tuple of ints, iaxis_pad_width[0] represents the number of values padded at the beginning of vector where iaxis_pad_width[1] represents the number of values padded at the end of vector.
- iaxis : int
- The axis currently being calculated.
- kwargs : dict
- Any keyword arguments the function requires.
Examples
>>> a = [1, 2, 3, 4, 5] >>> np.pad(a, (2, 3), 'constant', constant_values=(4, 6)) array([4, 4, 1, ..., 6, 6, 6])
>>> np.pad(a, (2, 3), 'edge') array([1, 1, 1, ..., 5, 5, 5])
>>> np.pad(a, (2, 3), 'linear_ramp', end_values=(5, -4)) array([ 5, 3, 1, 2, 3, 4, 5, 2, -1, -4])
>>> np.pad(a, (2,), 'maximum') array([5, 5, 1, 2, 3, 4, 5, 5, 5])
>>> np.pad(a, (2,), 'mean') array([3, 3, 1, 2, 3, 4, 5, 3, 3])
>>> np.pad(a, (2,), 'median') array([3, 3, 1, 2, 3, 4, 5, 3, 3])
>>> a = [[1, 2], [3, 4]] >>> np.pad(a, ((3, 2), (2, 3)), 'minimum') array([[1, 1, 1, 2, 1, 1, 1], [1, 1, 1, 2, 1, 1, 1], [1, 1, 1, 2, 1, 1, 1], [1, 1, 1, 2, 1, 1, 1], [3, 3, 3, 4, 3, 3, 3], [1, 1, 1, 2, 1, 1, 1], [1, 1, 1, 2, 1, 1, 1]])
>>> a = [1, 2, 3, 4, 5] >>> np.pad(a, (2, 3), 'reflect') array([3, 2, 1, 2, 3, 4, 5, 4, 3, 2])
>>> np.pad(a, (2, 3), 'reflect', reflect_type='odd') array([-1, 0, 1, 2, 3, 4, 5, 6, 7, 8])
>>> np.pad(a, (2, 3), 'symmetric') array([2, 1, 1, 2, 3, 4, 5, 5, 4, 3])
>>> np.pad(a, (2, 3), 'symmetric', reflect_type='odd') array([0, 1, 1, 2, 3, 4, 5, 5, 6, 7])
>>> np.pad(a, (2, 3), 'wrap') array([4, 5, 1, 2, 3, 4, 5, 1, 2, 3])
>>> def pad_with(vector, pad_width, iaxis, kwargs):
...     pad_value = kwargs.get('padder', 10)
...     vector[:pad_width[0]] = pad_value
...     vector[-pad_width[1]:] = pad_value
>>> a = np.arange(6)
>>> a = a.reshape((2, 3))
>>> np.pad(a, 2, pad_with)
array([[10, 10, 10, 10, 10, 10, 10],
       [10, 10, 10, 10, 10, 10, 10],
       [10, 10,  0,  1,  2, 10, 10],
       [10, 10,  3,  4,  5, 10, 10],
       [10, 10, 10, 10, 10, 10, 10],
       [10, 10, 10, 10, 10, 10, 10]])
>>> np.pad(a, 2, pad_with, padder=100)
array([[100, 100, 100, 100, 100, 100, 100],
       [100, 100, 100, 100, 100, 100, 100],
       [100, 100,   0,   1,   2, 100, 100],
       [100, 100,   3,   4,   5, 100, 100],
       [100, 100, 100, 100, 100, 100, 100],
       [100, 100, 100, 100, 100, 100, 100]])
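A common how-to built on the 'constant' mode: zero-pad a 1-D signal to a fixed length, splitting the padding between the two edges. center_pad below is a hypothetical helper name, not part of the API:

import numpy as np

def center_pad(x, length):
    # zero-pad x on both sides so the result has exactly `length` samples
    total = length - len(x)
    return np.pad(x, (total // 2, total - total // 2), mode='constant')

center_pad(np.array([1, 2, 3]), 7)  # expected: array([0, 0, 1, 2, 3, 0, 0])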
- symjax.tensor.percentile(a, q, axis=None, out=None, overwrite_input=False, interpolation='linear', keepdims=False)[source]¶ Compute the q-th percentile of the data along the specified axis.
LAX-backend implementation of percentile(). Original docstring below.
Returns the q-th percentile(s) of the array elements.
Parameters: - a (array_like) – Input array or object that can be converted to an array.
- q (array_like of float) – Percentile or sequence of percentiles to compute, which must be between 0 and 100 inclusive.
- axis ({int, tuple of int, None}, optional) – Axis or axes along which the percentiles are computed. The default is to compute the percentile(s) along a flattened version of the array.
- out (ndarray, optional) – Alternative output array in which to place the result. It must have the same shape and buffer length as the expected output, but the type (of the output) will be cast if necessary.
- overwrite_input (bool, optional) – If True, then allow the input array a to be modified by intermediate calculations, to save memory. In this case, the contents of the input a after this function completes is undefined.
- interpolation ({'linear', 'lower', 'higher', 'midpoint', 'nearest'}) – This optional parameter specifies the interpolation method to use when the desired percentile lies between two data points i < j.
- keepdims (bool, optional) – If this is set to True, the axes which are reduced are left in the result as dimensions with size one. With this option, the result will broadcast correctly against the original array a.
Returns: percentile – If q is a single percentile and axis=None, then the result is a scalar. If multiple percentiles are given, first axis of the result corresponds to the percentiles. The other axes are the axes that remain after the reduction of a. If the input contains integers or floats smaller than float64, the output data-type is float64. Otherwise, the output data-type is the same as that of the input. If out is specified, that array is returned instead.
Return type: scalar or ndarray
See also
median()
- equivalent to percentile(..., 50)
nanpercentile()
quantile()
- equivalent to percentile, except with q in the range [0, 1].
Notes
Given a vector V of length N, the q-th percentile of V is the value q/100 of the way from the minimum to the maximum in a sorted copy of V. The values and distances of the two nearest neighbors as well as the interpolation parameter will determine the percentile if the normalized ranking does not match the location of q exactly. This function is the same as the median if q=50, the same as the minimum if q=0 and the same as the maximum if q=100.
Examples
>>> a = np.array([[10, 7, 4], [3, 2, 1]])
>>> a
array([[10,  7,  4],
       [ 3,  2,  1]])
>>> np.percentile(a, 50)
3.5
>>> np.percentile(a, 50, axis=0)
array([6.5, 4.5, 2.5])
>>> np.percentile(a, 50, axis=1)
array([7., 2.])
>>> np.percentile(a, 50, axis=1, keepdims=True)
array([[7.],
       [2.]])
>>> m = np.percentile(a, 50, axis=0)
>>> out = np.zeros_like(m)
>>> np.percentile(a, 50, axis=0, out=out)
array([6.5, 4.5, 2.5])
>>> m
array([6.5, 4.5, 2.5])
>>> b = a.copy()
>>> np.percentile(b, 50, axis=1, overwrite_input=True)
array([7., 2.])
>>> assert not np.all(a == b)
The different types of interpolation can be visualized graphically (the illustrating figure appears in the original NumPy documentation).
- symjax.tensor.polyval(p, x)[source]¶ Evaluate a polynomial at specific values.
LAX-backend implementation of polyval(). Original docstring below.
If p is of length N, this function returns the value:
p[0]*x**(N-1) + p[1]*x**(N-2) + ... + p[N-2]*x + p[N-1]
If x is a sequence, then p(x) is returned for each element of x. If x is another polynomial then the composite polynomial p(x(t)) is returned.
Parameters: - p (array_like or poly1d object) – 1D array of polynomial coefficients (including coefficients equal to zero) from highest degree to the constant term.
- x (array_like or poly1d object) – A number, an array of numbers, or an instance of poly1d, at which to evaluate p.
Returns: values – If x is a poly1d instance, the result is the composition of the two polynomials, i.e., x is “substituted” in p and the simplified result is returned. In addition, the type of x - array_like or poly1d - governs the type of the output: x array_like => values array_like, x a poly1d object => values is also.
Return type: ndarray or poly1d
See also
poly1d()
- A polynomial class.
Notes
Horner’s scheme [1]_ is used to evaluate the polynomial. Even so, for polynomials of high degree the values may be inaccurate due to rounding errors. Use carefully.
If x is a subtype of ndarray the return value will be of the same type.
References
[1] I. N. Bronshtein, K. A. Semendyayev, and K. A. Hirsch (Eng. trans. Ed.), Handbook of Mathematics, New York, Van Nostrand Reinhold Co., 1985, pg. 720.
Examples
>>> np.polyval([3,0,1], 5)  # 3 * 5**2 + 0 * 5**1 + 1
76
>>> np.polyval([3,0,1], np.poly1d(5))
poly1d([76.])
>>> np.polyval(np.poly1d([3,0,1]), 5)
76
>>> np.polyval(np.poly1d([3,0,1]), np.poly1d(5))
poly1d([76.])
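The Notes above mention Horner's scheme; the sketch below (horner is a hypothetical helper, for illustration only) spells out the recurrence that polyval applies:

import numpy as np

def horner(p, x):
    # evaluates p[0]*x**(N-1) + p[1]*x**(N-2) + ... + p[N-1] one step at a time
    acc = 0.0
    for c in p:
        acc = acc * x + c
    return acc

horner([3, 0, 1], 5) == np.polyval([3, 0, 1], 5)  # expected: True (both give 76)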
- symjax.tensor.power(x1, x2)[source]¶ First array elements raised to powers from second array, element-wise.
LAX-backend implementation of power(). Original docstring below.
power(x1, x2, /, out=None, *, where=True, casting='same_kind', order='K', dtype=None, subok=True[, signature, extobj])
Raise each base in x1 to the positionally-corresponding power in x2. x1 and x2 must be broadcastable to the same shape. Note that an integer type raised to a negative integer power will raise a ValueError.
Parameters: - x1 (array_like) – The bases.
- x2 (array_like) – The exponents. If x1.shape != x2.shape, they must be broadcastable to a common shape (which becomes the shape of the output).
Returns: y – The bases in x1 raised to the exponents in x2. This is a scalar if both x1 and x2 are scalars.
Return type: ndarray
See also
float_power()
- power function that promotes integers to float
Examples
Cube each element in a list.
>>> x1 = range(6) >>> x1 [0, 1, 2, 3, 4, 5] >>> np.power(x1, 3) array([ 0, 1, 8, 27, 64, 125])
Raise the bases to different exponents.
>>> x2 = [1.0, 2.0, 3.0, 3.0, 2.0, 1.0] >>> np.power(x1, x2) array([ 0., 1., 8., 27., 16., 5.])
The effect of broadcasting.
>>> x2 = np.array([[1, 2, 3, 3, 2, 1], [1, 2, 3, 3, 2, 1]]) >>> x2 array([[1, 2, 3, 3, 2, 1], [1, 2, 3, 3, 2, 1]]) >>> np.power(x1, x2) array([[ 0, 1, 8, 27, 16, 5], [ 0, 1, 8, 27, 16, 5]])
- symjax.tensor.positive(x)¶ Numerical positive, element-wise.
LAX-backend implementation of positive(). Original docstring below.
positive(x, /, out=None, *, where=True, casting='same_kind', order='K', dtype=None, subok=True[, signature, extobj])
New in version 1.13.0.
Parameters: x (array_like or scalar) – Input array.
Returns: y – Returned array or scalar: y = +x. This is a scalar if x is a scalar.
Return type: ndarray or scalar
Notes
Equivalent to x.copy(), but only defined for types that support arithmetic.
- symjax.tensor.prod(a, axis=None, dtype=None, out=None, keepdims=None, initial=None, where=None)[source]¶ Return the product of array elements over a given axis.
LAX-backend implementation of prod(). Original docstring below.
Parameters: - a (array_like) – Input data.
- axis (None or int or tuple of ints, optional) – Axis or axes along which a product is performed. The default, axis=None, will calculate the product of all the elements in the input array. If axis is negative it counts from the last to the first axis.
- dtype (dtype, optional) – The type of the returned array, as well as of the accumulator in which the elements are multiplied. The dtype of a is used by default unless a has an integer dtype of less precision than the default platform integer. In that case, if a is signed then the platform integer is used while if a is unsigned then an unsigned integer of the same precision as the platform integer is used.
- out (ndarray, optional) – Alternative output array in which to place the result. It must have the same shape as the expected output, but the type of the output values will be cast if necessary.
- keepdims (bool, optional) – If this is set to True, the axes which are reduced are left in the result as dimensions with size one. With this option, the result will broadcast correctly against the input array.
- initial (scalar, optional) – The starting value for this product. See ~numpy.ufunc.reduce for details.
- where (array_like of bool, optional) – Elements to include in the product. See ~numpy.ufunc.reduce for details.
Returns: product_along_axis – An array shaped as a but with the specified axis removed. Returns a reference to out if specified.
Return type: ndarray, see dtype parameter above.
See also
ndarray.prod()
- equivalent method
ufuncs-output-type()
Notes
Arithmetic is modular when using integer types, and no error is raised on overflow. That means that, on a 32-bit platform:
>>> x = np.array([536870910, 536870910, 536870910, 536870910]) >>> np.prod(x) 16 # may vary
The product of an empty array is the neutral element 1:
>>> np.prod([]) 1.0
Examples
By default, calculate the product of all elements:
>>> np.prod([1.,2.]) 2.0
Even when the input array is two-dimensional:
>>> np.prod([[1.,2.],[3.,4.]]) 24.0
But we can also specify the axis over which to multiply:
>>> np.prod([[1.,2.],[3.,4.]], axis=1) array([ 2., 12.])
Or select specific elements to include:
>>> np.prod([1., np.nan, 3.], where=[True, False, True]) 3.0
If the type of x is unsigned, then the output type is the unsigned platform integer:
>>> x = np.array([1, 2, 3], dtype=np.uint8) >>> np.prod(x).dtype == np.uint True
If x is of a signed integer type, then the output type is the default platform integer:
>>> x = np.array([1, 2, 3], dtype=np.int8) >>> np.prod(x).dtype == int True
You can also start the product with a value other than one:
>>> np.prod([1, 2], initial=5) 10
- symjax.tensor.product(a, axis=None, dtype=None, out=None, keepdims=None, initial=None, where=None)¶ Return the product of array elements over a given axis.
LAX-backend implementation of prod(). product is an alias for prod; see symjax.tensor.prod() above for the full documentation.
- symjax.tensor.promote_types(a, b)[source]¶ Returns the type to which a binary operation should cast its arguments.
For details of JAX’s type promotion semantics, see type-promotion.
Parameters: - a – a numpy.dtype or a dtype specifier.
- b – a numpy.dtype or a dtype specifier.
Returns: A numpy.dtype object.
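Because the result follows JAX's type promotion semantics rather than NumPy's, the two libraries can disagree: NumPy promotes float32 with int32 to float64 to preserve precision, while JAX stays in 32 bits. A small comparison, for illustration:

>>> import numpy as np
>>> import jax.numpy as jnp
>>> np.promote_types('float32', 'int32')    # NumPy's rules
dtype('float64')
>>> jnp.promote_types('float32', 'int32')   # JAX's rules, which symjax follows
dtype('float32')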
- symjax.tensor.ptp(a, axis=None, out=None, keepdims=False)[source]¶ Range of values (maximum - minimum) along an axis.
LAX-backend implementation of ptp(). Original docstring below.
The name of the function comes from the acronym for 'peak to peak'.
Warning
ptp preserves the data type of the array. This means the return value for an input of signed integers with n bits (e.g. np.int8, np.int16, etc) is also a signed integer with n bits. In that case, peak-to-peak values greater than 2**(n-1)-1 will be returned as negative values. An example with a work-around is shown below.
Parameters: - a (array_like) – Input values.
- axis (None or int or tuple of ints, optional) – Axis along which to find the peaks. By default, flatten the array. axis may be negative, in which case it counts from the last to the first axis.
- out (array_like) – Alternative output array in which to place the result. It must have the same shape and buffer length as the expected output, but the type of the output values will be cast if necessary.
- keepdims (bool, optional) – If this is set to True, the axes which are reduced are left in the result as dimensions with size one. With this option, the result will broadcast correctly against the input array.
Returns: ptp – A new array holding the result, unless out was specified, in which case a reference to out is returned.
Return type: ndarray
Examples
>>> x = np.array([[4, 9, 2, 10], ... [6, 9, 7, 12]])
>>> np.ptp(x, axis=1) array([8, 6])
>>> np.ptp(x, axis=0) array([2, 0, 5, 2])
>>> np.ptp(x) 10
This example shows that a negative value can be returned when the input is an array of signed integers.
>>> y = np.array([[1, 127], ... [0, 127], ... [-1, 127], ... [-2, 127]], dtype=np.int8) >>> np.ptp(y, axis=1) array([ 126, 127, -128, -127], dtype=int8)
A work-around is to use the view() method to view the result as unsigned integers with the same bit width:
>>> np.ptp(y, axis=1).view(np.uint8) array([126, 127, 128, 129], dtype=uint8)
- symjax.tensor.quantile(a, q, axis=None, out=None, overwrite_input=False, interpolation='linear', keepdims=False)[source]¶ Compute the q-th quantile of the data along the specified axis.
LAX-backend implementation of quantile(). Original docstring below.
New in version 1.15.0.
Parameters: - a (array_like) – Input array or object that can be converted to an array.
- q (array_like of float) – Quantile or sequence of quantiles to compute, which must be between 0 and 1 inclusive.
- axis ({int, tuple of int, None}, optional) – Axis or axes along which the quantiles are computed. The default is to compute the quantile(s) along a flattened version of the array.
- out (ndarray, optional) – Alternative output array in which to place the result. It must have the same shape and buffer length as the expected output, but the type (of the output) will be cast if necessary.
- overwrite_input (bool, optional) – If True, then allow the input array a to be modified by intermediate calculations, to save memory. In this case, the contents of the input a after this function completes is undefined.
- interpolation ({'linear', 'lower', 'higher', 'midpoint', 'nearest'}) – This optional parameter specifies the interpolation method to use when the desired quantile lies between two data points i < j.
- keepdims (bool, optional) – If this is set to True, the axes which are reduced are left in the result as dimensions with size one. With this option, the result will broadcast correctly against the original array a.
Returns: quantile – If q is a single quantile and axis=None, then the result is a scalar. If multiple quantiles are given, first axis of the result corresponds to the quantiles. The other axes are the axes that remain after the reduction of a. If the input contains integers or floats smaller than float64, the output data-type is float64. Otherwise, the output data-type is the same as that of the input. If out is specified, that array is returned instead.
Return type: scalar or ndarray
See also
percentile()
- equivalent to quantile, but with q in the range [0, 100].
median()
- equivalent to quantile(..., 0.5)
nanquantile()
Notes
Given a vector V of length N, the q-th quantile of V is the value q of the way from the minimum to the maximum in a sorted copy of V. The values and distances of the two nearest neighbors as well as the interpolation parameter will determine the quantile if the normalized ranking does not match the location of q exactly. This function is the same as the median if q=0.5, the same as the minimum if q=0.0 and the same as the maximum if q=1.0.
Examples
>>> a = np.array([[10, 7, 4], [3, 2, 1]])
>>> a
array([[10,  7,  4],
       [ 3,  2,  1]])
>>> np.quantile(a, 0.5)
3.5
>>> np.quantile(a, 0.5, axis=0)
array([6.5, 4.5, 2.5])
>>> np.quantile(a, 0.5, axis=1)
array([7., 2.])
>>> np.quantile(a, 0.5, axis=1, keepdims=True)
array([[7.],
       [2.]])
>>> m = np.quantile(a, 0.5, axis=0)
>>> out = np.zeros_like(m)
>>> np.quantile(a, 0.5, axis=0, out=out)
array([6.5, 4.5, 2.5])
>>> m
array([6.5, 4.5, 2.5])
>>> b = a.copy()
>>> np.quantile(b, 0.5, axis=1, overwrite_input=True)
array([7., 2.])
>>> assert not np.all(a == b)
- symjax.tensor.rad2deg(x)[source]¶ Convert angles from radians to degrees.
LAX-backend implementation of rad2deg(). Original docstring below.
rad2deg(x, /, out=None, *, where=True, casting='same_kind', order='K', dtype=None, subok=True[, signature, extobj])
Parameters: x (array_like) – Angle in radians.
Returns: y – The corresponding angle in degrees. This is a scalar if x is a scalar.
Return type: ndarray
See also
deg2rad()
- Convert angles from degrees to radians.
unwrap()
- Remove large jumps in angle by wrapping.
Notes
New in version 1.3.0.
rad2deg(x) is 180 * x / pi.
Examples
>>> np.rad2deg(np.pi/2) 90.0
- symjax.tensor.radians(x)¶ Convert angles from degrees to radians.
LAX-backend implementation of deg2rad(). Original docstring below.
deg2rad(x, /, out=None, *, where=True, casting='same_kind', order='K', dtype=None, subok=True[, signature, extobj])
Parameters: x (array_like) – Angles in degrees.
Returns: y – The corresponding angle in radians. This is a scalar if x is a scalar.
Return type: ndarray
See also
rad2deg()
- Convert angles from radians to degrees.
unwrap()
- Remove large jumps in angle by wrapping.
Notes
New in version 1.3.0.
deg2rad(x) is x * pi / 180.
Examples
>>> np.deg2rad(180) 3.1415926535897931
- symjax.tensor.ravel(a, order='C')[source]¶ Return a contiguous flattened array.
LAX-backend implementation of ravel(). Original docstring below.
A 1-D array, containing the elements of the input, is returned. A copy is made only if needed.
As of NumPy 1.10, the returned array will have the same type as the input array. (for example, a masked array will be returned for a masked array input)
Parameters: - a (array_like) – Input array. The elements in a are read in the order specified by order, and packed as a 1-D array.
- order ({'C', 'F', 'A', 'K'}, optional) – The elements of a are read using this index order. 'C' means row-major, C-style order; 'F' means column-major, Fortran-style order; 'A' means Fortran-like order if a is Fortran contiguous in memory, C-like order otherwise; 'K' means to read the elements in the order they occur in memory. By default, 'C' index order is used.
Returns: y – y is an array of the same subtype as a, with shape (a.size,). Note that matrices are special cased for backward compatibility; if a is a matrix, then y is a 1-D ndarray.
Return type: array_like
See also
ndarray.flat()
- 1-D iterator over an array.
ndarray.flatten()
- 1-D array copy of the elements of an array in row-major order.
ndarray.reshape()
- Change the shape of an array without changing its data.
Notes
In row-major, C-style order, in two dimensions, the row index varies the slowest, and the column index the quickest. This can be generalized to multiple dimensions, where row-major order implies that the index along the first axis varies slowest, and the index along the last quickest. The opposite holds for column-major, Fortran-style index ordering.
When a view is desired in as many cases as possible, arr.reshape(-1) may be preferable.
Examples
It is equivalent to reshape(-1, order=order).
>>> x = np.array([[1, 2, 3], [4, 5, 6]])
>>> np.ravel(x)
array([1, 2, 3, 4, 5, 6])
>>> x.reshape(-1) array([1, 2, 3, 4, 5, 6])
>>> np.ravel(x, order='F') array([1, 4, 2, 5, 3, 6])
When order is 'A', it will preserve the array's 'C' or 'F' ordering:
>>> np.ravel(x.T)
array([1, 4, 2, 5, 3, 6])
>>> np.ravel(x.T, order='A')
array([1, 2, 3, 4, 5, 6])
When order is 'K', it will preserve orderings that are neither 'C' nor 'F', but won't reverse axes:
>>> a = np.arange(3)[::-1]; a
array([2, 1, 0])
>>> a.ravel(order='C')
array([2, 1, 0])
>>> a.ravel(order='K')
array([2, 1, 0])
>>> a = np.arange(12).reshape(2,3,2).swapaxes(1,2); a array([[[ 0, 2, 4], [ 1, 3, 5]], [[ 6, 8, 10], [ 7, 9, 11]]]) >>> a.ravel(order='C') array([ 0, 2, 4, 1, 3, 5, 6, 8, 10, 7, 9, 11]) >>> a.ravel(order='K') array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])
- symjax.tensor.real(val)[source]¶ Return the real part of the complex argument.
LAX-backend implementation of real(). Original docstring below.
Parameters: val (array_like) – Input array.
Returns: out – The real component of the complex argument. If val is real, the type of val is used for the output. If val has complex elements, the returned type is float.
Return type: ndarray or scalar
Examples
>>> a = np.array([1+2j, 3+4j, 5+6j])
>>> a.real
array([1., 3., 5.])
>>> a.real = 9
>>> a
array([9.+2.j, 9.+4.j, 9.+6.j])
>>> a.real = np.array([9, 8, 7])
>>> a
array([9.+2.j, 8.+4.j, 7.+6.j])
>>> np.real(1 + 1j)
1.0
- symjax.tensor.reciprocal(x)[source]¶ Return the reciprocal of the argument, element-wise.
LAX-backend implementation of reciprocal(). Original docstring below.
reciprocal(x, /, out=None, *, where=True, casting='same_kind', order='K', dtype=None, subok=True[, signature, extobj])
Calculates 1/x.
Parameters: x (array_like) – Input array.
Returns: y – Return array. This is a scalar if x is a scalar.
Return type: ndarray
Notes
Note
This function is not designed to work with integers.
For integer arguments with absolute value larger than 1 the result is always zero because of the way Python handles integer division. For integer zero the result is an overflow.
Examples
>>> np.reciprocal(2.) 0.5 >>> np.reciprocal([1, 2., 3.33]) array([ 1. , 0.5 , 0.3003003])
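The integer caveat above is easy to trip over: with an integer dtype the operation truncates, so every input with absolute value larger than 1 maps to 0. For example:

>>> np.reciprocal(np.array([1, 2, 4]))
array([1, 0, 0])
>>> np.reciprocal(np.array([1., 2., 4.]))  # use a floating dtype instead
array([1.  , 0.5 , 0.25])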
- symjax.tensor.remainder(x1, x2)[source]¶ Return element-wise remainder of division.
LAX-backend implementation of remainder(). Original docstring below.
remainder(x1, x2, /, out=None, *, where=True, casting='same_kind', order='K', dtype=None, subok=True[, signature, extobj])
Computes the remainder complementary to the floor_divide function. It is equivalent to the Python modulus operator x1 % x2 and has the same sign as the divisor x2. The MATLAB function equivalent to np.remainder is mod.
Warning
This should not be confused with:
- Python 3.7's math.remainder and C's remainder, which compute the IEEE remainder, which are the complement to round(x1 / x2).
- The MATLAB rem function and the C % operator, which are the complement to int(x1 / x2).
Parameters: - x1 (array_like) – Dividend array.
- x2 (array_like) – Divisor array. If x1.shape != x2.shape, they must be broadcastable to a common shape (which becomes the shape of the output).
Returns: y – The element-wise remainder of the quotient floor_divide(x1, x2). This is a scalar if both x1 and x2 are scalars.
Return type: ndarray
See also
floor_divide()
- Equivalent of Python // operator.
divmod()
- Simultaneous floor division and remainder.
fmod()
- Equivalent of the MATLAB rem function.
Notes
Returns 0 when x2 is 0 and both x1 and x2 are (arrays of) integers.
mod is an alias of remainder.
Examples
>>> np.remainder([4, 7], [2, 3]) array([0, 1]) >>> np.remainder(np.arange(7), 5) array([0, 1, 2, 3, 4, 0, 1])
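The warning above is clearest side by side: remainder follows the sign of the divisor (like Python's %), while fmod follows the sign of the dividend (like C's %):

>>> np.remainder(-5, 3)   # same sign as the divisor, as with Python's -5 % 3
1
>>> np.fmod(-5, 3)        # same sign as the dividend, as with C's %
-2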
- symjax.tensor.repeat(a, repeats, axis=None, *, total_repeat_length=None)[source]¶ Repeat elements of an array.
LAX-backend implementation of repeat(). JAX adds the optional total_repeat_length parameter, which specifies the total number of repeats and defaults to sum(repeats). It must be specified for repeat to be compilable. If sum(repeats) is larger than the specified total_repeat_length, the remaining values will be discarded. In the case of sum(repeats) being smaller than the specified target length, the final value will be repeated.
Original docstring below.
Parameters: - a (array_like) – Input array.
- repeats (int or array of ints) – The number of repetitions for each element. repeats is broadcasted to fit the shape of the given axis.
- axis (int, optional) – The axis along which to repeat values. By default, use the flattened input array, and return a flat output array.
Returns: repeated_array – Output array which has the same shape as a, except along the given axis.
Return type: ndarray
See also
tile()
- Tile an array.
Examples
>>> np.repeat(3, 4)
array([3, 3, 3, 3])
>>> x = np.array([[1,2],[3,4]])
>>> np.repeat(x, 2)
array([1, 1, 2, 2, 3, 3, 4, 4])
>>> np.repeat(x, 3, axis=1)
array([[1, 1, 1, 2, 2, 2],
       [3, 3, 3, 4, 4, 4]])
>>> np.repeat(x, [1, 2], axis=0)
array([[1, 2],
       [3, 4],
       [3, 4]])
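The JAX-specific total_repeat_length parameter described above is what makes repeat usable under compilation: it pins the output shape ahead of time. A minimal jax.numpy sketch (expand is an illustrative name):

import jax
import jax.numpy as jnp

@jax.jit
def expand(x):
    # the static total_repeat_length=4 fixes the output shape, so jit can trace this
    return jnp.repeat(x, jnp.array([1, 3]), total_repeat_length=4)

expand(jnp.array([10, 20]))  # expected: [10, 20, 20, 20]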
- symjax.tensor.reshape(a, newshape, order='C')[source]¶ Gives a new shape to an array without changing its data.
LAX-backend implementation of reshape(). Original docstring below.
Parameters: - a (array_like) – Array to be reshaped.
- newshape (int or tuple of ints) – The new shape should be compatible with the original shape. If an integer, then the result will be a 1-D array of that length. One shape dimension can be -1. In this case, the value is inferred from the length of the array and remaining dimensions.
- order ({'C', 'F', 'A'}, optional) – Read the elements of a using this index order, and place the elements into the reshaped array using this index order. ‘C’ means to read / write the elements using C-like index order, with the last axis index changing fastest, back to the first axis index changing slowest. ‘F’ means to read / write the elements using Fortran-like index order, with the first index changing fastest, and the last index changing slowest. Note that the ‘C’ and ‘F’ options take no account of the memory layout of the underlying array, and only refer to the order of indexing. ‘A’ means to read / write the elements in Fortran-like index order if a is Fortran contiguous in memory, C-like order otherwise.
Returns: reshaped_array – This will be a new view object if possible; otherwise, it will be a copy. Note there is no guarantee of the memory layout (C- or Fortran- contiguous) of the returned array.
Return type: ndarray
See also
ndarray.reshape()
- Equivalent method.
Notes
It is not always possible to change the shape of an array without copying the data. If you want an error to be raised when the data is copied, you should assign the new shape to the shape attribute of the array:
>>> a = np.zeros((10, 2))
# A transpose makes the array non-contiguous
>>> b = a.T
# Taking a view makes it possible to modify the shape without modifying
# the initial object.
>>> c = b.view()
>>> c.shape = (20)
Traceback (most recent call last):
   ...
AttributeError: Incompatible shape for in-place modification. Use `.reshape()` to make a copy with the desired shape.
The order keyword gives the index ordering both for fetching the values from a, and then placing the values into the output array. For example, let’s say you have an array:
>>> a = np.arange(6).reshape((3, 2)) >>> a array([[0, 1], [2, 3], [4, 5]])
You can think of reshaping as first raveling the array (using the given index order), then inserting the elements from the raveled array into the new array using the same kind of index ordering as was used for the raveling.
>>> np.reshape(a, (2, 3)) # C-like index ordering array([[0, 1, 2], [3, 4, 5]]) >>> np.reshape(np.ravel(a), (2, 3)) # equivalent to C ravel then C reshape array([[0, 1, 2], [3, 4, 5]]) >>> np.reshape(a, (2, 3), order='F') # Fortran-like index ordering array([[0, 4, 3], [2, 1, 5]]) >>> np.reshape(np.ravel(a, order='F'), (2, 3), order='F') array([[0, 4, 3], [2, 1, 5]])
Examples
>>> a = np.array([[1,2,3], [4,5,6]]) >>> np.reshape(a, 6) array([1, 2, 3, 4, 5, 6]) >>> np.reshape(a, 6, order='F') array([1, 4, 2, 5, 3, 6])
>>> np.reshape(a, (3,-1)) # the unspecified value is inferred to be 2 array([[1, 2], [3, 4], [5, 6]])
- symjax.tensor.result_type(*args)[source]¶ Returns the type that results from applying the NumPy type promotion rules to the arguments.
LAX-backend implementation of result_type(). Original docstring below.
result_type(*arrays_and_dtypes)
Type promotion in NumPy works similarly to the rules in languages like C++, with some slight differences. When both scalars and arrays are used, the array’s type takes precedence and the actual value of the scalar is taken into account.
For example, calculating 3*a, where a is an array of 32-bit floats, intuitively should result in a 32-bit float output. If the 3 is a 32-bit integer, the NumPy rules indicate it can’t convert losslessly into a 32-bit float, so a 64-bit float should be the result type. By examining the value of the constant, ‘3’, we see that it fits in an 8-bit integer, which can be cast losslessly into the 32-bit float.
Returns: out – The result type.
Return type: dtype
See also
dtype, promote_types, min_scalar_type, can_cast
Notes
New in version 1.6.0.
The specific algorithm used is as follows.
Categories are determined by first checking which of boolean, integer (int/uint), or floating point (float/complex) the maximum kind of all the arrays and the scalars are.
If there are only scalars or the maximum category of the scalars is higher than the maximum category of the arrays, the data types are combined with promote_types() to produce the return value.
Otherwise, min_scalar_type is called on each array, and the resulting data types are all combined with promote_types() to produce the return value.
The set of int values is not a subset of the uint values for types with the same number of bits, something not reflected in min_scalar_type(), but handled as a special case in result_type.
>>> np.result_type(3, np.arange(7, dtype='i1'))
dtype('int8')
>>> np.result_type('i4', 'c8') dtype('complex128')
>>> np.result_type(3.0, -2) dtype('float64')
- symjax.tensor.right_shift(x1, x2)[source]¶ Shift the bits of an integer to the right.
LAX-backend implementation of right_shift(). Original docstring below.
right_shift(x1, x2, /, out=None, *, where=True, casting='same_kind', order='K', dtype=None, subok=True[, signature, extobj])
Bits are shifted to the right by x2. Because the internal representation of numbers is in binary format, this operation is equivalent to dividing x1 by 2**x2.
Parameters: - x1 (array_like, int) – Input values.
- x2 (array_like, int) – Number of bits to remove at the right of x1. If x1.shape != x2.shape, they must be broadcastable to a common shape (which becomes the shape of the output).
Returns: out – Return x1 with bits shifted x2 times to the right. This is a scalar if both x1 and x2 are scalars.
Return type: ndarray, int
See also
left_shift()
- Shift the bits of an integer to the left.
binary_repr()
- Return the binary representation of the input number as a string.
Examples
>>> np.binary_repr(10) '1010' >>> np.right_shift(10, 1) 5 >>> np.binary_repr(5) '101'
>>> np.right_shift(10, [1,2,3]) array([5, 2, 1])
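Since shifting right by x2 bits divides by 2**x2, right_shift matches floor division by a power of two, including for negative inputs (the shift is arithmetic and rounds towards minus infinity):

>>> np.right_shift(20, 2) == 20 // 4
True
>>> np.right_shift(-5, 1) == -5 // 2
True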
- symjax.tensor.roll(a, shift, axis=None)[source]¶ Roll array elements along a given axis.
LAX-backend implementation of roll(). Original docstring below.
Elements that roll beyond the last position are re-introduced at the first.
Parameters: - a (array_like) – Input array.
- shift (int or tuple of ints) – The number of places by which elements are shifted. If a tuple, then axis must be a tuple of the same size, and each of the given axes is shifted by the corresponding number. If an int while axis is a tuple of ints, then the same value is used for all given axes.
- axis (int or tuple of ints, optional) – Axis or axes along which elements are shifted. By default, the array is flattened before shifting, after which the original shape is restored.
Returns: res – Output array, with the same shape as a.
Return type: ndarray
See also
rollaxis()
- Roll the specified axis backwards, until it lies in a given position.
Notes
New in version 1.12.0.
Supports rolling over multiple dimensions simultaneously.
Examples
>>> x = np.arange(10)
>>> np.roll(x, 2)
array([8, 9, 0, 1, 2, 3, 4, 5, 6, 7])
>>> np.roll(x, -2)
array([2, 3, 4, 5, 6, 7, 8, 9, 0, 1])
>>> x2 = np.reshape(x, (2,5))
>>> x2
array([[0, 1, 2, 3, 4],
       [5, 6, 7, 8, 9]])
>>> np.roll(x2, 1)
array([[9, 0, 1, 2, 3],
       [4, 5, 6, 7, 8]])
>>> np.roll(x2, -1)
array([[1, 2, 3, 4, 5],
       [6, 7, 8, 9, 0]])
>>> np.roll(x2, 1, axis=0)
array([[5, 6, 7, 8, 9],
       [0, 1, 2, 3, 4]])
>>> np.roll(x2, -1, axis=0)
array([[5, 6, 7, 8, 9],
       [0, 1, 2, 3, 4]])
>>> np.roll(x2, 1, axis=1)
array([[4, 0, 1, 2, 3],
       [9, 5, 6, 7, 8]])
>>> np.roll(x2, -1, axis=1)
array([[1, 2, 3, 4, 0],
       [6, 7, 8, 9, 5]])
- symjax.tensor.rot90(m, k=1, axes=(0, 1))[source]¶ Rotate an array by 90 degrees in the plane specified by axes.
LAX-backend implementation of rot90(). Original docstring below.
Rotation direction is from the first towards the second axis.
Parameters: - m (array_like) – Array of two or more dimensions.
- k (integer) – Number of times the array is rotated by 90 degrees.
- axes ((2,) array_like) – The array is rotated in the plane defined by the axes. Axes must be different.
Returns: y – A rotated view of m.
Return type: ndarray
Notes
- rot90(m, k=1, axes=(1,0)) is the reverse of rot90(m, k=1, axes=(0,1))
- rot90(m, k=1, axes=(1,0)) is equivalent to rot90(m, k=-1, axes=(0,1))
Examples
>>> m = np.array([[1,2],[3,4]], int)
>>> m
array([[1, 2],
       [3, 4]])
>>> np.rot90(m)
array([[2, 4],
       [1, 3]])
>>> np.rot90(m, 2)
array([[4, 3],
       [2, 1]])
>>> m = np.arange(8).reshape((2,2,2))
>>> np.rot90(m, 1, (1,2))
array([[[1, 3],
        [0, 2]],
       [[5, 7],
        [4, 6]]])
- symjax.tensor.round(a, decimals=0, out=None)[source]¶ Round an array to the given number of decimals.
LAX-backend implementation of round_(). Original docstring below.
See also
around()
- Equivalent function; see for details.
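As with numpy.around, halfway cases round to the nearest even value rather than always away from zero, which keeps the rounding unbiased on average. For example:

>>> np.round(3.14159, 2)
3.14
>>> np.round([0.5, 1.5, 2.5])  # ties go to the nearest even value
array([0., 2., 2.])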
- symjax.tensor.row_stack(tup)¶ Stack arrays in sequence vertically (row wise).
LAX-backend implementation of vstack(). Original docstring below.
This is equivalent to concatenation along the first axis after 1-D arrays of shape (N,) have been reshaped to (1,N). Rebuilds arrays divided by vsplit.
This function makes most sense for arrays with up to 3 dimensions. For instance, for pixel-data with a height (first axis), width (second axis), and r/g/b channels (third axis). The functions concatenate, stack and block provide more general stacking and concatenation operations.
Parameters: tup (sequence of ndarrays) – The arrays must have the same shape along all but the first axis. 1-D arrays must have the same length.
Returns: stacked – The array formed by stacking the given arrays, will be at least 2-D.
Return type: ndarray
See also
concatenate()
- Join a sequence of arrays along an existing axis.
stack()
- Join a sequence of arrays along a new axis.
block()
- Assemble an nd-array from nested lists of blocks.
hstack()
- Stack arrays in sequence horizontally (column wise).
dstack()
- Stack arrays in sequence depth wise (along third axis).
column_stack()
- Stack 1-D arrays as columns into a 2-D array.
vsplit()
- Split an array into multiple sub-arrays vertically (row-wise).
Examples
>>> a = np.array([1, 2, 3]) >>> b = np.array([2, 3, 4]) >>> np.vstack((a,b)) array([[1, 2, 3], [2, 3, 4]])
>>> a = np.array([[1], [2], [3]]) >>> b = np.array([[2], [3], [4]]) >>> np.vstack((a,b)) array([[1], [2], [3], [2], [3], [4]])
- symjax.tensor.select(condlist, choicelist, default=0)[source]¶ Return an array drawn from elements in choicelist, depending on conditions.
LAX-backend implementation of select(). Original docstring below.
Parameters: - condlist (list of bool ndarrays) – The list of conditions which determine from which array in choicelist the output elements are taken. When multiple conditions are satisfied, the first one encountered in condlist is used.
- choicelist (list of ndarrays) – The list of arrays from which the output elements are taken. It has to be of the same length as condlist.
- default (scalar, optional) – The element inserted in output when all conditions evaluate to False.
Returns: output – The output at position m is the m-th element of the array in choicelist where the m-th element of the corresponding array in condlist is True.
Return type: ndarray
See also
where()
- Return elements from one of two arrays depending on condition.
take(), choose(), compress(), diag(), diagonal()
Examples
>>> x = np.arange(10)
>>> condlist = [x<3, x>5]
>>> choicelist = [x, x**2]
>>> np.select(condlist, choicelist)
array([ 0,  1,  2, ..., 49, 64, 81])
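select is a convenient building block for piecewise definitions: each condition picks its branch and default fills the remaining positions. For instance, keeping x below 2, squaring it above 3 and flagging the gap with -1:

>>> x = np.arange(6)
>>> np.select([x < 2, x > 3], [x, x**2], default=-1)
array([ 0,  1, -1, -1, 16, 25])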
- symjax.tensor.sign(x)[source]¶ Returns an element-wise indication of the sign of a number.
LAX-backend implementation of sign(). Original docstring below.
sign(x, /, out=None, *, where=True, casting='same_kind', order='K', dtype=None, subok=True[, signature, extobj])
The sign function returns -1 if x < 0, 0 if x==0, 1 if x > 0. nan is returned for nan inputs.
For complex inputs, the sign function returns sign(x.real) + 0j if x.real != 0 else sign(x.imag) + 0j.
complex(nan, 0) is returned for complex nan inputs.
Parameters: x (array_like) – Input values.
Returns: y – The sign of x. This is a scalar if x is a scalar.
Return type: ndarray
Notes
There is more than one definition of sign in common use for complex numbers. The definition used here is equivalent to \(x/\sqrt{x*x}\) which is different from a common alternative, \(x/|x|\).
Examples
>>> np.sign([-5., 4.5]) array([-1., 1.]) >>> np.sign(0) 0 >>> np.sign(5-2j) (1+0j)
- symjax.tensor.signbit(x)[source]¶ Returns element-wise True where signbit is set (less than zero).
LAX-backend implementation of signbit(). Original docstring below.
signbit(x, /, out=None, *, where=True, casting='same_kind', order='K', dtype=None, subok=True[, signature, extobj])
Parameters: x (array_like) – The input value(s).
Returns: result – Output array, or reference to out if that was supplied. This is a scalar if x is a scalar.
Return type: ndarray of bool
Examples
>>> np.signbit(-1.2) True >>> np.signbit(np.array([1, -2.3, 2.1])) array([False, True, False])
- symjax.tensor.sin(x)¶ Trigonometric sine, element-wise.
LAX-backend implementation of sin(). Original docstring below.
sin(x, /, out=None, *, where=True, casting='same_kind', order='K', dtype=None, subok=True[, signature, extobj])
Parameters: x (array_like) – Angle, in radians (\(2 \pi\) rad equals 360 degrees).
Returns: y – The sine of each element of x. This is a scalar if x is a scalar.
Return type: array_like
Notes
The sine is one of the fundamental functions of trigonometry (the mathematical study of triangles). Consider a circle of radius 1 centered on the origin. A ray comes in from the \(+x\) axis, makes an angle at the origin (measured counter-clockwise from that axis), and departs from the origin. The \(y\) coordinate of the outgoing ray’s intersection with the unit circle is the sine of that angle. It ranges from -1 for \(x=3\pi / 2\) to +1 for \(\pi / 2.\) The function has zeroes where the angle is a multiple of \(\pi\). Sines of angles between \(\pi\) and \(2\pi\) are negative. The numerous properties of the sine and related functions are included in any standard trigonometry text.
Examples
Print sine of one angle:
>>> np.sin(np.pi/2.) 1.0
Print sines of an array of angles given in degrees:
>>> np.sin(np.array((0., 30., 45., 60., 90.)) * np.pi / 180. ) array([ 0. , 0.5 , 0.70710678, 0.8660254 , 1. ])
Plot the sine function:
>>> import matplotlib.pylab as plt
>>> x = np.linspace(-np.pi, np.pi, 201)
>>> plt.plot(x, np.sin(x))
>>> plt.xlabel('Angle [rad]')
>>> plt.ylabel('sin(x)')
>>> plt.axis('tight')
>>> plt.show()
- symjax.tensor.sinc(x)[source]¶ Return the sinc function.
LAX-backend implementation of sinc(). Original docstring below.
The sinc function is \(\sin(\pi x)/(\pi x)\).
Parameters: x (ndarray) – Array (possibly multi-dimensional) of values for which to calculate sinc(x).
Returns: out – sinc(x), which has the same shape as the input.
Return type: ndarray
Notes
sinc(0) is the limit value 1.
The name sinc is short for “sine cardinal” or “sinus cardinalis”.
The sinc function is used in various signal processing applications, including in anti-aliasing, in the construction of a Lanczos resampling filter, and in interpolation.
For bandlimited interpolation of discrete-time signals, the ideal interpolation kernel is proportional to the sinc function.
References
[1] Weisstein, Eric W. “Sinc Function.” From MathWorld–A Wolfram Web Resource. http://mathworld.wolfram.com/SincFunction.html
[2] Wikipedia, “Sinc function”, https://en.wikipedia.org/wiki/Sinc_function
Examples
>>> import matplotlib.pyplot as plt
>>> x = np.linspace(-4, 4, 41)
>>> np.sinc(x)
array([-3.89804309e-17, -4.92362781e-02, -8.40918587e-02, # may vary
       -8.90384387e-02, -5.84680802e-02,  3.89804309e-17,
        6.68206631e-02,  1.16434881e-01,  1.26137788e-01,
        8.50444803e-02, -3.89804309e-17, -1.03943254e-01,
       -1.89206682e-01, -2.16236208e-01, -1.55914881e-01,
        3.89804309e-17,  2.33872321e-01,  5.04551152e-01,
        7.56826729e-01,  9.35489284e-01,  1.00000000e+00,
        9.35489284e-01,  7.56826729e-01,  5.04551152e-01,
        2.33872321e-01,  3.89804309e-17, -1.55914881e-01,
       -2.16236208e-01, -1.89206682e-01, -1.03943254e-01,
       -3.89804309e-17,  8.50444803e-02,  1.26137788e-01,
        1.16434881e-01,  6.68206631e-02,  3.89804309e-17,
       -5.84680802e-02, -8.90384387e-02, -8.40918587e-02,
       -4.92362781e-02, -3.89804309e-17])
>>> plt.plot(x, np.sinc(x))
[<matplotlib.lines.Line2D object at 0x...>]
>>> plt.title("Sinc Function")
Text(0.5, 1.0, 'Sinc Function')
>>> plt.ylabel("Amplitude")
Text(0, 0.5, 'Amplitude')
>>> plt.xlabel("X")
Text(0.5, 0, 'X')
>>> plt.show()
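The bandlimited-interpolation remark above corresponds to the Whittaker-Shannon formula x(t) = sum_n x[n] sinc(t - n) for a unit sample period. A minimal NumPy sketch (all names are illustrative):

import numpy as np

n = np.arange(32)                       # sample indices, unit sample period
samples = np.sin(2 * np.pi * 0.05 * n)  # a slowly varying, bandlimited signal
t = 10.5                                # a time between two samples
estimate = np.sum(samples * np.sinc(t - n))  # sinc-kernel reconstruction at t
exact = np.sin(2 * np.pi * 0.05 * t)         # estimate is close to this value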
- symjax.tensor.sinh(x)¶ Hyperbolic sine, element-wise.
LAX-backend implementation of sinh(). Original docstring below.
sinh(x, /, out=None, *, where=True, casting='same_kind', order='K', dtype=None, subok=True[, signature, extobj])
Equivalent to 1/2 * (np.exp(x) - np.exp(-x)) or -1j * np.sin(1j*x).
Parameters: x (array_like) – Input array.
Returns: y – The corresponding hyperbolic sine values. This is a scalar if x is a scalar.
Return type: ndarray
Notes
If out is provided, the function writes the result into it, and returns a reference to out. (See Examples)
References
M. Abramowitz and I. A. Stegun, Handbook of Mathematical Functions. New York, NY: Dover, 1972, pg. 83.
Examples
>>> np.sinh(0)
0.0
>>> np.sinh(np.pi*1j/2)
1j
>>> np.sinh(np.pi*1j)  # (exact value is 0)
1.2246063538223773e-016j
>>> # Discrepancy due to vagaries of floating point arithmetic.
>>> # Example of providing the optional output parameter
>>> out1 = np.array([0], dtype='d')
>>> out2 = np.sinh([0.1], out1)
>>> out2 is out1
True
>>> # Example of ValueError due to provision of shape mis-matched `out`
>>> np.sinh(np.zeros((3,3)), np.zeros((2,2)))
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
ValueError: operands could not be broadcast together with shapes (3,3) (2,2)
- symjax.tensor.sometrue(a, axis=None, out=None, keepdims=None)¶ Test whether any array element along a given axis evaluates to True.
LAX-backend implementation of any(). Original docstring below.
Returns single boolean unless axis is not None.
Parameters: - a (array_like) – Input array or object that can be converted to an array.
- axis (None or int or tuple of ints, optional) – Axis or axes along which a logical OR reduction is performed.
The default (axis=None) is to perform a logical OR over all the dimensions of the input array. axis may be negative, in which case it counts from the last to the first axis.
- out (ndarray, optional) – Alternate output array in which to place the result. It must have the same shape as the expected output and its type is preserved (e.g., if it is of type float, then it will remain so, returning 1.0 for True and 0.0 for False, regardless of the type of a). See ufuncs-output-type for more details.
- keepdims (bool, optional) – If this is set to True, the axes which are reduced are left in the result as dimensions with size one. With this option, the result will broadcast correctly against the input array.
Returns: any – A new boolean or ndarray is returned unless out is specified, in which case a reference to out is returned.
Return type: bool or ndarray
See also
ndarray.any()
- equivalent method
all()
- Test whether all elements along a given axis evaluate to True.
Notes
Not a Number (NaN), positive infinity and negative infinity evaluate to True because these are not equal to zero.
Examples
>>> np.any([[True, False], [True, True]]) True
>>> np.any([[True, False], [False, False]], axis=0) array([ True, False])
>>> np.any([-1, 0, 5]) True
>>> np.any(np.nan) True
>>> o = np.array(False)
>>> z = np.any([-1, 4, 5], out=o)
>>> z, o
(array(True), array(True))
>>> # Check now that z is a reference to o
>>> z is o
True
>>> id(z), id(o)  # identity of z and o  # doctest: +SKIP
(191614240, 191614240)
- symjax.tensor.sort(a, axis=-1, kind='quicksort', order=None)[source]¶ Return a sorted copy of an array.
LAX-backend implementation of sort(). Original docstring below.
Parameters: - a (array_like) – Array to be sorted.
- axis (int or None, optional) – Axis along which to sort. If None, the array is flattened before sorting. The default is -1, which sorts along the last axis.
- kind ({'quicksort', 'mergesort', 'heapsort', 'stable'}, optional) – Sorting algorithm. The default is ‘quicksort’. Note that both ‘stable’ and ‘mergesort’ use timsort or radix sort under the covers and, in general, the actual implementation will vary with data type. The ‘mergesort’ option is retained for backwards compatibility.
- order (str or list of str, optional) – When a is an array with fields defined, this argument specifies which fields to compare first, second, etc. A single field can be specified as a string, and not all fields need be specified, but unspecified fields will still be used, in the order in which they come up in the dtype, to break ties.
Returns: sorted_array – Array of the same type and shape as a.
Return type: ndarray
See also
ndarray.sort()
- Method to sort an array in-place.
argsort()
- Indirect sort.
lexsort()
- Indirect stable sort on multiple keys.
searchsorted()
- Find elements in a sorted array.
partition()
- Partial sort.
Notes
The various sorting algorithms are characterized by their average speed, worst case performance, work space size, and whether they are stable. A stable sort keeps items with the same key in the same relative order. The four algorithms implemented in NumPy have the following properties:
kind | speed | worst case | work space | stable
‘quicksort’ | 1 | O(n^2) | 0 | no
‘heapsort’ | 3 | O(n*log(n)) | 0 | no
‘mergesort’ | 2 | O(n*log(n)) | ~n/2 | yes
‘timsort’ | 2 | O(n*log(n)) | ~n/2 | yes
Note
The datatype determines which of ‘mergesort’ or ‘timsort’ is actually used, even if ‘mergesort’ is specified. User selection at a finer scale is not currently available.
All the sort algorithms make temporary copies of the data when sorting along any but the last axis. Consequently, sorting along the last axis is faster and uses less space than sorting along any other axis.
The sort order for complex numbers is lexicographic. If both the real and imaginary parts are non-nan then the order is determined by the real parts except when they are equal, in which case the order is determined by the imaginary parts.
Previous to numpy 1.4.0 sorting real and complex arrays containing nan values led to undefined behaviour. In numpy versions >= 1.4.0 nan values are sorted to the end. The extended sort order is:
- Real: [R, nan]
- Complex: [R + Rj, R + nanj, nan + Rj, nan + nanj]
where R is a non-nan real value. Complex values with the same nan placements are sorted according to the non-nan part if it exists. Non-nan values are sorted as before.
New in version 1.12.0.
quicksort has been changed to introsort. When sorting does not make enough progress it switches to heapsort. This implementation makes quicksort O(n*log(n)) in the worst case.
‘stable’ automatically chooses the best stable sorting algorithm for the data type being sorted. It, along with ‘mergesort’ is currently mapped to timsort or radix sort depending on the data type. API forward compatibility currently limits the ability to select the implementation and it is hardwired for the different data types.
New in version 1.17.0.
Timsort is added for better performance on already or nearly sorted data. On random data timsort is almost identical to mergesort. It is now used for stable sort while quicksort is still the default sort if none is chosen. For timsort details, refer to CPython listsort.txt. ‘mergesort’ and ‘stable’ are mapped to radix sort for integer data types. Radix sort is an O(n) sort instead of O(n log n).
Changed in version 1.18.0.
NaT now sorts to the end of arrays for consistency with NaN.
Examples
>>> a = np.array([[1,4],[3,1]])
>>> np.sort(a) # sort along the last axis
array([[1, 4],
       [1, 3]])
>>> np.sort(a, axis=None) # sort the flattened array
array([1, 1, 3, 4])
>>> np.sort(a, axis=0) # sort along the first axis
array([[1, 1],
       [3, 4]])
Use the order keyword to specify a field to use when sorting a structured array:
>>> dtype = [('name', 'S10'), ('height', float), ('age', int)]
>>> values = [('Arthur', 1.8, 41), ('Lancelot', 1.9, 38),
...           ('Galahad', 1.7, 38)]
>>> a = np.array(values, dtype=dtype) # create a structured array
>>> np.sort(a, order='height') # doctest: +SKIP
array([('Galahad', 1.7, 38), ('Arthur', 1.8, 41),
       ('Lancelot', 1.8999999999999999, 38)],
      dtype=[('name', '|S10'), ('height', '<f8'), ('age', '<i4')])
Sort by age, then height if ages are equal:
>>> np.sort(a, order=['age', 'height']) # doctest: +SKIP
array([('Galahad', 1.7, 38), ('Lancelot', 1.8999999999999999, 38),
       ('Arthur', 1.8, 41)],
      dtype=[('name', '|S10'), ('height', '<f8'), ('age', '<i4')])
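Since ‘stable’ preserves the relative order of equal keys, stability is easiest to observe through argsort; a small sketch in plain numpy (the LAX-backed version accepts the same kind argument, though the backend may substitute its own stable algorithm):

>>> import numpy as np
>>> a = np.array([1, 0, 1, 0])
>>> np.argsort(a, kind='stable')  # equal keys keep their original left-to-right order
array([1, 3, 0, 2])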
-
symjax.tensor.
split
(ary, indices_or_sections, axis=0)[source]¶ Split an array into multiple sub-arrays as views into ary.
LAX-backend implementation of
split()
. Original docstring below.Parameters: - ary (ndarray) – Array to be divided into sub-arrays.
- indices_or_sections (int or 1-D array) – If indices_or_sections is an integer, N, the array will be divided into N equal arrays along axis. If such a split is not possible, an error is raised.
- axis (int, optional) – The axis along which to split, default is 0.
Returns: sub-arrays – A list of sub-arrays as views into ary.
Return type: list of ndarrays
Raises: ValueError
– If indices_or_sections is given as an integer, but a split does not result in equal division.See also
array_split()
- Split an array into multiple sub-arrays of equal or near-equal size. Does not raise an exception if an equal division cannot be made.
hsplit()
- Split array into multiple sub-arrays horizontally (column-wise).
vsplit()
- Split array into multiple sub-arrays vertically (row wise).
dsplit()
- Split array into multiple sub-arrays along the 3rd axis (depth).
concatenate()
- Join a sequence of arrays along an existing axis.
stack()
- Join a sequence of arrays along a new axis.
hstack()
- Stack arrays in sequence horizontally (column wise).
vstack()
- Stack arrays in sequence vertically (row wise).
dstack()
- Stack arrays in sequence depth wise (along third dimension).
Examples
>>> x = np.arange(9.0)
>>> np.split(x, 3)
[array([0., 1., 2.]), array([3., 4., 5.]), array([6., 7., 8.])]

>>> x = np.arange(8.0)
>>> np.split(x, [3, 5, 6, 10])
[array([0., 1., 2.]),
 array([3., 4.]),
 array([5.]),
 array([6., 7.]),
 array([], dtype=float64)]
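To see the ValueError described above, and the array_split alternative that tolerates unequal divisions, a short sketch in plain numpy:

>>> import numpy as np
>>> x = np.arange(7.0)
>>> np.split(x, 3)
Traceback (most recent call last):
    ...
ValueError: array split does not result in an equal division
>>> np.array_split(x, 3)  # near-equal pieces instead of an error
[array([0., 1., 2.]), array([3., 4., 5.]), array([6.])]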
-
symjax.tensor.
sqrt
(x)¶ Return the non-negative square-root of an array, element-wise.
LAX-backend implementation of
sqrt()
. Original docstring below.sqrt(x, /, out=None, *, where=True, casting=’same_kind’, order=’K’, dtype=None, subok=True[, signature, extobj])
Parameters: x (array_like) – The values whose square-roots are required. Returns: y – An array of the same shape as x, containing the positive square-root of each element in x. If any element in x is complex, a complex array is returned (and the square-roots of negative reals are calculated). If all of the elements in x are real, so is y, with negative elements returning nan
. If out was provided, y is a reference to it. This is a scalar if x is a scalar.Return type: ndarray See also
lib.scimath.sqrt()
- A version which returns complex numbers when given negative reals.
Notes
sqrt has–consistent with common convention–as its branch cut the real “interval” [-inf, 0), and is continuous from above on it. A branch cut is a curve in the complex plane across which a given complex function fails to be continuous.
Examples
>>> np.sqrt([1,4,9]) array([ 1., 2., 3.])
>>> np.sqrt([4, -1, -3+4J]) array([ 2.+0.j, 0.+1.j, 1.+2.j])
>>> np.sqrt([4, -1, np.inf]) array([ 2., nan, inf])
-
symjax.tensor.
square
(x)[source]¶ Return the element-wise square of the input.
LAX-backend implementation of
square()
. Original docstring below.square(x, /, out=None, *, where=True, casting=’same_kind’, order=’K’, dtype=None, subok=True[, signature, extobj])
Parameters: x (array_like) – Input data. Returns: out – Element-wise x*x, of the same shape and dtype as x. This is a scalar if x is a scalar. Return type: ndarray or scalar Examples
>>> np.square([-1j, 1]) array([-1.-0.j, 1.+0.j])
-
symjax.tensor.
squeeze
(a, axis: Union[int, Tuple[int, ...]] = None)[source]¶ Remove single-dimensional entries from the shape of an array.
LAX-backend implementation of
squeeze()
. Original docstring below.Parameters: - a (array_like) – Input data.
- axis (None or int or tuple of ints, optional) –
New in version 1.7.0.
Selects a subset of the single-dimensional entries in the shape. If an axis is selected with shape entry greater than one, an error is raised.
Returns: squeezed – The input array, but with all or a subset of the dimensions of length 1 removed. This is always a itself or a view into a. Note that if all axes are squeezed, the result is a 0d array and not a scalar.
Return type: ndarray
Raises: ValueError
– If axis is not None, and an axis being squeezed is not of length 1See also
expand_dims()
- The inverse operation, adding singleton dimensions
reshape()
- Insert, remove, and combine dimensions, and resize existing ones
Examples
>>> x = np.array([[[0], [1], [2]]])
>>> x.shape
(1, 3, 1)
>>> np.squeeze(x).shape
(3,)
>>> np.squeeze(x, axis=0).shape
(3, 1)
>>> np.squeeze(x, axis=1).shape
Traceback (most recent call last):
...
ValueError: cannot select an axis to squeeze out which has size not equal to one
>>> np.squeeze(x, axis=2).shape
(1, 3)
>>> x = np.array([[1234]])
>>> x.shape
(1, 1)
>>> np.squeeze(x)
array(1234)  # 0d array
>>> np.squeeze(x).shape
()
>>> np.squeeze(x)[()]
1234
-
symjax.tensor.
stack
(arrays, axis=0, out=None)[source]¶ Join a sequence of arrays along a new axis.
LAX-backend implementation of
stack()
. Original docstring below.The
axis
parameter specifies the index of the new axis in the dimensions of the result. For example, ifaxis=0
it will be the first dimension and ifaxis=-1
it will be the last dimension.New in version 1.10.0.
Parameters: - arrays (sequence of array_like) – Each array must have the same shape.
- axis (int, optional) – The axis in the result array along which the input arrays are stacked.
- out (ndarray, optional) – If provided, the destination to place the result. The shape must be correct, matching that of what stack would have returned if no out argument were specified.
Returns: stacked – The stacked array has one more dimension than the input arrays.
Return type: ndarray
See also
concatenate()
- Join a sequence of arrays along an existing axis.
block()
- Assemble an nd-array from nested lists of blocks.
split()
- Split array into a list of multiple sub-arrays of equal size.
Examples
>>> arrays = [np.random.randn(3, 4) for _ in range(10)]
>>> np.stack(arrays, axis=0).shape
(10, 3, 4)

>>> np.stack(arrays, axis=1).shape
(3, 10, 4)

>>> np.stack(arrays, axis=2).shape
(3, 4, 10)

>>> a = np.array([1, 2, 3])
>>> b = np.array([2, 3, 4])
>>> np.stack((a, b))
array([[1, 2, 3],
       [2, 3, 4]])

>>> np.stack((a, b), axis=-1)
array([[1, 2],
       [2, 3],
       [3, 4]])
-
symjax.tensor.
std
(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False)[source]¶ Compute the standard deviation along the specified axis.
LAX-backend implementation of
std()
. Original docstring below.Returns the standard deviation, a measure of the spread of a distribution, of the array elements. The standard deviation is computed for the flattened array by default, otherwise over the specified axis.
Parameters: - a (array_like) – Calculate the standard deviation of these values.
- axis (None or int or tuple of ints, optional) – Axis or axes along which the standard deviation is computed. The default is to compute the standard deviation of the flattened array.
- dtype (dtype, optional) – Type to use in computing the standard deviation. For arrays of integer type the default is float64, for arrays of float types it is the same as the array type.
- out (ndarray, optional) – Alternative output array in which to place the result. It must have the same shape as the expected output but the type (of the calculated values) will be cast if necessary.
- ddof (int, optional) – Means Delta Degrees of Freedom. The divisor used in calculations
is
N - ddof
, whereN
represents the number of elements. By default ddof is zero. - keepdims (bool, optional) – If this is set to True, the axes which are reduced are left in the result as dimensions with size one. With this option, the result will broadcast correctly against the input array.
Returns: standard_deviation – If out is None, return a new array containing the standard deviation, otherwise return a reference to the output array.
Return type: ndarray, see dtype parameter above.
Notes
The standard deviation is the square root of the average of the squared deviations from the mean, i.e.,
std = sqrt(mean(abs(x - x.mean())**2))
.The average squared deviation is normally calculated as
x.sum() / N
, whereN = len(x)
. If, however, ddof is specified, the divisorN - ddof
is used instead. In standard statistical practice,ddof=1
provides an unbiased estimator of the variance of the infinite population.ddof=0
provides a maximum likelihood estimate of the variance for normally distributed variables. The standard deviation computed in this function is the square root of the estimated variance, so even withddof=1
, it will not be an unbiased estimate of the standard deviation per se.Note that, for complex numbers, std takes the absolute value before squaring, so that the result is always real and nonnegative.
For floating-point input, the std is computed using the same precision the input has. Depending on the input data, this can cause the results to be inaccurate, especially for float32 (see example below). Specifying a higher-accuracy accumulator using the dtype keyword can alleviate this issue.
Examples
>>> a = np.array([[1, 2], [3, 4]])
>>> np.std(a)
1.1180339887498949 # may vary
>>> np.std(a, axis=0)
array([1., 1.])
>>> np.std(a, axis=1)
array([0.5, 0.5])
In single precision, std() can be inaccurate:
>>> a = np.zeros((2, 512*512), dtype=np.float32)
>>> a[0, :] = 1.0
>>> a[1, :] = 0.1
>>> np.std(a)
0.45000005
Computing the standard deviation in float64 is more accurate:
>>> np.std(a, dtype=np.float64)
0.44999999925494177 # may vary
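The ddof discussion above is easy to check numerically; a small sketch contrasting the ddof=0 (maximum-likelihood) and ddof=1 (unbiased-variance) divisors in plain numpy:

>>> import numpy as np
>>> a = np.array([1., 2., 3., 4.])
>>> np.std(a)          # divisor N = 4
1.118033988749895 # may vary
>>> np.std(a, ddof=1)  # divisor N - 1 = 3
1.2909944487358056 # may vary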
-
symjax.tensor.
subtract
(x1, x2)¶ Subtract arguments, element-wise.
LAX-backend implementation of
subtract()
. Original docstring below.subtract(x1, x2, /, out=None, *, where=True, casting=’same_kind’, order=’K’, dtype=None, subok=True[, signature, extobj])
Parameters: x2 (x1,) – The arrays to be subtracted from each other. If x1.shape != x2.shape
, they must be broadcastable to a common shape (which becomes the shape of the output).Returns: y – The difference of x1 and x2, element-wise. This is a scalar if both x1 and x2 are scalars. Return type: ndarray Notes
Equivalent to
x1 - x2
in terms of array broadcasting.Examples
>>> np.subtract(1.0, 4.0) -3.0
>>> x1 = np.arange(9.0).reshape((3, 3))
>>> x2 = np.arange(3.0)
>>> np.subtract(x1, x2)
array([[ 0., 0., 0.],
       [ 3., 3., 3.],
       [ 6., 6., 6.]])
-
symjax.tensor.
sum
(a, axis=None, dtype=None, out=None, keepdims=None, initial=None, where=None)[source]¶ Sum of array elements over a given axis.
LAX-backend implementation of
sum()
. Original docstring below.Parameters: - a (array_like) – Elements to sum.
- axis (None or int or tuple of ints, optional) – Axis or axes along which a sum is performed. The default, axis=None, will sum all of the elements of the input array. If axis is negative it counts from the last to the first axis.
- dtype (dtype, optional) – The type of the returned array and of the accumulator in which the elements are summed. The dtype of a is used by default unless a has an integer dtype of less precision than the default platform integer. In that case, if a is signed then the platform integer is used while if a is unsigned then an unsigned integer of the same precision as the platform integer is used.
- out (ndarray, optional) – Alternative output array in which to place the result. It must have the same shape as the expected output, but the type of the output values will be cast if necessary.
- keepdims (bool, optional) – If this is set to True, the axes which are reduced are left in the result as dimensions with size one. With this option, the result will broadcast correctly against the input array.
- initial (scalar, optional) – Starting value for the sum. See ~numpy.ufunc.reduce for details.
- where (array_like of bool, optional) – Elements to include in the sum. See ~numpy.ufunc.reduce for details.
Returns: sum_along_axis – An array with the same shape as a, with the specified axis removed. If a is a 0-d array, or if axis is None, a scalar is returned. If an output array is specified, a reference to out is returned.
Return type: ndarray
See also
ndarray.sum()
- Equivalent method.
add.reduce()
- Equivalent functionality of add.
cumsum()
- Cumulative sum of array elements.
trapz()
- Integration of array values using the composite trapezoidal rule.
mean()
,average()
Notes
Arithmetic is modular when using integer types, and no error is raised on overflow.
The sum of an empty array is the neutral element 0:
>>> np.sum([]) 0.0
For floating point numbers the numerical precision of sum (and
np.add.reduce
) is in general limited by directly adding each number individually to the result causing rounding errors in every step. However, often numpy will use a numerically better approach (partial pairwise summation) leading to improved precision in many use-cases. This improved precision is always provided when noaxis
is given. Whenaxis
is given, it will depend on which axis is summed. Technically, to provide the best speed possible, the improved precision is only used when the summation is along the fast axis in memory. Note that the exact precision may vary depending on other parameters. In contrast to NumPy, Python’smath.fsum
function uses a slower but more precise approach to summation. Especially when summing a large number of lower precision floating point numbers, such asfloat32
, numerical errors can become significant. In such cases it can be advisable to use dtype=”float64” to use a higher precision for the output.Examples
>>> np.sum([0.5, 1.5])
2.0
>>> np.sum([0.5, 0.7, 0.2, 1.5], dtype=np.int32)
1
>>> np.sum([[0, 1], [0, 5]])
6
>>> np.sum([[0, 1], [0, 5]], axis=0)
array([0, 6])
>>> np.sum([[0, 1], [0, 5]], axis=1)
array([1, 5])
>>> np.sum([[0, 1], [np.nan, 5]], where=[False, True], axis=1)
array([1., 5.])
If the accumulator is too small, overflow occurs:
>>> np.ones(128, dtype=np.int8).sum(dtype=np.int8) -128
You can also start the sum with a value other than zero:
>>> np.sum([10], initial=5) 15
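The keepdims option described above keeps the reduced axis as a size-one dimension, so the result broadcasts back against the input; a short sketch in plain numpy:

>>> import numpy as np
>>> x = np.array([[0, 1], [0, 5]])
>>> np.sum(x, axis=1, keepdims=True)
array([[1],
       [5]])
>>> x - np.sum(x, axis=1, keepdims=True)  # broadcasts row-wise
array([[-1,  0],
       [-5,  0]])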
-
symjax.tensor.
swapaxes
(a, axis1, axis2)[source]¶ Interchange two axes of an array.
LAX-backend implementation of
swapaxes()
. Original docstring below.Parameters: - a (array_like) – Input array.
- axis1 (int) – First axis.
- axis2 (int) – Second axis.
Returns: a_swapped – For NumPy >= 1.10.0, if a is an ndarray, then a view of a is returned; otherwise a new array is created. For earlier NumPy versions a view of a is returned only if the order of the axes is changed, otherwise the input array is returned.
Return type: ndarray
Examples
>>> x = np.array([[1,2,3]])
>>> np.swapaxes(x,0,1)
array([[1],
       [2],
       [3]])

>>> x = np.array([[[0,1],[2,3]],[[4,5],[6,7]]])
>>> x
array([[[0, 1],
        [2, 3]],
       [[4, 5],
        [6, 7]]])

>>> np.swapaxes(x,0,2)
array([[[0, 4],
        [2, 6]],
       [[1, 5],
        [3, 7]]])
-
symjax.tensor.
take
(a, indices, axis=None, out=None, mode=None)[source]¶ Take elements from an array along an axis.
LAX-backend implementation of
take()
. Original docstring below.When axis is not None, this function does the same thing as “fancy” indexing (indexing arrays using arrays); however, it can be easier to use if you need elements along a given axis. A call such as
np.take(arr, indices, axis=3)
is equivalent toarr[:,:,:,indices,...]
.Explained without fancy indexing, this is equivalent to the following use of ndindex, which sets each of
ii
,jj
, andkk
to a tuple of indices:Ni, Nk = a.shape[:axis], a.shape[axis+1:] Nj = indices.shape for ii in ndindex(Ni): for jj in ndindex(Nj): for kk in ndindex(Nk): out[ii + jj + kk] = a[ii + (indices[jj],) + kk]
Parameters: - a (array_like (Ni..., M, Nk...)) – The source array.
- indices (array_like (Nj...)) – The indices of the values to extract.
- axis (int, optional) – The axis over which to select values. By default, the flattened input array is used.
- out (ndarray, optional (Ni..., Nj..., Nk...)) – If provided, the result will be placed in this array. It should be of the appropriate shape and dtype. Note that out is always buffered if mode=’raise’; use other modes for better performance.
- mode ({'raise', 'wrap', 'clip'}, optional) – Specifies how out-of-bounds indices will behave.
Returns: out – The returned array has the same type as a.
Return type: ndarray (Ni…, Nj…, Nk…)
See also
compress()
- Take elements using a boolean mask
ndarray.take()
- equivalent method
take_along_axis()
- Take elements by matching the array and the index arrays
Notes
By eliminating the inner loop in the description above, and using s_ to build simple slice objects, take can be expressed in terms of applying fancy indexing to each 1-d slice:
Ni, Nk = a.shape[:axis], a.shape[axis+1:]
for ii in ndindex(Ni):
    for kk in ndindex(Nk):
        out[ii + s_[...,] + kk] = a[ii + s_[:,] + kk][indices]
For this reason, it is equivalent to (but faster than) the following use of apply_along_axis:
out = np.apply_along_axis(lambda a_1d: a_1d[indices], axis, a)
Examples
>>> a = [4, 3, 5, 7, 6, 8]
>>> indices = [0, 1, 4]
>>> np.take(a, indices)
array([4, 3, 6])
In this example if a is an ndarray, “fancy” indexing can be used.
>>> a = np.array(a)
>>> a[indices]
array([4, 3, 6])
If indices is not one dimensional, the output also has these dimensions.
>>> np.take(a, [[0, 1], [2, 3]])
array([[4, 3],
       [5, 7]])
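The mode parameter decides what happens to out-of-bounds indices; a minimal sketch of 'wrap' and 'clip' in plain numpy (out-of-bounds behaviour beyond these modes may differ under the LAX backend):

>>> import numpy as np
>>> a = np.array([4, 3, 5, 7, 6, 8])
>>> np.take(a, [0, 7], mode='wrap')  # index 7 wraps around to index 1
array([4, 3])
>>> np.take(a, [0, 7], mode='clip')  # index 7 is clipped to the last valid index
array([4, 8])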
-
symjax.tensor.
take_along_axis
(arr, indices, axis)[source]¶ Take values from the input array by matching 1d index and data slices.
LAX-backend implementation of
take_along_axis()
. Original docstring below.This iterates over matching 1d slices oriented along the specified axis in the index and data arrays, and uses the former to look up values in the latter. These slices can be different lengths.
Functions returning an index along an axis, like argsort and argpartition, produce suitable indices for this function.
New in version 1.15.0.
Parameters: - arr (ndarray (Ni…, M, Nk…)) – Source array
- indices (ndarray (Ni…, J, Nk…)) – Indices to take along each 1d slice of arr. This must match the dimension of arr, but dimensions Ni and Nj only need to broadcast against arr.
- axis (int) – The axis to take 1d slices along. If axis is None, the input array is treated as if it had first been flattened to 1d, for consistency with sort and argsort.
Returns: out – The indexed result.
Return type: ndarray (Ni…, J, Nk…)
This is equivalent to (but faster than) the following use of ndindex and s_, which sets each of
ii
andkk
to a tuple of indices:Ni, M, Nk = a.shape[:axis], a.shape[axis], a.shape[axis+1:] J = indices.shape[axis] # Need not equal M out = np.empty(Ni + (J,) + Nk) for ii in ndindex(Ni): for kk in ndindex(Nk): a_1d = a [ii + s_[:,] + kk] indices_1d = indices[ii + s_[:,] + kk] out_1d = out [ii + s_[:,] + kk] for j in range(J): out_1d[j] = a_1d[indices_1d[j]]
Equivalently, eliminating the inner loop, the last two lines would be:
out_1d[:] = a_1d[indices_1d]
See also
take() - Take along an axis, using the same indices for every 1d slice
put_along_axis() - Put values into the destination array by matching 1d index and data slices
Examples
For this sample array
>>> a = np.array([[10, 30, 20], [60, 40, 50]])
We can sort either by using sort directly, or argsort and this function
>>> np.sort(a, axis=1)
array([[10, 20, 30],
       [40, 50, 60]])
>>> ai = np.argsort(a, axis=1); ai
array([[0, 2, 1],
       [1, 2, 0]])
>>> np.take_along_axis(a, ai, axis=1)
array([[10, 20, 30],
       [40, 50, 60]])
The same works for max and min, if you expand the dimensions:
>>> np.expand_dims(np.max(a, axis=1), axis=1)
array([[30],
       [60]])
>>> ai = np.expand_dims(np.argmax(a, axis=1), axis=1)
>>> ai
array([[1],
       [0]])
>>> np.take_along_axis(a, ai, axis=1)
array([[30],
       [60]])
If we want to get the max and min at the same time, we can stack the indices first
>>> ai_min = np.expand_dims(np.argmin(a, axis=1), axis=1)
>>> ai_max = np.expand_dims(np.argmax(a, axis=1), axis=1)
>>> ai = np.concatenate([ai_min, ai_max], axis=1)
>>> ai
array([[0, 1],
       [1, 0]])
>>> np.take_along_axis(a, ai, axis=1)
array([[10, 30],
       [40, 60]])
-
symjax.tensor.
tan
(x)¶ Compute tangent element-wise.
LAX-backend implementation of
tan()
. Original docstring below.tan(x, /, out=None, *, where=True, casting=’same_kind’, order=’K’, dtype=None, subok=True[, signature, extobj])
Equivalent to
np.sin(x)/np.cos(x)
element-wise.Parameters: x (array_like) – Input array. Returns: y – The corresponding tangent values. This is a scalar if x is a scalar. Return type: ndarray Notes
If out is provided, the function writes the result into it, and returns a reference to out. (See Examples)
References
M. Abramowitz and I. A. Stegun, Handbook of Mathematical Functions. New York, NY: Dover, 1972.
Examples
>>> from math import pi
>>> np.tan(np.array([-pi,pi/2,pi]))
array([ 1.22460635e-16, 1.63317787e+16, -1.22460635e-16])
>>>
>>> # Example of providing the optional output parameter illustrating
>>> # that what is returned is a reference to said parameter
>>> out1 = np.array([0], dtype='d')
>>> out2 = np.tan([0.1], out1)
>>> out2 is out1
True
>>>
>>> # Example of ValueError due to provision of shape mis-matched `out`
>>> np.tan(np.zeros((3,3)), np.zeros((2,2)))
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
ValueError: operands could not be broadcast together with shapes (3,3) (2,2)
-
symjax.tensor.
tanh
(x)¶ Compute hyperbolic tangent element-wise.
LAX-backend implementation of
tanh()
. Original docstring below.tanh(x, /, out=None, *, where=True, casting=’same_kind’, order=’K’, dtype=None, subok=True[, signature, extobj])
Equivalent to
np.sinh(x)/np.cosh(x)
or-1j * np.tan(1j*x)
.Parameters: x (array_like) – Input array. Returns: y – The corresponding hyperbolic tangent values. This is a scalar if x is a scalar. Return type: ndarray Notes
If out is provided, the function writes the result into it, and returns a reference to out. (See Examples)
References
[1] M. Abramowitz and I. A. Stegun, Handbook of Mathematical Functions. New York, NY: Dover, 1972, pg. 83. http://www.math.sfu.ca/~cbm/aands/
[2] Wikipedia, “Hyperbolic function”, https://en.wikipedia.org/wiki/Hyperbolic_function
Examples
>>> np.tanh((0, np.pi*1j, np.pi*1j/2)) array([ 0. +0.00000000e+00j, 0. -1.22460635e-16j, 0. +1.63317787e+16j])
>>> # Example of providing the optional output parameter illustrating
>>> # that what is returned is a reference to said parameter
>>> out1 = np.array([0], dtype='d')
>>> out2 = np.tanh([0.1], out1)
>>> out2 is out1
True

>>> # Example of ValueError due to provision of shape mis-matched `out`
>>> np.tanh(np.zeros((3,3)), np.zeros((2,2)))
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
ValueError: operands could not be broadcast together with shapes (3,3) (2,2)
-
symjax.tensor.
tensordot
(a, b, axes=2, *, precision=None)[source]¶ Compute tensor dot product along specified axes.
LAX-backend implementation of
tensordot()
. In addition to the original NumPy arguments listed below, also supportsprecision
for extra control over matrix-multiplication precision on supported devices.precision
may be set toNone
, which means default precision for the backend, alax.Precision
enum value (Precision.DEFAULT
,Precision.HIGH
orPrecision.HIGHEST
) or a tuple of twolax.Precision
enums indicating separate precision for each argument.Original docstring below.
Given two tensors, a and b, and an array_like object containing two array_like objects,
(a_axes, b_axes)
, sum the products of a’s and b’s elements (components) over the axes specified bya_axes
andb_axes
. The third argument can be a single non-negative integer_like scalar,N
; if it is such, then the lastN
dimensions of a and the firstN
dimensions of b are summed over.Parameters: - b (a,) – Tensors to “dot”.
- axes (int or (2,) array_like) –
- integer_like If an int N, sum over the last N axes of a and the first N axes of b in order. The sizes of the corresponding axes must match.
- (2,) array_like Or, a list of axes to be summed over, first sequence applying to a, second to b. Both elements array_like must be of the same length.
Returns: output – The tensor dot product of the input.
Return type: ndarray
Notes
- Three common use cases are:
axes = 0
: tensor product \(a\otimes b\)axes = 1
: tensor dot product \(a\cdot b\)axes = 2
: (default) tensor double contraction \(a:b\)
When axes is integer_like, the sequence for evaluation will be: first the -Nth axis in a and 0th axis in b, and the -1th axis in a and Nth axis in b last.
When there is more than one axis to sum over - and they are not the last (first) axes of a (b) - the argument axes should consist of two sequences of the same length, with the first axis to sum over given first in both sequences, the second axis second, and so forth.
The shape of the result consists of the non-contracted axes of the first tensor, followed by the non-contracted axes of the second.
Examples
A “traditional” example:
>>> a = np.arange(60.).reshape(3,4,5)
>>> b = np.arange(24.).reshape(4,3,2)
>>> c = np.tensordot(a,b, axes=([1,0],[0,1]))
>>> c.shape
(5, 2)
>>> c
array([[4400., 4730.],
       [4532., 4874.],
       [4664., 5018.],
       [4796., 5162.],
       [4928., 5306.]])
>>> # A slower but equivalent way of computing the same...
>>> d = np.zeros((5,2))
>>> for i in range(5):
...     for j in range(2):
...         for k in range(3):
...             for n in range(4):
...                 d[i,j] += a[k,n,i] * b[n,k,j]
>>> c == d
array([[ True, True],
       [ True, True],
       [ True, True],
       [ True, True],
       [ True, True]])
An extended example taking advantage of the overloading of + and *:
>>> a = np.array(range(1, 9))
>>> a.shape = (2, 2, 2)
>>> A = np.array(('a', 'b', 'c', 'd'), dtype=object)
>>> A.shape = (2, 2)
>>> a; A
array([[[1, 2],
        [3, 4]],
       [[5, 6],
        [7, 8]]])
array([['a', 'b'],
       ['c', 'd']], dtype=object)
>>> np.tensordot(a, A) # third argument default is 2 for double-contraction array(['abbcccdddd', 'aaaaabbbbbbcccccccdddddddd'], dtype=object)
>>> np.tensordot(a, A, 1) array([[['acc', 'bdd'], ['aaacccc', 'bbbdddd']], [['aaaaacccccc', 'bbbbbdddddd'], ['aaaaaaacccccccc', 'bbbbbbbdddddddd']]], dtype=object)
>>> np.tensordot(a, A, 0) # tensor product (result too long to incl.) array([[[[['a', 'b'], ['c', 'd']], ...
>>> np.tensordot(a, A, (0, 1)) array([[['abbbbb', 'cddddd'], ['aabbbbbb', 'ccdddddd']], [['aaabbbbbbb', 'cccddddddd'], ['aaaabbbbbbbb', 'ccccdddddddd']]], dtype=object)
>>> np.tensordot(a, A, (2, 1)) array([[['abb', 'cdd'], ['aaabbbb', 'cccdddd']], [['aaaaabbbbbb', 'cccccdddddd'], ['aaaaaaabbbbbbbb', 'cccccccdddddddd']]], dtype=object)
>>> np.tensordot(a, A, ((0, 1), (0, 1))) array(['abbbcccccddddddd', 'aabbbbccccccdddddddd'], dtype=object)
>>> np.tensordot(a, A, ((2, 1), (1, 0))) array(['acccbbdddd', 'aaaaacccccccbbbbbbdddddddd'], dtype=object)
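To connect the integer axes values from the Notes to familiar operations, a brief sketch in plain numpy:

>>> import numpy as np
>>> a = np.arange(6.).reshape(2, 3)
>>> b = np.arange(12.).reshape(3, 4)
>>> np.allclose(np.tensordot(a, b, axes=1), a @ b)  # axes=1 on 2-D inputs is a matrix product
True
>>> np.tensordot(a, b, axes=0).shape  # axes=0 is the outer (tensor) product
(2, 3, 3, 4)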
-
symjax.tensor.
tile
(A, reps)[source]¶ Construct an array by repeating A the number of times given by reps.
LAX-backend implementation of
tile()
. Original docstring below.If reps has length
d
, the result will have dimension ofmax(d, A.ndim)
.If
A.ndim < d
, A is promoted to be d-dimensional by prepending new axes. So a shape (3,) array is promoted to (1, 3) for 2-D replication, or shape (1, 1, 3) for 3-D replication. If this is not the desired behavior, promote A to d-dimensions manually before calling this function.If
A.ndim > d
, reps is promoted to A.ndim by pre-pending 1’s to it. Thus for an A of shape (2, 3, 4, 5), a reps of (2, 2) is treated as (1, 1, 2, 2).Note : Although tile may be used for broadcasting, it is strongly recommended to use numpy’s broadcasting operations and functions.
Parameters: - A (array_like) – The input array.
- reps (array_like) – The number of repetitions of A along each axis.
Returns: c – The tiled output array.
Return type: ndarray
See also
repeat()
- Repeat elements of an array.
broadcast_to()
- Broadcast an array to a new shape
Examples
>>> a = np.array([0, 1, 2])
>>> np.tile(a, 2)
array([0, 1, 2, 0, 1, 2])
>>> np.tile(a, (2, 2))
array([[0, 1, 2, 0, 1, 2],
       [0, 1, 2, 0, 1, 2]])
>>> np.tile(a, (2, 1, 2))
array([[[0, 1, 2, 0, 1, 2]],
       [[0, 1, 2, 0, 1, 2]]])

>>> b = np.array([[1, 2], [3, 4]])
>>> np.tile(b, 2)
array([[1, 2, 1, 2],
       [3, 4, 3, 4]])
>>> np.tile(b, (2, 1))
array([[1, 2],
       [3, 4],
       [1, 2],
       [3, 4]])

>>> c = np.array([1,2,3,4])
>>> np.tile(c,(4,1))
array([[1, 2, 3, 4],
       [1, 2, 3, 4],
       [1, 2, 3, 4],
       [1, 2, 3, 4]])
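The promotion rule for A.ndim > d stated above can be verified directly; a one-line sketch in plain numpy:

>>> import numpy as np
>>> A = np.ones((2, 3, 4, 5))
>>> np.tile(A, (2, 2)).shape  # reps (2, 2) is treated as (1, 1, 2, 2)
(2, 3, 8, 10)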
-
symjax.tensor.
trace
(a, offset=0, axis1=0, axis2=1, dtype=None, out=None)[source]¶ Return the sum along diagonals of the array.
LAX-backend implementation of
trace()
. Original docstring below.If a is 2-D, the sum along its diagonal with the given offset is returned, i.e., the sum of elements
a[i,i+offset]
for all i.If a has more than two dimensions, then the axes specified by axis1 and axis2 are used to determine the 2-D sub-arrays whose traces are returned. The shape of the resulting array is the same as that of a with axis1 and axis2 removed.
Parameters: - a (array_like) – Input array, from which the diagonals are taken.
- offset (int, optional) – Offset of the diagonal from the main diagonal. Can be both positive and negative. Defaults to 0.
- axis2 (axis1,) – Axes to be used as the first and second axis of the 2-D sub-arrays from which the diagonals should be taken. Defaults are the first two axes of a.
- dtype (dtype, optional) – Determines the data-type of the returned array and of the accumulator where the elements are summed. If dtype has the value None and a is of integer type of precision less than the default integer precision, then the default integer precision is used. Otherwise, the precision is the same as that of a.
- out (ndarray, optional) – Array into which the output is placed. Its type is preserved and it must be of the right shape to hold the output.
Returns: sum_along_diagonals – If a is 2-D, the sum along the diagonal is returned. If a has larger dimensions, then an array of sums along diagonals is returned.
Return type: ndarray
See also
diag()
,diagonal()
,diagflat()
Examples
>>> np.trace(np.eye(3))
3.0
>>> a = np.arange(8).reshape((2,2,2))
>>> np.trace(a)
array([6, 8])

>>> a = np.arange(24).reshape((2,2,2,3))
>>> np.trace(a).shape
(2, 3)
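The offset parameter shifts which diagonal is summed; a minimal sketch in plain numpy:

>>> import numpy as np
>>> a = np.arange(9).reshape(3, 3)
>>> np.trace(a, offset=1)   # a[0,1] + a[1,2] = 1 + 5
6
>>> np.trace(a, offset=-1)  # a[1,0] + a[2,1] = 3 + 7
10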
-
symjax.tensor.
transpose
(a, axes=None)[source]¶ Reverse or permute the axes of an array; returns the modified array.
LAX-backend implementation of
transpose()
. Original docstring below.For an array a with two axes, transpose(a) gives the matrix transpose.
Parameters: - a (array_like) – Input array.
- axes (tuple or list of ints, optional) – If specified, it must be a tuple or list which contains a permutation of [0,1,..,N-1] where N is the number of axes of a. The i’th axis of the returned array will correspond to the axis numbered axes[i] of the input. If not specified, defaults to range(a.ndim)[::-1], which reverses the order of the axes.
Returns: p – a with its axes permuted. A view is returned whenever possible.
Return type: ndarray
See also
Notes
Use transpose(a, argsort(axes)) to invert the transposition of tensors when using the axes keyword argument.
Transposing a 1-D array returns an unchanged view of the original array.
Examples
>>> x = np.arange(4).reshape((2,2))
>>> x
array([[0, 1],
       [2, 3]])

>>> np.transpose(x)
array([[0, 2],
       [1, 3]])

>>> x = np.ones((1, 2, 3))
>>> np.transpose(x, (1, 0, 2)).shape
(2, 1, 3)
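The Notes suggest transpose(a, argsort(axes)) to undo a transposition; a quick sketch in plain numpy:

>>> import numpy as np
>>> x = np.ones((1, 2, 3))
>>> axes = (1, 2, 0)
>>> y = np.transpose(x, axes)
>>> y.shape
(2, 3, 1)
>>> np.transpose(y, np.argsort(axes)).shape  # argsort(axes) inverts the permutation
(1, 2, 3)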
-
symjax.tensor.
tri
(N, M=None, k=0, dtype=None)[source]¶ An array with ones at and below the given diagonal and zeros elsewhere.
LAX-backend implementation of
tri()
. Original docstring below.Parameters: - N (int) – Number of rows in the array.
- M (int, optional) – Number of columns in the array. By default, M is taken equal to N.
- k (int, optional) – The sub-diagonal at and below which the array is filled. k = 0 is the main diagonal, while k < 0 is below it, and k > 0 is above. The default is 0.
- dtype (dtype, optional) – Data type of the returned array. The default is float.
Returns: tri – Array with its lower triangle filled with ones and zero elsewhere; in other words
T[i,j] == 1
forj <= i + k
, 0 otherwise.Return type: ndarray of shape (N, M)
Examples
>>> np.tri(3, 5, 2, dtype=int) array([[1, 1, 1, 0, 0], [1, 1, 1, 1, 0], [1, 1, 1, 1, 1]])
>>> np.tri(3, 5, -1) array([[0., 0., 0., 0., 0.], [1., 0., 0., 0., 0.], [1., 1., 0., 0., 0.]])
-
symjax.tensor.
tril
(m, k=0)[source]¶ Lower triangle of an array.
LAX-backend implementation of
tril()
. Original docstring below.Return a copy of an array with elements above the k-th diagonal zeroed.
Parameters: - m (array_like, shape (M, N)) – Input array.
- k (int, optional) – Diagonal above which to zero elements. k = 0 (the default) is the main diagonal, k < 0 is below it and k > 0 is above.
Returns: tril – Lower triangle of m, of same shape and data-type as m.
Return type: ndarray, shape (M, N)
See also
triu()
- same thing, only for the upper triangle
Examples
>>> np.tril([[1,2,3],[4,5,6],[7,8,9],[10,11,12]], -1) array([[ 0, 0, 0], [ 4, 0, 0], [ 7, 8, 0], [10, 11, 12]])
-
symjax.tensor.
tril_indices
(*args, **kwargs)¶ Return the indices for the lower-triangle of an (n, m) array.
LAX-backend implementation of
tril_indices()
. Original docstring below.- n : int
- The row dimension of the arrays for which the returned indices will be valid.
- k : int, optional
- Diagonal offset (see tril for details).
- m : int, optional
New in version 1.9.0.
The column dimension of the arrays for which the returned arrays will be valid. By default m is taken equal to n.
- inds : tuple of arrays
- The indices for the triangle. The returned tuple contains two arrays, each with the indices along one dimension of the array.
triu_indices : similar function, for upper-triangular. mask_indices : generic function accepting an arbitrary mask function. tril, triu
New in version 1.4.0.
Compute two different sets of indices to access 4x4 arrays, one for the lower triangular part starting at the main diagonal, and one starting two diagonals further right:
>>> il1 = np.tril_indices(4)
>>> il2 = np.tril_indices(4, 2)
Here is how they can be used with a sample array:
>>> a = np.arange(16).reshape(4, 4)
>>> a
array([[ 0,  1,  2,  3],
       [ 4,  5,  6,  7],
       [ 8,  9, 10, 11],
       [12, 13, 14, 15]])
Both for indexing:
>>> a[il1] array([ 0, 4, 5, ..., 13, 14, 15])
And for assigning values:
>>> a[il1] = -1
>>> a
array([[-1,  1,  2,  3],
       [-1, -1,  6,  7],
       [-1, -1, -1, 11],
       [-1, -1, -1, -1]])
These cover almost the whole array (two diagonals right of the main one):
>>> a[il2] = -10
>>> a
array([[-10, -10, -10,   3],
       [-10, -10, -10, -10],
       [-10, -10, -10, -10],
       [-10, -10, -10, -10]])
-
symjax.tensor.
triu
(m, k=0)[source]¶ Upper triangle of an array.
LAX-backend implementation of
triu()
. Original docstring below.Return a copy of a matrix with the elements below the k-th diagonal zeroed.
Please refer to the documentation for tril for further details.
See also
tril() - lower triangle of an array
Examples
>>> np.triu([[1,2,3],[4,5,6],[7,8,9],[10,11,12]], -1)
array([[ 1,  2,  3],
       [ 4,  5,  6],
       [ 0,  8,  9],
       [ 0,  0, 12]])
-
symjax.tensor.
triu_indices
(*args, **kwargs)¶ Return the indices for the upper-triangle of an (n, m) array.
LAX-backend implementation of
triu_indices()
. Original docstring below.- n : int
- The size of the arrays for which the returned indices will be valid.
- k : int, optional
- Diagonal offset (see triu for details).
- m : int, optional
New in version 1.9.0.
The column dimension of the arrays for which the returned arrays will be valid. By default m is taken equal to n.
- inds : tuple, shape(2) of ndarrays, shape(n)
- The indices for the triangle. The returned tuple contains two arrays, each with the indices along one dimension of the array. Can be used to slice a ndarray of shape(n, n).
tril_indices : similar function, for lower-triangular. mask_indices : generic function accepting an arbitrary mask function. triu, tril
New in version 1.4.0.
Compute two different sets of indices to access 4x4 arrays, one for the upper triangular part starting at the main diagonal, and one starting two diagonals further right:
>>> iu1 = np.triu_indices(4)
>>> iu2 = np.triu_indices(4, 2)
Here is how they can be used with a sample array:
>>> a = np.arange(16).reshape(4, 4)
>>> a
array([[ 0,  1,  2,  3],
       [ 4,  5,  6,  7],
       [ 8,  9, 10, 11],
       [12, 13, 14, 15]])
Both for indexing:
>>> a[iu1] array([ 0, 1, 2, ..., 10, 11, 15])
And for assigning values:
>>> a[iu1] = -1
>>> a
array([[-1, -1, -1, -1],
       [ 4, -1, -1, -1],
       [ 8,  9, -1, -1],
       [12, 13, 14, -1]])
These cover only a small part of the whole array (two diagonals right of the main one):
>>> a[iu2] = -10
>>> a
array([[ -1,  -1, -10, -10],
       [  4,  -1,  -1, -10],
       [  8,   9,  -1,  -1],
       [ 12,  13,  14,  -1]])
-
symjax.tensor.
true_divide
(x1, x2)[source]¶ Returns a true division of the inputs, element-wise.
LAX-backend implementation of
true_divide()
. Original docstring below.true_divide(x1, x2, /, out=None, *, where=True, casting=’same_kind’, order=’K’, dtype=None, subok=True[, signature, extobj])
Instead of the Python traditional ‘floor division’, this returns a true division. True division adjusts the output type to present the best answer, regardless of input types.
Parameters: - x1 (array_like) – Dividend array.
- x2 (array_like) – Divisor array.
If
x1.shape != x2.shape
, they must be broadcastable to a common shape (which becomes the shape of the output).
Returns: out – This is a scalar if both x1 and x2 are scalars.
Return type: ndarray or scalar
Notes
In Python,
//
is the floor division operator and/
the true division operator. Thetrue_divide(x1, x2)
function is equivalent to true division in Python.Examples
>>> x = np.arange(5)
>>> np.true_divide(x, 4)
array([ 0.  ,  0.25,  0.5 ,  0.75,  1.  ])
>>> x/4 array([ 0. , 0.25, 0.5 , 0.75, 1. ])
>>> x//4 array([0, 0, 0, 0, 1])
-
symjax.tensor.
vander
(x, N=None, increasing=False)[source]¶ Generate a Vandermonde matrix.
LAX-backend implementation of
vander()
. Original docstring below.The columns of the output matrix are powers of the input vector. The order of the powers is determined by the increasing boolean argument. Specifically, when increasing is False, the i-th output column is the input vector raised element-wise to the power of
N - i - 1
. Such a matrix with a geometric progression in each row is named for Alexandre- Theophile Vandermonde.Parameters: - x (array_like) – 1-D input array.
- N (int, optional) – Number of columns in the output. If N is not specified, a square
array is returned (
N = len(x)
). - increasing (bool, optional) – Order of the powers of the columns. If True, the powers increase from left to right, if False (the default) they are reversed.
Returns: out – Vandermonde matrix. If increasing is False, the first column is
x^(N-1)
, the secondx^(N-2)
and so forth. If increasing is True, the columns arex^0, x^1, ..., x^(N-1)
.Return type: ndarray
See also
polynomial.polynomial.polyvander()
Examples
>>> x = np.array([1, 2, 3, 5])
>>> N = 3
>>> np.vander(x, N)
array([[ 1,  1,  1],
       [ 4,  2,  1],
       [ 9,  3,  1],
       [25,  5,  1]])
>>> np.column_stack([x**(N-1-i) for i in range(N)]) array([[ 1, 1, 1], [ 4, 2, 1], [ 9, 3, 1], [25, 5, 1]])
>>> x = np.array([1, 2, 3, 5])
>>> np.vander(x)
array([[  1,   1,   1,   1],
       [  8,   4,   2,   1],
       [ 27,   9,   3,   1],
       [125,  25,   5,   1]])
>>> np.vander(x, increasing=True)
array([[  1,   1,   1,   1],
       [  1,   2,   4,   8],
       [  1,   3,   9,  27],
       [  1,   5,  25, 125]])
The determinant of a square Vandermonde matrix is the product of the differences between the values of the input vector:
>>> np.linalg.det(np.vander(x))
48.000000000000043 # may vary
>>> (5-3)*(5-2)*(5-1)*(3-2)*(3-1)*(2-1)
48
-
symjax.tensor.
var
(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False)[source]¶ Compute the variance along the specified axis.
LAX-backend implementation of
var()
. Original docstring below.Returns the variance of the array elements, a measure of the spread of a distribution. The variance is computed for the flattened array by default, otherwise over the specified axis.
Parameters: - a (array_like) – Array containing numbers whose variance is desired. If a is not an array, a conversion is attempted.
- axis (None or int or tuple of ints, optional) – Axis or axes along which the variance is computed. The default is to compute the variance of the flattened array.
- dtype (data-type, optional) – Type to use in computing the variance. For arrays of integer type the default is float64; for arrays of float types it is the same as the array type.
- out (ndarray, optional) – Alternate output array in which to place the result. It must have the same shape as the expected output, but the type is cast if necessary.
- ddof (int, optional) – “Delta Degrees of Freedom”: the divisor used in the calculation is
N - ddof
, whereN
represents the number of elements. By default ddof is zero. - keepdims (bool, optional) – If this is set to True, the axes which are reduced are left in the result as dimensions with size one. With this option, the result will broadcast correctly against the input array.
Returns: variance – If
out=None
, returns a new array containing the variance; otherwise, a reference to the output array is returned.Return type: ndarray, see dtype parameter above
Notes
The variance is the average of the squared deviations from the mean, i.e.,
var = mean(abs(x - x.mean())**2)
.The mean is normally calculated as
x.sum() / N
, whereN = len(x)
. If, however, ddof is specified, the divisorN - ddof
is used instead. In standard statistical practice,ddof=1
provides an unbiased estimator of the variance of a hypothetical infinite population.ddof=0
provides a maximum likelihood estimate of the variance for normally distributed variables.Note that for complex numbers, the absolute value is taken before squaring, so that the result is always real and nonnegative.
For floating-point input, the variance is computed using the same precision the input has. Depending on the input data, this can cause the results to be inaccurate, especially for float32 (see example below). Specifying a higher-accuracy accumulator using the
dtype
keyword can alleviate this issue.Examples
>>> a = np.array([[1, 2], [3, 4]])
>>> np.var(a)
1.25
>>> np.var(a, axis=0)
array([1., 1.])
>>> np.var(a, axis=1)
array([0.25, 0.25])
In single precision, var() can be inaccurate:
>>> a = np.zeros((2, 512*512), dtype=np.float32)
>>> a[0, :] = 1.0
>>> a[1, :] = 0.1
>>> np.var(a)
0.20250003
Computing the variance in float64 is more accurate:
>>> np.var(a, dtype=np.float64)
0.20249999932944759 # may vary
>>> ((1-0.55)**2 + (0.1-0.55)**2)/2
0.2025
-
symjax.tensor.
vdot
(a, b, *, precision=None)[source]¶ Return the dot product of two vectors.
LAX-backend implementation of
vdot()
. In addition to the original NumPy arguments listed below, also supportsprecision
for extra control over matrix-multiplication precision on supported devices.precision
may be set toNone
, which means default precision for the backend, alax.Precision
enum value (Precision.DEFAULT
,Precision.HIGH
orPrecision.HIGHEST
) or a tuple of twolax.Precision
enums indicating separate precision for each argument.Original docstring below.
vdot(a, b)
The vdot(a, b) function handles complex numbers differently than dot(a, b). If the first argument is complex the complex conjugate of the first argument is used for the calculation of the dot product.
Note that vdot handles multidimensional arrays differently than dot: it does not perform a matrix product, but flattens input arguments to 1-D vectors first. Consequently, it should only be used for vectors.
Returns: output – Dot product of a and b. Can be an int, float, or complex depending on the types of a and b.
Return type: ndarray
See also
dot() - Return the dot product without using the complex conjugate of the first argument.
Examples
>>> a = np.array([1+2j,3+4j])
>>> b = np.array([5+6j,7+8j])
>>> np.vdot(a, b)
(70-8j)
>>> np.vdot(b, a)
(70+8j)
Note that higher-dimensional arrays are flattened!
>>> a = np.array([[1, 4], [5, 6]])
>>> b = np.array([[4, 1], [2, 2]])
>>> np.vdot(a, b)
30
>>> np.vdot(b, a)
30
>>> 1*4 + 4*1 + 5*2 + 6*2
30
-
symjax.tensor.
vsplit
(ary, indices_or_sections)¶ Split an array into multiple sub-arrays vertically (row-wise).
LAX-backend implementation of
vsplit()
. Original docstring below.Please refer to the
split
documentation.vsplit
is equivalent tosplit
with axis=0 (default), the array is always split along the first axis regardless of the array dimension.
See also
split() - Split an array into multiple sub-arrays of equal size.
Examples
>>> x = np.arange(16.0).reshape(4, 4)
>>> x
array([[ 0.,  1.,  2.,  3.],
       [ 4.,  5.,  6.,  7.],
       [ 8.,  9., 10., 11.],
       [12., 13., 14., 15.]])
>>> np.vsplit(x, 2)
[array([[0., 1., 2., 3.],
        [4., 5., 6., 7.]]),
 array([[ 8.,  9., 10., 11.],
        [12., 13., 14., 15.]])]
>>> np.vsplit(x, np.array([3, 6]))
[array([[ 0.,  1.,  2.,  3.],
        [ 4.,  5.,  6.,  7.],
        [ 8.,  9., 10., 11.]]),
 array([[12., 13., 14., 15.]]),
 array([], shape=(0, 4), dtype=float64)]
With a higher dimensional array the split is still along the first axis.
>>> x = np.arange(8.0).reshape(2, 2, 2)
>>> x
array([[[0., 1.],
        [2., 3.]],
       [[4., 5.],
        [6., 7.]]])
>>> np.vsplit(x, 2)
[array([[[0., 1.],
         [2., 3.]]]),
 array([[[4., 5.],
         [6., 7.]]])]
-
symjax.tensor.
vstack
(tup)[source]¶ Stack arrays in sequence vertically (row wise).
LAX-backend implementation of
vstack()
. Original docstring below.This is equivalent to concatenation along the first axis after 1-D arrays of shape (N,) have been reshaped to (1,N). Rebuilds arrays divided by vsplit.
This function makes most sense for arrays with up to 3 dimensions. For instance, for pixel-data with a height (first axis), width (second axis), and r/g/b channels (third axis). The functions concatenate, stack and block provide more general stacking and concatenation operations.
Parameters: tup (sequence of ndarrays) – The arrays must have the same shape along all but the first axis. 1-D arrays must have the same length. Returns: stacked – The array formed by stacking the given arrays, will be at least 2-D. Return type: ndarray See also
concatenate()
- Join a sequence of arrays along an existing axis.
stack()
- Join a sequence of arrays along a new axis.
block()
- Assemble an nd-array from nested lists of blocks.
hstack()
- Stack arrays in sequence horizontally (column wise).
dstack()
- Stack arrays in sequence depth wise (along third axis).
column_stack()
- Stack 1-D arrays as columns into a 2-D array.
vsplit()
- Split an array into multiple sub-arrays vertically (row-wise).
Examples
>>> a = np.array([1, 2, 3])
>>> b = np.array([2, 3, 4])
>>> np.vstack((a,b))
array([[1, 2, 3],
       [2, 3, 4]])
>>> a = np.array([[1], [2], [3]])
>>> b = np.array([[2], [3], [4]])
>>> np.vstack((a,b))
array([[1],
       [2],
       [3],
       [2],
       [3],
       [4]])
-
symjax.tensor.
zeros
(shape, dtype=None)[source]¶ Return a new array of given shape and type, filled with zeros.
LAX-backend implementation of
zeros()
. Original docstring below.zeros(shape, dtype=float, order=’C’)
Returns: out – Array of zeros with the given shape, dtype, and order.
Return type: ndarray
See also
zeros_like() - Return an array of zeros with shape and type of input.
empty() - Return a new uninitialized array.
ones() - Return a new array setting values to one.
full() - Return a new array of given shape filled with value.
Examples
>>> np.zeros(5) array([ 0., 0., 0., 0., 0.])
>>> np.zeros((5,), dtype=int) array([0, 0, 0, 0, 0])
>>> np.zeros((2, 1)) array([[ 0.], [ 0.]])
>>> s = (2,2)
>>> np.zeros(s)
array([[ 0.,  0.],
       [ 0.,  0.]])
>>> np.zeros((2,), dtype=[('x', 'i4'), ('y', 'i4')]) # custom dtype array([(0, 0), (0, 0)], dtype=[('x', '<i4'), ('y', '<i4')])
-
symjax.tensor.
stop_gradient
(x)[source]¶ Stops gradient computation.
Operationally
stop_gradient
is the identity function, that is, it returns argument x unchanged. However,stop_gradient
prevents the flow of gradients during forward or reverse-mode automatic differentiation. If there are multiple nested gradient computations,stop_gradient
stops gradients for all of them.For example:
>>> jax.grad(lambda x: x**2)(3.)
array(6., dtype=float32)
>>> jax.grad(lambda x: jax.lax.stop_gradient(x)**2)(3.)
array(0., dtype=float32)
>>> jax.grad(jax.grad(lambda x: x**2))(3.)
array(2., dtype=float32)
>>> jax.grad(jax.grad(lambda x: jax.lax.stop_gradient(x)**2))(3.)
array(0., dtype=float32)
-
symjax.tensor.
dimshuffle
(tensor, pattern)[source]¶ Reorder the dimensions of this variable, optionally inserting broadcasted dimensions.
Parameters: - tensor (Tensor) –
- pattern (list of int and str) – List/tuple of int mixed with ‘x’ for broadcastable dimensions.
Examples
For example, to create a 3D view of a [2D] matrix, call
dimshuffle([0,'x',1])
. This will create a 3D view such that the middle dimension is an implicit broadcasted dimension. To do the same thing on the transpose of that matrix, calldimshuffle([1, 'x', 0])
.Notes
This function supports the pattern passed as a tuple, or as a variable-length argument (e.g.
a.dimshuffle(pattern)
is equivalent toa.dimshuffle(*pattern)
wherepattern
is a list/tuple of ints mixed with ‘x’ characters).
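A minimal sketch of the pattern semantics, assuming a symbolic tensor created with symjax.tensor.ones (the shapes shown are what the pattern implies; exact printing may differ):

>>> import symjax.tensor as T
>>> m = T.ones((2, 3))
>>> T.dimshuffle(m, [0, 'x', 1]).shape  # insert a broadcastable middle dimension
(2, 1, 3)
>>> T.dimshuffle(m, [1, 'x', 0]).shape  # same view built on the transpose
(3, 1, 2)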
-
symjax.tensor.
index
()¶ Index object singleton
-
symjax.tensor.
index_update
(x, idx, y, indices_are_sorted=False, unique_indices=False)[source]¶ Pure equivalent of
x[idx] = y
.Returns the value of x that would result from the NumPy-style
indexed assignment
:x[idx] = y
Note the index_update operator is pure; x itself is not modified, instead the new value that x would have taken is returned.
Unlike NumPy’s
x[idx] = y
, if multiple indices refer to the same location it is undefined which update is chosen; JAX may choose the order of updates arbitrarily and nondeterministically (e.g., due to concurrent updates on some hardware platforms).Parameters: - x – an array with the values to be updated.
- idx – a Numpy-style index, consisting of None, integers, slice objects,
ellipses, ndarrays with integer dtypes, or a tuple of the above. A
convenient syntactic sugar for forming indices is via the
jax.ops.index
object. - y – the array of updates. y must be broadcastable to the shape of the array that would be returned by x[idx].
- indices_are_sorted – whether idx is known to be sorted
- unique_indices – whether idx is known to be free of duplicates
Returns: An array.
>>> x = jax.numpy.ones((5, 6))
>>> jax.ops.index_update(x, jax.ops.index[::2, 3:], 6.)
array([[1., 1., 1., 6., 6., 6.],
       [1., 1., 1., 1., 1., 1.],
       [1., 1., 1., 6., 6., 6.],
       [1., 1., 1., 1., 1., 1.],
       [1., 1., 1., 6., 6., 6.]], dtype=float32)
-
symjax.tensor.
index_min
(x, idx, y, indices_are_sorted=False, unique_indices=False)[source]¶ Pure equivalent of
x[idx] = minimum(x[idx], y)
.Returns the value of x that would result from the NumPy-style
indexed assignment
:x[idx] = minimum(x[idx], y)
Note the index_min operator is pure; x itself is not modified, instead the new value that x would have taken is returned.
Unlike the NumPy code
x[idx] = minimum(x[idx], y)
, if multiple indices refer to the same location the final value will be the overall min. (NumPy would only look at the last update, rather than all of the updates.)Parameters: - x – an array with the values to be updated.
- idx – a Numpy-style index, consisting of None, integers, slice objects,
ellipses, ndarrays with integer dtypes, or a tuple of the above. A
convenient syntactic sugar for forming indices is via the
jax.ops.index
object. - y – the array of updates. y must be broadcastable to the shape of the array that would be returned by x[idx].
- indices_are_sorted – whether idx is known to be sorted
- unique_indices – whether idx is known to be free of duplicates
Returns: An array.
>>> x = jax.numpy.ones((5, 6))
>>> jax.ops.index_min(x, jax.ops.index[2:4, 3:], 0.)
array([[1., 1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1., 1.],
       [1., 1., 1., 0., 0., 0.],
       [1., 1., 1., 0., 0., 0.],
       [1., 1., 1., 1., 1., 1.]], dtype=float32)
-
symjax.tensor.
index_add
(x, idx, y, indices_are_sorted=False, unique_indices=False)[source]¶ Pure equivalent of
x[idx] += y
.Returns the value of x that would result from the NumPy-style
indexed assignment
:x[idx] += y
Note the index_add operator is pure; x itself is not modified, instead the new value that x would have taken is returned.
Unlike the NumPy code
x[idx] += y
, if multiple indices refer to the same location the updates will be summed. (NumPy would only apply the last update, rather than summing the updates.) The order in which conflicting updates are applied is implementation-defined and may be nondeterministic (e.g., due to concurrency on some hardware platforms).Parameters: - x – an array with the values to be updated.
- idx – a Numpy-style index, consisting of None, integers, slice objects,
ellipses, ndarrays with integer dtypes, or a tuple of the above. A
convenient syntactic sugar for forming indices is via the
jax.ops.index
object. - y – the array of updates. y must be broadcastable to the shape of the array that would be returned by x[idx].
- indices_are_sorted – whether idx is known to be sorted
- unique_indices – whether idx is known to be free of duplicates
Returns: An array.
>>> x = jax.numpy.ones((5, 6))
>>> jax.ops.index_add(x, jax.ops.index[2:4, 3:], 6.)
array([[1., 1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1., 1.],
       [1., 1., 1., 7., 7., 7.],
       [1., 1., 1., 7., 7., 7.],
       [1., 1., 1., 1., 1., 1.]], dtype=float32)
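The duplicate-index behaviour described above (NumPy assignment would keep only one update; index_add sums them) can be seen with a repeated index; a small sketch (the printed array type may vary with the JAX version):

>>> import jax
>>> x = jax.numpy.zeros(3)
>>> jax.ops.index_add(x, jax.numpy.array([0, 0, 1]), 1.)  # index 0 appears twice, so its updates sum
array([2., 1., 0.], dtype=float32)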
-
symjax.tensor.
index_max
(x, idx, y, indices_are_sorted=False, unique_indices=False)[source]¶ Pure equivalent of
x[idx] = maximum(x[idx], y)
.Returns the value of x that would result from the NumPy-style
indexed assignment
:x[idx] = maximum(x[idx], y)
Note the index_max operator is pure; x itself is not modified, instead the new value that x would have taken is returned.
Unlike the NumPy code
x[idx] = maximum(x[idx], y)
, if multiple indices refer to the same location the final value will be the overall max. (NumPy would only look at the last update, rather than all of the updates.)
Parameters: - x – an array with the values to be updated.
- idx – a Numpy-style index, consisting of None, integers, slice objects,
ellipses, ndarrays with integer dtypes, or a tuple of the above. A
convenient syntactic sugar for forming indices is via the
jax.ops.index
object. - y – the array of updates. y must be broadcastable to the shape of the array that would be returned by x[idx].
- indices_are_sorted – whether idx is known to be sorted
- unique_indices – whether idx is known to be free of duplicates
Returns: An array.
>>> x = jax.numpy.ones((5, 6))
>>> jax.ops.index_max(x, jax.ops.index[2:4, 3:], 6.)
array([[1., 1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1., 1.],
       [1., 1., 1., 6., 6., 6.],
       [1., 1., 1., 6., 6., 6.],
       [1., 1., 1., 1., 1., 1.]], dtype=float32)
-
symjax.tensor.
index_in_dim
(operand: Any, index: int, axis: int = 0, keepdims: bool = True) → Any[source]¶ Convenience wrapper around slice to perform int indexing.
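As a minimal sketch of the keepdims behavior, written against the jax.lax operator that this op wraps (the symjax.tensor version is assumed to behave identically):
import jax.numpy as jnp
from jax import lax

x = jnp.arange(12.).reshape(3, 4)
lax.index_in_dim(x, 1, axis=0)                   # keeps the indexed axis: shape (1, 4)
lax.index_in_dim(x, 1, axis=0, keepdims=False)   # drops the indexed axis: shape (4,)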
-
symjax.tensor.
dynamic_slice_in_dim
(operand: Any, start_index: Any, slice_size: int, axis: int = 0) → Any[source]¶ Convenience wrapper around dynamic_slice applying to one dimension.
-
symjax.tensor.
dynamic_slice
(operand: Any, start_indices: Sequence[Any], slice_sizes: Sequence[int]) → Any[source]¶ Wraps XLA’s DynamicSlice operator.
Parameters: - operand – an array to slice.
- start_indices – a list of scalar indices, one per dimension. These values may be dynamic.
- slice_sizes – the size of the slice. Must be a sequence of non-negative integers with length equal to ndim(operand). Inside a JIT compiled function, only static values are supported (all JAX arrays inside JIT must have statically known size).
Returns: An array containing the slice.
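A minimal sketch of the slicing semantics, written against the underlying jax.lax operator (inside a jitted function the start indices may be traced values, while the slice sizes must remain static):
import jax.numpy as jnp
from jax import lax

x = jnp.arange(12.).reshape(3, 4)
# a 2x2 window whose top-left corner is at (row=1, col=2)
lax.dynamic_slice(x, (1, 2), (2, 2))
# [[ 6.,  7.],
#  [10., 11.]]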
Extra¶
symjax.tensor.interpolation
¶
map_coordinates (input, coordinates, order[, …]) |
Map the input array to new coordinates by interpolation. |
upsample_1d (tensor, repeat[, axis, mode, …]) |
1-d upsampling of tensor |
hermite_1d (samples, knots, values, derivatives) |
Real interpolation with hermite cubic spline. |
hermite_2d (values, n_x, n_y) |
TODO: test and finalize this |
thin_plate_spline (input, dest_offsets[, …]) |
Applies a thin plate spline transformation [2] on the input as in [1]. |
affine_transform (input, theta[, order, …]) |
Spatial transformer layer The layer applies an affine transformation on the input. |
Detailed Descriptions¶
-
symjax.tensor.interpolation.
map_coordinates
(input, coordinates, order, mode='constant', cval=0.0)[source]¶ Map the input array to new coordinates by interpolation.
LAX-backend implementation of
map_coordinates()
. Only nearest neighbor (order=0
), linear interpolation (order=1
) and modes'constant'
,'nearest'
and'wrap'
are currently supported. Note that interpolation near boundaries differs from the scipy function, because we fixed an outstanding bug (https://github.com/scipy/scipy/issues/2640); this function interprets themode
argument as documented by SciPy, but not as implemented by SciPy. Original docstring below.
The array of coordinates is used to find, for each point in the output, the corresponding coordinates in the input. The value of the input at those coordinates is determined by spline interpolation of the requested order.
The shape of the output is derived from that of the coordinate array by dropping the first axis. The values of the array along the first axis are the coordinates in the input array at which the output value is found.
Parameters: - input (array_like) – The input array.
- coordinates (array_like) – The coordinates at which input is evaluated.
- order (int, optional) – The order of the spline interpolation, default is 3. The order has to be in the range 0-5.
- mode ({'reflect', 'constant', 'nearest', 'mirror', 'wrap'}, optional) – The mode parameter determines how the input array is extended beyond its boundaries. Default is ‘constant’. Behavior for each valid value is as follows:
- cval (scalar, optional) – Value to fill past edges of input if mode is ‘constant’. Default is 0.0.
Returns: map_coordinates – The result of transforming the input. The shape of the output is derived from that of coordinates by dropping the first axis.
Return type: ndarray
See also
spline_filter()
,geometric_transform()
,scipy.interpolate()
Examples
>>> from scipy import ndimage
>>> a = np.arange(12.).reshape((4, 3))
>>> a
array([[ 0.,  1.,  2.],
       [ 3.,  4.,  5.],
       [ 6.,  7.,  8.],
       [ 9., 10., 11.]])
>>> ndimage.map_coordinates(a, [[0.5, 2], [0.5, 1]], order=1)
array([ 2.,  7.])
Above, the interpolated value of a[0.5, 0.5] gives output[0], while a[2, 1] is output[1].
>>> inds = np.array([[0.5, 2], [0.5, 4]])
>>> ndimage.map_coordinates(a, inds, order=1, cval=-33.3)
array([  2. , -33.3])
>>> ndimage.map_coordinates(a, inds, order=1, mode='nearest')
array([ 2.,  8.])
>>> ndimage.map_coordinates(a, inds, order=1, cval=0, output=bool)
array([ True, False], dtype=bool)
-
symjax.tensor.interpolation.
upsample_1d
(tensor, repeat, axis=-1, mode='constant', value=0.0, boundary_condition='periodic')[source]¶ 1-d upsampling of tensor
Allows upsampling a tensor by an arbitrary (integer) amount along a given axis by applying a univariate upsampling strategy.
Parameters: - tensor (tensor) – the input tensor to upsample
- repeat (int) – the number of new values to insert between each value
- axis (int) – the axis to upsample
- mode (str) – the type of upsampling to perform (linear, constant, nearest)
- value (float (default=0)) – the value to use in the case of constant upsampling
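A minimal usage sketch under the symbolic SymJAX workflow; the placeholder shape and the exact boundary handling of the output are illustrative assumptions:
import numpy as np
import symjax
import symjax.tensor as T

x = T.Placeholder((4,), 'float32')
# insert one zero between consecutive samples: x0, 0, x1, 0, ...
y = T.interpolation.upsample_1d(x, repeat=1, mode='constant', value=0.)
f = symjax.function(x, outputs=y)
print(f(np.array([1., 2., 3., 4.], 'float32')))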
-
symjax.tensor.interpolation.
hermite_1d
(samples, knots, values, derivatives)[source]¶ Real interpolation with hermite cubic spline.
Parameters: - knots (array-like) – The knots on which the function (derivative and antiderivative) is defined; a tensor of knots can be given, in which case the shape is (…, N_KNOTS) and the leading dimensions are treated independently.
- samples (array-like) – The points where the interpolation is required, of shape (TIME). If more dimensions are given, the leading dimensions must be broadcastable against knots.
- values (array-like) – The real values of amplitude at the knots, same shape as knots.
- derivatives (array-like) – The real values of the derivatives at the knots, same shape as knots.
Returns: yi – The interpolated real-valued function.
Return type: array-like
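A hedged sketch of how the arguments fit together; the shapes and the compilation step below are assumptions for illustration:
import numpy as np
import symjax
import symjax.tensor as T

knots = np.linspace(0, 1, 5).astype('float32')
values = np.sin(2 * np.pi * knots)                    # amplitudes at the knots
derivatives = 2 * np.pi * np.cos(2 * np.pi * knots)   # slopes at the knots

samples = T.Placeholder((50,), 'float32')
y = T.interpolation.hermite_1d(samples, knots, values, derivatives)
f = symjax.function(samples, outputs=y)
curve = f(np.linspace(0, 1, 50).astype('float32'))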
-
symjax.tensor.interpolation.
hermite_2d
(values, n_x, n_y)[source]¶ TODO: test and finalize this
Parameters: - values (array-like) – the values, the two directional derivatives, and the cross derivative for the 4 knots of each region; it should hence be of shape (n, N, M, 4), holding (values, vx, vy, vxy)
- n_x (int) – the number of points in x per region
- n_y (int) – the number of points in y per region
Returns: interpolation
Return type: array-like
-
symjax.tensor.interpolation.
thin_plate_spline
(input, dest_offsets, order=1, downsample_factor=1, border_mode='nearest')[source]¶ Applies a thin plate spline transformation [2] on the input as in [1].
The thin plate spline transform is determined based on the movement of some number of control points. The starting positions for these control points are fixed. The output is interpolated with a bilinear transformation.
Implementation based on Lasagne
Parameters: - input (Tensor) – The input to be transformed; should be a 4D tensor, with shape
(batch_size, num_input_channels, input_rows, input_columns)
. - dest_offsets (Tensor) – The parameters of the thin plate spline
transformation as the x and y coordinates of the destination offsets of
each control point. This should be a
2D tensor, with shape
(batch_size, 2 * num_control_points)
. The number of control points is inferred from this shape; the points are arranged as a grid along the image, so num_control_points must be a perfect square (the Lasagne implementation this is based on defaults to 16). - order (int (default 1)) – The order of the interpolation
- downsample_factor (float or iterable of float) – A float or a 2-element tuple specifying the downsample factor for the output image (in both spatial dimensions). A value of 1 will keep the original size of the input. Values larger than 1 will downsample the input. Values below 1 will upsample the input.
- border_mode ('nearest', 'mirror', or 'wrap') – Determines how border conditions are handled during interpolation. If ‘nearest’, points outside the grid are clipped to the boundary. If ‘mirror’, points are mirrored across the boundary. If ‘wrap’, points wrap around to the other side of the grid. See http://stackoverflow.com/q/22669252/22670830#22670830 for details.
References
[1] Max Jaderberg, Karen Simonyan, Andrew Zisserman, Koray Kavukcuoglu (2015): Spatial Transformer Networks. NIPS 2015, http://papers.nips.cc/paper/5854-spatial-transformer-networks.pdf
[2] Fred L. Bookstein (1989): Principal warps: thin-plate splines and the decomposition of deformations. IEEE Transactions on Pattern Analysis and Machine Intelligence. http://doi.org/10.1109/34.24792
-
symjax.tensor.interpolation.
affine_transform
(input, theta, order=1, downsample_factor=1, border_mode='nearest')[source]¶ Spatial transformer layer: applies an affine transformation on the input. The affine transformation is parameterized with six learned parameters [1]. The output is interpolated with a bilinear transformation if order is 1.
It is also convenient to interpret (each) \(\theta \in \mathbb{R}^6\) vector in terms of an affine transformation of the image coordinates. In that case the 2d coordinates x are transformed to Ax + b, with A a 2x2 matrix and b a 2d bias vector, with
\[\begin{split}A = \begin{pmatrix} \theta[0] & \theta[1]\\ \theta[3] & \theta[4] \end{pmatrix}, \quad b = \begin{pmatrix} \theta[2]\\ \theta[5] \end{pmatrix}.\end{split}\]
Parameters: - input (Tensor) – The input, which should be a 4D tensor, with shape
(batch_size, num_input_channels, input_rows, input_columns)
. - theta (Tensor) – The parameters of the affine transformation. See the example below for how to initialize to the identity transform. - order (int (default 1)) – The order of the interpolation
- downsample_factor (float or iterable of float) – A float or a 2-element tuple specifying the downsample factor for the output image (in both spatial dimensions). A value of 1 will keep the original size of the input. Values larger than 1 will downsample the input. Values below 1 will upsample the input.
- border_mode ('nearest', 'mirror', or 'wrap') – Determines how border conditions are handled during interpolation. If ‘nearest’, points outside the grid are clipped to the boundary. If ‘mirror’, points are mirrored across the boundary. If ‘wrap’, points wrap around to the other side of the grid. See http://stackoverflow.com/q/22669252/22670830#22670830 for details.
References
[1] Max Jaderberg, Karen Simonyan, Andrew Zisserman, Koray Kavukcuoglu (2015): Spatial Transformer Networks. NIPS 2015, http://papers.nips.cc/paper/5854-spatial-transformer-networks.pdf
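The docstring above refers to an identity-transform example; a minimal sketch (the batch and image shapes are illustrative assumptions) is:
import numpy as np
import symjax
import symjax.tensor as T

images = T.Placeholder((8, 3, 32, 32), 'float32')
# identity transform: A = I and b = 0, i.e. theta = [1, 0, 0, 0, 1, 0] per sample
theta = np.tile(np.array([1., 0., 0., 0., 1., 0.], 'float32'), (8, 1))
warped = T.interpolation.affine_transform(images, theta)
f = symjax.function(images, outputs=warped)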
symjax.tensor.signal
¶
Implementation of various signal processing related techniques such as time-frequency representations, convolution/correlation/pooling operations, as well as various apodization windows and filter-bank creation.
Apodization Windows¶
blackman (*args, **kwargs) |
Return the Blackman window. |
bartlett (*args, **kwargs) |
Return the Bartlett window. |
hamming (*args, **kwargs) |
Return the Hamming window. |
hanning (*args, **kwargs) |
Return the Hanning window. |
kaiser (*args, **kwargs) |
Return the Kaiser window. |
tukey (M[, alpha]) |
Return a Tukey window, also known as a tapered cosine window. |
Time-Frequency Representations¶
mfcc (signal, window, hop, n_filter, …[, …]) |
https://librosa.github.io/librosa/_modules/librosa/feature/spectral.html#mfcc |
stft (signal, window, hop[, apod, nfft, mode]) |
Compute the Short-Time Fourier Transform of a signal given the window length, hop, and additional parameters. |
dct (signal[, axes]) |
https://dsp.stackexchange.com/questions/2807/fast-cosine-transform-via-fft |
wvd (signal, window, hop, L[, apod, mode]) |
|
hilbert_transform (signal) |
The time should be the last dimension; returns the analytic signal. |
Filters¶
fourier_complex_morlet (bandwidths, centers, N) |
Complex Morlet wavelet in Fourier |
complex_morlet (bandwidths, centers[, time]) |
Complex Morlet wavelet |
sinc_bandpass (time, f0, f1) |
Ensures that f0 < f1, f0 > 0 and f1 < 1 whenever time is …, -1, 0, 1, … |
mel_filterbank (length, n_filter, low, high, …) |
|
hat_1d |
Convolution/Correlation/Pooling¶
convolve (in1, in2[, mode, method, precision]) |
Convolve two N-dimensional arrays. |
batch_convolve (input, filter[, strides, …]) |
General n-dimensional batch convolution with dilations. |
convolve2d (in1, in2[, mode, boundary, …]) |
Convolve two 2-dimensional arrays. |
correlate (in1, in2[, mode, method, precision]) |
Cross-correlate two N-dimensional arrays. |
correlate2d (in1, in2[, mode, boundary, …]) |
Cross-correlate two 2-dimensional arrays. |
batch_pool |
Utilities¶
extract_signal_patches |
|
extract_image_patches (image, window_shape[, …]) |
extract patches from an input tensor |
Detailed Descriptions¶
-
symjax.tensor.signal.
blackman
(*args, **kwargs)¶ Return the Blackman window.
LAX-backend implementation of
blackman()
. Original docstring below.
The Blackman window is a taper formed by using the first three terms of a summation of cosines. It was designed to have close to the minimal leakage possible. It is close to optimal, only slightly worse than a Kaiser window.
- M : int
- Number of points in the output window. If zero or less, an empty array is returned.
- out : ndarray
- The window, with the maximum value normalized to one (the value one appears only if the number of samples is odd).
bartlett, hamming, hanning, kaiser
The Blackman window is defined as
\[w(n) = 0.42 - 0.5 \cos(2\pi n/M) + 0.08 \cos(4\pi n/M)\]
Most references to the Blackman window come from the signal processing literature, where it is used as one of many windowing functions for smoothing values. It is also known as an apodization (which means “removing the foot”, i.e. smoothing discontinuities at the beginning and end of the sampled signal) or tapering function. It is known as a “near optimal” tapering function, almost as good (by some measures) as the kaiser window.
Blackman, R.B. and Tukey, J.W., (1958) The measurement of power spectra, Dover Publications, New York.
Oppenheim, A.V., and R.W. Schafer. Discrete-Time Signal Processing. Upper Saddle River, NJ: Prentice-Hall, 1999, pp. 468-471.
>>> import matplotlib.pyplot as plt
>>> np.blackman(12)
array([-1.38777878e-17,  3.26064346e-02,  1.59903635e-01,  # may vary
        4.14397981e-01,  7.36045180e-01,  9.67046769e-01,
        9.67046769e-01,  7.36045180e-01,  4.14397981e-01,
        1.59903635e-01,  3.26064346e-02, -1.38777878e-17])
Plot the window and the frequency response:
>>> from numpy.fft import fft, fftshift
>>> window = np.blackman(51)
>>> plt.plot(window)
[<matplotlib.lines.Line2D object at 0x...>]
>>> plt.title("Blackman window")
Text(0.5, 1.0, 'Blackman window')
>>> plt.ylabel("Amplitude")
Text(0, 0.5, 'Amplitude')
>>> plt.xlabel("Sample")
Text(0.5, 0, 'Sample')
>>> plt.show()
>>> plt.figure()
<Figure size 640x480 with 0 Axes>
>>> A = fft(window, 2048) / 25.5
>>> mag = np.abs(fftshift(A))
>>> freq = np.linspace(-0.5, 0.5, len(A))
>>> with np.errstate(divide='ignore', invalid='ignore'):
...     response = 20 * np.log10(mag)
...
>>> response = np.clip(response, -100, 100)
>>> plt.plot(freq, response)
[<matplotlib.lines.Line2D object at 0x...>]
>>> plt.title("Frequency response of Blackman window")
Text(0.5, 1.0, 'Frequency response of Blackman window')
>>> plt.ylabel("Magnitude [dB]")
Text(0, 0.5, 'Magnitude [dB]')
>>> plt.xlabel("Normalized frequency [cycles per sample]")
Text(0.5, 0, 'Normalized frequency [cycles per sample]')
>>> _ = plt.axis('tight')
>>> plt.show()
-
symjax.tensor.signal.
bartlett
(*args, **kwargs)¶ Return the Bartlett window.
LAX-backend implementation of
bartlett()
. Original docstring below.
The Bartlett window is very similar to a triangular window, except that the end points are at zero. It is often used in signal processing for tapering a signal, without generating too much ripple in the frequency domain.
- M : int
- Number of points in the output window. If zero or less, an empty array is returned.
- out : array
- The triangular window, with the maximum value normalized to one (the value one appears only if the number of samples is odd), with the first and last samples equal to zero.
blackman, hamming, hanning, kaiser
The Bartlett window is defined as
\[w(n) = \frac{2}{M-1} \left( \frac{M-1}{2} - \left|n - \frac{M-1}{2}\right| \right)\]
Most references to the Bartlett window come from the signal processing literature, where it is used as one of many windowing functions for smoothing values. Note that convolution with this window produces linear interpolation. It is also known as an apodization (which means “removing the foot”, i.e. smoothing discontinuities at the beginning and end of the sampled signal) or tapering function. The Fourier transform of the Bartlett window is the product of two sinc functions. Note the excellent discussion in Kanasewich.
[1] M.S. Bartlett, “Periodogram Analysis and Continuous Spectra”, Biometrika 37, 1-16, 1950.
[2] E.R. Kanasewich, “Time Sequence Analysis in Geophysics”, The University of Alberta Press, 1975, pp. 109-110.
[3] A.V. Oppenheim and R.W. Schafer, “Discrete-Time Signal Processing”, Prentice-Hall, 1999, pp. 468-471.
[4] Wikipedia, “Window function”, https://en.wikipedia.org/wiki/Window_function
[5] W.H. Press, B.P. Flannery, S.A. Teukolsky, and W.T. Vetterling, “Numerical Recipes”, Cambridge University Press, 1986, page 429.
>>> import matplotlib.pyplot as plt
>>> np.bartlett(12)
array([ 0.        ,  0.18181818,  0.36363636,  0.54545455,  0.72727273,  # may vary
        0.90909091,  0.90909091,  0.72727273,  0.54545455,  0.36363636,
        0.18181818,  0.        ])
Plot the window and its frequency response (requires SciPy and matplotlib):
>>> from numpy.fft import fft, fftshift
>>> window = np.bartlett(51)
>>> plt.plot(window)
[<matplotlib.lines.Line2D object at 0x...>]
>>> plt.title("Bartlett window")
Text(0.5, 1.0, 'Bartlett window')
>>> plt.ylabel("Amplitude")
Text(0, 0.5, 'Amplitude')
>>> plt.xlabel("Sample")
Text(0.5, 0, 'Sample')
>>> plt.show()
>>> plt.figure()
<Figure size 640x480 with 0 Axes>
>>> A = fft(window, 2048) / 25.5
>>> mag = np.abs(fftshift(A))
>>> freq = np.linspace(-0.5, 0.5, len(A))
>>> with np.errstate(divide='ignore', invalid='ignore'):
...     response = 20 * np.log10(mag)
...
>>> response = np.clip(response, -100, 100)
>>> plt.plot(freq, response)
[<matplotlib.lines.Line2D object at 0x...>]
>>> plt.title("Frequency response of Bartlett window")
Text(0.5, 1.0, 'Frequency response of Bartlett window')
>>> plt.ylabel("Magnitude [dB]")
Text(0, 0.5, 'Magnitude [dB]')
>>> plt.xlabel("Normalized frequency [cycles per sample]")
Text(0.5, 0, 'Normalized frequency [cycles per sample]')
>>> _ = plt.axis('tight')
>>> plt.show()
-
symjax.tensor.signal.
hamming
(*args, **kwargs)¶ Return the Hamming window.
LAX-backend implementation of
hamming()
. Original docstring below.
The Hamming window is a taper formed by using a weighted cosine.
- M : int
- Number of points in the output window. If zero or less, an empty array is returned.
- out : ndarray
- The window, with the maximum value normalized to one (the value one appears only if the number of samples is odd).
bartlett, blackman, hanning, kaiser
The Hamming window is defined as
\[w(n) = 0.54 - 0.46\cos\left(\frac{2\pi{n}}{M-1}\right) \qquad 0 \leq n \leq M-1\]
The Hamming was named for R. W. Hamming, an associate of J. W. Tukey and is described in Blackman and Tukey. It was recommended for smoothing the truncated autocovariance function in the time domain. Most references to the Hamming window come from the signal processing literature, where it is used as one of many windowing functions for smoothing values. It is also known as an apodization (which means “removing the foot”, i.e. smoothing discontinuities at the beginning and end of the sampled signal) or tapering function.
[1] Blackman, R.B. and Tukey, J.W., (1958) The measurement of power spectra, Dover Publications, New York.
[2] E.R. Kanasewich, “Time Sequence Analysis in Geophysics”, The University of Alberta Press, 1975, pp. 109-110.
[3] Wikipedia, “Window function”, https://en.wikipedia.org/wiki/Window_function
[4] W.H. Press, B.P. Flannery, S.A. Teukolsky, and W.T. Vetterling, “Numerical Recipes”, Cambridge University Press, 1986, page 425.
>>> np.hamming(12)
array([ 0.08      ,  0.15302337,  0.34890909,  0.60546483,  0.84123594,  # may vary
        0.98136677,  0.98136677,  0.84123594,  0.60546483,  0.34890909,
        0.15302337,  0.08      ])
Plot the window and the frequency response:
>>> import matplotlib.pyplot as plt
>>> from numpy.fft import fft, fftshift
>>> window = np.hamming(51)
>>> plt.plot(window)
[<matplotlib.lines.Line2D object at 0x...>]
>>> plt.title("Hamming window")
Text(0.5, 1.0, 'Hamming window')
>>> plt.ylabel("Amplitude")
Text(0, 0.5, 'Amplitude')
>>> plt.xlabel("Sample")
Text(0.5, 0, 'Sample')
>>> plt.show()
>>> plt.figure()
<Figure size 640x480 with 0 Axes>
>>> A = fft(window, 2048) / 25.5
>>> mag = np.abs(fftshift(A))
>>> freq = np.linspace(-0.5, 0.5, len(A))
>>> response = 20 * np.log10(mag)
>>> response = np.clip(response, -100, 100)
>>> plt.plot(freq, response)
[<matplotlib.lines.Line2D object at 0x...>]
>>> plt.title("Frequency response of Hamming window")
Text(0.5, 1.0, 'Frequency response of Hamming window')
>>> plt.ylabel("Magnitude [dB]")
Text(0, 0.5, 'Magnitude [dB]')
>>> plt.xlabel("Normalized frequency [cycles per sample]")
Text(0.5, 0, 'Normalized frequency [cycles per sample]')
>>> plt.axis('tight')
...
>>> plt.show()
-
symjax.tensor.signal.
hanning
(*args, **kwargs)¶ Return the Hanning window.
LAX-backend implementation of
hanning()
. Original docstring below.
The Hanning window is a taper formed by using a weighted cosine.
- M : int
- Number of points in the output window. If zero or less, an empty array is returned.
- out : ndarray, shape(M,)
- The window, with the maximum value normalized to one (the value one appears only if M is odd).
bartlett, blackman, hamming, kaiser
The Hanning window is defined as
\[w(n) = 0.5 - 0.5\cos\left(\frac{2\pi{n}}{M-1}\right) \qquad 0 \leq n \leq M-1\]
The Hanning was named for Julius von Hann, an Austrian meteorologist. It is also known as the Cosine Bell. Some authors prefer that it be called a Hann window, to help avoid confusion with the very similar Hamming window.
Most references to the Hanning window come from the signal processing literature, where it is used as one of many windowing functions for smoothing values. It is also known as an apodization (which means “removing the foot”, i.e. smoothing discontinuities at the beginning and end of the sampled signal) or tapering function.
[1] Blackman, R.B. and Tukey, J.W., (1958) The measurement of power spectra, Dover Publications, New York.
[2] E.R. Kanasewich, “Time Sequence Analysis in Geophysics”, The University of Alberta Press, 1975, pp. 106-108.
[3] Wikipedia, “Window function”, https://en.wikipedia.org/wiki/Window_function
[4] W.H. Press, B.P. Flannery, S.A. Teukolsky, and W.T. Vetterling, “Numerical Recipes”, Cambridge University Press, 1986, page 425.
>>> np.hanning(12)
array([0.        , 0.07937323, 0.29229249, 0.57115742, 0.82743037,
       0.97974649, 0.97974649, 0.82743037, 0.57115742, 0.29229249,
       0.07937323, 0.        ])
Plot the window and its frequency response:
>>> import matplotlib.pyplot as plt
>>> from numpy.fft import fft, fftshift
>>> window = np.hanning(51)
>>> plt.plot(window)
[<matplotlib.lines.Line2D object at 0x...>]
>>> plt.title("Hann window")
Text(0.5, 1.0, 'Hann window')
>>> plt.ylabel("Amplitude")
Text(0, 0.5, 'Amplitude')
>>> plt.xlabel("Sample")
Text(0.5, 0, 'Sample')
>>> plt.show()
>>> plt.figure()
<Figure size 640x480 with 0 Axes>
>>> A = fft(window, 2048) / 25.5
>>> mag = np.abs(fftshift(A))
>>> freq = np.linspace(-0.5, 0.5, len(A))
>>> with np.errstate(divide='ignore', invalid='ignore'):
...     response = 20 * np.log10(mag)
...
>>> response = np.clip(response, -100, 100)
>>> plt.plot(freq, response)
[<matplotlib.lines.Line2D object at 0x...>]
>>> plt.title("Frequency response of the Hann window")
Text(0.5, 1.0, 'Frequency response of the Hann window')
>>> plt.ylabel("Magnitude [dB]")
Text(0, 0.5, 'Magnitude [dB]')
>>> plt.xlabel("Normalized frequency [cycles per sample]")
Text(0.5, 0, 'Normalized frequency [cycles per sample]')
>>> plt.axis('tight')
...
>>> plt.show()
-
symjax.tensor.signal.
kaiser
(*args, **kwargs)¶ Return the Kaiser window.
LAX-backend implementation of
kaiser()
. Original docstring below.
The Kaiser window is a taper formed by using a Bessel function.
- M : int
- Number of points in the output window. If zero or less, an empty array is returned.
- beta : float
- Shape parameter for window.
- out : array
- The window, with the maximum value normalized to one (the value one appears only if the number of samples is odd).
bartlett, blackman, hamming, hanning
The Kaiser window is defined as
\[w(n) = I_0\left( \beta \sqrt{1-\frac{4n^2}{(M-1)^2}} \right)/I_0(\beta)\]
with
\[\quad -\frac{M-1}{2} \leq n \leq \frac{M-1}{2},\]
where \(I_0\) is the modified zeroth-order Bessel function.
The Kaiser was named for Jim Kaiser, who discovered a simple approximation to the DPSS window based on Bessel functions. The Kaiser window is a very good approximation to the Digital Prolate Spheroidal Sequence, or Slepian window, which is the transform which maximizes the energy in the main lobe of the window relative to total energy.
The Kaiser can approximate many other windows by varying the beta parameter.
beta   Window shape
0      Rectangular
5      Similar to a Hamming
6      Similar to a Hanning
8.6    Similar to a Blackman
A beta value of 14 is probably a good starting point. Note that as beta gets large, the window narrows, and so the number of samples needs to be large enough to sample the increasingly narrow spike, otherwise NaNs will get returned.
Most references to the Kaiser window come from the signal processing literature, where it is used as one of many windowing functions for smoothing values. It is also known as an apodization (which means “removing the foot”, i.e. smoothing discontinuities at the beginning and end of the sampled signal) or tapering function.
[1] J. F. Kaiser, “Digital Filters” - Ch 7 in “Systems analysis by digital computer”, Editors: F.F. Kuo and J.F. Kaiser, p 218-285. John Wiley and Sons, New York, (1966).
[2] E.R. Kanasewich, “Time Sequence Analysis in Geophysics”, The University of Alberta Press, 1975, pp. 177-178.
[3] Wikipedia, “Window function”, https://en.wikipedia.org/wiki/Window_function
>>> import matplotlib.pyplot as plt
>>> np.kaiser(12, 14)
array([7.72686684e-06, 3.46009194e-03, 4.65200189e-02,  # may vary
       2.29737120e-01, 5.99885316e-01, 9.45674898e-01,
       9.45674898e-01, 5.99885316e-01, 2.29737120e-01,
       4.65200189e-02, 3.46009194e-03, 7.72686684e-06])
Plot the window and the frequency response:
>>> from numpy.fft import fft, fftshift
>>> window = np.kaiser(51, 14)
>>> plt.plot(window)
[<matplotlib.lines.Line2D object at 0x...>]
>>> plt.title("Kaiser window")
Text(0.5, 1.0, 'Kaiser window')
>>> plt.ylabel("Amplitude")
Text(0, 0.5, 'Amplitude')
>>> plt.xlabel("Sample")
Text(0.5, 0, 'Sample')
>>> plt.show()
>>> plt.figure()
<Figure size 640x480 with 0 Axes>
>>> A = fft(window, 2048) / 25.5
>>> mag = np.abs(fftshift(A))
>>> freq = np.linspace(-0.5, 0.5, len(A))
>>> response = 20 * np.log10(mag)
>>> response = np.clip(response, -100, 100)
>>> plt.plot(freq, response)
[<matplotlib.lines.Line2D object at 0x...>]
>>> plt.title("Frequency response of Kaiser window")
Text(0.5, 1.0, 'Frequency response of Kaiser window')
>>> plt.ylabel("Magnitude [dB]")
Text(0, 0.5, 'Magnitude [dB]')
>>> plt.xlabel("Normalized frequency [cycles per sample]")
Text(0.5, 0, 'Normalized frequency [cycles per sample]')
>>> plt.axis('tight')
(-0.5, 0.5, -100.0, ...)  # may vary
>>> plt.show()
-
symjax.tensor.signal.
tukey
(M, alpha=0.5)[source]¶ Return a Tukey window, also known as a tapered cosine window.
Parameters: - M (int) – Number of points in the output window. If zero or less, an empty array is returned.
- alpha (float, optional) – Shape parameter of the Tukey window, representing the fraction of the window inside the cosine tapered region. If zero, the Tukey window is equivalent to a rectangular window. If one, the Tukey window is equivalent to a Hann window.
Returns: w – The window, with the maximum value normalized to 1 (though the value 1 does not appear if M is even and sym is True).
Return type: ndarray
References
[1] Harris, Fredric J. (Jan 1978). “On the use of Windows for Harmonic Analysis with the Discrete Fourier Transform”. Proceedings of the IEEE 66 (1): 51-83. doi:10.1109/PROC.1978.10837
[2] Wikipedia, “Window function”, https://en.wikipedia.org/wiki/Window_function#Tukey_window
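A minimal sketch of building and evaluating the window symbolically (the no-input symjax.function call is an assumption based on the compilation API used throughout these docs):
import symjax
import symjax.tensor as T

w = T.signal.tukey(64, alpha=0.25)   # flat middle, cosine-tapered over 25% of the window
f = symjax.function(outputs=w)
window = f()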
-
symjax.tensor.signal.
mfcc
(signal, window, hop, n_filter, low_freq, high_freq, nyquist, n_mfcc, nfft=None, mode='valid', apod=<function _wrap_numpy_nullary_function.<locals>.wrapper>)[source]¶ https://librosa.github.io/librosa/_modules/librosa/feature/spectral.html#mfcc
-
symjax.tensor.signal.
stft
(signal, window, hop, apod=<function ones>, nfft=None, mode='valid')[source]¶ Compute the Short-Time Fourier Transform of a signal given the window length, hop, and additional parameters.
Parameters: - signal (array) – the signal (possibly stacked of signals)
- window (int) – the window length to be considered for the fft
- hop (int) – the amount by which the window is moved
- apod (func) – a function that takes an integer as input and returns the apodization window of that length
- nfft (int (optional)) – the number of bins that the FFT on the window will use. If not given, it is set the same as window.
- mode ('valid', 'same' or 'full') – the padding of the input signals
Returns: output – the complex stft
Return type: complex array
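A hedged usage sketch; the signal shape and the magnitude post-processing are illustrative assumptions:
import numpy as np
import symjax
import symjax.tensor as T

signal = T.Placeholder((1, 1024), 'float32')
# 256-sample windows advanced by 128 samples -> complex spectrogram
S = T.signal.stft(signal, window=256, hop=128)
f = symjax.function(signal, outputs=T.abs(S))
spectrogram = f(np.random.randn(1, 1024).astype('float32'))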
-
symjax.tensor.signal.
dct
(signal, axes=(-1, ))[source]¶ https://dsp.stackexchange.com/questions/2807/fast-cosine-transform-via-fft
-
symjax.tensor.signal.
wvd
(signal, window, hop, L, apod=<function _wrap_numpy_nullary_function.<locals>.wrapper>, mode='valid')[source]¶
-
symjax.tensor.signal.
hilbert_transform
(signal)[source]¶ The time should be the last dimension; returns the analytic signal.
-
symjax.tensor.signal.
fourier_complex_morlet
(bandwidths, centers, N)[source]¶ Complex Morlet wavelet in Fourier
Parameters: - bandwidths (array) – the bandwidth of the wavelet
- centers (array) – the centers of the wavelet
- N (int) – the length of the frequency sampling, in radians, going from 0 to pi and back to 0
-
symjax.tensor.signal.
complex_morlet
(bandwidths, centers, time=None)[source]¶ Complex Morlet wavelet
With bandwidth B and center frequency C, it corresponds to:
\[\phi(t) = \frac{1}{\pi B} e^{-\frac{t^2}{B}} e^{j 2 \pi C t}\]
For a filter bank do:
J = 8
Q = 1
scales = T.power(2, T.linspace(0, J, J * Q))
scales = scales[:, None]
complex_morlet(scales, 1 / scales)
Parameters: - bandwidths (array) – the bandwidth of the wavelet
- centers (array) – the centers of the wavelet
- time (array (optional)) – the time sampling
Returns: wavelet – the wavelet centered at 0
Return type: array-like
-
symjax.tensor.signal.
sinc_bandpass
(time, f0, f1)[source]¶ Ensure that f0 < f1, f0 > 0 and f1 < 1 whenever time is …, -1, 0, 1, …
-
symjax.tensor.signal.
convolve
(in1, in2, mode='full', method='auto', precision=None)[source]¶ Convolve two N-dimensional arrays.
LAX-backend implementation of
convolve()
. Original docstring below.
Convolve in1 and in2, with the output size determined by the mode argument.
Parameters: - in1 (array_like) – First input.
- in2 (array_like) – Second input. Should have the same number of dimensions as in1.
- mode (str {'full', 'valid', 'same'}, optional) – A string indicating the size of the output:
- method (str {'auto', 'direct', 'fft'}, optional) – A string indicating which method to use to calculate the convolution.
Returns: convolve – An N-dimensional array containing a subset of the discrete linear convolution of in1 with in2.
Return type: array
See also
numpy.polymul()
- performs polynomial multiplication (same operation, but also accepts poly1d objects)
choose_conv_method()
- chooses the fastest appropriate convolution method
fftconvolve()
- Always uses the FFT method.
oaconvolve()
- Uses the overlap-add method to do convolution, which is generally faster when the input arrays are large and significantly different in size.
Notes
By default, convolve and correlate use
method='auto'
, which calls choose_conv_method to choose the fastest method using pre-computed values (choose_conv_method can also measure real-world timing with a keyword argument). Because fftconvolve relies on floating point numbers, there are certain constraints that may force method=direct (more detail in the choose_conv_method docstring).
Examples
Smooth a square pulse using a Hann window:
>>> from scipy import signal
>>> sig = np.repeat([0., 1., 0.], 100)
>>> win = signal.hann(50)
>>> filtered = signal.convolve(sig, win, mode='same') / sum(win)
>>> import matplotlib.pyplot as plt
>>> fig, (ax_orig, ax_win, ax_filt) = plt.subplots(3, 1, sharex=True)
>>> ax_orig.plot(sig)
>>> ax_orig.set_title('Original pulse')
>>> ax_orig.margins(0, 0.1)
>>> ax_win.plot(win)
>>> ax_win.set_title('Filter impulse response')
>>> ax_win.margins(0, 0.1)
>>> ax_filt.plot(filtered)
>>> ax_filt.set_title('Filtered signal')
>>> ax_filt.margins(0, 0.1)
>>> fig.tight_layout()
>>> fig.show()
-
symjax.tensor.signal.
batch_convolve
(input, filter, strides=1, padding='VALID', input_format=None, filter_format=None, output_format=None, input_dilation=None, filter_dilation=None)[source]¶ General n-dimensional batch convolution with dilations.
Wraps JAX’s conv_general_dilated function, and thus also XLA’s Conv operator.
Parameters: - input (Tensor) – a rank n+2 dimensional input array.
- filter (Tensor) – a rank n+2 dimensional array of kernel weights.
- strides (int, sequence of int, optional) – a (sequence) of n integers, representing the inter-window strides. If a scalar is given, it is used n times. Defaults to 1.
- padding (sequence of couple, ‘SAME’, ‘VALID’, optional) – a sequence of n (low, high) integer pairs that give the padding to apply before and after each spatial dimension. For ‘VALID’, those are 0. For ‘SAME’, the padding is chosen so that (with unit strides) the output has the same spatial shape as the input. Defaults to ‘VALID’.
- input_format (None or str, optional) –
a string of the same length as the number of dimensions in input which specifies their role (see below). Defaults to ‘NCW’ for 1d conv, ‘NCHW’ for 2d conv,
and ‘NCDHW’ for 3d conv. - input_dilation (None, int or sequence of int, optional) – giving the dilation factor to apply in each spatial dimension of input. Input dilation is also known as transposed convolution as it allows increasing the output spatial dimension by inserting in the input any number of `0`s between each spatial value.
- filter_dilation (None, int or sequence of int) – giving the dilation factor to apply in each spatial dimension of filter. Filter dilation is also known as atrous convolution as it corresponds to inserting any number of `0`s in between the filter values, similar to performing the non-dilated filter convolution with a subsample version of the input across the spatial dimensions.
Returns: An array containing the convolution result.
Return type: Tensor
Format of input, filter and output: for example, to indicate dimension numbers consistent with the conv function with two spatial dimensions, one could use (‘NCHW’, ‘OIHW’, ‘NCHW’). As another example, to indicate dimension numbers consistent with the TensorFlow Conv2D operation, one could use (‘NHWC’, ‘HWIO’, ‘NHWC’). When using the latter form of convolution dimension specification, window strides are associated with spatial dimension character labels according to the order in which the labels appear in the rhs_spec string, so that window_strides[0] is matched with the dimension corresponding to the first character appearing in rhs_spec that is not ‘I’ or ‘O’. The filter_format and output_format arguments follow the same convention.
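A minimal sketch with the default NCHW/OIHW layouts; the shapes are illustrative assumptions:
import numpy as np
import symjax
import symjax.tensor as T

images = T.Placeholder((8, 3, 32, 32), 'float32')                      # NCHW
filters = T.Variable(np.random.randn(16, 3, 5, 5).astype('float32'))  # OIHW
y = T.signal.batch_convolve(images, filters)   # 'VALID' padding -> (8, 16, 28, 28)
f = symjax.function(images, outputs=y)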
-
symjax.tensor.signal.
convolve2d
(in1, in2, mode='full', boundary='fill', fillvalue=0, precision=None)[source]¶ Convolve two 2-dimensional arrays.
LAX-backend implementation of
convolve2d()
. Original docstring below.
Convolve in1 and in2 with output size determined by mode, and boundary conditions determined by boundary and fillvalue.
Parameters: - in1 (array_like) – First input.
- in2 (array_like) – Second input. Should have the same number of dimensions as in1.
- mode (str {'full', 'valid', 'same'}, optional) – A string indicating the size of the output:
- boundary (str {'fill', 'wrap', 'symm'}, optional) – A flag indicating how to handle boundaries:
- fillvalue (scalar, optional) – Value to fill pad input arrays with. Default is 0.
Returns: out – A 2-dimensional array containing a subset of the discrete linear convolution of in1 with in2.
Return type: ndarray
Examples
Compute the gradient of an image by 2D convolution with a complex Scharr operator. (Horizontal operator is real, vertical is imaginary.) Use symmetric boundary condition to avoid creating edges at the image boundaries.
>>> from scipy import signal
>>> from scipy import misc
>>> ascent = misc.ascent()
>>> scharr = np.array([[ -3-3j, 0-10j,  +3 -3j],
...                    [-10+0j, 0+ 0j, +10 +0j],
...                    [ -3+3j, 0+10j,  +3 +3j]])  # Gx + j*Gy
>>> grad = signal.convolve2d(ascent, scharr, boundary='symm', mode='same')
>>> import matplotlib.pyplot as plt
>>> fig, (ax_orig, ax_mag, ax_ang) = plt.subplots(3, 1, figsize=(6, 15))
>>> ax_orig.imshow(ascent, cmap='gray')
>>> ax_orig.set_title('Original')
>>> ax_orig.set_axis_off()
>>> ax_mag.imshow(np.absolute(grad), cmap='gray')
>>> ax_mag.set_title('Gradient magnitude')
>>> ax_mag.set_axis_off()
>>> ax_ang.imshow(np.angle(grad), cmap='hsv')  # hsv is cyclic, like angles
>>> ax_ang.set_title('Gradient orientation')
>>> ax_ang.set_axis_off()
>>> fig.show()
-
symjax.tensor.signal.
correlate
(in1, in2, mode='full', method='auto', precision=None)[source]¶ Cross-correlate two N-dimensional arrays.
LAX-backend implementation of
correlate()
. Original docstring below.
Cross-correlate in1 and in2, with the output size determined by the mode argument.
Parameters: - in1 (array_like) – First input.
- in2 (array_like) – Second input. Should have the same number of dimensions as in1.
- mode (str {'full', 'valid', 'same'}, optional) – A string indicating the size of the output:
- method (str {'auto', 'direct', 'fft'}, optional) – A string indicating which method to use to calculate the correlation.
Returns: correlate – An N-dimensional array containing a subset of the discrete linear cross-correlation of in1 with in2.
Return type: array
See also
choose_conv_method()
- contains more documentation on method.
Notes
The correlation z of two d-dimensional arrays x and y is defined as:
z[...,k,...] = sum[..., i_l, ...] x[..., i_l,...] * conj(y[..., i_l - k,...])
This way, if x and y are 1-D arrays and
z = correlate(x, y, 'full')
then
\[z[k] = (x * y)(k - N + 1) = \sum_{l=0}^{||x||-1}x_l y_{l-k+N-1}^{*}\]
for \(k = 0, 1, ..., ||x|| + ||y|| - 2\)
where \(||x||\) is the length of
x
, \(N = \max(||x||,||y||)\), and \(y_m\) is 0 when m is outside the range of y.method='fft'
only works for numerical arrays as it relies on fftconvolve. In certain cases (i.e., arrays of objects or when rounding integers can lose precision),method='direct'
is always used.When using “same” mode with even-length inputs, the outputs of correlate and correlate2d differ: There is a 1-index offset between them.
Examples
Implement a matched filter using cross-correlation, to recover a signal that has passed through a noisy channel.
>>> from scipy import signal
>>> sig = np.repeat([0., 1., 1., 0., 1., 0., 0., 1.], 128)
>>> sig_noise = sig + np.random.randn(len(sig))
>>> corr = signal.correlate(sig_noise, np.ones(128), mode='same') / 128
>>> import matplotlib.pyplot as plt
>>> clock = np.arange(64, len(sig), 128)
>>> fig, (ax_orig, ax_noise, ax_corr) = plt.subplots(3, 1, sharex=True)
>>> ax_orig.plot(sig)
>>> ax_orig.plot(clock, sig[clock], 'ro')
>>> ax_orig.set_title('Original signal')
>>> ax_noise.plot(sig_noise)
>>> ax_noise.set_title('Signal with noise')
>>> ax_corr.plot(corr)
>>> ax_corr.plot(clock, corr[clock], 'ro')
>>> ax_corr.axhline(0.5, ls=':')
>>> ax_corr.set_title('Cross-correlated with rectangular pulse')
>>> ax_orig.margins(0, 0.1)
>>> fig.tight_layout()
>>> fig.show()
-
symjax.tensor.signal.
correlate2d
(in1, in2, mode='full', boundary='fill', fillvalue=0, precision=None)[source]¶ Cross-correlate two 2-dimensional arrays.
LAX-backend implementation of
correlate2d()
. Original docstring below.
Cross correlate in1 and in2 with output size determined by mode, and boundary conditions determined by boundary and fillvalue.
Parameters: - in1 (array_like) – First input.
- in2 (array_like) – Second input. Should have the same number of dimensions as in1.
- mode (str {'full', 'valid', 'same'}, optional) – A string indicating the size of the output:
- boundary (str {'fill', 'wrap', 'symm'}, optional) – A flag indicating how to handle boundaries:
- fillvalue (scalar, optional) – Value to fill pad input arrays with. Default is 0.
Returns: correlate2d – A 2-dimensional array containing a subset of the discrete linear cross-correlation of in1 with in2.
Return type: ndarray
Notes
When using “same” mode with even-length inputs, the outputs of correlate and correlate2d differ: There is a 1-index offset between them.
Examples
Use 2D cross-correlation to find the location of a template in a noisy image:
>>> from scipy import signal
>>> from scipy import misc
>>> face = misc.face(gray=True) - misc.face(gray=True).mean()
>>> template = np.copy(face[300:365, 670:750])  # right eye
>>> template -= template.mean()
>>> face = face + np.random.randn(*face.shape) * 50  # add noise
>>> corr = signal.correlate2d(face, template, boundary='symm', mode='same')
>>> y, x = np.unravel_index(np.argmax(corr), corr.shape)  # find the match
>>> import matplotlib.pyplot as plt
>>> fig, (ax_orig, ax_template, ax_corr) = plt.subplots(3, 1,
...                                                     figsize=(6, 15))
>>> ax_orig.imshow(face, cmap='gray')
>>> ax_orig.set_title('Original')
>>> ax_orig.set_axis_off()
>>> ax_template.imshow(template, cmap='gray')
>>> ax_template.set_title('Template')
>>> ax_template.set_axis_off()
>>> ax_corr.imshow(corr, cmap='gray')
>>> ax_corr.set_title('Cross-correlation')
>>> ax_corr.set_axis_off()
>>> ax_orig.plot(x, y, 'ro')
>>> fig.show()
-
symjax.tensor.signal.
extract_signal_patches
()¶
-
symjax.tensor.signal.
extract_image_patches
(image, window_shape, strides=1, data_format='channel_first', padding='valid', flatten_patch=False)¶ extract patches from an input tensor
- image: Tensor-like
- the input to extract patches from
- window_shape: int or tuple of ints
- the spatial shape of the patch to extract
- strides: int or tuple of ints
- the spatial stride (hop) of the patch to extract
- data_format: str
- either
channel_first
orchannel_last
- padding: str
- either
same
orvalid
- flatten_patch: bool
- whether to return patches as flattened or not
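A hedged sketch; the non-overlapping 8x8 tiling below is an illustrative assumption:
import numpy as np
import symjax
import symjax.tensor as T

image = T.Placeholder((2, 3, 32, 32), 'float32')   # channel_first
patches = T.signal.extract_image_patches(image, window_shape=(8, 8), strides=8)
f = symjax.function(image, outputs=patches)
out = f(np.random.randn(2, 3, 32, 32).astype('float32'))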
symjax.tensor.fft
¶
fft (a[, n, axis, norm]) |
Compute the one-dimensional discrete Fourier Transform. |
ifft (a[, n, axis, norm]) |
Compute the one-dimensional inverse discrete Fourier Transform. |
fft2 (a[, s, axes, norm]) |
Compute the 2-dimensional discrete Fourier Transform |
ifft2 (a[, s, axes, norm]) |
Compute the 2-dimensional inverse discrete Fourier Transform. |
fftn (a[, s, axes, norm]) |
Compute the N-dimensional discrete Fourier Transform. |
ifftn (a[, s, axes, norm]) |
Compute the N-dimensional inverse discrete Fourier Transform. |
rfft (a[, n, axis, norm]) |
Compute the one-dimensional discrete Fourier Transform for real input. |
irfft (a[, n, axis, norm]) |
Compute the inverse of the n-point DFT for real input. |
rfft2 (a[, s, axes, norm]) |
Compute the 2-dimensional FFT of a real array. |
irfft2 (a[, s, axes, norm]) |
Compute the 2-dimensional inverse FFT of a real array. |
rfftn (a[, s, axes, norm]) |
Compute the N-dimensional discrete Fourier Transform for real input. |
irfftn (a[, s, axes, norm]) |
Compute the inverse of the N-dimensional FFT of real input. |
fftfreq (n[, d]) |
Return the Discrete Fourier Transform sample frequencies. |
rfftfreq (n[, d]) |
Return the Discrete Fourier Transform sample frequencies |
Detailed Descriptions¶
-
symjax.tensor.fft.
fft
(a, n=None, axis=-1, norm=None)[source]¶ Compute the one-dimensional discrete Fourier Transform.
LAX-backend implementation of
fft()
. Original docstring below.
This function computes the one-dimensional n-point discrete Fourier Transform (DFT) with the efficient Fast Fourier Transform (FFT) algorithm [CT].
Parameters: - a (array_like) – Input array, can be complex.
- n (int, optional) – Length of the transformed axis of the output. If n is smaller than the length of the input, the input is cropped. If it is larger, the input is padded with zeros. If n is not given, the length of the input along the axis specified by axis is used.
- axis (int, optional) – Axis over which to compute the FFT. If not given, the last axis is used.
- norm ({None, "ortho"}, optional) –
New in version 1.10.0.
Returns: out – The truncated or zero-padded input, transformed along the axis indicated by axis, or the last one if axis is not specified.
Return type: complex ndarray
Raises: IndexError
– if axes is larger than the last axis of a.See also
Notes
FFT (Fast Fourier Transform) refers to a way the discrete Fourier Transform (DFT) can be calculated efficiently, by using symmetries in the calculated terms. The symmetry is highest when n is a power of 2, and the transform is therefore most efficient for these sizes.
The DFT is defined, with the conventions used in this implementation, in the documentation for the numpy.fft module.
References
[CT] Cooley, James W., and John W. Tukey, 1965, “An algorithm for the machine calculation of complex Fourier series,” Math. Comput. 19: 297-301.
Examples
>>> np.fft.fft(np.exp(2j * np.pi * np.arange(8) / 8))
array([-2.33486982e-16+1.14423775e-17j,  8.00000000e+00-1.25557246e-15j,
        2.33486982e-16+2.33486982e-16j,  0.00000000e+00+1.22464680e-16j,
       -1.14423775e-17+2.33486982e-16j,  0.00000000e+00+5.20784380e-16j,
        1.14423775e-17+1.14423775e-17j,  0.00000000e+00+1.22464680e-16j])
In this example, real input has an FFT which is Hermitian, i.e., symmetric in the real part and anti-symmetric in the imaginary part, as described in the numpy.fft documentation:
>>> import matplotlib.pyplot as plt
>>> t = np.arange(256)
>>> sp = np.fft.fft(np.sin(t))
>>> freq = np.fft.fftfreq(t.shape[-1])
>>> plt.plot(freq, sp.real, freq, sp.imag)
[<matplotlib.lines.Line2D object at 0x...>, <matplotlib.lines.Line2D object at 0x...>]
>>> plt.show()
-
symjax.tensor.fft.
ifft
(a, n=None, axis=-1, norm=None)[source]¶ Compute the one-dimensional inverse discrete Fourier Transform.
LAX-backend implementation of
ifft()
. Original docstring below.
This function computes the inverse of the one-dimensional n-point discrete Fourier transform computed by fft. In other words,
ifft(fft(a)) == a
to within numerical accuracy. For a general description of the algorithm and definitions, see numpy.fft.
The input should be ordered in the same way as is returned by fft, i.e.,
a[0]
should contain the zero frequency term,a[1:n//2]
should contain the positive-frequency terms,a[n//2 + 1:]
should contain the negative-frequency terms, in increasing order starting from the most negative frequency.
For an even number of input points,
A[n//2]
represents the sum of the values at the positive and negative Nyquist frequencies, as the two are aliased together. See numpy.fft for details.
Parameters: - a (array_like) – Input array, can be complex.
- n (int, optional) – Length of the transformed axis of the output. If n is smaller than the length of the input, the input is cropped. If it is larger, the input is padded with zeros. If n is not given, the length of the input along the axis specified by axis is used. See notes about padding issues.
- axis (int, optional) – Axis over which to compute the inverse DFT. If not given, the last axis is used.
- norm ({None, "ortho"}, optional) –
New in version 1.10.0.
Returns: out – The truncated or zero-padded input, transformed along the axis indicated by axis, or the last one if axis is not specified.
Return type: complex ndarray
Raises: IndexError
– If axes is larger than the last axis of a.See also
Notes
If the input parameter n is larger than the size of the input, the input is padded by appending zeros at the end. Even though this is the common approach, it might lead to surprising results. If a different padding is desired, it must be performed before calling ifft.
Examples
>>> np.fft.ifft([0, 4, 0, 0])
array([ 1.+0.j,  0.+1.j, -1.+0.j,  0.-1.j])  # may vary
Create and plot a band-limited signal with random phases:
>>> import matplotlib.pyplot as plt
>>> t = np.arange(400)
>>> n = np.zeros((400,), dtype=complex)
>>> n[40:60] = np.exp(1j*np.random.uniform(0, 2*np.pi, (20,)))
>>> s = np.fft.ifft(n)
>>> plt.plot(t, s.real, 'b-', t, s.imag, 'r--')
[<matplotlib.lines.Line2D object at ...>, <matplotlib.lines.Line2D object at ...>]
>>> plt.legend(('real', 'imaginary'))
<matplotlib.legend.Legend object at ...>
>>> plt.show()
-
symjax.tensor.fft.
fft2
(a, s=None, axes=(-2, -1), norm=None)[source]¶ Compute the 2-dimensional discrete Fourier Transform
LAX-backend implementation of
fft2()
. Original docstring below.
This function computes the n-dimensional discrete Fourier Transform over any axes in an M-dimensional array by means of the Fast Fourier Transform (FFT). By default, the transform is computed over the last two axes of the input array, i.e., a 2-dimensional FFT.
Parameters: - a (array_like) – Input array, can be complex
- s (sequence of ints, optional) – Shape (length of each transformed axis) of the output
(
s[0]
refers to axis 0,s[1]
to axis 1, etc.). This corresponds ton
forfft(x, n)
. Along each axis, if the given shape is smaller than that of the input, the input is cropped. If it is larger, the input is padded with zeros. if s is not given, the shape of the input along the axes specified by axes is used. - axes (sequence of ints, optional) – Axes over which to compute the FFT. If not given, the last two axes are used. A repeated index in axes means the transform over that axis is performed multiple times. A one-element sequence means that a one-dimensional FFT is performed.
- norm ({None, "ortho"}, optional) –
New in version 1.10.0.
Returns: out – The truncated or zero-padded input, transformed along the axes indicated by axes, or the last two axes if axes is not given.
Return type: complex ndarray
Raises: ValueError
– If s and axes have different length, or axes not given andlen(s) != 2
.IndexError
– If an element of axes is larger than than the number of axes of a.
See also
numpy.fft()
- Overall view of discrete Fourier transforms, with definitions and conventions used.
ifft2()
- The inverse two-dimensional FFT.
fft()
- The one-dimensional FFT.
fftn()
- The n-dimensional FFT.
fftshift()
- Shifts zero-frequency terms to the center of the array. For two-dimensional input, swaps first and third quadrants, and second and fourth quadrants.
Notes
fft2 is just fftn with a different default for axes.
The output, analogously to fft, contains the term for zero frequency in the low-order corner of the transformed axes, the positive frequency terms in the first half of these axes, the term for the Nyquist frequency in the middle of the axes and the negative frequency terms in the second half of the axes, in order of decreasingly negative frequency.
See fftn for details and a plotting example, and numpy.fft for definitions and conventions used.
Examples
>>> a = np.mgrid[:5, :5][0]
>>> np.fft.fft2(a)
array([[ 50.  +0.j        ,   0.  +0.j        ,   0.  +0.j        ,  # may vary
          0.  +0.j        ,   0.  +0.j        ],
       [-12.5+17.20477401j,   0.  +0.j        ,   0.  +0.j        ,
          0.  +0.j        ,   0.  +0.j        ],
       [-12.5 +4.0614962j ,   0.  +0.j        ,   0.  +0.j        ,
          0.  +0.j        ,   0.  +0.j        ],
       [-12.5 -4.0614962j ,   0.  +0.j        ,   0.  +0.j        ,
          0.  +0.j        ,   0.  +0.j        ],
       [-12.5-17.20477401j,   0.  +0.j        ,   0.  +0.j        ,
          0.  +0.j        ,   0.  +0.j        ]])
-
symjax.tensor.fft.
ifft2
(a, s=None, axes=(-2, -1), norm=None)[source]¶ Compute the 2-dimensional inverse discrete Fourier Transform.
LAX-backend implementation of
ifft2()
. Original docstring below.
This function computes the inverse of the 2-dimensional discrete Fourier Transform over any number of axes in an M-dimensional array by means of the Fast Fourier Transform (FFT). In other words,
ifft2(fft2(a)) == a
to within numerical accuracy. By default, the inverse transform is computed over the last two axes of the input array.
The input, analogously to ifft, should be ordered in the same way as is returned by fft2, i.e. it should have the term for zero frequency in the low-order corner of the two axes, the positive frequency terms in the first half of these axes, the term for the Nyquist frequency in the middle of the axes and the negative frequency terms in the second half of both axes, in order of decreasingly negative frequency.
Parameters: - a (array_like) – Input array, can be complex.
- s (sequence of ints, optional) – Shape (length of each axis) of the output (
s[0]
refers to axis 0,s[1]
to axis 1, etc.). This corresponds to n forifft(x, n)
. Along each axis, if the given shape is smaller than that of the input, the input is cropped. If it is larger, the input is padded with zeros. if s is not given, the shape of the input along the axes specified by axes is used. See notes for issue on ifft zero padding. - axes (sequence of ints, optional) – Axes over which to compute the FFT. If not given, the last two axes are used. A repeated index in axes means the transform over that axis is performed multiple times. A one-element sequence means that a one-dimensional FFT is performed.
- norm ({None, "ortho"}, optional) –
New in version 1.10.0.
Returns: out – The truncated or zero-padded input, transformed along the axes indicated by axes, or the last two axes if axes is not given.
Return type: complex ndarray
Raises: ValueError
– If s and axes have different length, or axes not given andlen(s) != 2
.IndexError
– If an element of axes is larger than than the number of axes of a.
See also
Notes
ifft2 is just ifftn with a different default for axes.
See ifftn for details and a plotting example, and numpy.fft for definition and conventions used.
Zero-padding, analogously with ifft, is performed by appending zeros to the input along the specified dimension. Although this is the common approach, it might lead to surprising results. If another form of zero padding is desired, it must be performed before ifft2 is called.
Examples
>>> a = 4 * np.eye(4)
>>> np.fft.ifft2(a)
array([[1.+0.j, 0.+0.j, 0.+0.j, 0.+0.j],  # may vary
       [0.+0.j, 0.+0.j, 0.+0.j, 1.+0.j],
       [0.+0.j, 0.+0.j, 1.+0.j, 0.+0.j],
       [0.+0.j, 1.+0.j, 0.+0.j, 0.+0.j]])
-
symjax.tensor.fft.
fftn
(a, s=None, axes=None, norm=None)[source]¶ Compute the N-dimensional discrete Fourier Transform.
LAX-backend implementation of
fftn()
. Original docstring below.
This function computes the N-dimensional discrete Fourier Transform over any number of axes in an M-dimensional array by means of the Fast Fourier Transform (FFT).
Parameters: - a (array_like) – Input array, can be complex.
- s (sequence of ints, optional) – Shape (length of each transformed axis) of the output
(
s[0]
refers to axis 0,s[1]
to axis 1, etc.). This corresponds ton
forfft(x, n)
. Along any axis, if the given shape is smaller than that of the input, the input is cropped. If it is larger, the input is padded with zeros. if s is not given, the shape of the input along the axes specified by axes is used. - axes (sequence of ints, optional) – Axes over which to compute the FFT. If not given, the last
len(s)
axes are used, or all axes if s is also not specified. Repeated indices in axes means that the transform over that axis is performed multiple times. - norm ({None, "ortho"}, optional) –
New in version 1.10.0.
Returns: out – The truncated or zero-padded input, transformed along the axes indicated by axes, or by a combination of s and a, as explained in the parameters section above.
Return type: complex ndarray
Raises: ValueError
– If s and axes have different length.IndexError
– If an element of axes is larger than than the number of axes of a.
See also
numpy.fft() – Overall view of discrete Fourier transforms, with definitions and conventions used.
ifftn() – The inverse of fftn, the inverse n-dimensional FFT.
fft() – The one-dimensional FFT, with definitions and conventions used.
rfftn() – The n-dimensional FFT of real input.
fft2() – The two-dimensional FFT.
fftshift() – Shifts zero-frequency terms to the centre of the array.
Notes
The output, analogously to fft, contains the term for zero frequency in the low-order corner of all axes, the positive frequency terms in the first half of all axes, the term for the Nyquist frequency in the middle of all axes and the negative frequency terms in the second half of all axes, in order of decreasingly negative frequency.
See numpy.fft for details, definitions and conventions used.
Examples
>>> a = np.mgrid[:3, :3, :3][0]
>>> np.fft.fftn(a, axes=(1, 2))
array([[[ 0.+0.j,  0.+0.j,  0.+0.j], # may vary
        [ 0.+0.j,  0.+0.j,  0.+0.j],
        [ 0.+0.j,  0.+0.j,  0.+0.j]],
       [[ 9.+0.j,  0.+0.j,  0.+0.j],
        [ 0.+0.j,  0.+0.j,  0.+0.j],
        [ 0.+0.j,  0.+0.j,  0.+0.j]],
       [[18.+0.j,  0.+0.j,  0.+0.j],
        [ 0.+0.j,  0.+0.j,  0.+0.j],
        [ 0.+0.j,  0.+0.j,  0.+0.j]]])
>>> np.fft.fftn(a, (2, 2), axes=(0, 1))
array([[[ 2.+0.j,  2.+0.j,  2.+0.j], # may vary
        [ 0.+0.j,  0.+0.j,  0.+0.j]],
       [[-2.+0.j, -2.+0.j, -2.+0.j],
        [ 0.+0.j,  0.+0.j,  0.+0.j]]])
>>> import matplotlib.pyplot as plt
>>> [X, Y] = np.meshgrid(2 * np.pi * np.arange(200) / 12,
...                      2 * np.pi * np.arange(200) / 34)
>>> S = np.sin(X) + np.cos(Y) + np.random.uniform(0, 1, X.shape)
>>> FS = np.fft.fftn(S)
>>> plt.imshow(np.log(np.abs(np.fft.fftshift(FS))**2))
<matplotlib.image.AxesImage object at 0x...>
>>> plt.show()
-
symjax.tensor.fft.
ifftn
(a, s=None, axes=None, norm=None)[source]¶ Compute the N-dimensional inverse discrete Fourier Transform.
LAX-backend implementation of ifftn(). Original docstring below.
This function computes the inverse of the N-dimensional discrete Fourier Transform over any number of axes in an M-dimensional array by means of the Fast Fourier Transform (FFT). In other words, ifftn(fftn(a)) == a to within numerical accuracy. For a description of the definitions and conventions used, see numpy.fft.
The input, analogously to ifft, should be ordered in the same way as is returned by fftn, i.e. it should have the term for zero frequency in all axes in the low-order corner, the positive frequency terms in the first half of all axes, the term for the Nyquist frequency in the middle of all axes and the negative frequency terms in the second half of all axes, in order of decreasingly negative frequency.
Parameters: - a (array_like) – Input array, can be complex.
- s (sequence of ints, optional) – Shape (length of each transformed axis) of the output (s[0] refers to axis 0, s[1] to axis 1, etc.). This corresponds to n for ifft(x, n). Along any axis, if the given shape is smaller than that of the input, the input is cropped. If it is larger, the input is padded with zeros. If s is not given, the shape of the input along the axes specified by axes is used. See notes for issue on ifft zero padding.
- axes (sequence of ints, optional) – Axes over which to compute the IFFT. If not given, the last len(s) axes are used, or all axes if s is also not specified. Repeated indices in axes mean that the inverse transform over that axis is performed multiple times.
- norm ({None, "ortho"}, optional) – Normalization mode (see numpy.fft). Default is None.
New in version 1.10.0.
Returns: out – The truncated or zero-padded input, transformed along the axes indicated by axes, or by a combination of s or a, as explained in the parameters section above.
Return type: complex ndarray
Raises: ValueError – If s and axes have different length.
IndexError – If an element of axes is larger than the number of axes of a.
See also
numpy.fft() – Overall view of discrete Fourier transforms, with definitions and conventions used.
fftn() – The forward n-dimensional FFT, of which ifftn is the inverse.
ifft() – The one-dimensional inverse FFT.
ifft2() – The two-dimensional inverse FFT.
ifftshift() – Undoes fftshift, shifts zero-frequency terms to the beginning of the array.
Notes
See numpy.fft for definitions and conventions used.
Zero-padding, analogously with ifft, is performed by appending zeros to the input along the specified dimension. Although this is the common approach, it might lead to surprising results. If another form of zero padding is desired, it must be performed before ifftn is called.
Examples
>>> a = np.eye(4)
>>> np.fft.ifftn(np.fft.fftn(a, axes=(0,)), axes=(1,))
array([[1.+0.j, 0.+0.j, 0.+0.j, 0.+0.j], # may vary
       [0.+0.j, 1.+0.j, 0.+0.j, 0.+0.j],
       [0.+0.j, 0.+0.j, 1.+0.j, 0.+0.j],
       [0.+0.j, 0.+0.j, 0.+0.j, 1.+0.j]])
Create and plot an image with band-limited frequency content:
>>> import matplotlib.pyplot as plt
>>> n = np.zeros((200,200), dtype=complex)
>>> n[60:80, 20:40] = np.exp(1j*np.random.uniform(0, 2*np.pi, (20, 20)))
>>> im = np.fft.ifftn(n).real
>>> plt.imshow(im)
<matplotlib.image.AxesImage object at 0x...>
>>> plt.show()
-
symjax.tensor.fft.
rfft
(a, n=None, axis=-1, norm=None)[source]¶ Compute the one-dimensional discrete Fourier Transform for real input.
LAX-backend implementation of rfft(). Original docstring below.
This function computes the one-dimensional n-point discrete Fourier Transform (DFT) of a real-valued array by means of an efficient algorithm called the Fast Fourier Transform (FFT).
Parameters: - a (array_like) – Input array
- n (int, optional) – Number of points along transformation axis in the input to use. If n is smaller than the length of the input, the input is cropped. If it is larger, the input is padded with zeros. If n is not given, the length of the input along the axis specified by axis is used.
- axis (int, optional) – Axis over which to compute the FFT. If not given, the last axis is used.
- norm ({None, "ortho"}, optional) – Normalization mode (see numpy.fft). Default is None.
New in version 1.10.0.
Returns: out – The truncated or zero-padded input, transformed along the axis indicated by axis, or the last one if axis is not specified. If n is even, the length of the transformed axis is (n/2)+1. If n is odd, the length is (n+1)/2.
Return type: complex ndarray
Raises: IndexError – If axis is larger than the last axis of a.
Notes
When the DFT is computed for purely real input, the output is Hermitian-symmetric, i.e. the negative frequency terms are just the complex conjugates of the corresponding positive-frequency terms, and the negative-frequency terms are therefore redundant. This function does not compute the negative frequency terms, and the length of the transformed axis of the output is therefore n//2 + 1.
When A = rfft(a) and fs is the sampling frequency, A[0] contains the zero-frequency term 0*fs, which is real due to Hermitian symmetry.
If n is even, A[-1] contains the term representing both positive and negative Nyquist frequency (+fs/2 and -fs/2), and must also be purely real. If n is odd, there is no term at fs/2; A[-1] contains the largest positive frequency (fs/2*(n-1)/n), and is complex in the general case.
If the input a contains an imaginary part, it is silently discarded.
Examples
>>> np.fft.fft([0, 1, 0, 0])
array([ 1.+0.j,  0.-1.j, -1.+0.j,  0.+1.j]) # may vary
>>> np.fft.rfft([0, 1, 0, 0])
array([ 1.+0.j,  0.-1.j, -1.+0.j]) # may vary
Notice how the final element of the fft output is the complex conjugate of the second element, for real input. For rfft, this symmetry is exploited to compute only the non-negative frequency terms.
-
symjax.tensor.fft.
irfft
(a, n=None, axis=-1, norm=None)[source]¶ Compute the inverse of the n-point DFT for real input.
LAX-backend implementation of irfft(). Original docstring below.
This function computes the inverse of the one-dimensional n-point discrete Fourier Transform of real input computed by rfft. In other words, irfft(rfft(a), len(a)) == a to within numerical accuracy. (See Notes below for why len(a) is necessary here.)
The input is expected to be in the form returned by rfft, i.e. the real zero-frequency term followed by the complex positive frequency terms in order of increasing frequency. Since the discrete Fourier Transform of real input is Hermitian-symmetric, the negative frequency terms are taken to be the complex conjugates of the corresponding positive frequency terms.
Parameters: - a (array_like) – The input array.
- n (int, optional) – Length of the transformed axis of the output. For n output points, n//2+1 input points are necessary. If the input is longer than this, it is cropped. If it is shorter than this, it is padded with zeros. If n is not given, it is taken to be 2*(m-1) where m is the length of the input along the axis specified by axis.
- axis (int, optional) – Axis over which to compute the inverse FFT. If not given, the last axis is used.
- norm ({None, "ortho"}, optional) – Normalization mode (see numpy.fft). Default is None.
New in version 1.10.0.
Returns: out – The truncated or zero-padded input, transformed along the axis indicated by axis, or the last one if axis is not specified. The length of the transformed axis is n, or, if n is not given, 2*(m-1) where m is the length of the transformed axis of the input. To get an odd number of output points, n must be specified.
Return type: ndarray
Raises: IndexError – If axis is larger than the last axis of a.
Notes
Returns the real valued n-point inverse discrete Fourier transform of a, where a contains the non-negative frequency terms of a Hermitian-symmetric sequence. n is the length of the result, not the input.
If you specify an n such that a must be zero-padded or truncated, the extra/removed values will be added/removed at high frequencies. One can thus resample a series to m points via Fourier interpolation by: a_resamp = irfft(rfft(a), m).
The correct interpretation of the Hermitian input depends on the length of the original data, as given by n. This is because each input shape could correspond to either an odd or even length signal. By default, irfft assumes an even output length, which puts the last entry at the Nyquist frequency, aliasing with its symmetric counterpart. By Hermitian symmetry, the value is thus treated as purely real. To avoid losing information, the correct length of the real input must be given.
Examples
>>> np.fft.ifft([1, -1j, -1, 1j])
array([0.+0.j, 1.+0.j, 0.+0.j, 0.+0.j]) # may vary
>>> np.fft.irfft([1, -1j, -1])
array([0., 1., 0., 0.])
Notice how the last term in the input to the ordinary ifft is the complex conjugate of the second term, and the output has zero imaginary part everywhere. When calling irfft, the negative frequencies are not specified, and the output array is purely real.
-
symjax.tensor.fft.
rfft2
(a, s=None, axes=(-2, -1), norm=None)[source]¶ Compute the 2-dimensional FFT of a real array.
LAX-backend implementation of rfft2(). Original docstring below.
Parameters: - a (array) – Input array, taken to be real.
- s (sequence of ints, optional) – Shape of the FFT.
- axes (sequence of ints, optional) – Axes over which to compute the FFT.
- norm ({None, "ortho"}, optional) – Normalization mode (see numpy.fft). Default is None.
New in version 1.10.0.
Returns: out – The result of the real 2-D FFT.
Return type: ndarray
See also
rfftn() – Compute the N-dimensional discrete Fourier Transform for real input.
Notes
This is really just rfftn with different default behavior. For more details see rfftn.
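Example (a minimal shape-oriented sketch, using the numpy reference implementation as in the other examples in this section; the array values are illustrative): the last transformed axis is halved to n//2 + 1.
>>> a = np.ones((4, 4))
>>> np.fft.rfft2(a).shape  # last axis becomes 4//2 + 1 = 3
(4, 3)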
-
symjax.tensor.fft.
irfft2
(a, s=None, axes=(-2, -1), norm=None)[source]¶ Compute the 2-dimensional inverse FFT of a real array.
LAX-backend implementation of irfft2(). Original docstring below.
Parameters: - a (array_like) – The input array.
- s (sequence of ints, optional) – Shape of the real output to the inverse FFT.
- axes (sequence of ints, optional) – The axes over which to compute the inverse fft. Default is the last two axes.
- norm ({None, "ortho"}, optional) – Normalization mode (see numpy.fft). Default is None.
New in version 1.10.0.
Returns: out – The result of the inverse real 2-D FFT.
Return type: ndarray
See also
irfftn() – Compute the inverse of the N-dimensional FFT of real input.
Notes
This is really irfftn with different defaults. For more details see irfftn.
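Example (a minimal round-trip sketch with the numpy reference implementation; passing s explicitly recovers the original shape, since the halved last axis is otherwise ambiguous for odd lengths):
>>> a = np.ones((4, 5))
>>> np.allclose(np.fft.irfft2(np.fft.rfft2(a), s=a.shape), a)
True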
-
symjax.tensor.fft.
rfftn
(a, s=None, axes=None, norm=None)[source]¶ Compute the N-dimensional discrete Fourier Transform for real input.
LAX-backend implementation of rfftn(). Original docstring below.
This function computes the N-dimensional discrete Fourier Transform over any number of axes in an M-dimensional real array by means of the Fast Fourier Transform (FFT). By default, all axes are transformed, with the real transform performed over the last axis, while the remaining transforms are complex.
Parameters: - a (array_like) – Input array, taken to be real.
- s (sequence of ints, optional) – Shape (length along each transformed axis) to use from the input (s[0] refers to axis 0, s[1] to axis 1, etc.). The final element of s corresponds to n for rfft(x, n), while for the remaining axes, it corresponds to n for fft(x, n). Along any axis, if the given shape is smaller than that of the input, the input is cropped. If it is larger, the input is padded with zeros. If s is not given, the shape of the input along the axes specified by axes is used.
- axes (sequence of ints, optional) – Axes over which to compute the FFT. If not given, the last len(s) axes are used, or all axes if s is also not specified.
- norm ({None, "ortho"}, optional) – Normalization mode (see numpy.fft). Default is None.
New in version 1.10.0.
Returns: out – The truncated or zero-padded input, transformed along the axes indicated by axes, or by a combination of s and a, as explained in the parameters section above. The length of the last axis transformed will be s[-1]//2+1, while the remaining transformed axes will have lengths according to s, or unchanged from the input.
Return type: complex ndarray
Raises: ValueError – If s and axes have different length.
IndexError – If an element of axes is larger than the number of axes of a.
Notes
The transform for real input is performed over the last transformation axis, as by rfft, then the transform over the remaining axes is performed as by fftn. The order of the output is as for rfft for the final transformation axis, and as for fftn for the remaining transformation axes.
See fft for details, definitions and conventions used.
Examples
>>> a = np.ones((2, 2, 2))
>>> np.fft.rfftn(a)
array([[[8.+0.j, 0.+0.j], # may vary
        [0.+0.j, 0.+0.j]],
       [[0.+0.j, 0.+0.j],
        [0.+0.j, 0.+0.j]]])
>>> np.fft.rfftn(a, axes=(2, 0))
array([[[4.+0.j, 0.+0.j], # may vary
        [4.+0.j, 0.+0.j]],
       [[0.+0.j, 0.+0.j],
        [0.+0.j, 0.+0.j]]])
-
symjax.tensor.fft.
irfftn
(a, s=None, axes=None, norm=None)[source]¶ Compute the inverse of the N-dimensional FFT of real input.
LAX-backend implementation of irfftn(). Original docstring below.
This function computes the inverse of the N-dimensional discrete Fourier Transform for real input over any number of axes in an M-dimensional array by means of the Fast Fourier Transform (FFT). In other words, irfftn(rfftn(a), a.shape) == a to within numerical accuracy. (The a.shape is necessary like len(a) is for irfft, and for the same reason.)
The input should be ordered in the same way as is returned by rfftn, i.e. as for irfft for the final transformation axis, and as for ifftn along all the other axes.
Parameters: - a (array_like) – Input array.
- s (sequence of ints, optional) – Shape (length of each transformed axis) of the output (s[0] refers to axis 0, s[1] to axis 1, etc.). s is also the number of input points used along this axis, except for the last axis, where s[-1]//2+1 points of the input are used. Along any axis, if the shape indicated by s is smaller than that of the input, the input is cropped. If it is larger, the input is padded with zeros. If s is not given, the shape of the input along the axes specified by axes is used, except for the last axis which is taken to be 2*(m-1) where m is the length of the input along that axis.
- axes (sequence of ints, optional) – Axes over which to compute the inverse FFT. If not given, the last len(s) axes are used, or all axes if s is also not specified. Repeated indices in axes mean that the inverse transform over that axis is performed multiple times.
- norm ({None, "ortho"}, optional) – Normalization mode (see numpy.fft). Default is None.
New in version 1.10.0.
Returns: out – The truncated or zero-padded input, transformed along the axes indicated by axes, or by a combination of s or a, as explained in the parameters section above. The length of each transformed axis is as given by the corresponding element of s, or the length of the input in every axis except for the last one if s is not given. In the final transformed axis the length of the output when s is not given is 2*(m-1) where m is the length of the final transformed axis of the input. To get an odd number of output points in the final axis, s must be specified.
Return type: ndarray
Raises: ValueError – If s and axes have different length.
IndexError – If an element of axes is larger than the number of axes of a.
Notes
See fft for definitions and conventions used.
See rfft for definitions and conventions used for real input.
The correct interpretation of the Hermitian input depends on the shape of the original data, as given by s. This is because each input shape could correspond to either an odd or even length signal. By default, irfftn assumes an even output length, which puts the last entry at the Nyquist frequency, aliasing with its symmetric counterpart. When performing the final complex-to-real transform, the last value is thus treated as purely real. To avoid losing information, the correct shape of the real input must be given.
Examples
>>> a = np.zeros((3, 2, 2))
>>> a[0, 0, 0] = 3 * 2 * 2
>>> np.fft.irfftn(a)
array([[[1., 1.],
        [1., 1.]],
       [[1., 1.],
        [1., 1.]],
       [[1., 1.],
        [1., 1.]]])
-
symjax.tensor.fft.
fftfreq
(n, d=1.0)[source]¶ Return the Discrete Fourier Transform sample frequencies.
LAX-backend implementation of fftfreq(). Original docstring below.
The returned float array f contains the frequency bin centers in cycles per unit of the sample spacing (with zero at the start). For instance, if the sample spacing is in seconds, then the frequency unit is cycles/second.
Given a window length n and a sample spacing d:
f = [0, 1, ...,   n/2-1,     -n/2, ..., -1] / (d*n)  if n is even
f = [0, 1, ..., (n-1)/2, -(n-1)/2, ..., -1] / (d*n)  if n is odd
Parameters: - n (int) – Window length.
- d (scalar, optional) – Sample spacing (inverse of the sampling rate). Defaults to 1.
Returns: f – Array of length n containing the sample frequencies.
Return type: ndarray
Examples
>>> signal = np.array([-2, 8, 6, 4, 1, 0, 3, 5], dtype=float)
>>> fourier = np.fft.fft(signal)
>>> n = signal.size
>>> timestep = 0.1
>>> freq = np.fft.fftfreq(n, d=timestep)
>>> freq
array([ 0.  ,  1.25,  2.5 , ..., -3.75, -2.5 , -1.25])
-
symjax.tensor.fft.
rfftfreq
(n, d=1.0)[source]¶ Return the Discrete Fourier Transform sample frequencies (for usage with rfft, irfft).
LAX-backend implementation of rfftfreq(). Original docstring below.
The returned float array f contains the frequency bin centers in cycles per unit of the sample spacing (with zero at the start). For instance, if the sample spacing is in seconds, then the frequency unit is cycles/second.
Given a window length n and a sample spacing d:
f = [0, 1, ...,     n/2-1,     n/2] / (d*n)  if n is even
f = [0, 1, ..., (n-1)/2-1, (n-1)/2] / (d*n)  if n is odd
Unlike fftfreq (but like scipy.fftpack.rfftfreq) the Nyquist frequency component is considered to be positive.
Parameters: - n (int) – Window length.
- d (scalar, optional) – Sample spacing (inverse of the sampling rate). Defaults to 1.
Returns: f – Array of length n//2 + 1 containing the sample frequencies.
Return type: ndarray
Examples
>>> signal = np.array([-2, 8, 6, 4, 1, 0, 3, 5, -3, 4], dtype=float)
>>> fourier = np.fft.rfft(signal)
>>> n = signal.size
>>> sample_rate = 100
>>> freq = np.fft.fftfreq(n, d=1./sample_rate)
>>> freq
array([  0.,  10.,  20., ..., -30., -20., -10.])
>>> freq = np.fft.rfftfreq(n, d=1./sample_rate)
>>> freq
array([ 0., 10., 20., 30., 40., 50.])
symjax.tensor.random
¶
bernoulli (key, p, shape) – Sample Bernoulli random values with given shape and mean.
beta (key, a, …) – Sample Beta random values with given shape and float dtype.
cauchy (key[, shape, dtype]) – Sample Cauchy random values with given shape and float dtype.
dirichlet (key, alpha[, shape, dtype]) – Sample Dirichlet random values with given shape and float dtype.
gamma (key, a[, shape, dtype]) – Sample Gamma random values with given shape and float dtype.
gumbel (key[, shape, dtype]) – Sample Gumbel random values with given shape and float dtype.
laplace (key[, shape, dtype]) – Sample Laplace random values with given shape and float dtype.
logistic (key[, shape, dtype]) – Sample logistic random values with given shape and float dtype.
multivariate_normal (key, mean, cov, shape, dtype) – Sample multivariate normal random values with given mean and covariance.
normal (key, shape, dtype) – Sample standard normal random values with given shape and float dtype.
pareto (key, b[, shape, dtype]) – Sample Pareto random values with given shape and float dtype.
randint (key, shape, minval, …) – Sample uniform random values in [minval, maxval) with given shape/dtype.
shuffle (key, x, axis) – Shuffle the elements of an array uniformly at random along an axis.
truncated_normal (key, lower, …) – Sample truncated standard normal random values with given shape and dtype.
uniform (key, shape, dtype, minval, …) – Sample uniform random values in [minval, maxval) with given shape/dtype.
Detailed Description¶
-
symjax.tensor.random.
bernoulli
(key: jax._src.numpy.lax_numpy.ndarray, p: jax._src.numpy.lax_numpy.ndarray = 0.5, shape: Optional[Sequence[int]] = None) → jax._src.numpy.lax_numpy.ndarray[source]¶ Sample Bernoulli random values with given shape and mean.
Parameters: - key – a PRNGKey used as the random key.
- p – optional, a float or array of floats for the mean of the random variables. Must be broadcast-compatible with shape. Default 0.5.
- shape – optional, a tuple of nonnegative integers representing the result shape. Must be broadcast-compatible with p.shape. The default (None) produces a result shape equal to p.shape.
Returns: A random array with boolean dtype and shape given by shape if shape is not None, or else p.shape.
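Example (a minimal sketch, assuming the underlying jax.random call that this symbolic op wraps; the key and probability are illustrative):
>>> import jax
>>> key = jax.random.PRNGKey(0)
>>> samples = jax.random.bernoulli(key, p=0.3, shape=(4,))  # boolean array of shape (4,)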
-
symjax.tensor.random.
beta
(key: jax._src.numpy.lax_numpy.ndarray, a: Union[float, jax._src.numpy.lax_numpy.ndarray], b: Union[float, jax._src.numpy.lax_numpy.ndarray], shape: Optional[Sequence[int]] = None, dtype: numpy.dtype = <class 'numpy.float64'>) → jax._src.numpy.lax_numpy.ndarray[source]¶ Sample Beta random values with given shape and float dtype.
Parameters: - key – a PRNGKey used as the random key.
- a – a float or array of floats broadcast-compatible with shape representing the first parameter “alpha”.
- b – a float or array of floats broadcast-compatible with shape representing the second parameter “beta”.
- shape – optional, a tuple of nonnegative integers specifying the result shape. Must be broadcast-compatible with a and b. The default (None) produces a result shape by broadcasting a and b.
- dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).
Returns: A random array with the specified dtype and shape given by shape if shape is not None, or else by broadcasting a and b.
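Example (a minimal sketch assuming the underlying jax.random call; the concentration parameters a=2.0, b=5.0 are chosen arbitrarily):
>>> import jax
>>> key = jax.random.PRNGKey(1)
>>> x = jax.random.beta(key, 2.0, 5.0, shape=(3,))  # floats in (0, 1), shape (3,)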
-
symjax.tensor.random.
cauchy
(key, shape=(), dtype=<class 'numpy.float64'>)[source]¶ Sample Cauchy random values with given shape and float dtype.
Parameters: - key – a PRNGKey used as the random key.
- shape – optional, a tuple of nonnegative integers representing the result shape. Default ().
- dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).
Returns: A random array with the specified shape and dtype.
-
symjax.tensor.random.
dirichlet
(key, alpha, shape=None, dtype=<class 'numpy.float64'>)[source]¶ Sample Dirichlet random values with given shape and float dtype.
Parameters: - key – a PRNGKey used as the random key.
- alpha – an array of shape (..., n) used as the concentration parameter of the random variables.
- shape – optional, a tuple of nonnegative integers specifying the result batch shape; that is, the prefix of the result shape excluding the last element of value n. Must be broadcast-compatible with alpha.shape[:-1]. The default (None) produces a result shape equal to alpha.shape.
- dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).
Returns: A random array with the specified dtype and shape given by shape + (alpha.shape[-1],) if shape is not None, or else alpha.shape.
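Example (a minimal sketch assuming the underlying jax.random call; each draw lies on the simplex, i.e. the last axis sums to 1):
>>> import jax
>>> import jax.numpy as jnp
>>> key = jax.random.PRNGKey(2)
>>> x = jax.random.dirichlet(key, jnp.array([1.0, 2.0, 3.0]))  # shape (3,), entries sum to 1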
-
symjax.tensor.random.
gamma
(key, a, shape=None, dtype=<class 'numpy.float64'>)[source]¶ Sample Gamma random values with given shape and float dtype.
Parameters: - key – a PRNGKey used as the random key.
- a – a float or array of floats broadcast-compatible with shape representing the parameter of the distribution.
- shape – optional, a tuple of nonnegative integers specifying the result shape. Must be broadcast-compatible with a. The default (None) produces a result shape equal to a.shape.
- dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).
Returns: A random array with the specified dtype and with shape given by shape if shape is not None, or else by a.shape.
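Example (a minimal sketch assuming the underlying jax.random call; the draw is a standard Gamma with unit scale, so multiply to rescale):
>>> import jax
>>> key = jax.random.PRNGKey(3)
>>> x = 0.5 * jax.random.gamma(key, 2.0, shape=(4,))  # Gamma(shape=2.0, scale=0.5) draws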
-
symjax.tensor.random.
gumbel
(key, shape=(), dtype=<class 'numpy.float64'>)[source]¶ Sample Gumbel random values with given shape and float dtype.
Parameters: - key – a PRNGKey used as the random key.
- shape – optional, a tuple of nonnegative integers representing the result shape. Default ().
- dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).
Returns: A random array with the specified shape and dtype.
-
symjax.tensor.random.
laplace
(key, shape=(), dtype=<class 'numpy.float64'>)[source]¶ Sample Laplace random values with given shape and float dtype.
Parameters: - key – a PRNGKey used as the random key.
- shape – optional, a tuple of nonnegative integers representing the result shape. Default ().
- dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).
Returns: A random array with the specified shape and dtype.
-
symjax.tensor.random.
logistic
(key, shape=(), dtype=<class 'numpy.float64'>)[source]¶ Sample logistic random values with given shape and float dtype.
Parameters: - key – a PRNGKey used as the random key.
- shape – optional, a tuple of nonnegative integers representing the result shape. Default ().
- dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).
Returns: A random array with the specified shape and dtype.
-
symjax.tensor.random.
multivariate_normal
(key: jax._src.numpy.lax_numpy.ndarray, mean: jax._src.numpy.lax_numpy.ndarray, cov: jax._src.numpy.lax_numpy.ndarray, shape: Optional[Sequence[int]] = None, dtype: numpy.dtype = <class 'numpy.float64'>) → jax._src.numpy.lax_numpy.ndarray[source]¶ Sample multivariate normal random values with given mean and covariance.
Parameters: - key – a PRNGKey used as the random key.
- mean – a mean vector of shape (..., n).
- cov – a positive definite covariance matrix of shape (..., n, n). The batch shape ... must be broadcast-compatible with that of mean.
- shape – optional, a tuple of nonnegative integers specifying the result batch shape; that is, the prefix of the result shape excluding the last axis. Must be broadcast-compatible with mean.shape[:-1] and cov.shape[:-2]. The default (None) produces a result batch shape by broadcasting together the batch shapes of mean and cov.
- dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).
Returns: A random array with the specified dtype and shape given by shape + mean.shape[-1:] if shape is not None, or else broadcast_shapes(mean.shape[:-1], cov.shape[:-2]) + mean.shape[-1:].
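Example (a minimal sketch assuming the underlying jax.random call; drawing a batch of 100 correlated 2-D Gaussian samples):
>>> import jax
>>> import jax.numpy as jnp
>>> key = jax.random.PRNGKey(4)
>>> mean = jnp.zeros(2)
>>> cov = jnp.array([[1.0, 0.5], [0.5, 1.0]])
>>> x = jax.random.multivariate_normal(key, mean, cov, shape=(100,))  # shape (100, 2)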
-
symjax.tensor.random.
normal
(key: jax._src.numpy.lax_numpy.ndarray, shape: Sequence[int] = (), dtype: numpy.dtype = <class 'numpy.float64'>) → jax._src.numpy.lax_numpy.ndarray[source]¶ Sample standard normal random values with given shape and float dtype.
Parameters: - key – a PRNGKey used as the random key.
- shape – optional, a tuple of nonnegative integers representing the result shape. Default ().
- dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).
Returns: A random array with the specified shape and dtype.
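Example (a minimal sketch assuming the underlying jax.random call; scale and shift the standard normal to obtain other means/variances):
>>> import jax
>>> key = jax.random.PRNGKey(5)
>>> x = 2.0 + 0.1 * jax.random.normal(key, (3, 3))  # N(2, 0.1**2) draws, shape (3, 3)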
-
symjax.tensor.random.
pareto
(key, b, shape=None, dtype=<class 'numpy.float64'>)[source]¶ Sample Pareto random values with given shape and float dtype.
Parameters: - key – a PRNGKey used as the random key.
- b – a float or array of floats broadcast-compatible with shape representing the parameter of the distribution.
- shape – optional, a tuple of nonnegative integers specifying the result shape. Must be broadcast-compatible with b. The default (None) produces a result shape equal to b.shape.
Returns: A random array with the specified dtype and with shape given by shape if shape is not None, or else by b.shape.
-
symjax.tensor.random.
randint
(key: jax._src.numpy.lax_numpy.ndarray, shape: Sequence[int], minval: Union[int, jax._src.numpy.lax_numpy.ndarray], maxval: Union[int, jax._src.numpy.lax_numpy.ndarray], dtype: numpy.dtype = <class 'numpy.int64'>)[source]¶ Sample uniform random values in [minval, maxval) with given shape/dtype.
Parameters: - key – a PRNGKey used as the random key.
- shape – a tuple of nonnegative integers representing the shape.
- minval – int or array of ints broadcast-compatible with shape, a minimum (inclusive) value for the range.
- maxval – int or array of ints broadcast-compatible with shape, a maximum (exclusive) value for the range.
- dtype – optional, an int dtype for the returned values (default int64 if jax_enable_x64 is true, otherwise int32).
Returns: A random array with the specified shape and dtype.
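Example (a minimal sketch assuming the underlying jax.random call; note the half-open interval, so 10 is never drawn):
>>> import jax
>>> key = jax.random.PRNGKey(6)
>>> x = jax.random.randint(key, (5,), 0, 10)  # integers in [0, 10), shape (5,)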
-
symjax.tensor.random.
shuffle
(key: jax._src.numpy.lax_numpy.ndarray, x: jax._src.numpy.lax_numpy.ndarray, axis: int = 0) → jax._src.numpy.lax_numpy.ndarray[source]¶ Shuffle the elements of an array uniformly at random along an axis.
Parameters: - key – a PRNGKey used as the random key.
- x – the array to be shuffled.
- axis – optional, an int axis along which to shuffle (default 0).
Returns: A shuffled version of x.
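Example (a minimal sketch assuming the underlying jax.random call; only the given axis is permuted, so rows stay intact here):
>>> import jax
>>> import jax.numpy as jnp
>>> key = jax.random.PRNGKey(7)
>>> x = jnp.arange(6).reshape(3, 2)
>>> shuffled = jax.random.shuffle(key, x, axis=0)  # rows permuted, columns kept together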
-
symjax.tensor.random.
truncated_normal
(key: jax._src.numpy.lax_numpy.ndarray, lower: Union[float, jax._src.numpy.lax_numpy.ndarray], upper: Union[float, jax._src.numpy.lax_numpy.ndarray], shape: Optional[Sequence[int]] = None, dtype: numpy.dtype = <class 'numpy.float64'>) → jax._src.numpy.lax_numpy.ndarray[source]¶ Sample truncated standard normal random values with given shape and dtype.
Parameters: - key – a PRNGKey used as the random key.
- lower – a float or array of floats representing the lower bound for truncation. Must be broadcast-compatible with upper.
- upper – a float or array of floats representing the upper bound for truncation. Must be broadcast-compatible with lower.
- shape – optional, a tuple of nonnegative integers specifying the result shape. Must be broadcast-compatible with lower and upper. The default (None) produces a result shape by broadcasting lower and upper.
- dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).
Returns: A random array with the specified dtype and shape given by shape if shape is not None, or else by broadcasting lower and upper. Returns values in the open interval (lower, upper).
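Example (a minimal sketch assuming the underlying jax.random call; every draw falls strictly inside the truncation interval):
>>> import jax
>>> key = jax.random.PRNGKey(8)
>>> x = jax.random.truncated_normal(key, -1.0, 1.0, shape=(4,))  # standard normal truncated to (-1, 1)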
-
symjax.tensor.random.
uniform
(key: jax._src.numpy.lax_numpy.ndarray, shape: Sequence[int] = (), dtype: numpy.dtype = <class 'numpy.float64'>, minval: Union[float, jax._src.numpy.lax_numpy.ndarray] = 0.0, maxval: Union[float, jax._src.numpy.lax_numpy.ndarray] = 1.0) → jax._src.numpy.lax_numpy.ndarray[source]¶ Sample uniform random values in [minval, maxval) with given shape/dtype.
Parameters: - key – a PRNGKey used as the random key.
- shape – optional, a tuple of nonnegative integers representing the result shape. Default ().
- dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).
- minval – optional, a minimum (inclusive) value broadcast-compatible with shape for the range (default 0).
- maxval – optional, a maximum (exclusive) value broadcast-compatible with shape for the range (default 1).
Returns: A random array with the specified shape and dtype.
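Example (a minimal sketch assuming the underlying jax.random call; minval/maxval rescale the unit interval):
>>> import jax
>>> key = jax.random.PRNGKey(9)
>>> x = jax.random.uniform(key, (2, 2), minval=-1.0, maxval=1.0)  # floats in [-1, 1), shape (2, 2)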
symjax.tensor.linalg
¶
cond (x[, p]) – Compute the condition number of a matrix.
det
eig (a) – Compute the eigenvalues and right eigenvectors of a square array.
eigh (a[, b, lower, eigvals_only, …]) – Solve a standard or generalized eigenvalue problem for a complex Hermitian or real symmetric matrix.
eigvals (a) – Compute the eigenvalues of a general matrix.
eigvalsh (a[, UPLO]) – Compute the eigenvalues of a complex Hermitian or real symmetric matrix.
inv (a[, overwrite_a, check_finite]) – Compute the inverse of a matrix.
lstsq (a, b[, rcond, numpy_resid]) – Return the least-squares solution to a linear matrix equation.
matrix_power (a, n) – Raise a square matrix to the (integer) power n.
matrix_rank (M[, tol]) – Return matrix rank of array using SVD method.
multi_dot (arrays, *[, precision]) – Compute the dot product of two or more arrays in a single function call, while automatically selecting the fastest evaluation order.
norm (x[, ord, axis, keepdims]) – Tensor/Matrix/Vector norm.
pinv (a[, rcond]) – Compute the (Moore-Penrose) pseudo-inverse of a matrix.
qr (a[, mode]) – Compute the qr factorization of a matrix.
slogdet (a) – Compute the sign and (natural) logarithm of the determinant of an array.
solve (a, b) – Solve a linear matrix equation, or system of linear scalar equations.
svd (a[, full_matrices, compute_uv]) – Singular Value Decomposition.
tensorinv (a[, ind]) – Compute the ‘inverse’ of an N-dimensional array.
tensorsolve (a, b[, axes]) – Solve the tensor equation a x = b for x.
cholesky (a[, lower, overwrite_a, check_finite]) – Compute the Cholesky decomposition of a matrix.
block_diag (*arrs) – Create a block diagonal matrix from provided arrays.
cho_solve (c_and_lower, b[, overwrite_b, …]) – Solve the linear equations A x = b, given the Cholesky factorization of A.
eigh (a[, b, lower, eigvals_only, …]) – Solve a standard or generalized eigenvalue problem for a complex Hermitian or real symmetric matrix.
expm (A, *[, upper_triangular, max_squarings]) – Compute the matrix exponential using Pade approximation.
expm_frechet
inv (a[, overwrite_a, check_finite]) – Compute the inverse of a matrix.
lu (a[, permute_l, overwrite_a, check_finite]) – Compute pivoted LU decomposition of a matrix.
lu_factor (a[, overwrite_a, check_finite]) – Compute pivoted LU decomposition of a matrix.
lu_solve (lu_and_piv, b[, trans, …]) – Solve an equation system, a x = b, given the LU factorization of a.
solve_triangular (a, b[, trans, lower, …]) – Solve the equation a x = b for x, assuming a is a triangular matrix.
tril (m[, k]) – Make a copy of a matrix with elements above the kth diagonal zeroed.
triu (m[, k]) – Make a copy of a matrix with elements below the kth diagonal zeroed.
singular_vectors_power_iteration (weight[, …])
eigenvector_power_iteration (weight[, axis, …])
gram_schmidt (V[, normalize]) – Gram-Schmidt orthogonalization.
modified_gram_schmidt (V) – Modified Gram-Schmidt orthogonalization.
Detailed Description¶
-
symjax.tensor.linalg.
cond
(x, p=None)[source]¶ Compute the condition number of a matrix.
LAX-backend implementation of cond(). Original docstring below.
This function is capable of returning the condition number using one of seven different norms, depending on the value of p (see Parameters below).
Parameters: - x ((..., M, N) array_like) – The matrix whose condition number is sought.
- p ({None, 1, -1, 2, -2, inf, -inf, 'fro'}, optional) – Order of the norm (see numpy.linalg.norm for the available norms).
Returns: c – The condition number of the matrix. May be infinite.
Return type: {float, inf}
See also
numpy.linalg.norm()
Notes
The condition number of x is defined as the norm of x times the norm of the inverse of x [1]_; the norm can be the usual L2-norm (root-of-sum-of-squares) or one of a number of other matrix norms.
References
[1] G. Strang, Linear Algebra and Its Applications, Orlando, FL, Academic Press, Inc., 1980, pg. 285.
Examples
>>> from numpy import linalg as LA
>>> a = np.array([[1, 0, -1], [0, 1, 0], [1, 0, 1]])
>>> a
array([[ 1,  0, -1],
       [ 0,  1,  0],
       [ 1,  0,  1]])
>>> LA.cond(a)
1.4142135623730951
>>> LA.cond(a, 'fro')
3.1622776601683795
>>> LA.cond(a, np.inf)
2.0
>>> LA.cond(a, -np.inf)
1.0
>>> LA.cond(a, 1)
2.0
>>> LA.cond(a, -1)
1.0
>>> LA.cond(a, 2)
1.4142135623730951
>>> LA.cond(a, -2)
0.70710678118654746 # may vary
>>> min(LA.svd(a, compute_uv=False))*min(LA.svd(LA.inv(a), compute_uv=False))
0.70710678118654746 # may vary
-
symjax.tensor.linalg.
eig
(a)[source]¶ Compute the eigenvalues and right eigenvectors of a square array.
LAX-backend implementation of eig(). Original docstring below.
Parameters: a ((..., M, M) array) – Matrices for which the eigenvalues and right eigenvectors will be computed.
Returns: - w ((..., M) array) – The eigenvalues, each repeated according to its multiplicity. The eigenvalues are not necessarily ordered. The resulting array will be of complex type, unless the imaginary part is zero in which case it will be cast to a real type. When a is real the resulting eigenvalues will be real (0 imaginary part) or occur in conjugate pairs.
- v ((..., M, M) array) – The normalized (unit “length”) eigenvectors, such that the column v[:,i] is the eigenvector corresponding to the eigenvalue w[i].
Raises: LinAlgError – If the eigenvalue computation does not converge.
See also
eigvals() – eigenvalues of a non-symmetric array.
eigh() – eigenvalues and eigenvectors of a real symmetric or complex Hermitian (conjugate symmetric) array.
eigvalsh() – eigenvalues of a real symmetric or complex Hermitian (conjugate symmetric) array.
scipy.linalg.eig() – Similar function in SciPy that also solves the generalized eigenvalue problem.
scipy.linalg.schur() – Best choice for unitary and other non-Hermitian normal matrices.
Notes
New in version 1.8.0.
Broadcasting rules apply, see the numpy.linalg documentation for details.
This is implemented using the _geev LAPACK routines which compute the eigenvalues and eigenvectors of general square arrays.
The number w is an eigenvalue of a if there exists a vector v such that a @ v = w * v. Thus, the arrays a, w, and v satisfy the equations a @ v[:,i] = w[i] * v[:,i] for \(i \in \{0,...,M-1\}\).
The array v of eigenvectors may not be of maximum rank, that is, some of the columns may be linearly dependent, although round-off error may obscure that fact. If the eigenvalues are all different, then theoretically the eigenvectors are linearly independent and a can be diagonalized by a similarity transformation using v, i.e., inv(v) @ a @ v is diagonal.
For non-Hermitian normal matrices the SciPy function scipy.linalg.schur is preferred because the matrix v is guaranteed to be unitary, which is not the case when using eig. The Schur factorization produces an upper triangular matrix rather than a diagonal matrix, but for normal matrices only the diagonal of the upper triangular matrix is needed, the rest is roundoff error.
Finally, it is emphasized that v consists of the right (as in right-hand side) eigenvectors of a. A vector y satisfying y.T @ a = z * y.T for some number z is called a left eigenvector of a, and, in general, the left and right eigenvectors of a matrix are not necessarily the (perhaps conjugate) transposes of each other.
G. Strang, Linear Algebra and Its Applications, 2nd Ed., Orlando, FL, Academic Press, Inc., 1980, Various pp.
Examples
>>> from numpy import linalg as LA
(Almost) trivial example with real e-values and e-vectors.
>>> w, v = LA.eig(np.diag((1, 2, 3)))
>>> w; v
array([1., 2., 3.])
array([[1., 0., 0.],
       [0., 1., 0.],
       [0., 0., 1.]])
Real matrix possessing complex e-values and e-vectors; note that the e-values are complex conjugates of each other.
>>> w, v = LA.eig(np.array([[1, -1], [1, 1]]))
>>> w; v
array([1.+1.j, 1.-1.j])
array([[0.70710678+0.j        , 0.70710678-0.j        ],
       [0.        -0.70710678j, 0.        +0.70710678j]])
Complex-valued matrix with real e-values (but complex-valued e-vectors); note that a.conj().T == a, i.e., a is Hermitian.
>>> a = np.array([[1, 1j], [-1j, 1]])
>>> w, v = LA.eig(a)
>>> w; v
array([2.+0.j, 0.+0.j])
array([[ 0.        +0.70710678j,  0.70710678+0.j        ], # may vary
       [ 0.70710678+0.j        , -0.        +0.70710678j]])
Be careful about round-off error!
>>> a = np.array([[1 + 1e-9, 0], [0, 1 - 1e-9]])
>>> # Theor. e-values are 1 +/- 1e-9
>>> w, v = LA.eig(a)
>>> w; v
array([1., 1.])
array([[1., 0.],
       [0., 1.]])
-
symjax.tensor.linalg.
eigh
(a, b=None, lower=True, eigvals_only=False, overwrite_a=False, overwrite_b=False, turbo=True, eigvals=None, type=1, check_finite=True)[source]¶ Solve a standard or generalized eigenvalue problem for a complex Hermitian or real symmetric matrix.
LAX-backend implementation of eigh(). Original docstring below.
Find eigenvalues array w and optionally eigenvectors array v of array a, where b is positive definite such that for every eigenvalue λ (i-th entry of w) and its eigenvector vi (i-th column of v) satisfies:
a @ vi = λ * b @ vi
vi.conj().T @ a @ vi = λ
vi.conj().T @ b @ vi = 1
In the standard problem, b is assumed to be the identity matrix.
Parameters: - a ((M, M) array_like) – A complex Hermitian or real symmetric matrix whose eigenvalues and eigenvectors will be computed.
- b ((M, M) array_like, optional) – A complex Hermitian or real symmetric positive definite matrix. If omitted, the identity matrix is assumed.
- lower (bool, optional) – Whether the pertinent array data is taken from the lower or upper triangle of a and, if applicable, b. (Default: lower)
- eigvals_only (bool, optional) – Whether to calculate only eigenvalues and no eigenvectors. (Default: both are calculated)
- type (int, optional) – For the generalized problems, this keyword specifies the problem type to be solved for w and v (only takes 1, 2, 3 as possible inputs).
- overwrite_a (bool, optional) – Whether to overwrite data in a (may improve performance). Default is False.
- overwrite_b (bool, optional) – Whether to overwrite data in b (may improve performance). Default is False.
- check_finite (bool, optional) – Whether to check that the input matrices contain only finite numbers. Disabling may give a performance gain, but may result in problems (crashes, non-termination) if the inputs do contain infinities or NaNs.
- turbo (bool, optional) – Deprecated since v1.5.0, use the driver=gvd keyword instead. Use divide and conquer algorithm (faster but expensive in memory, only for generalized eigenvalue problem and if the full set of eigenvalues is requested). Has no significant effect if eigenvectors are not requested.
- eigvals (tuple (lo, hi), optional) – Deprecated since v1.5.0, use the subset_by_index keyword instead. Indices of the smallest and largest (in ascending order) eigenvalues and corresponding eigenvectors to be returned: 0 <= lo <= hi <= M-1. If omitted, all eigenvalues and eigenvectors are returned.
Returns: - w ((N,) ndarray) – The N (1<=N<=M) selected eigenvalues, in ascending order, each repeated according to its multiplicity.
- v ((M, N) ndarray) – (if eigvals_only == False)
Raises: LinAlgError – If eigenvalue computation does not converge, an error occurred, or the b matrix is not positive definite. Note that if the input matrices are not symmetric or Hermitian, no error will be reported but the results will be wrong.
See also
eigvalsh() – eigenvalues of symmetric or Hermitian arrays
eig() – eigenvalues and right eigenvectors for non-symmetric arrays
eigh_tridiagonal() – eigenvalues and right eigenvectors for symmetric/Hermitian tridiagonal matrices
Notes
This function does not check the input array for being Hermitian/symmetric in order to allow for representing arrays with only their upper/lower triangular parts. Also, note that even though not taken into account, the finiteness check applies to the whole array and is unaffected by the “lower” keyword.
This function uses LAPACK drivers for computations in all possible keyword combinations, prefixed with sy if arrays are real and he if complex, e.g., a float array with the “evr” driver is solved via “syevr”, complex arrays with the “gvx” driver are solved via “hegvx”, etc.
As a brief summary, the slowest and the most robust driver is the classical <sy/he>ev which uses symmetric QR. <sy/he>evr is seen as the optimal choice for the most general cases. However, there are certain occasions where <sy/he>evd computes faster at the expense of more memory usage. <sy/he>evx, while still being faster than <sy/he>ev, often performs worse than the rest except when very few eigenvalues are requested for large arrays, though there is still no performance guarantee.
For the generalized problem, normalization with respect to the given type argument:
type 1 and 3 : v.conj().T @ a @ v = w
type 2       : inv(v).conj().T @ a @ inv(v) = w
type 1 or 2  : v.conj().T @ b @ v = I
type 3       : v.conj().T @ inv(b) @ v = I
Examples
>>> from scipy.linalg import eigh
>>> A = np.array([[6, 3, 1, 5], [3, 0, 5, 1], [1, 5, 6, 2], [5, 1, 2, 2]])
>>> w, v = eigh(A)
>>> np.allclose(A @ v - v @ np.diag(w), np.zeros((4, 4)))
True
Request only the eigenvalues
>>> w = eigh(A, eigvals_only=True)
Request eigenvalues that are less than 10.
>>> A = np.array([[34, -4, -10, -7, 2],
...               [-4, 7, 2, 12, 0],
...               [-10, 2, 44, 2, -19],
...               [-7, 12, 2, 79, -34],
...               [2, 0, -19, -34, 29]])
>>> eigh(A, eigvals_only=True, subset_by_value=[-np.inf, 10])
array([6.69199443e-07, 9.11938152e+00])
Request the second smallest eigenvalue and its eigenvector
>>> w, v = eigh(A, subset_by_index=[1, 1])
>>> w
array([9.11938152])
>>> v.shape  # only a single column is returned
(5, 1)
-
symjax.tensor.linalg.
eigvals
(a)[source]¶ Compute the eigenvalues of a general matrix.
LAX-backend implementation of eigvals(). Original docstring below.
Main difference between eigvals and eig: the eigenvectors aren’t returned.
Parameters: a ((..., M, M) array_like) – A complex- or real-valued matrix whose eigenvalues will be computed.
Returns: w – The eigenvalues, each repeated according to its multiplicity. They are not necessarily ordered, nor are they necessarily real for real matrices.
Return type: (..., M,) ndarray
Raises: LinAlgError – If the eigenvalue computation does not converge.
See also
eig() – eigenvalues and right eigenvectors of general arrays
eigvalsh() – eigenvalues of real symmetric or complex Hermitian (conjugate symmetric) arrays.
eigh() – eigenvalues and eigenvectors of real symmetric or complex Hermitian (conjugate symmetric) arrays.
scipy.linalg.eigvals() – Similar function in SciPy.
Notes
New in version 1.8.0.
Broadcasting rules apply, see the numpy.linalg documentation for details.
This is implemented using the _geev LAPACK routines which compute the eigenvalues and eigenvectors of general square arrays.
Examples
Illustration, using the fact that the eigenvalues of a diagonal matrix are its diagonal elements, that multiplying a matrix on the left by an orthogonal matrix, Q, and on the right by Q.T (the transpose of Q), preserves the eigenvalues of the “middle” matrix. In other words, if Q is orthogonal, then Q * A * Q.T has the same eigenvalues as A:
>>> from numpy import linalg as LA
>>> x = np.random.random()
>>> Q = np.array([[np.cos(x), -np.sin(x)], [np.sin(x), np.cos(x)]])
>>> LA.norm(Q[0, :]), LA.norm(Q[1, :]), np.dot(Q[0, :], Q[1, :])
(1.0, 1.0, 0.0)
Now multiply a diagonal matrix by Q on one side and by Q.T on the other:
>>> D = np.diag((-1, 1))
>>> LA.eigvals(D)
array([-1.,  1.])
>>> A = np.dot(Q, D)
>>> A = np.dot(A, Q.T)
>>> LA.eigvals(A)
array([ 1., -1.]) # random
-
symjax.tensor.linalg.
eigvalsh
(a, UPLO='L')[source]¶ Compute the eigenvalues of a complex Hermitian or real symmetric matrix.
LAX-backend implementation of eigvalsh(). Original docstring below.
Main difference from eigh: the eigenvectors are not computed.
Parameters: - a ((..., M, M) array_like) – A complex- or real-valued matrix whose eigenvalues are to be computed.
- UPLO ({'L', 'U'}, optional) – Specifies whether the calculation is done with the lower triangular part of a (‘L’, default) or the upper triangular part (‘U’). Irrespective of this value only the real parts of the diagonal will be considered in the computation to preserve the notion of a Hermitian matrix. It therefore follows that the imaginary part of the diagonal will always be treated as zero.
Returns: w – The eigenvalues in ascending order, each repeated according to its multiplicity.
Return type: (..., M,) ndarray
Raises: LinAlgError – If the eigenvalue computation does not converge.
Notes
New in version 1.8.0.
Broadcasting rules apply, see the numpy.linalg documentation for details.
The eigenvalues are computed using LAPACK routines _syevd, _heevd.
Examples
>>> from numpy import linalg as LA
>>> a = np.array([[1, -2j], [2j, 5]])
>>> LA.eigvalsh(a)
array([ 0.17157288,  5.82842712]) # may vary
>>> # demonstrate the treatment of the imaginary part of the diagonal
>>> a = np.array([[5+2j, 9-2j], [0+2j, 2-1j]])
>>> a
array([[5.+2.j, 9.-2.j],
       [0.+2.j, 2.-1.j]])
>>> # with UPLO='L' this is numerically equivalent to using LA.eigvals()
>>> # with:
>>> b = np.array([[5.+0.j, 0.-2.j], [0.+2.j, 2.-0.j]])
>>> b
array([[5.+0.j, 0.-2.j],
       [0.+2.j, 2.+0.j]])
>>> wa = LA.eigvalsh(a)
>>> wb = LA.eigvals(b)
>>> wa; wb
array([1., 6.])
array([6.+0.j, 1.+0.j])
-
symjax.tensor.linalg.
inv
(a, overwrite_a=False, check_finite=True)[source]¶ Compute the inverse of a matrix.
LAX-backend implementation of inv(). Original docstring below.
Parameters: - a (array_like) – Square matrix to be inverted.
- overwrite_a (bool, optional) – Discard data in a (may improve performance). Default is False.
- check_finite (bool, optional) – Whether to check that the input matrix contains only finite numbers. Disabling may give a performance gain, but may result in problems (crashes, non-termination) if the inputs do contain infinities or NaNs.
Returns: ainv – Inverse of the matrix a.
Return type: ndarray
Raises: LinAlgError – If a is singular.
ValueError – If a is not square, or not 2D.
Examples
>>> from scipy import linalg
>>> a = np.array([[1., 2.], [3., 4.]])
>>> linalg.inv(a)
array([[-2. ,  1. ],
       [ 1.5, -0.5]])
>>> np.dot(a, linalg.inv(a))
array([[1., 0.],
       [0., 1.]])
-
symjax.tensor.linalg.
lstsq
(a, b, rcond=None, *, numpy_resid=False)[source]¶ Return the least-squares solution to a linear matrix equation.
LAX-backend implementation of lstsq(). It has two important differences:
. It has two important differences:- In numpy.linalg.lstsq, the default rcond is -1, and warns that in the future the default will be None. Here, the default rcond is None.
- In np.linalg.lstsq the returned residuals are empty for low-rank or over-determined solutions. Here, the residuals are returned in all cases, to make the function compatible with jit. The non-jit compatible numpy behavior can be recovered by passing numpy_resid=True.
The lstsq function does not currently have a custom JVP rule, so the gradient is poorly behaved for some inputs, particularly for low-rank a.
Original docstring below.
Computes the vector x that approximately solves the equation a @ x = b. The equation may be under-, well-, or over-determined (i.e., the number of linearly independent rows of a can be less than, equal to, or greater than its number of linearly independent columns). If a is square and of full rank, then x (but for round-off error) is the “exact” solution of the equation. Else, x minimizes the Euclidean 2-norm \(|| b - a x ||\).
Parameters: - a ((M, N) array_like) – “Coefficient” matrix.
- b ({(M,), (M, K)} array_like) – Ordinate or “dependent variable” values. If b is two-dimensional, the least-squares solution is calculated for each of the K columns of b.
- rcond (float, optional) – Cut-off ratio for small singular values of a. For the purposes of rank determination, singular values are treated as zero if they are smaller than rcond times the largest singular value of a.
Returns: - x ({(N,), (N, K)} ndarray) – Least-squares solution. If b is two-dimensional, the solutions are in the K columns of x.
- residuals ({(1,), (K,), (0,)} ndarray) – Sums of residuals; squared Euclidean 2-norm for each column in b - a*x. If the rank of a is < N or M <= N, this is an empty array. If b is 1-dimensional, this is a (1,) shape array. Otherwise the shape is (K,).
- rank (int) – Rank of matrix a.
- s ((min(M, N),) ndarray) – Singular values of a.
Raises: LinAlgError – If computation does not converge.
See also
scipy.linalg.lstsq() – Similar function in SciPy.
Notes
If b is a matrix, then all array results are returned as matrices.
Examples
Fit a line, y = mx + c, through some noisy data-points:
>>> x = np.array([0, 1, 2, 3])
>>> y = np.array([-1, 0.2, 0.9, 2.1])
By examining the coefficients, we see that the line should have a gradient of roughly 1 and cut the y-axis at, more or less, -1.
We can rewrite the line equation as y = Ap, where A = [[x 1]] and p = [[m], [c]]. Now use lstsq to solve for p:
>>> A = np.vstack([x, np.ones(len(x))]).T
>>> A
array([[ 0.,  1.],
       [ 1.,  1.],
       [ 2.,  1.],
       [ 3.,  1.]])
>>> m, c = np.linalg.lstsq(A, y, rcond=None)[0]
>>> m, c
(1.0, -0.95) # may vary
Plot the data along with the fitted line:
>>> import matplotlib.pyplot as plt
>>> _ = plt.plot(x, y, 'o', label='Original data', markersize=10)
>>> _ = plt.plot(x, m*x + c, 'r', label='Fitted line')
>>> _ = plt.legend()
>>> plt.show()
-
symjax.tensor.linalg.
matrix_power
(a, n)[source]¶ Raise a square matrix to the (integer) power n.
LAX-backend implementation of matrix_power(). Original docstring below.
For positive integers n, the power is computed by repeated matrix squarings and matrix multiplications. If n == 0, the identity matrix of the same shape as M is returned. If n < 0, the inverse is computed and then raised to the abs(n).
Stacks of object matrices are not currently supported.
Parameters: - a ((..., M, M) array_like) – Matrix to be “powered”.
- n (int) – The exponent can be any integer or long integer, positive, negative, or zero.
Returns: a**n – The return value is the same shape and type as M; if the exponent is positive or zero then the type of the elements is the same as those of M. If the exponent is negative the elements are floating-point.
Return type: (..., M, M) ndarray or matrix object
Raises: LinAlgError – For matrices that are not square or that (for negative powers) cannot be inverted numerically.
Examples
>>> from numpy.linalg import matrix_power
>>> i = np.array([[0, 1], [-1, 0]])  # matrix equiv. of the imaginary unit
>>> matrix_power(i, 3)  # should = -i
array([[ 0, -1],
       [ 1,  0]])
>>> matrix_power(i, 0)
array([[1, 0],
       [0, 1]])
>>> matrix_power(i, -3)  # should = 1/(-i) = i, but w/ f.p. elements
array([[ 0.,  1.],
       [-1.,  0.]])
Somewhat more sophisticated example
>>> q = np.zeros((4, 4))
>>> q[0:2, 0:2] = -i
>>> q[2:4, 2:4] = i
>>> q  # one of the three quaternion units not equal to 1
array([[ 0., -1.,  0.,  0.],
       [ 1.,  0.,  0.,  0.],
       [ 0.,  0.,  0.,  1.],
       [ 0.,  0., -1.,  0.]])
>>> matrix_power(q, 2)  # = -np.eye(4)
array([[-1.,  0.,  0.,  0.],
       [ 0., -1.,  0.,  0.],
       [ 0.,  0., -1.,  0.],
       [ 0.,  0.,  0., -1.]])
-
symjax.tensor.linalg.
matrix_rank
(M, tol=None)[source]¶ Return matrix rank of array using SVD method
LAX-backend implementation of matrix_rank(). Original docstring below.
Rank of the array is the number of singular values of the array that are greater than tol.
Changed in version 1.14: Can now operate on stacks of matrices
Parameters: - M ({(M,), (..., M, N)} array_like) – Input vector or stack of matrices.
- tol ((...) array_like, float, optional) – Threshold below which SVD values are considered zero. If tol is None, and S is an array with singular values for M, and eps is the epsilon value for the datatype of S, then tol is set to S.max() * max(M.shape) * eps.
Returns: rank – Rank of M.
Return type: (...) array_like
Notes
The default threshold to detect rank deficiency is a test on the magnitude of the singular values of M. By default, we identify singular values less than S.max() * max(M.shape) * eps as indicating rank deficiency (with the symbols defined above). This is the algorithm MATLAB uses [1]. It also appears in Numerical Recipes in the discussion of SVD solutions for linear least squares [2].
This default threshold is designed to detect rank deficiency accounting for the numerical errors of the SVD computation. Imagine that there is a column in M that is an exact (in floating point) linear combination of other columns in M. Computing the SVD on M will not produce a singular value exactly equal to 0 in general: any difference of the smallest SVD value from 0 will be caused by numerical imprecision in the calculation of the SVD. Our threshold for small SVD values takes this numerical imprecision into account, and the default threshold will detect such numerical rank deficiency. The threshold may declare a matrix M rank deficient even if the linear combination of some columns of M is not exactly equal to another column of M but only numerically very close to another column of M.
We chose our default threshold because it is in wide use. Other thresholds are possible. For example, elsewhere in the 2007 edition of Numerical recipes there is an alternative threshold of
S.max() * np.finfo(M.dtype).eps / 2. * np.sqrt(m + n + 1.)
. The authors describe this threshold as being based on “expected roundoff error” (p 71).
The thresholds above deal with floating point roundoff error in the calculation of the SVD. However, you may have more information about the sources of error in M that would make you consider other tolerance values to detect effective rank deficiency. The most useful measure of the tolerance depends on the operations you intend to use on your matrix. For example, if your data come from uncertain measurements with uncertainties greater than floating point epsilon, choosing a tolerance near that uncertainty may be preferable. The tolerance may be absolute if the uncertainties are absolute rather than relative.
References
[1] MATLAB reference documentation, “Rank”, https://www.mathworks.com/help/techdoc/ref/rank.html
[2] W. H. Press, S. A. Teukolsky, W. T. Vetterling and B. P. Flannery, “Numerical Recipes (3rd edition)”, Cambridge University Press, 2007, page 795.
Examples
>>> from numpy.linalg import matrix_rank
>>> matrix_rank(np.eye(4))  # Full rank matrix
4
>>> I = np.eye(4); I[-1, -1] = 0.  # rank deficient matrix
>>> matrix_rank(I)
3
>>> matrix_rank(np.ones((4,)))  # 1 dimension - rank 1 unless all 0
1
>>> matrix_rank(np.zeros((4,)))
0
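To make the default tolerance concrete, here is a small numpy sketch (an illustration of the rule described in the notes, not the symjax implementation) that recovers the rank from the singular values with the S.max() * max(M.shape) * eps cutoff:
>>> M = np.eye(4); M[-1, -1] = 0.
>>> S = np.linalg.svd(M, compute_uv=False)
>>> tol = S.max() * max(M.shape) * np.finfo(M.dtype).eps
>>> int((S > tol).sum())  # same value as matrix_rank(M)
3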
-
symjax.tensor.linalg.
multi_dot
(arrays, *, precision=None)[source]¶ Compute the dot product of two or more arrays in a single function call, while automatically selecting the fastest evaluation order.
LAX-backend implementation of
multi_dot()
. Original docstring below.
multi_dot chains numpy.dot and uses optimal parenthesization of the matrices [1] [2]. Depending on the shapes of the matrices, this can speed up the multiplication a lot.
If the first argument is 1-D it is treated as a row vector. If the last argument is 1-D it is treated as a column vector. The other arguments must be 2-D.
Think of multi_dot as:
import functools

def multi_dot(arrays):
    return functools.reduce(np.dot, arrays)
Parameters: arrays (sequence of array_like) – If the first argument is 1-D it is treated as a row vector. If the last argument is 1-D it is treated as a column vector. The other arguments must be 2-D.
Returns: output – Returns the dot product of the supplied arrays.
Return type: ndarray
See also
dot()
- dot multiplication with two arguments.
References
[1] Cormen, “Introduction to Algorithms”, Chapter 15.2, p. 370-378
[2] https://en.wikipedia.org/wiki/Matrix_chain_multiplication
Examples
multi_dot allows you to write:
>>> from numpy.linalg import multi_dot
>>> # Prepare some data
>>> A = np.random.random((10000, 100))
>>> B = np.random.random((100, 1000))
>>> C = np.random.random((1000, 5))
>>> D = np.random.random((5, 333))
>>> # the actual dot multiplication
>>> _ = multi_dot([A, B, C, D])
instead of:
>>> _ = np.dot(np.dot(np.dot(A, B), C), D)
>>> # or
>>> _ = A.dot(B).dot(C).dot(D)
Notes
The cost for a matrix multiplication can be calculated with the following function:
def cost(A, B):
    return A.shape[0] * A.shape[1] * B.shape[1]
Assume we have three matrices \(A_{10x100}, B_{100x5}, C_{5x50}\).
The costs for the two different parenthesizations are as follows:
cost((AB)C) = 10*100*5 + 10*5*50   = 5000 + 2500   = 7500
cost(A(BC)) = 10*100*50 + 100*5*50 = 50000 + 25000 = 75000
-
symjax.tensor.linalg.
norm
(x, ord=2, axis=None, keepdims=False)[source]¶ Tensor/Matrix/Vector norm.
For matrices and vectors, this function is able to return one of eight different matrix norms, or one of an infinite number of vector norms (described below), depending on the value of the
ord
parameter.for higher-dimensional tensors, only \(0<ord<\infty\) is supported.
Parameters: - x (array_like) – Input array. If axis is None, x must be 1-D or 2-D, unless ord
is None. If both axis and ord are None, the 2-norm of
x.ravel
will be returned. - ord ({non-zero int, inf, -inf, 'fro', 'nuc'}, optional) – Order of the norm (see table under
Notes
). inf means numpy’s inf object. The default is 2. - axis ({None, int, 2-tuple of ints}, optional.) – If axis is an integer, it specifies the axis of x along which to compute the vector norms. If axis is a 2-tuple, it specifies the axes that hold 2-D matrices, and the matrix norms of these matrices are computed. If axis is None then either a vector norm (when x is 1-D) or a matrix norm (when x is 2-D) is returned. The default is None. .. versionadded:: 1.8.0
- keepdims (bool, optional) – If this is set to True, the axes which are normed over are left in the result as dimensions with size one. With this option the result will broadcast correctly against the original x. New in version 1.10.0.
Returns: n – Norm of the matrix or vector(s).
Return type: float or ndarray
See also
scipy.linalg.norm()
- Similar function in SciPy.
Notes
For values of ord < 1, the result is, strictly speaking, not a mathematical ‘norm’, but it may still be useful for various numerical purposes.
The following norms can be calculated:
===== ============================ ==========================
ord   norm for matrices            norm for vectors
===== ============================ ==========================
None  Frobenius norm               2-norm
‘fro’ Frobenius norm               –
‘nuc’ nuclear norm                 –
inf   max(sum(abs(x), axis=1))     max(abs(x))
-inf  min(sum(abs(x), axis=1))     min(abs(x))
0     –                            sum(x != 0)
1     max(sum(abs(x), axis=0))     as below
-1    min(sum(abs(x), axis=0))     as below
2     2-norm (largest sing. value) as below
-2    smallest singular value      as below
other –                            sum(abs(x)**ord)**(1./ord)
===== ============================ ==========================
The Frobenius norm is given by [1]: \(||A||_F = [\sum_{i,j} abs(a_{i,j})^2]^{1/2}\)
The nuclear norm is the sum of the singular values. Both the Frobenius and nuclear norm orders are only defined for matrices and raise a ValueError when x.ndim != 2.
References
[1] G. H. Golub and C. F. Van Loan, Matrix Computations, Baltimore, MD, Johns Hopkins University Press, 1985, pg. 15
Examples
>>> from numpy import linalg as LA
>>> a = np.arange(9) - 4
>>> a
array([-4, -3, -2, ..., 2, 3, 4])
>>> b = a.reshape((3, 3))
>>> b
array([[-4, -3, -2],
       [-1,  0,  1],
       [ 2,  3,  4]])
>>> LA.norm(a)
7.745966692414834
>>> LA.norm(b)
7.745966692414834
>>> LA.norm(b, 'fro')
7.745966692414834
>>> LA.norm(a, np.inf)
4.0
>>> LA.norm(b, np.inf)
9.0
>>> LA.norm(a, -np.inf)
0.0
>>> LA.norm(b, -np.inf)
2.0
>>> LA.norm(a, 1)
20.0
>>> LA.norm(b, 1)
7.0
>>> LA.norm(a, -1)
-4.6566128774142013e-010
>>> LA.norm(b, -1)
6.0
>>> LA.norm(a, 2)
7.745966692414834
>>> LA.norm(b, 2)
7.3484692283495345
>>> LA.norm(a, -2)
0.0
>>> LA.norm(b, -2)
1.8570331885190563e-016  # may vary
>>> LA.norm(a, 3)
5.8480354764257312  # may vary
>>> LA.norm(a, -3)
0.0

Using the axis argument to compute vector norms:

>>> c = np.array([[ 1, 2, 3],
...               [-1, 1, 4]])
>>> LA.norm(c, axis=0)
array([ 1.41421356,  2.23606798,  5.        ])
>>> LA.norm(c, axis=1)
array([ 3.74165739,  4.24264069])
>>> LA.norm(c, ord=1, axis=1)
array([ 6.,  6.])

Using the axis argument to compute matrix norms:

>>> m = np.arange(8).reshape(2,2,2)
>>> LA.norm(m, axis=(1,2))
array([  3.74165739,  11.22497216])
>>> LA.norm(m[0, :, :]), LA.norm(m[1, :, :])
(3.7416573867739413, 11.224972160321824)
-
symjax.tensor.linalg.
pinv
(a, rcond=None)[source]¶ Compute the (Moore-Penrose) pseudo-inverse of a matrix.
LAX-backend implementation of
pinv()
. It differs only in default value of rcond. In numpy.linalg.pinv, the default rcond is 1e-15. Here the default is 10. * max(num_rows, num_cols) * jnp.finfo(dtype).eps.
Original docstring below.
Calculate the generalized inverse of a matrix using its singular-value decomposition (SVD) and including all large singular values.
Changed in version 1.14: Can now operate on stacks of matrices
Parameters: - a ((.., M, N) array_like) – Matrix or stack of matrices to be pseudo-inverted.
- rcond ((..) array_like of float) – Cutoff for small singular values.
Singular values less than or equal to
rcond * largest_singular_value
are set to zero. Broadcasts against the stack of matrices.
Returns: B – The pseudo-inverse of a. If a is a matrix instance, then so is B.
Return type: (.., N, M) ndarray
Raises: LinAlgError
– If the SVD computation does not converge.
See also
scipy.linalg.pinv()
- Similar function in SciPy.
scipy.linalg.pinv2()
- Similar function in SciPy (SVD-based).
scipy.linalg.pinvh()
- Compute the (Moore-Penrose) pseudo-inverse of a Hermitian matrix.
Notes
The pseudo-inverse of a matrix A, denoted \(A^+\), is defined as: “the matrix that ‘solves’ [the least-squares problem] \(Ax = b\),” i.e., if \(\bar{x}\) is said solution, then \(A^+\) is that matrix such that \(\bar{x} = A^+b\).
It can be shown that if \(Q_1 \Sigma Q_2^T = A\) is the singular value decomposition of A, then \(A^+ = Q_2 \Sigma^+ Q_1^T\), where \(Q_{1,2}\) are orthogonal matrices, \(\Sigma\) is a diagonal matrix consisting of A’s so-called singular values, (followed, typically, by zeros), and then \(\Sigma^+\) is simply the diagonal matrix consisting of the reciprocals of A’s singular values (again, followed by zeros). [1]_
References
[1] G. Strang, Linear Algebra and Its Applications, 2nd Ed., Orlando, FL, Academic Press, Inc., 1980, pp. 139-142.
Examples
The following example checks that
a * a+ * a == a
anda+ * a * a+ == a+
:
>>> a = np.random.randn(9, 6)
>>> B = np.linalg.pinv(a)
>>> np.allclose(a, np.dot(a, np.dot(B, a)))
True
>>> np.allclose(B, np.dot(B, np.dot(a, B)))
True
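The SVD-based construction described in the notes can also be written out directly; the following numpy sketch (an illustration, not the symjax code path; the 1e-15 cutoff mirrors numpy's default rcond) rebuilds the pseudo-inverse from the reduced SVD:
>>> a = np.random.randn(9, 6)
>>> u, s, vh = np.linalg.svd(a, full_matrices=False)
>>> cutoff = 1e-15 * s.max()  # rcond * largest singular value
>>> s_inv = np.where(s > cutoff, 1.0 / s, 0.0)  # reciprocals of the kept singular values
>>> np.allclose(np.linalg.pinv(a), (vh.conj().T * s_inv) @ u.conj().T)
True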
-
symjax.tensor.linalg.
qr
(a, mode='reduced')[source]¶ Compute the qr factorization of a matrix.
LAX-backend implementation of
qr()
. Original docstring below.
Factor the matrix a as qr, where q is orthonormal and r is upper-triangular.
Parameters: - a (array_like, shape (M, N)) – Matrix to be factored.
- mode ({'reduced', 'complete', 'r', 'raw'}, optional) – If K = min(M, N), then
Returns: - q (ndarray of float or complex, optional) – A matrix with orthonormal columns. When mode = ‘complete’ the result is an orthogonal/unitary matrix depending on whether or not a is real/complex. The determinant may be either +/- 1 in that case.
- r (ndarray of float or complex, optional) – The upper-triangular matrix.
- (h, tau) (ndarrays of np.double or np.cdouble, optional) – The array h contains the Householder reflectors that generate q along with r. The tau array contains scaling factors for the reflectors. In the deprecated ‘economic’ mode only h is returned.
Raises: LinAlgError
– If factoring fails.
See also
scipy.linalg.qr()
- Similar function in SciPy.
scipy.linalg.rq()
- Compute RQ decomposition of a matrix.
Notes
This is an interface to the LAPACK routines
dgeqrf
,zgeqrf
,dorgqr
, andzungqr
.
For more information on the qr factorization, see for example: https://en.wikipedia.org/wiki/QR_factorization
Subclasses of ndarray are preserved except for the ‘raw’ mode. So if a is of type matrix, all the return values will be matrices too.
New ‘reduced’, ‘complete’, and ‘raw’ options for mode were added in NumPy 1.8.0 and the old option ‘full’ was made an alias of ‘reduced’. In addition the options ‘full’ and ‘economic’ were deprecated. Because ‘full’ was the previous default and ‘reduced’ is the new default, backward compatibility can be maintained by letting mode default. The ‘raw’ option was added so that LAPACK routines that can multiply arrays by q using the Householder reflectors can be used. Note that in this case the returned arrays are of type np.double or np.cdouble and the h array is transposed to be FORTRAN compatible. No routines using the ‘raw’ return are currently exposed by numpy, but some are available in lapack_lite and just await the necessary work.
Examples
>>> a = np.random.randn(9, 6)
>>> q, r = np.linalg.qr(a)
>>> np.allclose(a, np.dot(q, r))  # a does equal qr
True
>>> r2 = np.linalg.qr(a, mode='r')
>>> np.allclose(r, r2)  # mode='r' returns the same r as mode='full'
True
Example illustrating a common use of qr: solving of least squares problems
What are the least-squares-best m and y0 in
y = y0 + mx
for the following data: {(0,1), (1,0), (1,2), (2,1)}. (Graph the points and you’ll see that it should be y0 = 0, m = 1.) The answer is provided by solving the over-determined matrix equationAx = b
, where:
A = array([[0, 1],
           [1, 1],
           [1, 1],
           [2, 1]])
x = array([[y0], [m]])
b = array([[1], [0], [2], [1]])
If A = qr such that q is orthonormal (which is always possible via Gram-Schmidt), then
x = inv(r) * (q.T) * b
. (In numpy practice, however, we simply use lstsq.)
>>> A = np.array([[0, 1], [1, 1], [1, 1], [2, 1]])
>>> A
array([[0, 1],
       [1, 1],
       [1, 1],
       [2, 1]])
>>> b = np.array([1, 0, 2, 1])
>>> q, r = np.linalg.qr(A)
>>> p = np.dot(q.T, b)
>>> np.dot(np.linalg.inv(r), p)
array([  1.1e-16,   1.0e+00])
-
symjax.tensor.linalg.
slogdet
(a)[source]¶ Compute the sign and (natural) logarithm of the determinant of an array.
LAX-backend implementation of
slogdet()
. Original docstring below.
If an array has a very small or very large determinant, then a call to det may overflow or underflow. This routine is more robust against such issues, because it computes the logarithm of the determinant rather than the determinant itself.
Returns: - sign ((…) array_like) – A number representing the sign of the determinant. For a real matrix, this is 1, 0, or -1. For a complex matrix, this is a complex number with absolute value 1 (i.e., it is on the unit circle), or else 0.
- logdet ((…) array_like) – The natural log of the absolute value of the determinant.
If the determinant is zero, then sign will be 0 and logdet will be -Inf. In all cases, the determinant is equal to sign * np.exp(logdet).
See also
det()
Notes
New in version 1.8.0.
Broadcasting rules apply, see the numpy.linalg documentation for details.
New in version 1.6.0.
The determinant is computed via LU factorization using the LAPACK routine
z/dgetrf
.Examples
The determinant of a 2-D array
[[a, b], [c, d]]
isad - bc
>>> a = np.array([[1, 2], [3, 4]])
>>> (sign, logdet) = np.linalg.slogdet(a)
>>> (sign, logdet)
(-1, 0.69314718055994529)  # may vary
>>> sign * np.exp(logdet)
-2.0
Computing log-determinants for a stack of matrices:
>>> a = np.array([ [[1, 2], [3, 4]], [[1, 2], [2, 1]], [[1, 3], [3, 1]] ])
>>> a.shape
(3, 2, 2)
>>> sign, logdet = np.linalg.slogdet(a)
>>> (sign, logdet)
(array([-1., -1., -1.]), array([ 0.69314718,  1.09861229,  2.07944154]))
>>> sign * np.exp(logdet)
array([-2., -3., -8.])
This routine succeeds where ordinary det does not:
>>> np.linalg.det(np.eye(500) * 0.1)
0.0
>>> np.linalg.slogdet(np.eye(500) * 0.1)
(1, -1151.2925464970228)
-
symjax.tensor.linalg.
solve
(a, b)[source]¶ Solve a linear matrix equation, or system of linear scalar equations.
LAX-backend implementation of
solve()
. Original docstring below.
Computes the “exact” solution, x, of the well-determined, i.e., full rank, linear matrix equation ax = b.
Returns: x – Solution to the system a x = b. Returned shape is identical to b.
Return type: {(.., M,), (.., M, K)} ndarray
Raises: LinAlgError
– If a is singular or not square.
See also
scipy.linalg.solve()
- Similar function in SciPy.
Notes
New in version 1.8.0.
Broadcasting rules apply, see the numpy.linalg documentation for details.
The solutions are computed using LAPACK routine
_gesv
.a must be square and of full-rank, i.e., all rows (or, equivalently, columns) must be linearly independent; if either is not true, use lstsq for the least-squares best “solution” of the system/equation.
References
[1] G. Strang, Linear Algebra and Its Applications, 2nd Ed., Orlando, FL, Academic Press, Inc., 1980, pg. 22.
Examples
Solve the system of equations
3 * x0 + x1 = 9
andx0 + 2 * x1 = 8
>>> a = np.array([[3,1], [1,2]])
>>> b = np.array([9,8])
>>> x = np.linalg.solve(a, b)
>>> x
array([2., 3.])
Check that the solution is correct:
>>> np.allclose(np.dot(a, x), b)
True
-
symjax.tensor.linalg.
svd
(a, full_matrices=True, compute_uv=True)[source]¶ Singular Value Decomposition.
LAX-backend implementation of
svd()
. Original docstring below.
When a is a 2D array, it is factorized as
u @ np.diag(s) @ vh = (u * s) @ vh
, where u and vh are 2D unitary arrays and s is a 1D array of a’s singular values. When a is higher-dimensional, SVD is applied in stacked mode as explained below.
Parameters: - a ((.., M, N) array_like) – A real or complex array with
a.ndim >= 2
. - full_matrices (bool, optional) – If True (default), u and vh have the shapes
(..., M, M)
and(..., N, N)
, respectively. Otherwise, the shapes are(..., M, K)
and(..., K, N)
, respectively, whereK = min(M, N)
. - compute_uv (bool, optional) – Whether or not to compute u and vh in addition to s. True by default.
Returns: - u ({ (…, M, M), (…, M, K) } array) – Unitary array(s). The first
a.ndim - 2
dimensions have the same size as those of the input a. The size of the last two dimensions depends on the value of full_matrices. Only returned when compute_uv is True. - s ((…, K) array) – Vector(s) with the singular values, within each vector sorted in
descending order. The first
a.ndim - 2
dimensions have the same size as those of the input a. - vh ({ (…, N, N), (…, K, N) } array) – Unitary array(s). The first
a.ndim - 2
dimensions have the same size as those of the input a. The size of the last two dimensions depends on the value of full_matrices. Only returned when compute_uv is True.
Raises: LinAlgError
– If SVD computation does not converge.
See also
scipy.linalg.svd()
- Similar function in SciPy.
scipy.linalg.svdvals()
- Compute singular values of a matrix.
Notes
Changed in version 1.8.0: Broadcasting rules apply, see the numpy.linalg documentation for details.
The decomposition is performed using LAPACK routine
_gesdd
.SVD is usually described for the factorization of a 2D matrix \(A\). The higher-dimensional case will be discussed below. In the 2D case, SVD is written as \(A = U S V^H\), where \(A = a\), \(U= u\), \(S= \mathtt{np.diag}(s)\) and \(V^H = vh\). The 1D array s contains the singular values of a and u and vh are unitary. The rows of vh are the eigenvectors of \(A^H A\) and the columns of u are the eigenvectors of \(A A^H\). In both cases the corresponding (possibly non-zero) eigenvalues are given by
s**2
.If a has more than two dimensions, then broadcasting rules apply, as explained in routines.linalg-broadcasting. This means that SVD is working in “stacked” mode: it iterates over all indices of the first
a.ndim - 2
dimensions and for each combination SVD is applied to the last two indices. The matrix a can be reconstructed from the decomposition with either(u * s[..., None, :]) @ vh
oru @ (s[..., None] * vh)
. (The@
operator can be replaced by the functionnp.matmul
for python versions below 3.5.)If a is a
matrix
object (as opposed to anndarray
), then so are all the return values.Examples
>>> a = np.random.randn(9, 6) + 1j*np.random.randn(9, 6)
>>> b = np.random.randn(2, 7, 8, 3) + 1j*np.random.randn(2, 7, 8, 3)
Reconstruction based on full SVD, 2D case:
>>> u, s, vh = np.linalg.svd(a, full_matrices=True)
>>> u.shape, s.shape, vh.shape
((9, 9), (6,), (6, 6))
>>> np.allclose(a, np.dot(u[:, :6] * s, vh))
True
>>> smat = np.zeros((9, 6), dtype=complex)
>>> smat[:6, :6] = np.diag(s)
>>> np.allclose(a, np.dot(u, np.dot(smat, vh)))
True
Reconstruction based on reduced SVD, 2D case:
>>> u, s, vh = np.linalg.svd(a, full_matrices=False)
>>> u.shape, s.shape, vh.shape
((9, 6), (6,), (6, 6))
>>> np.allclose(a, np.dot(u * s, vh))
True
>>> smat = np.diag(s)
>>> np.allclose(a, np.dot(u, np.dot(smat, vh)))
True
Reconstruction based on full SVD, 4D case:
>>> u, s, vh = np.linalg.svd(b, full_matrices=True)
>>> u.shape, s.shape, vh.shape
((2, 7, 8, 8), (2, 7, 3), (2, 7, 3, 3))
>>> np.allclose(b, np.matmul(u[..., :3] * s[..., None, :], vh))
True
>>> np.allclose(b, np.matmul(u[..., :3], s[..., None] * vh))
True
Reconstruction based on reduced SVD, 4D case:
>>> u, s, vh = np.linalg.svd(b, full_matrices=False)
>>> u.shape, s.shape, vh.shape
((2, 7, 8, 3), (2, 7, 3), (2, 7, 3, 3))
>>> np.allclose(b, np.matmul(u * s[..., None, :], vh))
True
>>> np.allclose(b, np.matmul(u, s[..., None] * vh))
True
- a ((.., M, N) array_like) – A real or complex array with
-
symjax.tensor.linalg.
tensorinv
(a, ind=2)[source]¶ Compute the ‘inverse’ of an N-dimensional array.
LAX-backend implementation of
tensorinv()
. Original docstring below.
The result is an inverse for a relative to the tensordot operation
tensordot(a, b, ind)
, i. e., up to floating-point accuracy,tensordot(tensorinv(a), a, ind)
is the “identity” tensor for the tensordot operation.Parameters: - a (array_like) – Tensor to ‘invert’. Its shape must be ‘square’, i. e.,
prod(a.shape[:ind]) == prod(a.shape[ind:])
. - ind (int, optional) – Number of first indices that are involved in the inverse sum. Must be a positive integer, default is 2.
Returns: b – a’s tensordot inverse, shape
a.shape[ind:] + a.shape[:ind]
.Return type: ndarray
Raises: LinAlgError
– If a is singular or not ‘square’ (in the above sense).
See also
numpy.tensordot()
,tensorsolve()
Examples
>>> a = np.eye(4*6)
>>> a.shape = (4, 6, 8, 3)
>>> ainv = np.linalg.tensorinv(a, ind=2)
>>> ainv.shape
(8, 3, 4, 6)
>>> b = np.random.randn(4, 6)
>>> np.allclose(np.tensordot(ainv, b), np.linalg.tensorsolve(a, b))
True
>>> a = np.eye(4*6)
>>> a.shape = (24, 8, 3)
>>> ainv = np.linalg.tensorinv(a, ind=1)
>>> ainv.shape
(8, 3, 24)
>>> b = np.random.randn(24)
>>> np.allclose(np.tensordot(ainv, b, 1), np.linalg.tensorsolve(a, b))
True
- a (array_like) – Tensor to ‘invert’. Its shape must be ‘square’, i. e.,
-
symjax.tensor.linalg.
tensorsolve
(a, b, axes=None)[source]¶ Solve the tensor equation
a x = b
for x.LAX-backend implementation of
tensorsolve()
. Original docstring below.
It is assumed that all indices of x are summed over in the product, together with the rightmost indices of a, as is done in, for example,
tensordot(a, x, axes=b.ndim)
.Parameters: - a (array_like) – Coefficient tensor, of shape
b.shape + Q
. Q, a tuple, equals the shape of that sub-tensor of a consisting of the appropriate number of its rightmost indices, and must be such thatprod(Q) == prod(b.shape)
(in which sense a is said to be ‘square’). - b (array_like) – Right-hand tensor, which can be of any shape.
- axes (tuple of ints, optional) – Axes in a to reorder to the right, before inversion. If None (default), no reordering is done.
Returns: x
Return type: ndarray, shape Q
Raises: LinAlgError
– If a is singular or not ‘square’ (in the above sense).
See also
numpy.tensordot()
,tensorinv()
,numpy.einsum()
Examples
>>> a = np.eye(2*3*4)
>>> a.shape = (2*3, 4, 2, 3, 4)
>>> b = np.random.randn(2*3, 4)
>>> x = np.linalg.tensorsolve(a, b)
>>> x.shape
(2, 3, 4)
>>> np.allclose(np.tensordot(a, x, axes=3), b)
True
- a (array_like) – Coefficient tensor, of shape
-
symjax.tensor.linalg.
cholesky
(a, lower=False, overwrite_a=False, check_finite=True)[source]¶ Compute the Cholesky decomposition of a matrix.
LAX-backend implementation of
cholesky()
. Original docstring below.
Returns the Cholesky decomposition, \(A = L L^*\) or \(A = U^* U\) of a Hermitian positive-definite matrix A.
Parameters: - a ((M, M) array_like) – Matrix to be decomposed
- lower (bool, optional) – Whether to compute the upper- or lower-triangular Cholesky factorization. Default is upper-triangular.
- overwrite_a (bool, optional) – Whether to overwrite data in a (may improve performance).
- check_finite (bool, optional) – Whether to check that the input matrix contains only finite numbers. Disabling may give a performance gain, but may result in problems (crashes, non-termination) if the inputs do contain infinities or NaNs.
Returns: c – Upper- or lower-triangular Cholesky factor of a.
Return type: (M, M) ndarray
Raises: LinAlgError : if decomposition fails.
Examples
>>> from scipy.linalg import cholesky
>>> a = np.array([[1,-2j],[2j,5]])
>>> L = cholesky(a, lower=True)
>>> L
array([[ 1.+0.j,  0.+0.j],
       [ 0.+2.j,  1.+0.j]])
>>> L @ L.T.conj()
array([[ 1.+0.j,  0.-2.j],
       [ 0.+2.j,  5.+0.j]])
-
symjax.tensor.linalg.
block_diag
(*arrs)[source]¶ Create a block diagonal matrix from provided arrays.
LAX-backend implementation of
block_diag()
. Original docstring below.
Given the inputs A, B and C, the output will have these arrays arranged on the diagonal:
[[A, 0, 0],
 [0, B, 0],
 [0, 0, C]]
Returns: D – Array with A, B, C, … on the diagonal. D has the same dtype as A.
Return type: ndarray
Notes
If all the input arrays are square, the output is known as a block diagonal matrix.
Empty sequences (i.e., array-likes of zero size) will not be ignored. Notably, both [] and [[]] are treated as matrices with shape
(1,0)
.Examples
>>> from scipy.linalg import block_diag
>>> A = [[1, 0],
...      [0, 1]]
>>> B = [[3, 4, 5],
...      [6, 7, 8]]
>>> C = [[7]]
>>> P = np.zeros((2, 0), dtype='int32')
>>> block_diag(A, B, C)
array([[1, 0, 0, 0, 0, 0],
       [0, 1, 0, 0, 0, 0],
       [0, 0, 3, 4, 5, 0],
       [0, 0, 6, 7, 8, 0],
       [0, 0, 0, 0, 0, 7]])
>>> block_diag(A, P, B, C)
array([[1, 0, 0, 0, 0, 0],
       [0, 1, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 3, 4, 5, 0],
       [0, 0, 6, 7, 8, 0],
       [0, 0, 0, 0, 0, 7]])
>>> block_diag(1.0, [2, 3], [[4, 5], [6, 7]])
array([[ 1.,  0.,  0.,  0.,  0.],
       [ 0.,  2.,  3.,  0.,  0.],
       [ 0.,  0.,  0.,  4.,  5.],
       [ 0.,  0.,  0.,  6.,  7.]])
-
symjax.tensor.linalg.
cho_solve
(c_and_lower, b, overwrite_b=False, check_finite=True)[source]¶ Solve the linear equations A x = b, given the Cholesky factorization of A.
LAX-backend implementation of
cho_solve()
. Original docstring below.
Parameters: - (c, lower) : tuple, (array, bool)
- Cholesky factorization of a, as given by cho_factor
- b : array
- Right-hand side
- overwrite_b : bool, optional
- Whether to overwrite data in b (may improve performance)
- check_finite : bool, optional
- Whether to check that the input matrices contain only finite numbers. Disabling may give a performance gain, but may result in problems (crashes, non-termination) if the inputs do contain infinities or NaNs.
Returns: - x : array
- The solution to the system A x = b
See also
cho_factor : Cholesky factorization of a matrix
Examples
>>> from scipy.linalg import cho_factor, cho_solve
>>> A = np.array([[9, 3, 1, 5], [3, 7, 5, 1], [1, 5, 9, 2], [5, 1, 2, 6]])
>>> c, low = cho_factor(A)
>>> x = cho_solve((c, low), [1, 1, 1, 1])
>>> np.allclose(A @ x - [1, 1, 1, 1], np.zeros(4))
True
-
symjax.tensor.linalg.
eigh
(a, b=None, lower=True, eigvals_only=False, overwrite_a=False, overwrite_b=False, turbo=True, eigvals=None, type=1, check_finite=True)[source] Solve a standard or generalized eigenvalue problem for a complex Hermitian or real symmetric matrix.
LAX-backend implementation of
eigh()
. Original docstring below.
Find eigenvalues array
w
and optionally eigenvectors arrayv
of arraya
, whereb
is positive definite such that for every eigenvalue λ (i-th entry of w) and its eigenvectorvi
(i-th column ofv
) satisfies:a @ vi = λ * b @ vi vi.conj().T @ a @ vi = λ vi.conj().T @ b @ vi = 1
In the standard problem,
b
is assumed to be the identity matrix.Parameters: - a ((M, M) array_like) – A complex Hermitian or real symmetric matrix whose eigenvalues and eigenvectors will be computed.
- b ((M, M) array_like, optional) – A complex Hermitian or real symmetric definite positive matrix in. If omitted, identity matrix is assumed.
- lower (bool, optional) – Whether the pertinent array data is taken from the lower or upper
triangle of
a
and, if applicable,b
. (Default: lower) - eigvals_only (bool, optional) – Whether to calculate only eigenvalues and no eigenvectors. (Default: both are calculated)
- type (int, optional) – For the generalized problems, this keyword specifies the problem type
to be solved for
w
andv
(only takes 1, 2, 3 as possible inputs): - overwrite_a (bool, optional) – Whether to overwrite data in
a
(may improve performance). Default is False. - overwrite_b (bool, optional) – Whether to overwrite data in
b
(may improve performance). Default is False. - check_finite (bool, optional) – Whether to check that the input matrices contain only finite numbers. Disabling may give a performance gain, but may result in problems (crashes, non-termination) if the inputs do contain infinities or NaNs.
- turbo (bool, optional) – Deprecated since v1.5.0, use the driver='gvd' keyword instead. Use divide and conquer algorithm (faster but expensive in memory, only for generalized eigenvalue problem and if the full set of eigenvalues is requested). Has no significant effect if eigenvectors are not requested.
- eigvals (tuple (lo, hi), optional) – Deprecated since v1.5.0, use the subset_by_index keyword instead. Indexes of the smallest and largest (in ascending order) eigenvalues and corresponding eigenvectors to be returned: 0 <= lo <= hi <= M-1. If omitted, all eigenvalues and eigenvectors are returned.
Returns: - w ((N,) ndarray) – The N (1<=N<=M) selected eigenvalues, in ascending order, each repeated according to its multiplicity.
- v ((M, N) ndarray) – (if
eigvals_only == False
)
Raises: LinAlgError
– If eigenvalue computation does not converge, an error occurred, or the b matrix is not positive definite. Note that if the input matrices are not symmetric or Hermitian, no error will be reported but the results will be wrong.
See also
eigvalsh()
- eigenvalues of symmetric or Hermitian arrays
eig()
- eigenvalues and right eigenvectors for non-symmetric arrays
eigh_tridiagonal()
- eigenvalues and right eigenvectors for symmetric/Hermitian tridiagonal matrices
Notes
This function does not check the input array for being Hermitian/symmetric in order to allow for representing arrays with only their upper/lower triangular parts. Also note that, even though not taken into account, the finiteness check applies to the whole array and is unaffected by the lower keyword.
This function uses LAPACK drivers for computations in all possible keyword combinations, prefixed with
sy
if arrays are real andhe
if complex, e.g., a float array with “evr” driver is solved via “syevr”, complex arrays with “gvx” driver problem is solved via “hegvx” etc.As a brief summary, the slowest and the most robust driver is the classical
<sy/he>ev
which uses symmetric QR.<sy/he>evr
is seen as the optimal choice for the most general cases. However, there are certain occasions when
computes faster at the expense of more memory usage.<sy/he>evx
, while still being faster than<sy/he>ev
, often performs worse than the rest except when very few eigenvalues are requested for large arrays, though there is still no performance guarantee.
For the generalized problem, normalization with respect to the given type argument:
type 1 and 3 : v.conj().T @ a @ v = w
type 2       : inv(v).conj().T @ a @ inv(v) = w
type 1 or 2  : v.conj().T @ b @ v = I
type 3       : v.conj().T @ inv(b) @ v = I
Examples
>>> from scipy.linalg import eigh
>>> A = np.array([[6, 3, 1, 5], [3, 0, 5, 1], [1, 5, 6, 2], [5, 1, 2, 2]])
>>> w, v = eigh(A)
>>> np.allclose(A @ v - v @ np.diag(w), np.zeros((4, 4)))
True
Request only the eigenvalues
>>> w = eigh(A, eigvals_only=True)
Request eigenvalues that are less than 10.
>>> A = np.array([[34, -4, -10, -7, 2],
...               [-4, 7, 2, 12, 0],
...               [-10, 2, 44, 2, -19],
...               [-7, 12, 2, 79, -34],
...               [2, 0, -19, -34, 29]])
>>> eigh(A, eigvals_only=True, subset_by_value=[-np.inf, 10])
array([6.69199443e-07, 9.11938152e+00])
Request the second smallest eigenvalue and its eigenvector
>>> w, v = eigh(A, subset_by_index=[1, 1])
>>> w
array([9.11938152])
>>> v.shape  # only a single column is returned
(5, 1)
-
symjax.tensor.linalg.
expm
(A, *, upper_triangular=False, max_squarings=16)[source]¶ Compute the matrix exponential using Pade approximation.
LAX-backend implementation of
expm()
. In addition to the original NumPy argument(s) listed below, this also supports the optional boolean argument
upper_triangular
to specify whether theA
matrix is upper triangular, and the optional argumentmax_squarings
to specify the max number of squarings allowed in the scaling-and-squaring approximation method. Return nan if the actual number of squarings required is more thanmax_squarings
The number of required squarings is max(0, ceil(log2(norm(A)) - c)), where norm() denotes the L1 norm, and c = 2.42 for float64 or complex128, c = 1.97 for float32 or complex64.
Original docstring below.
Parameters: A ((N, N) array_like or sparse matrix) – Matrix to be exponentiated.
Returns: expm – Matrix exponential of A.
Return type: (N, N) ndarray
References
[1] Awad H. Al-Mohy and Nicholas J. Higham (2009) “A New Scaling and Squaring Algorithm for the Matrix Exponential.” SIAM Journal on Matrix Analysis and Applications. 31 (3). pp. 970-989. ISSN 1095-7162
Examples
>>> from scipy.linalg import expm, sinm, cosm
Matrix version of the formula exp(0) = 1:
>>> expm(np.zeros((2,2)))
array([[ 1.,  0.],
       [ 0.,  1.]])
Euler’s identity (exp(i*theta) = cos(theta) + i*sin(theta)) applied to a matrix:
>>> a = np.array([[1.0, 2.0], [-1.0, 3.0]])
>>> expm(1j*a)
array([[ 0.42645930+1.89217551j, -2.13721484-0.97811252j],
       [ 1.06860742+0.48905626j, -1.71075555+0.91406299j]])
>>> cosm(a) + 1j*sinm(a)
array([[ 0.42645930+1.89217551j, -2.13721484-0.97811252j],
       [ 1.06860742+0.48905626j, -1.71075555+0.91406299j]])
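As a quick illustration of the squaring count quoted above, a small numpy sketch (an approximation of the stated rule, not the library's internal code):
>>> a = np.array([[1.0, 2.0], [-1.0, 3.0]])
>>> c = 2.42  # float64 constant quoted above
>>> max(0, int(np.ceil(np.log2(np.linalg.norm(a, 1)) - c)))
0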
-
symjax.tensor.linalg.
inv
(a, overwrite_a=False, check_finite=True)[source] Compute the inverse of a matrix.
LAX-backend implementation of
inv()
. Original docstring below.
Parameters: - a (array_like) – Square matrix to be inverted.
- overwrite_a (bool, optional) – Discard data in a (may improve performance). Default is False.
- check_finite (bool, optional) – Whether to check that the input matrix contains only finite numbers. Disabling may give a performance gain, but may result in problems (crashes, non-termination) if the inputs do contain infinities or NaNs.
Returns: ainv – Inverse of the matrix a.
Return type: ndarray
Raises: LinAlgError
– If a is singular.ValueError
– If a is not square, or not 2D.
Examples
>>> from scipy import linalg
>>> a = np.array([[1., 2.], [3., 4.]])
>>> linalg.inv(a)
array([[-2. ,  1. ],
       [ 1.5, -0.5]])
>>> np.dot(a, linalg.inv(a))
array([[ 1.,  0.],
       [ 0.,  1.]])
-
symjax.tensor.linalg.
lu
(a, permute_l=False, overwrite_a=False, check_finite=True)[source]¶ Compute pivoted LU decomposition of a matrix.
LAX-backend implementation of
lu()
. Original docstring below.
The decomposition is:
A = P L U
where P is a permutation matrix, L lower triangular with unit diagonal elements, and U upper triangular.
- a : (M, N) array_like
- Array to decompose
- permute_l : bool, optional
- Perform the multiplication P*L (Default: do not permute)
- overwrite_a : bool, optional
- Whether to overwrite data in a (may improve performance)
- check_finite : bool, optional
- Whether to check that the input matrix contains only finite numbers. Disabling may give a performance gain, but may result in problems (crashes, non-termination) if the inputs do contain infinities or NaNs.
(If permute_l == False)
- p : (M, M) ndarray
- Permutation matrix
- l : (M, K) ndarray
- Lower triangular or trapezoidal matrix with unit diagonal. K = min(M, N)
- u : (K, N) ndarray
- Upper triangular or trapezoidal matrix
(If permute_l == True)
- pl : (M, K) ndarray
- Permuted L matrix. K = min(M, N)
- u : (K, N) ndarray
- Upper triangular or trapezoidal matrix
This is an LU factorization routine written for SciPy.
Examples
>>> from scipy.linalg import lu
>>> A = np.array([[2, 5, 8, 7], [5, 2, 2, 8], [7, 5, 6, 6], [5, 4, 4, 8]])
>>> p, l, u = lu(A)
>>> np.allclose(A - p @ l @ u, np.zeros((4, 4)))
True
-
symjax.tensor.linalg.
lu_factor
(a, overwrite_a=False, check_finite=True)[source]¶ Compute pivoted LU decomposition of a matrix.
LAX-backend implementation of
lu_factor()
. Original docstring below.
The decomposition is:
A = P L U
where P is a permutation matrix, L lower triangular with unit diagonal elements, and U upper triangular.
Parameters: - a ((M, M) array_like) – Matrix to decompose
- overwrite_a (bool, optional) – Whether to overwrite data in A (may increase performance)
- check_finite (bool, optional) – Whether to check that the input matrix contains only finite numbers. Disabling may give a performance gain, but may result in problems (crashes, non-termination) if the inputs do contain infinities or NaNs.
Returns: - lu ((N, N) ndarray) – Matrix containing U in its upper triangle, and L in its lower triangle. The unit diagonal elements of L are not stored.
- piv ((N,) ndarray) – Pivot indices representing the permutation matrix P: row i of the matrix was interchanged with row piv[i].
See also
lu_solve()
- solve an equation system using the LU factorization of a matrix
Notes
This is a wrapper to the
*GETRF
routines from LAPACK.Examples
>>> from scipy.linalg import lu_factor
>>> A = np.array([[2, 5, 8, 7], [5, 2, 2, 8], [7, 5, 6, 6], [5, 4, 4, 8]])
>>> lu, piv = lu_factor(A)
>>> piv
array([2, 2, 3, 3], dtype=int32)
Convert LAPACK’s
piv
array to NumPy index and test the permutation>>> piv_py = [2, 0, 3, 1] >>> L, U = np.tril(lu, k=-1) + np.eye(4), np.triu(lu) >>> np.allclose(A[piv_py] - L @ U, np.zeros((4, 4))) True
-
symjax.tensor.linalg.
lu_solve
(lu_and_piv, b, trans=0, overwrite_b=False, check_finite=True)[source]¶ Solve an equation system, a x = b, given the LU factorization of a
LAX-backend implementation of
lu_solve()
. Original docstring below.
Parameters: - b (array) – Right-hand side
- trans ({0, 1, 2}, optional) – Type of system to solve: 0 for a x = b, 1 for a^T x = b, 2 for a^H x = b
- overwrite_b (bool, optional) – Whether to overwrite data in b (may increase performance)
- check_finite (bool, optional) – Whether to check that the input matrices contain only finite numbers. Disabling may give a performance gain, but may result in problems (crashes, non-termination) if the inputs do contain infinities or NaNs.
Returns: x – Solution to the system
Return type: array
See also
lu_factor()
- LU factorize a matrix
Examples
>>> from scipy.linalg import lu_factor, lu_solve
>>> A = np.array([[2, 5, 8, 7], [5, 2, 2, 8], [7, 5, 6, 6], [5, 4, 4, 8]])
>>> b = np.array([1, 1, 1, 1])
>>> lu, piv = lu_factor(A)
>>> x = lu_solve((lu, piv), b)
>>> np.allclose(A @ x - b, np.zeros((4,)))
True
-
symjax.tensor.linalg.
solve_triangular
(a, b, trans=0, lower=False, unit_diagonal=False, overwrite_b=False, debug=None, check_finite=True)[source]¶ Solve the equation a x = b for x, assuming a is a triangular matrix.
LAX-backend implementation of
solve_triangular()
. Original docstring below.
Parameters: - a ((M, M) array_like) – A triangular matrix
- b ((M,) or (M, N) array_like) – Right-hand side matrix in a x = b
- lower (bool, optional) – Use only data contained in the lower triangle of a. Default is to use upper triangle.
- trans ({0, 1, 2, 'N', 'T', 'C'}, optional) – Type of system to solve: 0 or 'N' for a x = b, 1 or 'T' for a^T x = b, 2 or 'C' for a^H x = b
- unit_diagonal (bool, optional) – If True, diagonal elements of a are assumed to be 1 and will not be referenced.
- overwrite_b (bool, optional) – Allow overwriting data in b (may enhance performance)
- check_finite (bool, optional) – Whether to check that the input matrices contain only finite numbers. Disabling may give a performance gain, but may result in problems (crashes, non-termination) if the inputs do contain infinities or NaNs.
Returns: x – Solution to the system a x = b. Shape of return matches b.
Return type: (M,) or (M, N) ndarray
Raises: LinAlgError
– If a is singular.
Notes
New in version 0.9.0.
Examples
Solve the lower triangular system a x = b, where:
    [3  0  0  0]      [4]
a = [2  1  0  0]  b = [2]
    [1  0  1  0]      [4]
    [1  1  1  1]      [2]
>>> from scipy.linalg import solve_triangular
>>> a = np.array([[3, 0, 0, 0], [2, 1, 0, 0], [1, 0, 1, 0], [1, 1, 1, 1]])
>>> b = np.array([4, 2, 4, 2])
>>> x = solve_triangular(a, b, lower=True)
>>> x
array([ 1.33333333, -0.66666667,  2.66666667, -1.33333333])
>>> a.dot(x)  # Check the result
array([ 4.,  2.,  4.,  2.])
-
symjax.tensor.linalg.
tril
(m, k=0)[source]¶ Make a copy of a matrix with elements above the kth diagonal zeroed.
LAX-backend implementation of
tril()
. Original docstring below.
Parameters: - m (array_like) – Matrix whose elements to return
- k (int, optional) – Diagonal above which to zero elements. k == 0 is the main diagonal, k < 0 subdiagonal and k > 0 superdiagonal.
Returns: tril – Return is the same shape and type as m.
Return type: ndarray
Examples
>>> from scipy.linalg import tril
>>> tril([[1,2,3],[4,5,6],[7,8,9],[10,11,12]], -1)
array([[ 0,  0,  0],
       [ 4,  0,  0],
       [ 7,  8,  0],
       [10, 11, 12]])
-
symjax.tensor.linalg.
triu
(m, k=0)[source]¶ Make a copy of a matrix with elements below the kth diagonal zeroed.
LAX-backend implementation of
triu()
. Original docstring below.
Parameters: - m (array_like) – Matrix whose elements to return
- k (int, optional) – Diagonal below which to zero elements. k == 0 is the main diagonal, k < 0 subdiagonal and k > 0 superdiagonal.
Returns: triu – Return matrix with zeroed elements below the kth diagonal and has same shape and type as m.
Return type: ndarray
Examples
>>> from scipy.linalg import triu
>>> triu([[1,2,3],[4,5,6],[7,8,9],[10,11,12]], -1)
array([[ 1,  2,  3],
       [ 4,  5,  6],
       [ 0,  8,  9],
       [ 0,  0, 12]])
symjax.nn
¶
Implements the machine learning/deep learning utilities to create and train any state-of-the-art deep neural network, adapt learning rates, etc.
Activation functions¶
relu (x) |
Rectified linear unit activation function. |
relu6 (x) |
Rectified Linear Unit 6 activation function. |
sigmoid (x) |
Sigmoid activation function. |
softplus (x) |
Softplus activation function. |
soft_sign (x) |
Soft-sign activation function. |
silu (x) |
SiLU activation function. |
swish (x, beta) |
Swish activation function. |
log_sigmoid (x) |
Log-sigmoid activation function. |
leaky_relu (x[, negative_slope]) |
Leaky rectified linear unit activation function. |
hard_sigmoid (x) |
Hard Sigmoid activation function. |
hard_silu (x) |
Hard SiLU activation function |
hard_swish |
|
hard_tanh (x) |
Hard \(\mathrm{tanh}\) activation function. |
elu (x[, alpha]) |
Exponential linear unit activation function. |
celu (x[, alpha]) |
Continuously-differentiable exponential linear unit activation. |
selu (x) |
Scaled exponential linear unit activation. |
gelu (x, approximate) |
Gaussian error linear unit activation function. |
glu (linear_x, gated_x[, axis]) |
Gated linear unit activation function. |
Other Ops¶
softmax (x[, axis]) |
Softmax function. |
log_softmax (x[, axis]) |
Log-Softmax function. |
normalize (x[, axis, mean, variance, epsilon]) |
Normalizes an array by subtracting mean and dividing by sqrt(var). |
one_hot |
Detailed Descriptions¶
-
symjax.nn.
relu
(x)[source]¶ Rectified linear unit activation function.
Computes the element-wise function:
\[\mathrm{relu}(x) = \max(x, 0)\]
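A quick numpy illustration of the element-wise definition above (numpy stands in for the symjax op here):
>>> import numpy as np
>>> x = np.array([-2.0, -0.5, 0.0, 1.5])
>>> np.maximum(x, 0)
array([0. , 0. , 0. , 1.5])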
-
symjax.nn.
relu6
(x)[source]¶ Rectified Linear Unit 6 activation function.
Computes the element-wise function
\[\mathrm{relu6}(x) = \min(\max(x, 0), 6)\]
-
symjax.nn.
sigmoid
(x)[source]¶ Sigmoid activation function.
Computes the element-wise function:
\[\mathrm{sigmoid}(x) = \frac{1}{1 + e^{-x}}\]
-
symjax.nn.
softplus
(x)[source]¶ Softplus activation function.
Computes the element-wise function
\[\mathrm{softplus}(x) = \log(1 + e^x)\]
-
symjax.nn.
soft_sign
(x)[source]¶ Soft-sign activation function.
Computes the element-wise function
\[\mathrm{soft\_sign}(x) = \frac{x}{|x| + 1}\]
-
symjax.nn.
silu
(x)[source]¶ SiLU activation function.
Computes the element-wise function:
\[\mathrm{silu}(x) = x \cdot \mathrm{sigmoid}(x) = \frac{x}{1 + e^{-x}}\]
-
symjax.nn.
swish
(x, beta)[source]¶ Swish activation function.
Computes the element-wise function:
\[\mathrm{swish}(x) = x \cdot \mathrm{sigmoid}(\beta x) = \frac{x}{1 + e^{-\beta x}}\]
-
symjax.nn.
log_sigmoid
(x)[source]¶ Log-sigmoid activation function.
Computes the element-wise function:
\[\mathrm{log\_sigmoid}(x) = \log(\mathrm{sigmoid}(x)) = -\log(1 + e^{-x})\]
-
symjax.nn.
leaky_relu
(x, negative_slope=0.01)[source]¶ Leaky rectified linear unit activation function.
Computes the element-wise function:
\[\begin{split}\mathrm{leaky\_relu}(x) = \begin{cases} x, & x \ge 0\\ \alpha x, & x < 0 \end{cases}\end{split}\]where \(\alpha\) =
negative_slope
.
-
symjax.nn.
hard_sigmoid
(x)[source]¶ Hard Sigmoid activation function.
Computes the element-wise function
\[\mathrm{hard\_sigmoid}(x) = \frac{\mathrm{relu6}(x + 3)}{6}\]
-
symjax.nn.
hard_silu
(x)[source]¶ Hard SiLU activation function
Computes the element-wise function
\[\mathrm{hard\_silu}(x) = x \cdot \mathrm{hard\_sigmoid}(x)\]
-
symjax.nn.
hard_tanh
(x)[source]¶ Hard \(\mathrm{tanh}\) activation function.
Computes the element-wise function:
\[\begin{split}\mathrm{hard\_tanh}(x) = \begin{cases} -1, & x < -1\\ x, & -1 \le x \le 1\\ 1, & 1 < x \end{cases}\end{split}\]
-
symjax.nn.
elu
(x, alpha=1.0)[source]¶ Exponential linear unit activation function.
Computes the element-wise function:
\[\begin{split}\mathrm{elu}(x) = \begin{cases} x, & x > 0\\ \alpha \left(\exp(x) - 1\right), & x \le 0 \end{cases}\end{split}\]
-
symjax.nn.
celu
(x, alpha=1.0)[source]¶ Continuously-differentiable exponential linear unit activation.
Computes the element-wise function:
\[\begin{split}\mathrm{celu}(x) = \begin{cases} x, & x > 0\\ \alpha \left(\exp(\frac{x}{\alpha}) - 1\right), & x \le 0 \end{cases}\end{split}\]For more information, see Continuously Differentiable Exponential Linear Units.
-
symjax.nn.
selu
(x)[source]¶ Scaled exponential linear unit activation.
Computes the element-wise function:
\[\begin{split}\mathrm{selu}(x) = \lambda \begin{cases} x, & x > 0\\ \alpha e^x - \alpha, & x \le 0 \end{cases}\end{split}\]where \(\lambda = 1.0507009873554804934193349852946\) and \(\alpha = 1.6732632423543772848170429916717\).
For more information, see Self-Normalizing Neural Networks.
-
symjax.nn.
gelu
(x, approximate: bool = True)[source]¶ Gaussian error linear unit activation function.
If
approximate=False
, computes the element-wise function:\[\mathrm{gelu}(x) = \frac{x}{2} \left(1 + \mathrm{erf} \left( \frac{x}{\sqrt{2}} \right) \right)\]If
approximate=True
, uses the approximate formulation of GELU:\[\mathrm{gelu}(x) = \frac{x}{2} \left(1 + \mathrm{tanh} \left( \sqrt{\frac{2}{\pi}} \left(x + 0.044715 x^3 \right) \right) \right)\]For more information, see Gaussian Error Linear Units (GELUs), section 2.
Parameters: approximate – whether to use the approximate or exact formulation.
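A numpy sketch comparing the two formulations above (illustrative only; scipy.special.erf stands in for the erf used in the exact form):
>>> import numpy as np
>>> from scipy.special import erf
>>> x = np.linspace(-3, 3, 7)
>>> exact = 0.5 * x * (1 + erf(x / np.sqrt(2)))
>>> approx = 0.5 * x * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * x**3)))
>>> bool(np.max(np.abs(exact - approx)) < 1e-2)  # the two stay very close
True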
-
symjax.nn.
softmax
(x, axis=-1)[source]¶ Softmax function.
Computes the function which rescales elements to the range \([0, 1]\) such that the elements along
axis
sum to \(1\).
\[\mathrm{softmax}(x) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}\]
Parameters: axis – the axis or axes along which the softmax should be computed. The softmax output summed across these dimensions should sum to \(1\). Either an integer or a tuple of integers.
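The normalization property is easy to check numerically; a small numpy sketch of the formula above (not the symjax op itself), with the usual max-subtraction for stability:
>>> import numpy as np
>>> x = np.array([1.0, 2.0, 3.0])
>>> e = np.exp(x - x.max())  # subtracting the max does not change the result
>>> p = e / e.sum()
>>> np.allclose(p.sum(), 1.0)
True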
-
symjax.nn.
log_softmax
(x, axis=-1)[source]¶ Log-Softmax function.
Computes the logarithm of the
softmax
function, which rescales elements to the range \([-\infty, 0)\).
\[\mathrm{log\_softmax}(x) = \log \left( \frac{\exp(x_i)}{\sum_j \exp(x_j)} \right)\]
Parameters: axis – the axis or axes along which the
should be computed. Either an integer or a tuple of integers.
symjax.nn.initializers
¶
This module provides all the basic initializers used in Deep Learning. All the involved operations are meant to take as input the shape of the desired weight tensor (vector, matrix, …) and return a numpy array.
constant (shape, value) |
|
uniform (shape[, scale]) |
Sample uniform weights U(-scale, scale). |
normal (shape[, scale]) |
Sample Gaussian weights N(0, scale). |
orthogonal (shape[, scale]) |
From Lasagne. |
glorot_uniform (shape) |
Reference: Glorot & Bengio, AISTATS 2010 |
glorot_normal (shape) |
Reference: Glorot & Bengio, AISTATS 2010 |
he_uniform (shape) |
Reference: He et al., http://arxiv.org/abs/1502.01852 |
he_normal (shape) |
Reference: He et al., http://arxiv.org/abs/1502.01852 |
lecun_uniform (shape[, name]) |
Reference: LeCun 98, Efficient Backprop http://yann.lecun.com/exdb/publis/pdf/lecun-98b.pdf |
get_fans (shape) |
utility giving fan_in and fan_out of a tensor (shape). |
variance_scaling (shape, mode[, gain, …]) |
Variance Scaling initialization. |
Detailed Descriptions¶
-
symjax.nn.initializers.
uniform
(shape, scale=0.05)[source]¶ Sample uniform weights U(-scale, scale).
Parameters: - shape (tuple) –
- scale (float (default=0.05)) –
-
symjax.nn.initializers.
normal
(shape, scale=0.05)[source]¶ Sample Gaussian weights N(0, scale).
Parameters: - shape (tuple) –
- scale (float (default=0.05)) –
-
symjax.nn.initializers.
orthogonal
(shape, scale=1)[source]¶ From Lasagne. Reference: Saxe et al., http://arxiv.org/abs/1312.6120
-
symjax.nn.initializers.
he_uniform
(shape)[source]¶ Reference: He et al., http://arxiv.org/abs/1502.01852
-
symjax.nn.initializers.
he_normal
(shape)[source]¶ Reference: He et al., http://arxiv.org/abs/1502.01852
-
symjax.nn.initializers.
lecun_uniform
(shape, name=None)[source]¶ Reference: LeCun 98, Efficient Backprop http://yann.lecun.com/exdb/publis/pdf/lecun-98b.pdf
-
symjax.nn.initializers.
get_fans
(shape)[source]¶ utility giving fan_in and fan_out of a tensor (shape).
The concept of fan_in and fan_out helps to create weight initializers. Those quantities represent the number of units that the current weight takes as input (from the previous layer) and the number of outputs it produces. From those two numbers, the variance of random variables can be obtained such that the layer feature maps do not vanish or explode in amplitude.
Parameters: shape (tuple) – the shape of the tensor. For a densely connected layer this is (previous layer width, current layer width) and for a convolutional (2D) layer it is (n_filters, input_channels) + spatial shape.
Returns: - fan_in (int)
- fan_out (int)
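A minimal sketch of that computation, assuming the shape conventions stated above (dense: (previous width, current width); 2-D convolution: (n_filters, input_channels) + spatial shape); the helper name is ours, for illustration:
>>> import numpy as np
>>> def get_fans_sketch(shape):
...     if len(shape) == 2:  # dense layer
...         return shape[0], shape[1]
...     receptive = int(np.prod(shape[2:]))  # spatial positions per filter
...     return shape[1] * receptive, shape[0] * receptive
>>> get_fans_sketch((128, 64))
(128, 64)
>>> get_fans_sketch((32, 3, 5, 5))  # 32 filters, 3 input channels, 5x5 kernels
(75, 800)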
symjax.nn.layers
¶
Dense¶
Dense (input, units[, W, b, trainable_W, …]) |
Fully-connected/Dense layer |
Renormalization¶
BatchNormalization (input, axis, deterministic) |
batch-normalization layer |
Data Augmentation¶
RandomCrop (input, crop_shape, deterministic) |
random crop selection from the input |
RandomFlip (input, p, axis, deterministic[, seed]) |
random axis flip on the input |
Dropout (input, p, deterministic[, seed]) |
binary mask onto the input |
Convolution¶
Conv1D (input, n_filters, filter_length[, W, …]) |
1-D (time) convolution |
Conv2D (input, n_filters, filter_shape[, …]) |
2-D (spatial) convolution |
Pooling¶
Pool1D (input, pool_shape[, pool_type, strides]) |
1-D (time) pooling |
Pool2D (input, pool_shape[, pool_type, strides]) |
2-D (spatial) pooling |
Recurrent¶
RNN (sequence, init_h, units[, W, H, b, …]) |
|
GRU (sequence, init_h, units[, Wh, Uh, bh, …]) |
|
LSTM (sequence, init_h, units[, Wf, Uf, bf, …]) |
Detailed Description¶
-
class
symjax.nn.layers.
BatchNormalization
(input, axis, deterministic, const=0.001, beta_1=0.99, beta_2=0.99, W=<function ones>, b=<function zeros>, trainable_W=True, trainable_b=True)[source]¶ batch-normalization layer
- input_or_shape: shape or Tensor
- the layer input tensor or shape
- axis: list or tuple of ints
- the axis to normalize on. If using BN on a dense layer then axis should be [0] to normalize over the samples. If the layer is a convolutional layer with data format NCHW then axis should be [0, 2, 3] to normalize over the samples and spatial dimensions (commonly done)
- deterministic: bool or Tensor
- controlling the state of the layer
- const: float32 (optional)
- the constant used in the standard deviation renormalization
- beta_1: float32 (optional)
- the parameter for the exponential moving average of the mean
- beta2: float32 (optional)
- the parameter for the exponential moving average of the std
Returns: output Return type: the layer output with attributes given by the layer options
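As an illustration, a usage sketch following the documented signature; the Placeholder inputs and the symjax.tensor import are assumptions borrowed from the graph tutorials, not part of this docstring:
>>> import symjax.tensor as T
>>> from symjax.nn.layers import BatchNormalization
>>> x = T.Placeholder((64, 3, 32, 32), 'float32')  # hypothetical NCHW batch
>>> deterministic = T.Placeholder((), 'bool')
>>> bn = BatchNormalization(x, axis=[0, 2, 3], deterministic=deterministic)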
-
class
symjax.nn.layers.
RandomCrop
(input, crop_shape, deterministic, padding=0, seed=None)[source]¶ random crop selection from the input
Random layer that will select a window of the input based on the given parameters, with the possibility to first apply a padding. This layer is commonly used as a data augmentation technique and positioned at the beginning of the deep network topology. Note that all the involved operations are GPU compatible and allow for backpropagation
Parameters: - input_or_shape (shape or Tensor) – the input of the layer or the shape of the layer input
- crop_shape (shape) – the shape of the cropped part of the input. It must have the same length as the input shape minus one for the first dimension
- deterministic (bool or Tensor) – if the layer is in deterministic mode or not
- padding (shape) – the amount of padding to apply on each dimension (except the first one); each dimension should have a pair giving the before and after padding amounts
- seed (seed (optional)) – to control reproducibility
Returns: output
Return type: the output tensor which contains the internal variables
-
class
symjax.nn.layers.
RandomFlip
(input, p, axis, deterministic, seed=None)[source]¶ random axis flip on the input
Random layer that will randomly flip the axis of the input. Note that all the involved operations are GPU compatible and allow for backpropagation
Parameters: - input_or_shape (shape or Tensor) – the input of the layer or the shape of the layer input
- p (float (0<=p<=1)) – the probability to flip the given axis
- axis (int) – the axis to flip
- deterministic (bool or Tensor) – if the layer is in deterministic mode or not
- seed (seed (optional)) – to control reproducibility
Returns: output
Return type: the output tensor which contains the internal variables
-
class
symjax.nn.layers.
Dropout
(input, p, deterministic, seed=None)[source]¶ binary mask onto the input
Parameters: - input_or_shape (shape or Tensor) – the layer input or shape
- p (float (0<=p<=1)) – the probability to drop the value
- deterministic (bool or Tensor) – the state of the layer
- seed (seed) – the RNG seed
Returns: output
Return type: the layer output
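A usage sketch mirroring the documented signature (the Placeholder inputs are assumptions, as in the BatchNormalization sketch above):
>>> import symjax.tensor as T
>>> from symjax.nn.layers import Dropout
>>> x = T.Placeholder((64, 128), 'float32')
>>> deterministic = T.Placeholder((), 'bool')
>>> out = Dropout(x, p=0.5, deterministic=deterministic)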
-
class
symjax.nn.layers.
Conv1D
(input, n_filters, filter_length, W=<function glorot_uniform>, b=<built-in function zeros>, stride=1, padding='VALID', trainable_W=True, trainable_b=True, inplace_W=False, inplace_b=False, W_preprocessor=None, b_preprocessor=None, input_dilations=None, filter_dilations=None)[source]¶ 1-D (time) convolution
perform a 1-D (temporal) convolution of the input, followed by an optional bias shift
input
n_filters
filter_length
W=initializers.glorot_uniform
b=numpy.zeros
stride=1
padding=”VALID”
trainable_W=True
trainable_b=True
inplace_W=False
inplace_b=False
W_preprocessor=None
b_preprocessor=None
input_dilations=None
filter_dilations=None
-
class
symjax.nn.layers.
Conv2D
(input, n_filters, filter_shape, padding='VALID', strides=1, W=<function glorot_uniform>, b=<built-in function zeros>, trainable_W=True, trainable_b=True, inplace_W=False, inplace_b=False, input_dilations=None, filter_dilations=None, W_preprocessor=None, b_preprocessor=None)[source]¶ 2-D (spatial) convolution
-
class
symjax.nn.layers.
Pool1D
(input, pool_shape, pool_type='MAX', strides=None)[source]¶ 1-D (time) pooling
-
class
symjax.nn.layers.
Pool2D
(input, pool_shape, pool_type='MAX', strides=None)[source]¶ 2-D (spatial) pooling
-
class
symjax.nn.layers.
RNN
(sequence, init_h, units, W=<function glorot_uniform>, H=<function orthogonal>, b=<function zeros>, trainable_W=True, trainable_H=True, trainable_b=True, activation=<function sigmoid>, only_last=False)[source]¶
-
class
symjax.nn.layers.
GRU
(sequence, init_h, units, Wh=<function glorot_uniform>, Uh=<function orthogonal>, bh=<function zeros>, Wz=<function glorot_uniform>, Uz=<function orthogonal>, bz=<function zeros>, Wr=<function glorot_uniform>, Ur=<function orthogonal>, br=<function zeros>, trainable_Wh=True, trainable_Uh=True, trainable_bh=True, trainable_Wz=True, trainable_Uz=True, trainable_bz=True, trainable_Wr=True, trainable_Ur=True, trainable_br=True, activation=<function sigmoid>, phi=<function _one_to_one_unop.<locals>.<lambda>>, only_last=False, gate='minimal')[source]¶
-
class
symjax.nn.layers.
LSTM
(sequence, init_h, units, Wf=<function glorot_uniform>, Uf=<function orthogonal>, bf=<function zeros>, Wi=<function glorot_uniform>, Ui=<function orthogonal>, bi=<function zeros>, Wo=<function glorot_uniform>, Uo=<function orthogonal>, bo=<function zeros>, Wc=<function glorot_uniform>, Uc=<function orthogonal>, bc=<function zeros>, trainable_Wf=True, trainable_Uf=True, trainable_bf=True, trainable_Wi=True, trainable_Ui=True, trainable_bi=True, trainable_Wo=True, trainable_Uo=True, trainable_bo=True, trainable_Wc=True, trainable_Uc=True, trainable_bc=True, activation_g=<function sigmoid>, activation_c=<function _one_to_one_unop.<locals>.<lambda>>, activation_h=<function _one_to_one_unop.<locals>.<lambda>>, only_last=False, gate='minimal')[source]¶
symjax.nn.optimizers¶
- class symjax.nn.optimizers.Adam(*args, name=None, **kwargs)[source]¶ Adaptive Gradient Based Optimization with renormalization.
The update rule for a variable with gradient g uses the optimization described at the end of Section 2 of the Kingma and Ba paper, with learning rate α.
If amsgrad is False:
initialization:
- \(m_0 = 0\) (Initialize initial 1st moment vector)
- \(v_0 = 0\) (Initialize initial 2nd moment vector)
- \(t = 0\) (Initialize timestep)
update:
- \(t = t + 1\)
- \(α_t = α × \sqrt{1 - β_2^t}/(1 - β_1^t)\)
- \(m_t = β_1 × m_{t-1} + (1 - β_1) × g\)
- \(v_t = β_2 × v_{t-1} + (1 - β_2) × g \odot g\)
- \(variable = variable - α_t × m_t / (\sqrt{v_t} + ε)\)
If amsgrad is True:
initialization:
- \(m_0 = 0\) (Initialize initial 1st moment vector)
- \(v_0 = 0\) (Initialize initial 2nd moment vector)
- \(v'_0 = 0\) (Initialize initial 2nd moment vector)
- \(t = 0\) (Initialize timestep)
update:
- \(t = t + 1\)
- \(α_t = α × \sqrt{1 - β_2^t}/(1 - β_1^t)\)
- \(m_t = β_1 × m_{t-1} + (1 - β_1) × g\)
- \(v_t = β_2 × v_{t-1} + (1 - β_2) × g \odot g\)
- \(v'_t := \max(v'_{t-1}, v_t)\)
- \(variable = variable - α_t × m_t / (\sqrt{v'_t} + ε)\)
The default value of \(\epsilon=1e-7\) might not be a good default in general. For example, when training an Inception network on ImageNet a current good choice is 1.0 or 0.1. Note that since AdamOptimizer uses the formulation just before Section 2.1 of the Kingma and Ba paper rather than the formulation in Algorithm 1, the “epsilon” referred to here is “epsilon hat” in the paper.
Parameters: - grads_or_loss (scalar tensor or list of gradients) – either the loss (scalar Tensor) to be differentiated or the list of gradients already computed and possibly altered manually (such as clipping)
- learning_rate (α) – the learning rate used to update the parameters
- amsgrad (bool) – whether to use the amsgrad updates or not
- β_1 (constant or Tensor) – the exponential moving average coefficient for the first moment (mean) of the gradients through time (updates)
- β_2 (constant or Tensor) – the exponential moving average coefficient for the second moment (variance) of the gradients through time
- ε (constant or Tensor) – the value added to the second order moment
- params (list, optional) – if grads_or_loss is a list then it should be ordered w.r.t. the given parameters; if not given then the optimizer will find all variables that are trainable and involved with the given loss
- updates¶ Type: list of updates
- variables¶ Type: list of variables
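A minimal usage sketch (hedged: the Variable creation and the quadratic loss are illustrative; only the Adam call and the updates attribute follow the signature documented above):
>>> import symjax
>>> import symjax.tensor as T
>>> symjax.current_graph().reset()
>>> w = T.Variable(1.0, dtype='float32', trainable=True)
>>> loss = (w - 4.0) ** 2  # toy scalar loss
>>> opt = symjax.nn.optimizers.Adam(loss, 0.1)
>>> # applying opt.updates at each call performs one Adam step
>>> train = symjax.function(outputs=loss, updates=opt.updates)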
- class symjax.nn.optimizers.NesterovMomentum(*args, name=None, **kwargs)[source]¶ Nesterov momentum Optimization
Parameters: - grads_or_loss (scalar tensor or list of gradients) – either the loss (scalar Tensor) to be differentiated or the list of gradients already computed and possibly altered manually (such as clipping)
- learning_rate (constant or Tensor) – the learning rate used to update the parameters
- momentum (constant or Tensor) – the amount of momentum to be applied
- params (list, optional) – if grads_or_loss is a list then it should be ordered w.r.t. the given parameters
- updates¶ Type: list of updates
- variables¶ Type: list of variables
- class symjax.nn.optimizers.SGD(*args, name=None, **kwargs)[source]¶ Stochastic gradient descent optimization.
Notice that SGD is also the acronym employed in tf.keras.optimizers.SGD and in torch.optim.sgd, but it might be misleading: those, like this implementation, actually implement GD. The SGD term only applies if one performs GD optimization using a single (random) sample to compute the gradients. If multiple samples are used, it is commonly referred to as mini-batch GD, and when the entire dataset is used the optimizer is referred to as GD. See an illustrative discussion here.
The produced update for parameter θ and a given learning rate α is:
\[θ = θ - α ∇_{θ} L\]
Parameters: - grads_or_loss (scalar tensor or list of gradients) – either the loss (scalar Tensor) to be differentiated or the list of gradients already computed and possibly altered manually (such as clipping)
- learning_rate (constant or Tensor) – the learning rate used to update the parameters
- params (list, optional) – if grads_or_loss is a list then it should be ordered w.r.t. the given parameters
- updates¶ Type: list of updates
- variables¶ Type: list of variables
- symjax.nn.optimizers.conjugate_gradients(Ax, b)[source]¶ Conjugate gradient algorithm (see https://en.wikipedia.org/wiki/Conjugate_gradient_method)
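For reference, a standalone NumPy sketch of the conjugate gradient iteration (an illustration of the algorithm itself, not of this function's internals; here Ax is assumed to be a callable computing the matrix-vector product):

import numpy as np

def conjugate_gradients(Ax, b, iters=10, tol=1e-10):
    # solve A x = b for a symmetric positive-definite A,
    # accessed only through the matrix-vector product Ax(v)
    x = np.zeros_like(b)
    r = b.copy()   # residual b - A x (x starts at 0)
    p = r.copy()   # search direction
    rs = r @ r
    for _ in range(iters):
        Ap = Ax(p)
        alpha = rs / (p @ Ap)      # step size along p
        x = x + alpha * p
        r = r - alpha * Ap
        rs_new = r @ r
        if np.sqrt(rs_new) < tol:  # converged
            break
        p = r + (rs_new / rs) * p  # new conjugate direction
        rs = rs_new
    return x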
symjax.nn.schedules¶
Implements schedules, i.e., Tensor variables that are updated online based on new values of some other tensors
PiecewiseConstant (init, steps_and_values) | piecewise constant variable updating automatically
ExponentialMovingAverage (value, alpha[, …]) | exponential moving average of a given value
Detailed Descriptions¶
- class symjax.nn.schedules.PiecewiseConstant[source]¶ piecewise constant variable updating automatically
This method allows to obtain a variable with an internal counter that is updated based on the function updates; whenever this counter reaches one of the steps given in the function input, the value of the variable becomes the one associated with that step
Parameters: - init (float-like) – the initial value of the variable, which remains as is until a step and update is reached
- steps_and_values (dict) – the dictionary mapping steps -> values; when the step counter reaches one of the given steps, the value of the variable becomes the one associated with that step
Returns: variable
Return type: float-like
Example
>>> import symjax
>>> symjax.current_graph().reset()
>>> var = symjax.nn.schedules.PiecewiseConstant(0.1, {4: 1, 8: 2})
>>> # in the background, symjax automatically records that every time
>>> # a function is using this variable an underlying update should occur
>>> print(symjax.get_updates())
{Variable(name=step, shape=(), dtype=int32, trainable=False, scope=/PiecewiseConstant/): Op(name=add, fn=add, shape=(), dtype=int32, scope=/PiecewiseConstant/)}
>>> # it is up to the user to use it or not; if not used, the internal counter
>>> # is never updated and thus the variable never changes.
>>> # example of use:
>>> f = symjax.function(outputs=var, updates=symjax.get_updates())
>>> for i in range(10):
...     print(i, f())
0 0.1
1 0.1
2 0.1
3 0.1
4 1.0
5 1.0
6 1.0
7 1.0
8 2.0
9 2.0
- class symjax.nn.schedules.ExponentialMovingAverage[source]¶ exponential moving average of a given value
This method allows to obtain an EMA of a given variable (or any Tensor) with an internal state automatically updating its values as new samples are observed; the internal updates are applied as part of a function. At each iteration the new value is given by
\[v(0) = init \text{ (or } value(0) \text{ if } init \text{ is not given)}\]
\[v(t) = α × v(t-1) + (1 - α) × value(t)\]
Parameters: - value (Tensor-like) – the value to use for the EMA
- alpha (scalar) – the decay of the EMA
- init (Tensor-like (same shape as value), optional) – the initialization of the EMA; if not given, uses the first value, allowing for an unbiased estimate
- decay_min (bool) – at early stages, clip the decay to avoid erratic behaviors
Returns: - ema (Tensor-like) – the current (latest) value of the EMA incorporating information of the latest observation of value
- fixed_ema (Tensor-like) – the value of the EMA of the previous pass. This is useful if one wants to keep the estimate of the EMA fixed for new observations: simply do not apply any more updates (using a new function) and use this fixed variable during testing (while ema keeps using the latest observed value)
Example
>>> import symjax
>>> import numpy as np
>>> np.random.seed(0)
>>> symjax.current_graph().reset()
>>> # suppose we want to do an EMA of a vector user-input
>>> input = symjax.tensor.Placeholder((2,), 'float32')
>>> ema, var = symjax.nn.schedules.ExponentialMovingAverage(input, 0.9)
>>> # in the background, symjax automatically records the needed updates
>>> print(symjax.get_updates())
{Variable(name=EMA, shape=(2,), dtype=float32, trainable=False, scope=/ExponentialMovingAverage/): Op(name=where, fn=where, shape=(2,), dtype=float32, scope=/ExponentialMovingAverage/), Variable(name=first_step, shape=(), dtype=bool, trainable=False, scope=/ExponentialMovingAverage/): False}
>>> # example of use:
>>> f = symjax.function(input, outputs=ema, updates=symjax.get_updates())
>>> for i in range(25):
...     print(f(np.ones(2) + np.random.randn(2) * 0.3))
[1.5292157 1.1200472]
[1.5056562 1.1752692]
[1.5111173 1.1284239]
[1.4885082 1.1110408]
[1.4365609 1.1122546]
[1.3972261 1.1446574]
[1.3803346 1.1338419]
[1.355617 1.1304679]
[1.3648777 1.1112664]
[1.3377819 1.0745169]
[1.227414 1.0866737]
[1.2306056 1.0557414]
[1.2756376 1.0065362]
[1.2494465 1.000267 ]
[1.2704852 1.0443211]
[1.2480851 1.0512339]
[1.196643 0.9866866]
[1.1665413 0.9927084]
[1.186796 1.029509]
[1.1564965 1.017489 ]
[1.1093903 0.97313946]
[1.0472631 1.0343488]
[1.0272473 1.0177717]
[0.9869387 1.0393193]
[0.93982786 1.029005 ]
symjax.nn.losses¶
vae (x, x_hat, q_mean, q_cov[, z_mean, …]) | N samples of dimension D to latent space in K dimension with Gaussian distributions
vae_gmm (x, x_hat, z_mu, z_logvar, mu, …[, …]) | N samples of dimension D to latent space of C clusters in K dimension
vae_comp_gmm (x, x_hat, z_mu, z_logvar, mu, …) | N samples of dimension D to latent space of I pieces each of C clusters in K dimension
sparse_softmax_crossentropy_logits (p, q) | Cross entropy loss given that \(p\) is sparse and \(q\) is the log-probability.
softmax_crossentropy_logits (p, q) | see sparse cross entropy
sigmoid_crossentropy_logits |
accuracy (targets, predictions) | classification accuracy.
clustering_accuracy (labels, predictions, …) | find accuracy of clustering based on intra cluster labels
huber (targets, predictions[, delta]) | huber loss (regression).
explained_variance (y, ypred[, axis, epsilon]) | Computes fraction of variance that ypred explains about y.
hinge_loss (predictions, targets[, delta]) | (binary) hinge loss.
multiclass_hinge_loss (predictions, targets) | multi-class hinge loss.
squared_differences (x, y) | elementwise squared differences.
Detailed Descriptions¶
- symjax.nn.losses.vae(x, x_hat, q_mean, q_cov, z_mean=None, z_cov=None, x_cov=None)[source]¶ N samples of dimension D to latent space in K dimension with Gaussian distributions
Parameters: - x (array) – should be of shape (N, D)
- x_hat (array) – should be of shape (N, D)
- q_mean (array) – should be of shape (N, K), inferred mean of the variational Gaussian
- q_cov (array) – should be of shape (N, K), inferred log-variance of the variational Gaussian
- z_mean (array) – should be of shape (K,), mean of the z variable
- z_cov (array) – should be of shape (K,), logstd of the z variable
- symjax.nn.losses.vae_gmm(x, x_hat, z_mu, z_logvar, mu, logvar, logpi, logvar_x=0.0, eps=1e-08)[source]¶ N samples of dimension D to latent space of C clusters in K dimension
Parameters: - x (array) – should be of shape (N, D)
- x_hat (array) – should be of shape (N, D)
- z_mu (array) – should be of shape (N, K), inferred mean of the variational Gaussian
- z_logvar (array) – should be of shape (N, K), inferred log-variance of the variational Gaussian
- mu (array) – should be of shape (C, K), parameter (centroids)
- logvar (array) – should be of shape (C, K), parameter (logvar of clusters)
- logpi (array) – should be of shape (C,), parameter (prior of clusters)
- logvar_x –
- eps –
- symjax.nn.losses.vae_comp_gmm(x, x_hat, z_mu, z_logvar, mu, logvar, logpi, logvar_x=0.0, eps=1e-08)[source]¶ N samples of dimension D to latent space of I pieces each of C clusters in K dimension
Parameters: - x (array) – should be of shape (N, D)
- x_hat (array) – should be of shape (N, D)
- z_mu (array) – should be of shape (N, I, K), inferred mean of the variational Gaussian
- z_logvar (array) – should be of shape (N, I, K), inferred log-variance of the variational Gaussian
- mu (array) – should be of shape (I, C, K), parameter (centroids)
- logvar (array) – should be of shape (I, C, K), parameter (logvar of clusters)
- logpi (array) – should be of shape (I, C), parameter (prior of clusters)
- logvar_x –
- eps –
- symjax.nn.losses.sparse_softmax_crossentropy_logits(p, q)[source]¶ Cross entropy loss given that \(p\) is sparse and \(q\) is the log-probability.
The formal definition, given that \(p\) is now an index (of the Dirac) such that \(p\in \{1,\dots,D\}\) and \(q\) is unnormalized (log-probabilities), is given by (for discrete variables, sparse \(p\))
\[\mathcal{L}(p,q)=-q_{p}+\log(\sum_{d=1}^D \exp(q_d))\]
\[\mathcal{L}(p,q)=-q_{p}+LogSumExp(q)\]
\[\mathcal{L}(p,q)=-q_{p}+\max_{d}q_d+LogSumExp(q-\max_{d}q_d)\]
or by (non-sparse \(p\))
\[\mathcal{L}(p,q)=-\sum_{d=1}^D p_{d}q_{d}+\log(\sum_{d=1}^D \exp(q_d))\]
\[\mathcal{L}(p,q)=-\sum_{d=1}^D p_{d}q_{d}+LogSumExp(q)\]
\[\mathcal{L}(p,q)=-\sum_{d=1}^D p_{d}q_{d}+\max_{d}q_d+LogSumExp(q-\max_{d}q_d)\]
with \(p\) the class index and \(q\) the predicted one (output of the network). The non-sparse version takes two dense vectors; \(p\) should be nonnegative and sum to one.
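A NumPy sketch of the stabilized sparse formula above (an illustration, not the library's implementation):

import numpy as np

def sparse_softmax_crossentropy_logits(p, q):
    # p: integer class index, q: 1D vector of unnormalized log-probabilities
    m = q.max()
    # -q_p + max_d q_d + LogSumExp(q - max_d q_d)
    return -q[p] + m + np.log(np.exp(q - m).sum())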
- symjax.nn.losses.accuracy(targets, predictions)[source]¶ classification accuracy.
It is computed by averaging the 0-1 loss as in
\[(Σ_{n=1}^N 1_{\{y_n == p_n\}})/N\]
where \(p\) denotes the predictions. The inputs must be vectors, but in the special case where targets is a vector and predictions is a matrix, the argmax is used to get the actual predictions, as in
\[(Σ_{n=1}^N 1_{\{y_n == \arg\max_d p_{n,d}\}})/N\]
Parameters: - targets (1D tensor-like) –
- predictions (tensor-like) – can be a \(2D\) matrix, in which case the argmax is used to get the predictions
Returns:
Return type: tensor-like
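A NumPy sketch of the behavior described above (illustrative only):

import numpy as np

def accuracy(targets, predictions):
    # when predictions is a matrix, reduce it to class indices first
    if predictions.ndim == 2:
        predictions = predictions.argmax(axis=1)
    return (targets == predictions).mean()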
- symjax.nn.losses.clustering_accuracy(labels, predictions, n_clusters)[source]¶ find accuracy of clustering based on intra cluster labels
This accuracy quantifies the ability of a clustering algorithm to solve the clustering task given the true labels of the data. This function finds, for each predicted cluster, the most frequent true label and uses it as the cluster label. Based on those cluster labels, the accuracy is then computed.
Parameters: - labels (1d integer Tensor) – the true labels of the data
- predictions (1d integer Tensor) – the predicted data clusters
- n_clusters (int) – the number of clusters
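A NumPy sketch of the majority-label relabeling described above (illustrative, not the library's implementation):

import numpy as np

def clustering_accuracy(labels, predictions, n_clusters):
    relabeled = np.zeros_like(labels)
    for c in range(n_clusters):
        mask = predictions == c
        if mask.any():
            # label the cluster with the most frequent true label it contains
            relabeled[mask] = np.bincount(labels[mask]).argmax()
    return (relabeled == labels).mean()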
- symjax.nn.losses.huber(targets, predictions, delta=1.0)[source]¶ huber loss (regression).
For each value x in error=targets-predictions, the following is calculated:
- \(0.5 × x^2\) if \(|x| <= Δ\)
- \(0.5 × Δ^2 + Δ × (|x| - Δ)\) if \(|x| > Δ\)
leading to a loss that is quadratic for small errors and linear for large ones.
Parameters: - targets – The ground truth output tensor, same dimensions as ‘predictions’.
- predictions – The predicted outputs.
- delta (Δ) – float, the point where the huber loss function changes from quadratic to linear
Returns: loss float, this has the same shape as targets
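A NumPy sketch of the piecewise definition above (illustrative):

import numpy as np

def huber(targets, predictions, delta=1.0):
    x = np.abs(targets - predictions)
    # quadratic for |x| <= delta, linear beyond
    return np.where(x <= delta,
                    0.5 * x ** 2,
                    0.5 * delta ** 2 + delta * (x - delta))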
- symjax.nn.losses.explained_variance(y, ypred, axis=None, epsilon=1e-06)[source]¶ Computes fraction of variance that ypred explains about y. The formula is
\[1 - Var[y-ypred] / Var[y]\]
and in the special case of centered targets and predictions it becomes
\[1 - \|y-ypred\|^2_2 / \|y\|_2^2\]
hence it can be seen as an \(ℓ_2\) loss rescaled by the energy in the targets.
interpretation:
- ev=0 => might as well have predicted zero
- ev=1 => perfect prediction
- ev<0 => worse than just predicting zero
Parameters: - y (Tensor-like) – true target
- ypred (Tensor-like) – prediction
- axis (integer or None, default=None) – the axis along which to compute the variance; by default uses all axes
- epsilon (ϵ, float, default=1e-6) – the added constant in the denominator
The computed quantity is thus
\[1 - Var(y-ypred)/(Var(y)+ϵ)\]
Note that this is not a symmetric function.
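A NumPy sketch of this computation (illustrative):

import numpy as np

def explained_variance(y, ypred, axis=None, epsilon=1e-6):
    # 1 - Var(y - ypred) / (Var(y) + epsilon)
    return 1.0 - np.var(y - ypred, axis=axis) / (np.var(y, axis=axis) + epsilon)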
- symjax.nn.losses.hinge_loss(predictions, targets, delta=1)[source]¶ (binary) hinge loss.
For an intended output \(t = ±1\) and a classifier score \(p\), the hinge loss is defined for each datum as
\[\max ( 0 , Δ − t p)\]
As soon as the loss is smaller than \(Δ\), the datum is correctly classified; the margin is further increased by pushing the loss to \(0\), hence \(Δ\) is the user-defined preferred margin to reach. In the standard SVM, \(Δ=1\).
Note that \(p\) should be the “raw” output of the classifier’s decision function, not the predicted class label. For instance, in linear SVMs, \(p = <w, x> + b\) where \(w, b\) are the parameters of the hyperplane and \(x\) is the input variable(s).
Parameters: - predictions (1D tensor) – raw prediction of the classifier
- targets (1D binary tensor with values in \(t\in\{-1,1\}\)) –
Returns: An expression for the item-wise hinge loss
Return type: 1D tensor
Notes
This is an alternative to the categorical cross-entropy loss for classification problems
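A NumPy sketch of the item-wise loss (illustrative):

import numpy as np

def hinge_loss(predictions, targets, delta=1):
    # targets take values in {-1, 1}; predictions are raw classifier scores
    return np.maximum(0, delta - targets * predictions)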
- symjax.nn.losses.multiclass_hinge_loss(predictions, targets, delta=1)[source]¶ multi-class hinge loss.
\[L_i = \max_{j ≠ t_i} (0, p_j - p_{t_i} + Δ)\]
Parameters: - predictions (2D tensor) – Predictions in (0, 1), such as softmax output of a neural network, with data points in rows and class probabilities in columns.
- targets (2D tensor or 1D tensor) – Either a vector of int giving the correct class index per data point or a 2D tensor of one-hot encoding of the correct class in the same layout as predictions (non-binary targets in [0, 1] do not work!)
- delta (scalar, default 1) – The hinge loss margin
Returns: An expression for the item-wise multi-class hinge loss
Return type: 1D tensor
Notes
This is an alternative to the categorical cross-entropy loss for multi-class classification problems
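A NumPy sketch for integer targets (illustrative; the one-hot case reduces to this after an argmax over targets):

import numpy as np

def multiclass_hinge_loss(predictions, targets, delta=1):
    # L_i = max_{j != t_i} (0, p_j - p_{t_i} + delta)
    n = predictions.shape[0]
    correct = predictions[np.arange(n), targets]
    margins = predictions - correct[:, None] + delta
    margins[np.arange(n), targets] = 0  # exclude the true class
    return np.maximum(0, margins).max(axis=1)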
symjax.probabilities¶
Implementation of basic distributions, their (log) densities, sampling, KL divergence, and entropies
Categorical ([probabilities, logits, eps]) |
Normal (mean, cov) | (batched, multivariate) normal distribution
KL (X, Y[, EPS]) | Normal: distributions are specified by means and log stds.
Detailed Descriptions¶
- symjax.probabilities.KL(X, Y, EPS=1e-08)[source]¶ Normal: distributions are specified by means and log stds. (https://en.wikipedia.org/wiki/Kullback-Leibler_divergence#Multivariate_normal_distributions)
\[\begin{align}
KL(p\|q) &= \int [\log p(x) - \log q(x)]\, p(x)\, dx\\
&= \int \Big[\tfrac{1}{2}\log\frac{|\Sigma_2|}{|\Sigma_1|} - \tfrac{1}{2}(x-\mu_1)^T\Sigma_1^{-1}(x-\mu_1) + \tfrac{1}{2}(x-\mu_2)^T\Sigma_2^{-1}(x-\mu_2)\Big]\, p(x)\, dx\\
&= \tfrac{1}{2}\log\frac{|\Sigma_2|}{|\Sigma_1|} - \tfrac{1}{2}\mathrm{tr}\{E[(x-\mu_1)(x-\mu_1)^T]\,\Sigma_1^{-1}\} + \tfrac{1}{2}E[(x-\mu_2)^T\Sigma_2^{-1}(x-\mu_2)]\\
&= \tfrac{1}{2}\log\frac{|\Sigma_2|}{|\Sigma_1|} - \tfrac{1}{2}\mathrm{tr}\{I_d\} + \tfrac{1}{2}(\mu_1-\mu_2)^T\Sigma_2^{-1}(\mu_1-\mu_2) + \tfrac{1}{2}\mathrm{tr}\{\Sigma_2^{-1}\Sigma_1\}\\
&= \tfrac{1}{2}\Big[\log\frac{|\Sigma_2|}{|\Sigma_1|} - d + \mathrm{tr}\{\Sigma_2^{-1}\Sigma_1\} + (\mu_2-\mu_1)^T\Sigma_2^{-1}(\mu_2-\mu_1)\Big].
\end{align}\]
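A NumPy sketch of the closed form above for the diagonal-covariance case, with each distribution given by its mean and log std as stated (illustrative):

import numpy as np

def kl_normal_diag(mu1, logstd1, mu2, logstd2):
    # KL(N(mu1, diag(exp(2*logstd1))) || N(mu2, diag(exp(2*logstd2))))
    var1, var2 = np.exp(2 * logstd1), np.exp(2 * logstd2)
    return 0.5 * np.sum(2 * (logstd2 - logstd1) - 1
                        + var1 / var2
                        + (mu2 - mu1) ** 2 / var2)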
- class symjax.probabilities.Normal(mean, cov)[source]¶ (batched, multivariate) normal distribution
Parameters: - mean (N dimensional Tensor) – the mean of the normal distribution; the last dimension represents the dimension of the data, the leading dimensions are batch dimensions
- cov ((N or N+1) dimensional Tensor) – the covariance matrix; if N-dimensional then it is assumed to be diagonal, if (N+1)-dimensional then the last 2 dimensions represent the covariance dimensions and thus their shapes should be equal
- entropy()[source]¶ Compute the differential entropy of the multivariate normal.
Returns: h – Entropy of the multivariate normal distribution
Return type: scalar
- log_prob(x)[source]¶ Log of the multivariate normal probability density function.
Parameters: x (Tensor) – samples to use to evaluate the log pdf, with the last axis of x denoting the components.
Returns: pdf – Log of the probability density function evaluated at x
Return type: Tensor
symjax.rl¶
Implementation of basic agents, environment utilities and learning policies
Buffer (maxlen[, priority_sampling, gamma, lam]) | Buffer holding different values of experience
run (env, agent, buffer[, rewarder, noise, …]) |
Actor (states[, actions_distribution, name]) | actor (state to action mapping) for RL
Critic (states[, actions]) |
REINFORCE (state_shape, actions_shape, …[, …]) | policy gradient reinforce also called reward-to-go policy gradient
ActorCritic (state_shape, actions_shape, …) | this corresponds to Q actor critic or V actor critic depending on the given critic
PPO (state_shape, actions_shape, batch_size, …) | instead of using target networks one can record the old log probs
DDPG (state_shape, actions_shape, batch_size, …) |
Detailed Descriptions¶
- class symjax.rl.utils.Buffer(maxlen, priority_sampling=False, gamma=0.99, lam=0.95)[source]¶ Buffer holding different values of experience
By default this contains
"reward", "reward-to-go", "V" or "Q", "action", "state", "episode", "priorities", "TD-error", "terminal", "next-state"
\[Q^{\pi}(s,a) = E_{\pi}\{R_t \mid s_t=s, a_t=a\} = E_{\pi}\{\sum_{k=0}^{\infty} \gamma^k r_{t+k+1} \mid s_t=s, a_t=a\}\]
\[V^{\pi}(s) = E_{\pi}\{R_t \mid s_t=s\} = E_{\pi}\{\sum_{k=0}^{\infty} \gamma^k r_{t+k+1} \mid s_t=s\}\]
\(\gamma \in [0,1]\) is called the discount factor and determines whether one focuses on immediate rewards (\(\gamma=0\)), the total reward (\(\gamma=1\)), or some trade-off.
lam (float): Lambda for GAE-Lambda (always between 0 and 1, close to 1).
- symjax.rl.utils.run(env, agent, buffer, rewarder=None, noise=None, action_processor=None, max_episode_steps=10000, max_episodes=1000, update_every=1, update_after=1, skip_frames=1, reset_each_episode=False, wait_end_path=False, eval_every=10, eval_max_episode_steps=10000, eval_max_episodes=10)[source]¶
- class symjax.rl.agents.Actor(states, actions_distribution=None, name='actor')[source]¶ actor (state to action mapping) for RL
This class implements an actor. The user must first define their own class inheriting from Actor and implementing only the create_network method; this method will then be used internally to instantiate the actor network. If the used distribution is symjax.probabilities.Normal then the output of the create_network method should be first the mean and then the covariance.
In general the user should not instantiate this class; instead, pass the user’s inherited class (uninstantiated) to a policy-learning method.
Parameters: - states (Tensor-like) – the states of the environment (batch size in first axis)
- batch_size (int) – the batch size
- actions_distribution (None or symjax.probabilities.Distribution object) – the distribution of the actions; if the policy is deterministic, set this to None. Note that this is different from the noise parameter employed for exploration: this is simply the random-variable modeling of the actions, used to compute probabilities of sampled actions and the like
- class symjax.rl.REINFORCE(state_shape, actions_shape, n_episodes, episode_length, actor, lr=0.001, gamma=0.99)[source]¶ policy gradient reinforce, also called reward-to-go policy gradient
the vanilla policy gradient uses the total reward of each episode as a weight; in this implementation, the discounted rewards-to-go are used instead. Setting gamma to 1 leads to the reward-to-go policy gradient.
https://medium.com/@thechrisyoon/deriving-policy-gradients-and-implementing-reinforce-f887949bd63
- class symjax.rl.ActorCritic(state_shape, actions_shape, n_episodes, episode_length, actor, critic, lr=0.001, gamma=0.99, train_v_iters=10)[source]¶ this corresponds to Q actor critic or V actor critic depending on the given critic
(with GAE-Lambda for advantage estimation)