Note
Click here to download the full example code
RNTK kernel
Time series regression and classification
Out:
Op(name=multiply, fn=multiply, shape=(3, 3), dtype=float32, scope=/)
[array([[374.33334, 120.098 , 131.59918],
[120.098 , 302.12424, 86.20666],
[131.59918, 86.20666, 203.82948]], dtype=float32), array([[71.971275, 48.085426, 45.39569 ],
[48.085426, 60.663128, 35.050926],
[45.39569 , 35.050926, 38.225647]], dtype=float32)]
import numpy as np
import symjax
import symjax.tensor as T
import networkx as nx
def RNTK_first_time_step(x, param):
    """Build the initial GP and RNTK kernels for time step t = 1.

    The same initial kernel serves both the ReLU and the erf variants.

    Args:
        x: values of the first time step; presumably a 1-D tensor with one
            entry per sample (the caller passes ``DATA[:, 0]``) -- confirm.
        param: mapping providing the keys "sigmaw", "sigmau", "sigmab"
            and "sigmah".

    Returns:
        A ``(RNTK_new, GP_new)`` pair; at this step both entries are the
        very same kernel expression.

    Note:
        The divisor ``m`` is read from module scope, so it must be defined
        before this function is called.
    """
    sigma_w, sigma_u, sigma_b, sigma_h = (
        param[key] for key in ("sigmaw", "sigmau", "sigmab", "sigmah")
    )

    # Pairwise products x_i * x_j between all samples.
    gram = x * x[:, None]
    # Emits the symbolic op for the products (visible in the script output).
    print(gram)
    size = gram.shape[0]

    recurrent_term = sigma_h ** 2 * sigma_w ** 2 * T.eye(size, size)
    input_term = (sigma_u ** 2 / m) * gram
    kernel = recurrent_term + input_term + sigma_b ** 2

    # At t = 1 the RNTK coincides with the GP kernel.
    return kernel, kernel
def RNTK_relu(x, RNTK_old, GP_old, param, output):
    """Advance the ReLU GP / RNTK kernels by one step.

    Args:
        x: values of the current time step; only used when ``output`` is
            falsy (the caller passes ``0`` for the output step).
        RNTK_old: RNTK kernel from the previous step.
        GP_old: GP kernel from the previous step; an n-by-n matrix holding
            the output GP kernel of all pairs of data in the data set.
        param: mapping providing the keys "sigmaw", "sigmau", "sigmab"
            and "sigmav".
        output: truthy to apply the output-layer update (scaled by
            ``sigmav``); falsy for a recurrent step (scaled by ``sigmaw``,
            plus the input and bias terms).

    Returns:
        A ``(RNTK_new, GP_new)`` pair with the updated kernels.

    Note:
        The recurrent branch reads the divisor ``m`` from module scope.
    """
    sigma_w, sigma_u, sigma_b, sigma_v = (
        param[key] for key in ("sigmaw", "sigmau", "sigmab", "sigmav")
    )

    variances = T.diag(GP_old)
    # sqrt(K_ii * K_jj) for every pair (i, j); an n-by-n matrix.
    norms = T.sqrt(variances * variances[:, None])
    # This ratio is the lambda of the ReLU analytical formula; it is clipped
    # to [-1, 1] for numerical stability.
    cosine = T.clip(GP_old / norms, -1, 1)

    angle_term = np.pi - T.arccos(cosine)
    relu_kernel = (
        (1 / (2 * np.pi))
        * (cosine * angle_term + T.sqrt(1 - cosine ** 2))
        * norms
    )
    relu_grad_kernel = angle_term / (2 * np.pi)

    if output:
        GP_new = sigma_v ** 2 * relu_kernel
        RNTK_new = sigma_v ** 2.0 * RNTK_old * relu_grad_kernel + GP_new
        return RNTK_new, GP_new

    # Recurrent step: add the contribution of the current input and the bias.
    gram = x * x[:, None]
    GP_new = sigma_w ** 2 * relu_kernel + (sigma_u ** 2 / m) * gram + sigma_b ** 2
    RNTK_new = sigma_w ** 2.0 * RNTK_old * relu_grad_kernel + GP_new
    return RNTK_new, GP_new
# L: number of time steps per sequence; N: number of sequences.
L = 10
N = 3
DATA = T.Placeholder((N, L), "float32", name="data")

# Kernel hyper-parameters looked up by the RNTK helpers.
param = {
    "sigmaw": 1.33,
    "sigmau": 1.45,
    "sigmab": 1.2,
    "sigmah": 0.4,
    "sigmav": 2.34,
}
# Divisor of the input term; the RNTK helpers read it from module scope.
m = 1

# First time step initialises both kernels.
RNTK, GP = RNTK_first_time_step(DATA[:, 0], param)
# Remaining time steps apply the recurrent ReLU update.
for t in range(1, L):
    RNTK, GP = RNTK_relu(DATA[:, t], RNTK, GP, param, False)
# Output-layer update; x is not used on this path, so 0 is passed.
RNTK, GP = RNTK_relu(0, RNTK, GP, param, True)

f = symjax.function(DATA, outputs=[RNTK, GP])

# Three random sequences of length L, stacked into an array of shape (3, L).
a, b, c = (np.random.randn(L) for _ in range(3))
example = np.stack([a, b, c])
print(f(example))
Total running time of the script: ( 0 minutes 4.321 seconds)