{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\nComputation times\n=================\n\nIn this example we demonstrate how to perform a simple optimization with Adam\nin TensorFlow (TF) and SymJAX, and compare their computation times.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import matplotlib.pyplot as plt\n",
        "\n",
        "import symjax\n",
        "import symjax.tensor as T\n",
        "from symjax.nn import optimizers\n",
        "import numpy as np\n",
        "import time\n",
        "\n",
        "\n",
        "# Shared benchmark configuration: learning rate, batch size, input dimension.\n",
        "lr = 0.01\n",
        "BS = 10000\n",
        "D = 1000\n",
        "# Synthetic linear-regression data: Y = X @ w + 2 for a random w.\n",
        "X = np.random.randn(BS, D).astype(\"float32\")\n",
        "Y = X.dot(np.random.randn(D, 1).astype(\"float32\")) + 2\n",
        "\n",
        "\n",
        "def TF1(x, y, N, preallocate=False):\n",
        "    \"\"\"Time N Adam steps of a linear regression in TensorFlow 1 graph mode.\n",
        "\n",
        "    Args:\n",
        "        x: input array fed to (or baked into) the graph.\n",
        "        y: target array fed to (or baked into) the graph.\n",
        "        N: number of training steps to run inside the timed loop.\n",
        "        preallocate: if True, x and y become graph constants; otherwise\n",
        "            they are passed through placeholders at every step.\n",
        "\n",
        "    Returns:\n",
        "        Elapsed time of the training loop, in seconds.\n",
        "    \"\"\"\n",
        "    import tensorflow.compat.v1 as tf\n",
        "\n",
        "    tf.disable_v2_behavior()\n",
        "    # This function is called repeatedly; start from an empty default graph\n",
        "    # so ops and constants from earlier calls do not accumulate.\n",
        "    tf.reset_default_graph()\n",
        "\n",
        "    if preallocate:\n",
        "        tf_input = tf.constant(x)\n",
        "        tf_output = tf.constant(y)\n",
        "    else:\n",
        "        tf_input = tf.placeholder(dtype=tf.float32, shape=[BS, D])\n",
        "        tf_output = tf.placeholder(dtype=tf.float32, shape=[BS, 1])\n",
        "\n",
        "    # Fixed seed so every framework starts from the same W and b.\n",
        "    np.random.seed(0)\n",
        "\n",
        "    tf_W = tf.Variable(np.random.randn(D, 1).astype(\"float32\"))\n",
        "    tf_b = tf.Variable(np.random.randn(1).astype(\"float32\"))\n",
        "\n",
        "    tf_loss = tf.reduce_mean((tf.matmul(tf_input, tf_W) + tf_b - tf_output) ** 2)\n",
        "\n",
        "    train_op = tf.train.AdamOptimizer(lr).minimize(tf_loss)\n",
        "\n",
        "    config = tf.ConfigProto()\n",
        "    config.gpu_options.allow_growth = True\n",
        "\n",
        "    # With constants there is nothing to feed; feed_dict=None is the default.\n",
        "    feed_dict = None if preallocate else {tf_input: x, tf_output: y}\n",
        "\n",
        "    # The context manager closes the session and releases its resources.\n",
        "    with tf.Session(config=config) as sess:\n",
        "        sess.run(tf.global_variables_initializer())\n",
        "\n",
        "        t = time.perf_counter()\n",
        "        for _ in range(N):\n",
        "            sess.run(train_op, feed_dict=feed_dict)\n",
        "        elapsed = time.perf_counter() - t\n",
        "\n",
        "    return elapsed\n",
        "\n",
        "\n",
        "def TF2(x, y, N, preallocate=False):\n",
        "    \"\"\"Time N Adam steps of the same regression using TF2 and tf.function.\n",
        "\n",
        "    Args:\n",
        "        x: input array.\n",
        "        y: target array.\n",
        "        N: number of training steps to run inside the timed loop.\n",
        "        preallocate: if True, x and y are converted to tf.constant once\n",
        "            before the timed loop instead of being passed as NumPy arrays.\n",
        "\n",
        "    Returns:\n",
        "        Elapsed time of the training loop, in seconds.\n",
        "    \"\"\"\n",
        "    import tensorflow as tf\n",
        "\n",
        "    # Use the shared learning rate so this run is comparable to TF1 and SJ.\n",
        "    optimizer = tf.keras.optimizers.Adam(learning_rate=lr)\n",
        "    np.random.seed(0)\n",
        "\n",
        "    tf_W = tf.Variable(np.random.randn(D, 1).astype(\"float32\"))\n",
        "    tf_b = tf.Variable(np.random.randn(1).astype(\"float32\"))\n",
        "\n",
        "    @tf.function\n",
        "    def train(tf_input, tf_output):\n",
        "        with tf.GradientTape() as tape:\n",
        "            tf_loss = tf.reduce_mean(\n",
        "                (tf.matmul(tf_input, tf_W) + tf_b - tf_output) ** 2\n",
        "            )\n",
        "        grads = tape.gradient(tf_loss, [tf_W, tf_b])\n",
        "        optimizer.apply_gradients(zip(grads, [tf_W, tf_b]))\n",
        "        return tf_loss\n",
        "\n",
        "    if preallocate:\n",
        "        x = tf.constant(x)\n",
        "        y = tf.constant(y)\n",
        "\n",
        "    t = time.perf_counter()\n",
        "    for _ in range(N):\n",
        "        train(x, y)\n",
        "\n",
        "    return time.perf_counter() - t\n",
        "\n",
        "\n",
        "def SJ(x, y, N, preallocate=False):\n",
        "    \"\"\"Time N Adam steps of the same regression using SymJAX.\n",
        "\n",
        "    Args:\n",
        "        x: input array.\n",
        "        y: target array.\n",
        "        N: number of training steps to run inside the timed loop.\n",
        "        preallocate: if True, x and y are passed through jax.device_put\n",
        "            once before the timed loop.\n",
        "\n",
        "    Returns:\n",
        "        Elapsed time of the training loop, in seconds.\n",
        "    \"\"\"\n",
        "    # Clear the current SymJAX graph so repeated calls start fresh.\n",
        "    symjax.current_graph().reset()\n",
        "    sj_input = T.Placeholder(dtype=np.float32, shape=[BS, D])\n",
        "    sj_output = T.Placeholder(dtype=np.float32, shape=[BS, 1])\n",
        "\n",
        "    np.random.seed(0)\n",
        "\n",
        "    sj_W = T.Variable(np.random.randn(D, 1).astype(\"float32\"))\n",
        "    sj_b = T.Variable(np.random.randn(1).astype(\"float32\"))\n",
        "\n",
        "    sj_loss = ((sj_input.dot(sj_W) + sj_b - sj_output) ** 2).mean()\n",
        "\n",
        "    optimizers.Adam(sj_loss, lr)\n",
        "\n",
        "    train = symjax.function(sj_input, sj_output, updates=symjax.get_updates())\n",
        "\n",
        "    if preallocate:\n",
        "        import jax\n",
        "\n",
        "        x = jax.device_put(x)\n",
        "        y = jax.device_put(y)\n",
        "\n",
        "    t = time.perf_counter()\n",
        "    for _ in range(N):\n",
        "        train(x, y)\n",
        "\n",
        "    return time.perf_counter() - t\n",
        "\n",
        "\n",
        "values = []\n",
        "Ns = [10, 100, 200, 400, 1000]\n",
        "for pre in [False, True]:\n",
        "    for N in Ns:\n",
        "        print(pre, N)\n",
        "        print(\"TF1\")\n",
        "        values.append(TF1(X, Y, N, pre))\n",
        "        # NOTE(review): TF2 is defined but not benchmarked. Enabling it also\n",
        "        # requires changing the reshape, colors and legend below to 3 entries.\n",
        "        # print(\"TF2\")\n",
        "        # values.append(TF2(X, Y, N, pre))\n",
        "        print(\"SJ\")\n",
        "        values.append(SJ(X, Y, N, pre))\n",
        "\n",
        "\n",
        "# Axes after reshape: (preallocate flag, entry of Ns, framework).\n",
        "values = np.array(values).reshape((2, len(Ns), 2))\n",
        "\n",
        "for i, ls in enumerate([\"-\", \"--\"]):\n",
        "    for j, c in enumerate([\"r\", \"g\"]):\n",
        "        plt.plot(Ns, values[i, :, j], linestyle=ls, c=c, linewidth=3, alpha=0.8)\n",
        "plt.legend([\"TF1 no prealloc.\", \"SJ no prealloc.\", \"TF1 prealloc.\", \"SJ prealloc.\"])\n",
        "plt.xlabel(\"number of training steps\")\n",
        "plt.ylabel(\"time (s)\")\n",
        "plt.title(\"Adam training time: TF1 vs SymJAX\")\n",
        "plt.show()"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.7.6"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}