{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\nBasic scan/loops examples\n=========================\n\nIn this example we demonstrate how to use :py:func:`symjax.tensor.scan`\nand other similar functions.\n\nWe first demonstrate how to compute an exponential moving average with\n:py:func:`symjax.tensor.scan`.\n\nWe then demonstrate how to compute a simple moving average over sliding\nwindows with :py:func:`symjax.tensor.map`.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import matplotlib.pyplot as plt\n\nimport symjax\nimport symjax.tensor as T\nimport numpy as np\n\n\n# configuration: signal length, moving-average window and random seed\nSIGNAL_LENGTH = 512\nWINDOW_SIZE = 10\nSEED = 0\n\n# suppose we are given a time-series and we want to compute an\n# exponential moving average (EMA), we also use the EMA coefficient alpha\n# based on the user input\n\nsignal = T.Placeholder((SIGNAL_LENGTH,), \"float32\", name=\"signal\")\nalpha = T.Placeholder((), \"float32\", \"alpha\")\n\n# to use a scan function one needs a function to be applied at each step\n# in our case an exponential moving average function\n# this function should output the new value of the carry as well as an\n# additional output, in our case, the carry (EMA) is also what we want to\n# output at each time step\n\n\ndef ema_step(at, xt, alpha):\n    \"\"\"Perform one exponential moving average update.\n\n    The first input is the carry (the previous EMA), then come the\n    (ordered) values from sequences and non_sequences, similar to Theano.\n    Returns the new carry and the output of the step, both the new EMA.\n    \"\"\"\n    EMA = at * alpha + (1 - alpha) * xt\n    return EMA, EMA\n\n\n# the scan function returns the last carry as well as the outputs of\n# all the time steps, we also need to provide an init.\nlast_ema, all_ema = T.scan(\n    ema_step, init=signal[0], sequences=[signal[1:]], non_sequences=[alpha]\n)\n\n\nema_fn = symjax.function(signal, alpha, outputs=all_ema)\n\n# generate a noisy signal, the seed makes the figures reproducible\nnp.random.seed(SEED)\nnoise = np.random.randn(SIGNAL_LENGTH) * 0.2\nx = np.cos(np.linspace(-3, 3, SIGNAL_LENGTH)) + noise\n\nfig, ax = plt.subplots(3, 1, figsize=(3, 9))\n\n# the loop variable is not called alpha so that the alpha Placeholder\n# defined above is not overwritten by a python float\nfor k, alpha_value in enumerate([0.1, 0.5, 0.9]):\n    ax[k].plot(x, c=\"b\")\n    ax[k].plot(ema_fn(x, alpha_value), c=\"r\")\n    ax[k].set_title(\"EMA: {}\".format(alpha_value))\n    ax[k].set_xticks([])\n    ax[k].set_yticks([])\n\nplt.tight_layout()\n\n\n# Now let's do a simple map for which we can compute a simple\n# moving average (SMA). The map consists of moving a window over the\n# signal and averaging the values on that window\n\n# in that case the function also needs to be defined\n\n\ndef window_mean(window):\n    \"\"\"Average the values of one window.\n\n    The inputs of the mapped function are the (ordered) values from\n    sequences, here a single window of the signal.\n    \"\"\"\n    return T.mean(window)\n\n\nwindowed = T.extract_signal_patches(signal, WINDOW_SIZE)\noutput = T.map(window_mean, sequences=[windowed])\nsma_fn = symjax.function(signal, outputs=output)\n\nfig, ax = plt.subplots(1, 1, figsize=(5, 2))\n\nax.plot(x, c=\"b\")\nax.plot(sma_fn(x), c=\"r\")\nax.set_title(\"SMA: {}\".format(WINDOW_SIZE))\nax.set_xticks([])\nax.set_yticks([])\n\nplt.tight_layout()"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.7.6"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}