# Source code for symjax.nn.layers

from symjax import tensor as T
from . import ops_nn as nn
from symjax.nn import initializers, schedules
import symjax
import numpy
import jax


# IMPORTANT NOTE
# in order to keep the sphinx doc clean we use a hacky pattern: the
# __init__ method is invoked as a plain function (see Layer.__new__) and
# has an actual return value...
# not pythonic, suggestions welcome!


def create_variable(
    name,
    tensor_or_func,
    shape,
    trainable,
    inplace=False,
    dtype="float32",
    preprocessor=None,
):
    """Instantiate (or pass through) a layer parameter.

    Parameters
    ----------
    name: str
        name given to the created Variable
    tensor_or_func: Tensor-like/ndarray, callable or None
        the value or initializer of the parameter; ``None`` disables the
        parameter entirely
    shape: shape
        the shape of the parameter (resolved through the current graph)
    trainable: bool
        whether the created Variable is trainable
    inplace: bool (default False)
        reuse ``tensor_or_func`` directly (weight sharing) instead of
        wrapping it into a Variable; initializer callables are not allowed
        in that case
    dtype: str (default "float32")
        dtype of the created Variable
    preprocessor: None or callable (default None)
        optional transform applied to the Variable before it is returned

    Returns
    -------
    None, the given tensor (inplace), or the (possibly preprocessed)
    Variable
    """
    # a disabled parameter: nothing to create
    if tensor_or_func is None:
        return None

    # inplace: hand the caller's tensor back untouched (used for weight
    # sharing across layers); an initializer function makes no sense here
    if inplace:
        assert not callable(tensor_or_func)
        return tensor_or_func

    var = T.Variable(
        tensor_or_func,
        name=name,
        shape=symjax.current_graph().get(shape),
        dtype=dtype,
        trainable=trainable,
    )
    return var if preprocessor is None else preprocessor(var)


class Layer(T.Tensor):
    """Base class for all layers.

    A layer is never really instantiated: ``__new__`` runs the subclass
    ``__init__`` inside a fresh ``symjax.Scope`` and returns whatever
    ``__init__`` returns (the output tensor) instead of a Layer object
    (see the module-level note above).
    """

    def __new__(cls, *args, name=None, **kwargs):

        # the variable scope defaults to the subclass identifier
        if name is None:
            name = cls.__NAME__

        with symjax.Scope(name):

            # hack: __init__ is called as a plain function with the class
            # itself as first argument; its return value becomes the
            # "instance" seen by the caller
            output = cls.__init__(cls, *args, **kwargs)

        return output

    @staticmethod
    def add_updates(self, update):
        # NOTE(review): decorated @staticmethod yet still takes `self`;
        # it only works when called as Layer.add_updates(obj, update) —
        # confirm the intended calling convention before relying on it.
        symjax.current_graph().add_updates(update)

    def forward(self):
        # placeholder meant to be overridden by subclasses; no-op here
        pass


class Identity(Layer):
    """Pass-through layer: returns its input unchanged."""

    __NAME__ = "Identity"

    def __init__(self, input):
        # __init__ acts as a factory (see Layer.__new__): the returned
        # tensor is the layer output
        return input


class Upsample1D(Layer):
    """Upsample the input along a single axis.

    Thin wrapper around ``T.interpolation.upsample_1d``.
    """

    __NAME__ = "Upsample1D"

    def __init__(self, input, repeat, axis=-1, mode="constant", value=0.0):
        # repeat: number of values inserted per step along `axis`;
        # mode/value control how the inserted values are produced
        # (e.g. constant filling with `value`)
        return T.interpolation.upsample_1d(
            input,
            repeat=repeat,
            axis=axis,
            mode=mode,
            value=value,
        )


class Upsample2D(Layer):
    """Upsample the input along two axes.

    Applies the 1-D upsampling twice, once per (repeat, axis) pair.

    Parameters
    ----------
    input: Tensor
        the input to upsample
    repeat: couple of ints
        number of inserted values per step, one per axis
    axis: couple of ints
        the two axes to upsample
    mode, value:
        forwarded to the 1-D upsampling (see Upsample1D)
    """

    __NAME__ = "Upsample2D"

    def __init__(self, input, repeat, axis, mode="constant", value=0.0):
        # fixed: call the same fully-qualified helper as Upsample1D
        # (the code used T.upsample_1d while the 1-D layer uses
        # T.interpolation.upsample_1d)
        p1 = T.interpolation.upsample_1d(
            input,
            repeat=repeat[0],
            axis=axis[0],
            mode=mode,
            value=value,
        )
        p2 = T.interpolation.upsample_1d(
            p1,
            repeat=repeat[1],
            axis=axis[1],
            mode=mode,
            value=value,
        )
        return p2


class Dense(Layer):
    """Fully-connected/Dense layer.

    Computes ``dot(input, W.T) + b`` (each term optional), optionally
    flattening the input to 2-D first.

    Parameters:
    -----------
    input: Tensor
        the input to the layer (does not have to be 2D)

    units: int
        the width of the layer

    W: Tensor-like/ndarray or callable (default initializers.glorot_uniform)
        the matrix weight of the layer of shape (units, input_dim)

    b: Tensor-like/ndarray or callable (default numpy.zeros)
        the bias vector of the layer

    trainable_W: bool (default True)
        if the variable initialized from W should be trainable

    trainable_b: bool (default True)
        if the vector initialized from b should be trainable

    W_preprocessor: None or callable (default: None)
        a possible preprocessing function applied onto the layer variable
        of W before computing the layer output

    b_preprocessor: None or callable (default: None)
        a possible preprocessing function applied onto the layer variable
        of b before computing the layer output

    inplace_W: bool (default False)
        if the given Tensor-like/array or callable W should be used in
        place and not put as a Variable (which is then either frozen or
        learned). This is useful to create multiple layers with same
        weights.

    inplace_b: bool (default False)
        same as inplace_W but for b

    flatten: bool (default True)
        whether to flatten or not the input if more than 2 dimensional
    """

    __NAME__ = "Dense"

    def __init__(
        self,
        input,
        units,
        W=initializers.glorot_uniform,
        b=numpy.zeros,
        trainable_W=True,
        trainable_b=True,
        W_preprocessor=None,
        b_preprocessor=None,
        inplace_W=False,
        inplace_b=False,
        flatten=True,
    ):
        # the effective input width after (optional) flattening
        in_width = numpy.prod(input.shape[1:]) if flatten else input.shape[-1]

        W = create_variable(
            "W",
            W,
            (units, in_width),
            trainable=trainable_W,
            preprocessor=W_preprocessor,
            inplace=inplace_W,
        )
        b = create_variable(
            "b",
            b,
            (units,),
            trainable=trainable_b,
            preprocessor=b_preprocessor,
            inplace=inplace_b,
        )

        # each term is optional: None disables the matmul and/or the bias
        out = T.flatten2d(input) if flatten else input
        if W is not None:
            out = T.dot(out, W.T)
        if b is not None:
            out = out + b
        return out


class Conv1D(Layer):
    """1-D (time) convolution.

    Performs a batched 1-D convolution of the input with learned filters
    and an optional bias shift.

    Parameters:
    -----------
    input: Tensor
        the layer input, with channels on axis 1
    n_filters: int
        number of output filters
    filter_length: int
        temporal length of each filter
    W: Tensor-like/ndarray or callable (default initializers.glorot_uniform)
        filter weights of shape (n_filters, in_channels, filter_length)
    b: Tensor-like/ndarray or callable (default numpy.zeros)
        bias vector of shape (n_filters,)
    stride: int (default 1)
    padding: str (default "VALID")
    trainable_W, trainable_b: bool (default True)
    inplace_W, inplace_b: bool (default False)
        see Dense
    W_preprocessor, b_preprocessor: None or callable (default None)
        see Dense
    input_dilations, filter_dilations: None or int/tuple (default None)
        dilation factors forwarded to the convolution
    """

    __NAME__ = "Conv1D"

    def __init__(
        self,
        input,
        n_filters,
        filter_length,
        W=initializers.glorot_uniform,
        b=numpy.zeros,
        stride=1,
        padding="VALID",
        trainable_W=True,
        trainable_b=True,
        inplace_W=False,
        inplace_b=False,
        W_preprocessor=None,
        b_preprocessor=None,
        input_dilations=None,
        filter_dilations=None,
    ):
        if numpy.isscalar(input_dilations):
            # NOTE(review): the * 2 broadcast looks copied from the 2-D
            # layer; for a single spatial dimension a 1-tuple may be the
            # intent — confirm against T.signal.batch_convolve
            input_dilations = (input_dilations,) * 2

        W = create_variable(
            "W",
            W,
            (n_filters, input.shape[1], filter_length),
            trainable=trainable_W,
            preprocessor=W_preprocessor,
            inplace=inplace_W,
        )
        b = create_variable(
            "b",
            b,
            (n_filters,),
            trainable=trainable_b,
            preprocessor=b_preprocessor,
            inplace=inplace_b,
        )

        conv = T.signal.batch_convolve(
            input,
            W,
            strides=stride,
            padding=padding,
            input_dilation=input_dilations,
            filter_dilation=filter_dilations,
        )

        # broadcast the per-filter bias over the time axis
        if b is not None:
            return conv + b[:, None]
        else:
            return conv
class Conv2DTranspose(Layer):
    """2-D (spatial) transposed convolution.

    Parameters:
    -----------
    input: Tensor
        the layer input in NCHW layout
    n_filters: int
        number of output filters
    filter_shape: couple of ints
        spatial shape of each filter
    padding: str (default "VALID")
    strides: int or couple (default 1)
    W, b: initializers or tensors (defaults glorot_uniform / zeros)
    trainable_W, trainable_b: bool (default True)
    transpose_W: bool (default True)
        whether the kernel is transposed inside the convolution
    filter_dilations: None or tuple (default None)
    """

    __NAME__ = "Conv2DTranspose"

    def __init__(
        self,
        input,
        n_filters,
        filter_shape,
        padding="VALID",
        strides=1,
        W=initializers.glorot_uniform,
        b=numpy.zeros,
        trainable_W=True,
        trainable_b=True,
        transpose_W=True,
        filter_dilations=None,
    ):
        # fixed: the previous body used a stale API (self.init_input /
        # self.create_variable / self.W) that no longer exists on Layer;
        # rewritten with the module-level create_variable like Conv2D
        W = create_variable(
            "W",
            W,
            (input.shape[1], n_filters) + tuple(filter_shape),
            trainable=trainable_W,
        )
        b = create_variable("b", b, (n_filters,), trainable=trainable_b)

        conv = T.signal.batch_convolve_transpose(
            input,
            W,
            strides=strides,
            padding=padding,
            transpose_kernel=transpose_W,
            filter_dilation=filter_dilations,
        )
        # broadcast the per-filter bias over the two spatial axes
        return conv + b.reshape((-1, 1, 1))
class Conv2D(Layer):
    """2-D (spatial) convolution.

    Performs a batched 2-D convolution (NCHW layout) with learned filters
    and an optional bias shift.

    Parameters:
    -----------
    input: Tensor
        the layer input, channels on axis 1
    n_filters: int
        number of output filters
    filter_shape: couple of ints
        spatial shape of each filter
    padding: str (default "VALID")
    strides: int or couple (default 1)
    W: Tensor-like/ndarray or callable (default initializers.glorot_uniform)
        filter weights of shape (n_filters, in_channels) + filter_shape
    b: Tensor-like/ndarray or callable (default numpy.zeros)
        bias vector of shape (n_filters,)
    trainable_W, trainable_b: bool (default True)
    inplace_W, inplace_b: bool (default False)
        see Dense
    input_dilations, filter_dilations: None or tuple (default None)
    W_preprocessor, b_preprocessor: None or callable (default None)
        see Dense
    """

    __NAME__ = "Conv2D"

    def __init__(
        self,
        input,
        n_filters,
        filter_shape,
        padding="VALID",
        strides=1,
        W=initializers.glorot_uniform,
        b=numpy.zeros,
        trainable_W=True,
        trainable_b=True,
        inplace_W=False,
        inplace_b=False,
        input_dilations=None,
        filter_dilations=None,
        W_preprocessor=None,
        b_preprocessor=None,
    ):
        W = create_variable(
            "W",
            W,
            (n_filters, input.shape[1]) + tuple(filter_shape),
            trainable=trainable_W,
            preprocessor=W_preprocessor,
            inplace=inplace_W,
        )
        b = create_variable(
            "b",
            b,
            (n_filters,),
            trainable=trainable_b,
            preprocessor=b_preprocessor,
            inplace=inplace_b,
        )

        conv = T.signal.batch_convolve(
            input,
            W,
            strides=strides,
            padding=padding,
            input_dilation=input_dilations,
            filter_dilation=filter_dilations,
        )

        # broadcast the per-filter bias over the two spatial axes
        if b is not None:
            return conv + b.reshape((-1, 1, 1))
        else:
            return conv
class Pool1D(Layer):
    """1-D (temporal) pooling.

    (fixed: docstring previously said "2-D (spatial)", copied from Pool2D)

    Parameters:
    -----------
    input: Tensor
        the layer input; pooling is applied on the last axis
    pool_shape: int
        pooling window length
    pool_type: str (default "MAX")
        the reducer used by T.signal.pool
    strides: None or int (default None)
        stride of the pooling window; defaults to pool_shape
    """

    __NAME__ = "Pool1D"

    def __init__(self, input, pool_shape, pool_type="MAX", strides=None):
        # pooling acts only on the last axis: (1, 1, pool) window
        pool_shape = (1, 1, pool_shape)
        if strides is None:
            strides = pool_shape
        else:
            strides = (1, 1, strides)

        return T.signal.pool(
            input,
            pool_shape,
            strides=strides,
            reducer=pool_type,
        )
class Pool2D(Layer):
    """2-D (spatial) pooling.

    Parameters:
    -----------
    input: Tensor
        the layer input in NCHW layout
    pool_shape: int or couple of ints
        pooling window shape (broadcast to both spatial axes if scalar)
    pool_type: str (default "MAX")
        the reducer used by T.signal.pool
    strides: None, int or couple (default None)
        stride of the pooling window; defaults to pool_shape
    """

    __NAME__ = "Pool2D"

    def __init__(self, input, pool_shape, pool_type="MAX", strides=None):
        # pooling acts on the two trailing (spatial) axes only
        pool_shape = (1, 1) + symjax.data.utils.as_tuple(pool_shape, 2)
        if strides is None:
            strides = pool_shape
        else:
            strides = (1, 1) + symjax.data.utils.as_tuple(strides, 2)

        return T.signal.pool(
            input,
            pool_shape,
            strides=strides,
            reducer=pool_type,
        )
class Dropout(Layer):
    """binary mask onto the input (inverted dropout).

    Parameters
    ----------
    input: Tensor
        the layer input (fixed: docstring previously documented a
        non-existent ``input_or_shape`` parameter)

    p: float (0<=p<=1)
        the probability to drop the value

    deterministic: bool or Tensor
        the state of the layer; when true the input passes through
        unchanged

    seed: seed
        the RNG seed

    Returns
    -------
    output: the layer output
    """

    __NAME__ = "Dropout"

    def __init__(self, input, p, deterministic, seed=None):
        # keep-probability is 1 - p; surviving units are rescaled by
        # 1/(1-p) so the expected activation is unchanged (the maximum
        # with 1e-4 guards against division by zero when p ~ 1)
        mask = T.random.bernoulli(shape=input.shape, p=1 - p, seed=seed)
        return T.where(deterministic, input, mask * input / T.maximum(1e-4, (1 - p)))
class RandomFlip(Layer):
    """random axis flip on the input.

    Random layer that will randomly flip the given axis of the input.
    Note that all the involved operations are GPU compatible and allow
    for backpropagation.

    (fixed: the previous docstring was copy-pasted from RandomCrop and
    documented crop_shape/padding parameters that do not exist here)

    Parameters
    ----------
    input: Tensor
        the input of the layer

    p: float (0<=p<=1)
        probability, per example, of flipping

    axis: int
        the axis to (possibly) flip

    deterministic: bool or Tensor
        if the layer is in deterministic mode or not; when true the input
        passes through unflipped

    seed: seed (optional)
        to control reproducibility

    Returns
    -------
    output: the output tensor
    """

    __NAME__ = "RandomFlip"

    def __init__(self, input, p, axis, deterministic, seed=None):
        # one Bernoulli draw per example, broadcast over all other axes
        extra_dims = input.ndim - 1
        flip = T.random.bernoulli(
            shape=(input.shape[0],) + (1,) * extra_dims,
            p=p,
            seed=seed,
        )

        dirac = T.cast(deterministic, "float32")

        flipped_input = T.where(flip, T.flip(input, axis), input)
        # deterministic -> plain input; otherwise the randomly flipped one
        return input * dirac + flipped_input * (1 - dirac)
class RandomCrop(Layer):
    """random crop selection from the input.

    Random layer that will select a window of the input based on the given
    parameters, with the possibility to first apply a padding. This layer
    is commonly used as a data augmentation technique and positioned at
    the beginning of the deep network topology. Note that all the involved
    operations are GPU compatible and allow for backpropagation.

    Parameters
    ----------
    input: Tensor
        the input of the layer

    crop_shape: shape
        the shape of the cropped part of the input. It must have the same
        length as the input shape minus one for the first dimension

    deterministic: bool or Tensor
        if the layer is in deterministic mode or not; when true a fixed
        centered crop is returned instead of a random one

    padding: int or shape (default 0)
        the amount of padding to apply on each dimension (except the first
        one); each dimension should have a couple for the before and after
        padding parts

    seed: seed (optional)
        to control reproducibility

    Returns
    -------
    output: the output tensor
    """

    __NAME__ = "RandomCrop"

    def __init__(self, input, crop_shape, deterministic, padding=0, seed=None):
        # scalar padding -> the same symmetric (before, after) pad on
        # every non-batch dimension
        if not hasattr(padding, "__len__"):
            pad_shape = [(padding, padding)] * (input.ndim - 1)
        else:
            pad_shape = [
                (pad, pad) if not hasattr(pad, "__len__") else pad for pad in padding
            ]

        assert len(pad_shape) == len(crop_shape)
        assert len(pad_shape) == input.ndim - 1

        # per-example random start indices, plus the centered (fixed)
        # indices used in deterministic mode
        start_indices = list()
        fixed_indices = list()
        for i, (pad, dim, crop) in enumerate(
            zip(pad_shape, input.shape[1:], crop_shape)
        ):
            maxval = pad[0] + pad[1] + dim - crop
            start_indices.append(
                T.random.randint(
                    minval=0,
                    maxval=maxval,
                    shape=(input.shape[0], 1),
                    dtype="int32",
                    # offset the seed per dimension for independent draws
                    seed=seed + i if seed is not None else seed,
                )
            )
            fixed_indices.append(T.ones((input.shape[0], 1), "int32") * (maxval // 2))
        start_indices = T.concatenate(start_indices, 1)
        fixed_indices = T.concatenate(fixed_indices, 1)

        dirac = T.cast(deterministic, "float32")

        # pad the input (batch axis untouched), then slice per example
        pinput = T.pad(input, [(0, 0)] + pad_shape)
        routput = T.map(
            lambda x, indices: T.dynamic_slice(x, indices, crop_shape),
            sequences=[pinput, start_indices],
        )
        doutput = T.map(
            lambda x, indices: T.dynamic_slice(x, indices, crop_shape),
            sequences=[pinput, fixed_indices],
        )

        return doutput * dirac + (1 - dirac) * routput
class BatchNormalization(Layer):
    """batch-normalization layer.

    Parameters:
    -----------
    input: Tensor
        the layer input tensor

    axis: list or tuple of ints
        the axes along which the learned parameters keep their full size;
        the batch statistics are computed over the complementary axes.
        NOTE(review): the original docstring described `axis` as the axes
        to normalize over, but the code reduces over the axes NOT listed
        in `axis` — the code behavior is kept and documented here; confirm
        the intended convention.

    deterministic: bool or Tensor
        controlling the state of the layer: when true the running EMA
        statistics are used instead of the batch statistics

    const: float32 (optional)
        epsilon used in the standard deviation renormalization

    beta_1: float32 (optional)
        the parameter for the exponential moving average of the mean

    beta_2: float32 (optional)
        the parameter for the exponential moving average of the variance

    W, b: initializers or tensors (defaults T.ones / T.zeros)
        learned scale and shift; None disables the corresponding term

    trainable_W, trainable_b: bool (default True)

    Returns
    -------
    output: the layer output
    """

    __NAME__ = "BatchNormalization"

    def __init__(
        self,
        input,
        axis,
        deterministic,
        const=0.001,
        beta_1=0.99,
        beta_2=0.99,
        W=T.ones,
        b=T.zeros,
        trainable_W=True,
        trainable_b=True,
    ):
        # learned scale/shift are full-size on `axis` and broadcast (size
        # 1) over the reduced axes
        parameter_shape = [
            input.shape[i] if i in axis else 1 for i in range(input.ndim)
        ]
        r_axes = [i for i in range(input.ndim) if i not in axis]

        W = create_variable("W", W, parameter_shape, trainable=trainable_W)
        b = create_variable("b", b, parameter_shape, trainable=trainable_b)

        input_mean = input.mean(r_axes, keepdims=True)
        # fixed: removed a dead E[x^2]-E[x]^2 + const variance computation
        # that was immediately overwritten by the line below
        input_var = input.var(r_axes, keepdims=True)

        # running statistics, used at inference time ([1] selects the EMA
        # tensor from the schedule's return)
        avg_mean = schedules.ExponentialMovingAverage(
            input_mean, beta_1, debias=False, name="mean_ema"
        )[1]
        avg_var = schedules.ExponentialMovingAverage(
            input_var,
            beta_2,
            init=T.ones_like(input_var, detach=True),
            debias=False,
            name="var_ema",
        )[1]

        m = T.where(deterministic, avg_mean, input_mean)
        v = T.where(deterministic, avg_var, input_var)
        output = nn.normalize(input, mean=m, variance=v, epsilon=const)

        if b is None and W is not None:
            return W * output
        elif b is not None and W is None:
            return output + b
        elif b is not None and W is not None:
            return W * output + b
        else:
            return output
class RNN(Layer):
    """Vanilla recurrent layer.

    Scans the recurrence h_t = activation(x_t W + b + h_{t-1} H) over the
    time axis of the input sequence.

    Parameters:
    -----------
    sequence: Tensor
        input of shape (batch, time, dim)
    init_h: Tensor
        initial hidden state, shape (batch, units)
    units: int
        hidden state width
    W, H, b: initializers or tensors
        input-to-hidden, hidden-to-hidden and bias parameters
    trainable_W, trainable_H, trainable_b: bool (default True)
    activation: callable (default nn.sigmoid)
    only_last: bool (default False)
        return only the final state instead of the whole sequence
    """

    __NAME__ = "RNN"

    @staticmethod
    def gate(h, x, W, H, b, sigma):
        # single recurrent step; returned twice: once as the carry, once
        # as the collected scan output
        ht = sigma(T.dot(x, W) + b + T.dot(h, H))
        return ht, ht

    def __init__(
        self,
        sequence,
        init_h,
        units,
        W=initializers.glorot_uniform,
        H=initializers.orthogonal,
        b=T.zeros,
        trainable_W=True,
        trainable_H=True,
        trainable_b=True,
        activation=nn.sigmoid,
        only_last=False,
    ):
        W = create_variable("W", W, (sequence.shape[2], units), trainable=trainable_W)
        H = create_variable("H", H, (units, units), trainable=trainable_H)
        b = create_variable("b", b, (units,), trainable=trainable_b)

        # scan over time: (batch, time, dim) -> (time, batch, dim)
        last, output = T.scan(
            lambda h, x, W, H, b: RNN.gate(h, x, W, H, b, activation),
            init=init_h,
            sequences=[sequence.transpose((1, 0, 2))],
            non_sequences=[W, H, b],
        )

        if only_last:
            return last
        else:
            # back to (batch, time, units)
            return output.transpose((1, 0, 2))
class GRU(Layer):
    """Gated recurrent unit layer ("minimal" or "full" gating).

    Parameters:
    -----------
    sequence: Tensor
        input of shape (batch, time, dim)
    init_h: Tensor
        initial hidden state, shape (batch, units)
    units: int
        hidden state width
    Wh/Uh/bh, Wz/Uz/bz, Wr/Ur/br: initializers or tensors
        candidate, update and (full gate only) reset parameters
    trainable_*: bool (default True)
    activation: callable (default nn.sigmoid)
        gating nonlinearity
    phi: callable (default T.tanh)
        candidate nonlinearity
    only_last: bool (default False)
        return only the final state instead of the whole sequence
    gate: str, "minimal" or "full" (default "minimal")
    """

    __NAME__ = "GRU"

    @staticmethod
    def full_gate(h, x, Wh, Uh, bh, Wz, Uz, bz, Wr, Ur, br, sigma, phi):
        # standard GRU step with update (z) and reset (r) gates
        zt = sigma(T.dot(x, Wz) + bz + T.dot(h, Uz))
        rt = sigma(T.dot(x, Wr) + br + T.dot(h, Ur))
        h_hat = phi(T.dot(x, Wh) + bh + T.dot(h * rt, Uh))
        ht = (1 - zt) * h + zt * h_hat
        return ht, ht

    @staticmethod
    def minimal_gate(h, x, Wh, Uh, bh, Wz, Uz, bz, sigma, phi):
        # minimal GRU step with a single forget gate
        ft = sigma(T.dot(x, Wz) + bz + T.dot(h, Uz))
        h_hat = phi(T.dot(x, Wh) + bh + T.dot(h * ft, Uh))
        ht = (1 - ft) * h + ft * h_hat
        return ht, ht

    def __init__(
        self,
        sequence,
        init_h,
        units,
        Wh=initializers.glorot_uniform,
        Uh=initializers.orthogonal,
        bh=T.zeros,
        Wz=initializers.glorot_uniform,
        Uz=initializers.orthogonal,
        bz=T.zeros,
        Wr=initializers.glorot_uniform,
        Ur=initializers.orthogonal,
        br=T.zeros,
        trainable_Wh=True,
        trainable_Uh=True,
        trainable_bh=True,
        trainable_Wz=True,
        trainable_Uz=True,
        trainable_bz=True,
        trainable_Wr=True,
        trainable_Ur=True,
        trainable_br=True,
        activation=nn.sigmoid,
        phi=T.tanh,
        only_last=False,
        gate="minimal",
    ):
        Wh = create_variable(
            "Wh", Wh, (sequence.shape[2], units), trainable=trainable_Wh
        )
        Uh = create_variable("Uh", Uh, (units, units), trainable=trainable_Uh)
        bh = create_variable("bh", bh, (units,), trainable=trainable_bh)

        Wz = create_variable(
            "Wz", Wz, (sequence.shape[2], units), trainable=trainable_Wz
        )
        Uz = create_variable("Uz", Uz, (units, units), trainable=trainable_Uz)
        bz = create_variable("bz", bz, (units,), trainable=trainable_bz)

        # the reset-gate parameters only exist for the full gate
        if gate == "full":
            Wr = create_variable(
                "Wr", Wr, (sequence.shape[2], units), trainable=trainable_Wr
            )
            Ur = create_variable("Ur", Ur, (units, units), trainable=trainable_Ur)
            br = create_variable("br", br, (units,), trainable=trainable_br)

        if gate == "minimal":

            def fn(*args):
                return GRU.minimal_gate(*args, activation, phi)

            last, output = T.scan(
                fn,
                init=init_h,
                sequences=[sequence.transpose((1, 0, 2))],
                non_sequences=[Wh, Uh, bh, Wz, Uz, bz],
            )
        elif gate == "full":

            def fn(*args):
                return GRU.full_gate(*args, activation, phi)

            last, output = T.scan(
                fn,
                init=init_h,
                sequences=[sequence.transpose((1, 0, 2))],
                non_sequences=[Wh, Uh, bh, Wz, Uz, bz, Wr, Ur, br],
            )
        else:
            # fixed: an unknown gate value previously fell through to a
            # confusing NameError on `last`
            raise ValueError(
                "gate must be 'minimal' or 'full', got {!r}".format(gate)
            )

        if only_last:
            return last
        else:
            return output.transpose((1, 0, 2))
class LSTM(Layer):
    """Long short-term memory layer.

    Parameters:
    -----------
    sequence: Tensor
        input of shape (batch, time, dim)
    init_h: Tensor
        initial hidden state, shape (batch, units); the cell state is
        initialized to zeros
    units: int
        hidden state width
    Wf/Uf/bf, Wi/Ui/bi, Wo/Uo/bo, Wc/Uc/bc: initializers or tensors
        forget, input, output and candidate-cell parameters
    trainable_*: bool (default True)
    activation_g: callable (default nn.sigmoid)
        gate nonlinearity
    activation_c, activation_h: callable (default T.tanh)
        cell and hidden nonlinearities
    only_last: bool (default False)
        return only the final carry instead of the whole sequence
    gate: str (default "minimal")
        currently unused; kept for signature compatibility
    """

    # fixed: was "GRU" (copy-paste bug)
    __NAME__ = "LSTM"

    @staticmethod
    def gate(
        carry,
        x,
        Wf,
        Uf,
        bf,
        Wi,
        Ui,
        bi,
        Wo,
        Uo,
        bo,
        Wc,
        Uc,
        bc,
        sigma_g,
        sigma_c,
        sigma_h,
    ):
        # carry stacks [hidden, cell]
        h, c = carry[0], carry[1]
        f = sigma_g(T.dot(x, Wf) + bf + T.dot(h, Uf))
        i = sigma_g(T.dot(x, Wi) + bi + T.dot(h, Ui))
        o = sigma_g(T.dot(x, Wo) + bo + T.dot(h, Uo))
        ctilde = sigma_c(T.dot(x, Wc) + bc + T.dot(h, Uc))
        cnew = f * c + i * ctilde
        hnew = o * sigma_h(cnew)
        # NOTE(review): the collected output is the PRE-update hidden
        # state `h`, so scanned outputs lag the carry by one step —
        # behavior preserved; confirm this is intended (vs returning hnew)
        return T.stack([hnew, cnew]), h

    def __init__(
        self,
        sequence,
        init_h,
        units,
        Wf=initializers.glorot_uniform,
        Uf=initializers.orthogonal,
        bf=T.zeros,
        Wi=initializers.glorot_uniform,
        Ui=initializers.orthogonal,
        bi=T.zeros,
        Wo=initializers.glorot_uniform,
        Uo=initializers.orthogonal,
        bo=T.zeros,
        Wc=initializers.glorot_uniform,
        Uc=initializers.orthogonal,
        bc=T.zeros,
        trainable_Wf=True,
        trainable_Uf=True,
        trainable_bf=True,
        trainable_Wi=True,
        trainable_Ui=True,
        trainable_bi=True,
        trainable_Wo=True,
        trainable_Uo=True,
        trainable_bo=True,
        trainable_Wc=True,
        trainable_Uc=True,
        trainable_bc=True,
        activation_g=nn.sigmoid,
        activation_c=T.tanh,
        activation_h=T.tanh,
        only_last=False,
        gate="minimal",
    ):
        # fixed: the previous body used the stale self.create_variable /
        # self.Wf API that no longer exists on Layer; rewritten with the
        # module-level create_variable like RNN/GRU
        Wf = create_variable(
            "Wf", Wf, (sequence.shape[2], units), trainable=trainable_Wf
        )
        Uf = create_variable("Uf", Uf, (units, units), trainable=trainable_Uf)
        bf = create_variable("bf", bf, (units,), trainable=trainable_bf)

        Wi = create_variable(
            "Wi", Wi, (sequence.shape[2], units), trainable=trainable_Wi
        )
        Ui = create_variable("Ui", Ui, (units, units), trainable=trainable_Ui)
        bi = create_variable("bi", bi, (units,), trainable=trainable_bi)

        Wo = create_variable(
            "Wo", Wo, (sequence.shape[2], units), trainable=trainable_Wo
        )
        Uo = create_variable("Uo", Uo, (units, units), trainable=trainable_Uo)
        bo = create_variable("bo", bo, (units,), trainable=trainable_bo)

        Wc = create_variable(
            "Wc", Wc, (sequence.shape[2], units), trainable=trainable_Wc
        )
        Uc = create_variable("Uc", Uc, (units, units), trainable=trainable_Uc)
        bc = create_variable("bc", bc, (units,), trainable=trainable_bc)

        def fn(*args):
            return LSTM.gate(*args, activation_g, activation_c, activation_h)

        # carry stacks [hidden, cell]; the cell starts at zero
        init = T.stack((init_h, T.zeros(init_h.shape, init_h.dtype)))
        last, output = T.scan(
            fn,
            init=init,
            sequences=[sequence.transpose((1, 0, 2))],
            non_sequences=[Wf, Uf, bf, Wi, Ui, bi, Wo, Uo, bo, Wc, Uc, bc],
        )

        if only_last:
            return last
        else:
            return output.transpose((1, 0, 2))