Source code for symjax.nn.schedules

#!/usr/bin/env python
# -*- coding: utf-8 -*-

from symjax import tensor as T
from ..base import current_graph, Scope
import numpy as np


def ExponentialMovingAverage(
    value,
    alpha,
    init=None,
    decay_min=False,
    debias=True,
    name="ExponentialMovingAverage",
):
    r"""exponential moving average of a given value

    This method returns an EMA of a given variable (or any tensor) whose
    internal state is automatically updated as new samples are observed;
    the internal updates are applied as part of a function.

    At each iteration the new value is given by

    .. math::

        v(0) = value(0) \text{ or } init

        v(t) = \alpha \, v(t-1) + (1 - \alpha) \, value(t)

    Args
    ----

    value: Tensor-like
        the value to use for the EMA

    alpha: scalar
        the decay of the EMA

    init: Tensor-like (same shape as value), optional
        the initialization of the EMA; if not given, a zero initialization
        is used together with debiasing, allowing for an unbiased estimate

    decay_min: bool
        at early stages, clip the decay to avoid erratic behaviors

    debias: bool
        whether to return the debiased estimate (Kingma et al., 2015)
        instead of the raw (biased) EMA

    Returns
    -------

    ema: Tensor-like
        the current (latest) value of the EMA incorporating the information
        of the latest observation of value

    fixed_ema: Tensor-like
        the value of the EMA from the previous pass. This is useful if one
        wants to keep the EMA estimate fixed for new observations: simply
        stop applying the updates (by using a new function) and use this
        fixed variable during testing (while ema will keep using the latest
        observed value)

    Example
    -------

    .. doctest::

        >>> import symjax
        >>> import numpy as np
        >>> np.random.seed(0)
        >>> symjax.current_graph().reset()
        >>> # suppose we want to do an EMA of a vector user-input
        >>> input = symjax.tensor.Placeholder((2,), 'float32')
        >>> ema, var = symjax.nn.schedules.ExponentialMovingAverage(input, 0.9)
        >>> # in the background, symjax automatically records the needed updates
        >>> print(symjax.get_updates())
        {Variable(name=EMA, shape=(2,), dtype=float32, trainable=False, scope=/ExponentialMovingAverage/): Op(name=where, fn=where, shape=(2,), dtype=float32, scope=/ExponentialMovingAverage/), Variable(name=first_step, shape=(), dtype=bool, trainable=False, scope=/ExponentialMovingAverage/): False}
        >>> # example of use:
        >>> f = symjax.function(input, outputs=ema, updates=symjax.get_updates())
        >>> for i in range(25):
        ...     print(f(np.ones(2) + np.random.randn(2) * 0.3))
        [1.5292157 1.1200472]
        [1.5056562 1.1752692]
        [1.5111173 1.1284239]
        [1.4885082 1.1110408]
        [1.4365609 1.1122546]
        [1.3972261 1.1446574]
        [1.3803346 1.1338419]
        [1.355617 1.1304679]
        [1.3648777 1.1112664]
        [1.3377819 1.0745169]
        [1.227414 1.0866737]
        [1.2306056 1.0557414]
        [1.2756376 1.0065362]
        [1.2494465 1.000267 ]
        [1.2704852 1.0443211]
        [1.2480851 1.0512339]
        [1.196643 0.9866866]
        [1.1665413 0.9927084]
        [1.186796 1.029509]
        [1.1564965 1.017489 ]
        [1.1093903 0.97313946]
        [1.0472631 1.0343488]
        [1.0272473 1.0177717]
        [0.9869387 1.0393193]
        [0.93982786 1.029005 ]
    """
    with Scope(name):
        init = init if init is not None else T.zeros_like(value, detach=True)
        num_steps = T.Variable(0, trainable=False, name="num_steps", dtype="int32")
        # internal (biased) EMA state, starting at the initialization
        var = T.Variable(init, trainable=False, dtype="float32", name="EMA")

        if decay_min:
            # cap the decay during the first steps to avoid erratic behaviors
            decay = T.minimum(alpha, (1.0 + num_steps) / (10.0 + num_steps))
        else:
            decay = alpha

        ema = decay * var + (1 - decay) * value
        var_update = T.where(T.equal(num_steps, 0), init, ema)
        current_graph().add_updates({var: ema, num_steps: num_steps + 1})

        if debias:
            debiased_ema = ema_debias(ema, init, decay, num_steps + 1)
            debiased_var = T.Variable(
                init, trainable=False, dtype="float32", name="debiased_EMA"
            )
            current_graph().add_updates({debiased_var: debiased_ema})
            return debiased_ema, debiased_var
        else:
            return ema, var
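# A minimal NumPy sketch of the scheme implemented above (illustration only;
# `_numpy_ema_sketch` is a hypothetical helper, not part of symjax): with a
# zero initialization, the biased recursion is debiased by dividing by
# 1 - alpha**t, which is what ema_debias below reduces to when init is zero.
def _numpy_ema_sketch(values, alpha=0.9):
    ema = np.zeros_like(values[0], dtype="float32")  # biased state, zero init
    debiased = []
    for t, v in enumerate(values, start=1):
        ema = alpha * ema + (1 - alpha) * v          # biased EMA recursion
        debiased.append(ema / (1 - alpha ** t))      # debias by 1 - alpha^t
    return debiased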
def ema_debias(biased_ema, init, decay, num_steps):
    """Compute the debiased value of a biased EMA.

    All exponential moving averages are biased toward their init. This
    function removes that bias using a scale factor, as in
    (Kingma et al., 2015)::

        EMA = init*b^(t) + c*(1 - b)*b^(t-1) + c*(1 - b)*b^(t-2) + ...
            = init*b^(t) + c*(1 - b^t)

    To recover the true value ``c``, we subtract ``init*b^(t)`` and divide
    by the scale factor ``1 - b^t``.

    Args
    ----

    biased_ema: Tensor
        the (biased) ema

    init: Tensor
        the init that was used to compute the ema

    decay: float
        the decay parameter

    num_steps: int
        the number of steps used

    Returns
    -------

    unbiased_ema
    """
    bt = T.power(decay, num_steps)  # b^t
    shift = init * bt  # bias contributed by the initialization
    scalor = 1 - bt  # scale factor 1 - b^t
    return (biased_ema - shift) / scalor
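# Quick numerical check of the identity used above (illustrative only;
# `_check_debias_sketch` is a hypothetical helper): for a constant signal c
# and initialization i, the biased EMA after t steps has the closed form
# i*b**t + c*(1 - b**t), so subtracting the shift and dividing by the scale
# factor recovers c exactly.
def _check_debias_sketch(c=3.0, i=1.0, b=0.9, t=5):
    biased = i * b ** t + c * (1 - b ** t)            # closed form of the biased EMA
    recovered = (biased - i * b ** t) / (1 - b ** t)  # same algebra as ema_debias
    return np.isclose(recovered, c)                   # -> True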
def PiecewiseConstant(init, steps_and_values):
    """piecewise constant variable updating automatically

    This method returns a variable with an internal counter that is
    incremented by the function updates; whenever this counter reaches one
    of the steps given in the input, the value of the variable becomes the
    one associated with that step.

    Args
    ----

    init: float-like
        the initial value of the variable, kept as is until a step with an
        update is reached

    steps_and_values: dict
        the dictionary mapping steps -> values: when the counter reaches one
        of the given steps, the value of the variable becomes the one
        associated with that step

    Returns
    -------

    value: Tensor
        the current value of the schedule

    step: Variable
        the internal step counter

    Example
    -------

    .. doctest::

        >>> import symjax
        >>> symjax.current_graph().reset()
        >>> var = symjax.nn.schedules.PiecewiseConstant(0.1, {4:1, 8:2})
        >>> # in the background, symjax automatically records that every time
        >>> # a function is using this variable an underlying update should occur
        >>> print(symjax.get_updates())
        {Variable(name=step, shape=(), dtype=int32, trainable=False, scope=/PiecewiseConstant/): Op(name=add, fn=add, shape=(), dtype=int32, scope=/PiecewiseConstant/)}
        >>> # it is up to the user to use it or not; if not used, the internal
        >>> # counter is never updated and thus the variable never changes.
        >>> # example of use:
        >>> f = symjax.function(outputs=var, updates=symjax.get_updates())
        >>> for i in range(10):
        ...     print(i, f())
        0 0.1
        1 0.1
        2 0.1
        3 0.1
        4 1.0
        5 1.0
        6 1.0
        7 1.0
        8 2.0
        9 2.0
    """
    with Scope("PiecewiseConstant"):
        # boundaries and corresponding values, padded so that any step falls
        # inside one of the intervals
        all_steps = T.stack([0] + list(steps_and_values.keys()) + [np.inf])
        all_values = T.stack([init] + list(steps_and_values.values()) + [0])
        step = T.Variable(
            0,
            trainable=False,
            name="step",
            dtype="int32",
        )
        # index of the first boundary not yet reached, minus one, selects the
        # currently active value
        value = all_values[(step >= all_steps).argmin() - 1]
        # register the counter increment so it is applied whenever a function
        # uses this schedule (as documented in the example above)
        current_graph().add_updates({step: step + 1})
        return value, step
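# Plain NumPy sketch of the lookup used above (illustrative only;
# `_piecewise_lookup_sketch` is a hypothetical helper): `step >= all_steps`
# is True for every boundary already passed, so argmin returns the first
# boundary not yet reached and the preceding entry is the active value.
def _piecewise_lookup_sketch(step, init=0.1, steps_and_values=None):
    steps_and_values = steps_and_values or {4: 1.0, 8: 2.0}
    all_steps = np.array([0] + list(steps_and_values.keys()) + [np.inf])
    all_values = np.array([init] + list(steps_and_values.values()) + [0])
    return all_values[(step >= all_steps).argmin() - 1]
# e.g. _piecewise_lookup_sketch(3) -> 0.1, _piecewise_lookup_sketch(4) -> 1.0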
def SimpleMovingAverage(value, n):
    """simple moving average

    This method returns the average of the ``n`` most recent observations of
    ``value``, with the internal buffer and counter updated automatically as
    part of the function updates.

    Args
    ----

    value: Tensor
        the value to average

    n: int
        the number of most recent observations to average over

    Returns
    -------

    avg: Tensor
        the current value of the simple moving average

    var: Variable
        the value of the SMA from the previous pass
    """
    with Scope("SimpleMovingAverage"):
        # circular buffer of the n last observations, NaN meaning "not yet seen"
        last_values = T.Variable(
            np.ones((n,) + value.shape) * np.nan,
            trainable=False,
            name="n_last_values",
            dtype="float32",
        )
        index = T.Variable(0, trainable=False, dtype="int32", name="index")
        var = T.Variable(
            T.zeros_like(value, detach=True),
            trainable=False,
            dtype="float32",
            name="SMA",
        )
        # overwrite the oldest entry of the buffer with the new observation
        updated = T.index_update(last_values, T.mod(index, n), value)
        # nanmean ignores the slots that have not been filled yet
        avg = T.nanmean(updated, axis=0)
        current_graph().add_updates(
            {var: avg, index: index + 1, last_values: updated}
        )
        return avg, var
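# NumPy sketch of the circular-buffer average above (illustrative only;
# `_numpy_sma_sketch` is a hypothetical helper): a NaN-filled buffer holding
# the n last observations is overwritten in round-robin fashion, and nanmean
# ignores the slots that have not been filled yet.
def _numpy_sma_sketch(values, n=3):
    buffer = np.full((n,) + np.shape(values[0]), np.nan, dtype="float32")
    averages = []
    for index, v in enumerate(values):
        buffer[index % n] = v                      # overwrite the oldest entry
        averages.append(np.nanmean(buffer, axis=0))
    return averages
# e.g. _numpy_sma_sketch([1.0, 2.0, 3.0, 4.0], n=2) -> [1.0, 1.5, 2.5, 3.5]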