{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\nRNTK kernel\n===========\n\nTime series regression and classification\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import numpy as np\nimport symjax\nimport symjax.tensor as T\nimport networkx as nx\n\n\ndef RNTK_first_time_step(x, param):\n    \"\"\"Build the GP and RNTK kernels for the first time step (t = 1).\n\n    The same first step is used for both the ReLU and erf variants.\n\n    Args:\n        x: 1-D tensor with the first-time-step input of every sequence.\n        param: dict providing \"sigmaw\", \"sigmau\", \"sigmab\" and \"sigmah\".\n\n    Returns:\n        Tuple (RNTK_new, GP_new); at t = 1 both are the same (n, n) kernel.\n\n    Note:\n        The constant m is read from the notebook's global namespace and must\n        be defined before this function is called.\n    \"\"\"\n    sw = param[\"sigmaw\"]\n    su = param[\"sigmau\"]\n    sb = param[\"sigmab\"]\n    sh = param[\"sigmah\"]\n    X = x * x[:, None]  # pairwise products of the inputs, shape (n, n)\n    n = X.shape[0]\n    GP_new = sh ** 2 * sw ** 2 * T.eye(n, n) + (su ** 2 / m) * X + sb ** 2\n    RNTK_new = GP_new\n    return RNTK_new, GP_new\n\n\ndef RNTK_relu(x, RNTK_old, GP_old, param, output):\n    \"\"\"Advance the GP and RNTK kernels by one step using the ReLU formula.\n\n    Args:\n        x: 1-D tensor with the current-time-step input of every sequence;\n            not used when output is true.\n        RNTK_old: RNTK kernel from the previous step.\n        GP_old: (n, n) GP kernel from the previous step.\n        param: dict providing \"sigmaw\", \"sigmau\", \"sigmab\" and \"sigmav\".\n        output: if true, compute the output-layer kernels (scaled by sigmav)\n            instead of a recurrent step.\n\n    Returns:\n        Tuple (RNTK_new, GP_new).\n\n    Note:\n        The constant m is read from the notebook's global namespace.\n    \"\"\"\n    sw = param[\"sigmaw\"]\n    su = param[\"sigmau\"]\n    sb = param[\"sigmab\"]\n    sv = param[\"sigmav\"]\n\n    a = T.diag(GP_old)  # GP_old is in R^{n*n}: the GP kernel\n    # of all pairs of samples in the data set\n    B = a * a[:, None]\n    C = T.sqrt(B)  # in R^{n*n}\n    D = GP_old / C  # this is lambda in the analytical ReLU formula\n    # clip D to [-1, 1] for numerical stability\n    E = T.clip(D, -1, 1)\n    F = (1 / (2 * np.pi)) * (E * (np.pi - T.arccos(E)) + T.sqrt(1 - E ** 2)) * C\n    G = (np.pi - T.arccos(E)) / (2 * np.pi)\n    if output:\n        GP_new = sv ** 2 * F\n        RNTK_new = sv ** 2.0 * RNTK_old * G + GP_new\n    else:\n        X = x * x[:, None]\n        GP_new = sw ** 2 * F + (su ** 2 / m) * X + sb ** 2\n        RNTK_new = sw ** 2.0 * RNTK_old * G + GP_new\n    return RNTK_new, GP_new\n\n\nL = 10  # sequence length\nN = 3  # number of sequences\nDATA = T.Placeholder((N, L), \"float32\", name=\"data\")\n# parameters\nparam = {}\nparam[\"sigmaw\"] = 1.33\nparam[\"sigmau\"] = 1.45\nparam[\"sigmab\"] = 1.2\nparam[\"sigmah\"] = 0.4\nparam[\"sigmav\"] = 2.34\nm = 1  # global constant read by both kernel functions\n\n# first time step\nRNTK, GP = RNTK_first_time_step(DATA[:, 0], param)\n\nfor t in range(1, L):\n    RNTK, GP = RNTK_relu(DATA[:, t], RNTK, GP, param, False)\n\n# output layer: x is not used when output=True, so a dummy 0 is passed\nRNTK, GP = RNTK_relu(0, RNTK, GP, param, True)\n\n\nf = symjax.function(DATA, outputs=[RNTK, GP])\n\n# N random sequences of length L; seeded so a fresh run reproduces the output\nnp.random.seed(0)\nexample = np.random.randn(N, L)  # it is of shape (N, L)\nprint(f(example))"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.7.6"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}