RNTK kernel

Time-series regression and classification with the recurrent neural tangent kernel (RNTK).

Out:

Op(name=multiply, fn=multiply, shape=(3, 3), dtype=float32, scope=/)
[array([[374.33334, 120.098  , 131.59918],
       [120.098  , 302.12424,  86.20666],
       [131.59918,  86.20666, 203.82948]], dtype=float32), array([[71.971275, 48.085426, 45.39569 ],
       [48.085426, 60.663128, 35.050926],
       [45.39569 , 35.050926, 38.225647]], dtype=float32)]

import numpy as np
import symjax
import symjax.tensor as T
import networkx as nx


def RNTK_first_time_step(x, param, m=1):
    """Return the RNTK and GP kernels for the first time step (t = 1).

    At t = 1 the RNTK equals the GP kernel, and the expression is the same
    for the ReLU and erf activations.

    Parameters
    ----------
    x : symbolic tensor
        First time step of every sequence. Assumed to be 1-D of shape ``(n,)``
        (one scalar per sequence) -- the outer product below relies on it.
    param : dict
        Must provide ``"sigmaw"``, ``"sigmau"``, ``"sigmab"`` and ``"sigmah"``.
    m : int or float, optional
        Divisor of the input term ``sigmau**2 / m``. Defaults to 1, the value
        used by the example script (previously read from a module-level
        global, which broke the function when used outside that script).

    Returns
    -------
    tuple
        ``(RNTK_new, GP_new)``, both the same symbolic ``(n, n)`` tensor.
    """
    sw = param["sigmaw"]
    su = param["sigmau"]
    sb = param["sigmab"]
    sh = param["sigmah"]
    # Pairwise products x_i * x_j for every pair of sequences, shape (n, n).
    X = x * x[:, None]
    n = X.shape[0]
    # The sigmah**2 * sigmaw**2 contribution only appears on the diagonal
    # (each sequence paired with itself), hence the identity matrix.
    GP_new = sh ** 2 * sw ** 2 * T.eye(n, n) + (su ** 2 / m) * X + sb ** 2
    RNTK_new = GP_new
    return RNTK_new, GP_new


def RNTK_relu(x, RNTK_old, GP_old, param, output, m=1):
    """Advance the GP and RNTK kernels by one step of a ReLU recurrent network.

    Parameters
    ----------
    x : symbolic tensor or scalar
        Input at the current time step, assumed 1-D of shape ``(n,)``.
        Ignored when ``output`` is true (callers may pass ``0``).
    RNTK_old, GP_old : symbolic ``(n, n)`` tensors
        RNTK and GP kernels from the previous step, over all pairs of
        sequences in the data set.
    param : dict
        Must provide ``"sigmaw"``, ``"sigmau"``, ``"sigmab"`` and ``"sigmav"``.
    output : bool
        True for the output layer (scaled by ``sigmav``, no input term);
        False for a hidden recurrent step.
    m : int or float, optional
        Divisor of the input term ``sigmau**2 / m``. Defaults to 1, the value
        used by the example script (previously read from a module-level
        global, which broke the function when used outside that script).

    Returns
    -------
    tuple
        ``(RNTK_new, GP_new)`` symbolic ``(n, n)`` tensors.
    """
    sw = param["sigmaw"]
    su = param["sigmau"]
    sb = param["sigmab"]
    sv = param["sigmav"]

    # Diagonal of the previous GP kernel, and the normaliser
    # sqrt(GP_old[i, i] * GP_old[j, j]) for every pair (i, j).
    a = T.diag(GP_old)
    C = T.sqrt(a * a[:, None])
    # Correlation lambda of the ReLU closed-form expressions, clipped to
    # [-1, 1] because rounding can push it just outside arccos's domain.
    E = T.clip(GP_old / C, -1, 1)
    # pi - arccos(lambda) appears in both closed forms; build it once.
    angle = np.pi - T.arccos(E)
    # Closed form of E[relu(u) relu(v)], rescaled by C.
    F = (1 / (2 * np.pi)) * (E * angle + T.sqrt(1 - E ** 2)) * C
    # Closed form of E[relu'(u) relu'(v)].
    G = angle / (2 * np.pi)
    if output:
        GP_new = sv ** 2 * F
        RNTK_new = sv ** 2.0 * RNTK_old * G + GP_new
    else:
        # Pairwise products of the current inputs, shape (n, n).
        X = x * x[:, None]
        GP_new = sw ** 2 * F + (su ** 2 / m) * X + sb ** 2
        RNTK_new = sw ** 2.0 * RNTK_old * G + GP_new
    return RNTK_new, GP_new


L = 10  # sequence length (number of time steps)
N = 3  # number of sequences
DATA = T.Placeholder((N, L), "float32", name="data")
# kernel hyper-parameters
param = {}
param["sigmaw"] = 1.33
param["sigmau"] = 1.45
param["sigmab"] = 1.2
param["sigmah"] = 0.4
param["sigmav"] = 2.34
m = 1  # divisor of the input term sigmau**2 / m in the kernel recursions

# first time step
RNTK, GP = RNTK_first_time_step(DATA[:, 0], param)

# remaining L - 1 recurrent steps
for t in range(1, L):
    RNTK, GP = RNTK_relu(DATA[:, t], RNTK, GP, param, False)

# output layer: no new input, so 0 is passed in place of x
RNTK, GP = RNTK_relu(0, RNTK, GP, param, True)


f = symjax.function(DATA, outputs=[RNTK, GP])

# N random sequences of length L, i.e. an array of shape (N, L) matching DATA.
# Row-major filling draws the same values as N separate randn(L) calls.
example = np.random.randn(N, L)
print(f(example))

Total running time of the script: ( 0 minutes 4.321 seconds)

Gallery generated by Sphinx-Gallery