{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\nImage classification, Keras and SymJAX\n======================================\n\nExample of image classification on CIFAR-10 with deep networks (a dense MLP or a residual CNN), implemented with both Keras and SymJAX.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import symjax.tensor as T\n",
        "from symjax import nn\n",
        "import symjax\n",
        "import numpy as np\n",
        "import matplotlib.pyplot as plt\n",
        "import sys\n",
        "\n",
        "# Raised above Python's default; presumably needed by the deep graphs built\n",
        "# below -- TODO confirm it is still required.\n",
        "sys.setrecursionlimit(3500)\n",
        "\n",
        "\n",
        "def classif_tf(\n",
        "    train_x,\n",
        "    train_y,\n",
        "    test_x,\n",
        "    test_y,\n",
        "    mlp=True,\n",
        "    n_epochs=5,\n",
        "    batch_size=128,\n",
        "    learning_rate=0.001,\n",
        "):\n",
        "    \"\"\"Train a Keras classifier with a manual loop and print its accuracy.\n",
        "\n",
        "    Args:\n",
        "        train_x, test_x: images; the model input is declared as (3, 32, 32)\n",
        "            per sample.\n",
        "        train_y, test_y: integer class labels (the model emits 10 logits).\n",
        "        mlp: build the 6-layer dense network when True, otherwise the\n",
        "            residual convolutional network.\n",
        "        n_epochs: number of passes over the training set.\n",
        "        batch_size: mini-batch size handed to ``batchify``.\n",
        "        learning_rate: Adam step size.\n",
        "\n",
        "    After every epoch the mean per-batch training accuracy is printed,\n",
        "    followed by the mean per-batch test accuracy. Nothing is returned.\n",
        "    \"\"\"\n",
        "    # Imported here rather than at the top, so TensorFlow is only needed\n",
        "    # when this function is actually called.\n",
        "    import tensorflow as tf\n",
        "    from tensorflow.keras import layers\n",
        "\n",
        "    def batch_accuracy(targets, logits):\n",
        "        # Compare in NumPy so the label dtype never has to match argmax's.\n",
        "        hits = np.asarray(targets) == tf.argmax(logits, 1).numpy()\n",
        "        return float(np.mean(hits))\n",
        "\n",
        "    inputs = layers.Input(shape=(3, 32, 32))\n",
        "    if not mlp:\n",
        "        # Move channels last, matching the axis=-1 batch normalisation below.\n",
        "        out = layers.Permute((2, 3, 1))(inputs)\n",
        "        out = layers.Conv2D(32, 3, activation=\"relu\")(out)\n",
        "        for i in range(3):\n",
        "            width = 32 * (i + 1)\n",
        "            for _ in range(3):\n",
        "                # Residual block: conv-BN-ReLU-conv-BN added to its input.\n",
        "                conv = layers.Conv2D(\n",
        "                    width, 3, activation=\"linear\", padding=\"SAME\"\n",
        "                )(out)\n",
        "                bn = layers.BatchNormalization(axis=-1)(conv)\n",
        "                relu = layers.Activation(\"relu\")(bn)\n",
        "                conv = layers.Conv2D(\n",
        "                    width, 3, activation=\"linear\", padding=\"SAME\"\n",
        "                )(relu)\n",
        "                bn = layers.BatchNormalization(axis=-1)(conv)\n",
        "                out = layers.Add()([out, bn])\n",
        "            out = layers.AveragePooling2D()(out)\n",
        "            # The 1x1 convolution sets the channel count used by the next stage.\n",
        "            out = layers.Conv2D(32 * (i + 2), 1, activation=\"linear\")(out)\n",
        "        out = layers.GlobalAveragePooling2D()(out)\n",
        "    else:\n",
        "        out = layers.Flatten()(inputs)\n",
        "        for _ in range(6):\n",
        "            out = layers.Dense(4000, activation=\"linear\")(out)\n",
        "            bn = layers.BatchNormalization(axis=-1)(out)\n",
        "            out = layers.Activation(\"relu\")(bn)\n",
        "    outputs = layers.Dense(10, activation=\"linear\")(out)\n",
        "\n",
        "    model = tf.keras.Model(inputs, outputs)\n",
        "    optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)\n",
        "\n",
        "    for _ in range(n_epochs):\n",
        "        total, n_batches = 0.0, 0\n",
        "        for x, y in symjax.data.utils.batchify(\n",
        "            train_x, train_y, batch_size=batch_size, option=\"random\"\n",
        "        ):\n",
        "            with tf.GradientTape() as tape:\n",
        "                preds = model(x, training=True)\n",
        "                loss = tf.reduce_mean(\n",
        "                    tf.nn.sparse_softmax_cross_entropy_with_logits(\n",
        "                        labels=y, logits=preds\n",
        "                    )\n",
        "                )\n",
        "            grads = tape.gradient(loss, model.trainable_variables)\n",
        "            optimizer.apply_gradients(zip(grads, model.trainable_variables))\n",
        "            total += batch_accuracy(y, preds)\n",
        "            n_batches += 1\n",
        "        # Average over the batches actually seen, not an assumed count.\n",
        "        print(\"training\", total / max(n_batches, 1))\n",
        "\n",
        "        total, n_batches = 0.0, 0\n",
        "        for x, y in symjax.data.utils.batchify(\n",
        "            test_x, test_y, batch_size=batch_size, option=\"continuous\"\n",
        "        ):\n",
        "            preds = model(x, training=False)\n",
        "            total += batch_accuracy(y, preds)\n",
        "            n_batches += 1\n",
        "        print(total / max(n_batches, 1))\n",
        "\n",
        "\n",
        "def classif_sj(\n",
        "    train_x,\n",
        "    train_y,\n",
        "    test_x,\n",
        "    test_y,\n",
        "    mlp=True,\n",
        "    n_epochs=5,\n",
        "    batch_size=128,\n",
        "    learning_rate=0.001,\n",
        "):\n",
        "    \"\"\"Train the SymJAX counterpart of ``classif_tf`` and print its accuracy.\n",
        "\n",
        "    Args:\n",
        "        train_x, test_x: images; the input placeholder is declared as\n",
        "            float32 with shape (batch_size, 3, 32, 32).\n",
        "        train_y, test_y: class labels; the label placeholder is int32 with\n",
        "            shape (batch_size,).\n",
        "        mlp: build the 6-layer dense network when True, otherwise the\n",
        "            residual convolutional network.\n",
        "        n_epochs: number of passes over the training set.\n",
        "        batch_size: mini-batch size; it also fixes the placeholder shapes.\n",
        "        learning_rate: Adam step size.\n",
        "\n",
        "    After every epoch the mean per-batch training accuracy is printed,\n",
        "    followed by the mean per-batch test accuracy. Nothing is returned.\n",
        "    \"\"\"\n",
        "    # Reset the current SymJAX graph before building the model.\n",
        "    symjax.current_graph().reset()\n",
        "\n",
        "    images = T.Placeholder((batch_size, 3, 32, 32), \"float32\")\n",
        "    labels = T.Placeholder((batch_size,), \"int32\")\n",
        "    # Fed as 0 while training and 1 while testing (see the loops below).\n",
        "    deterministic = T.Placeholder((), \"bool\")\n",
        "\n",
        "    if not mlp:\n",
        "        out = nn.relu(nn.layers.Conv2D(images, 32, (3, 3)))\n",
        "        for i in range(3):\n",
        "            width = 32 * (i + 1)\n",
        "            for _ in range(3):\n",
        "                # Residual block: conv-BN-ReLU-conv-BN added to its input.\n",
        "                conv = nn.layers.Conv2D(out, width, (3, 3), pad=\"SAME\")\n",
        "                bn = nn.layers.BatchNormalization(\n",
        "                    conv, [1], deterministic=deterministic\n",
        "                )\n",
        "                bn = nn.relu(bn)\n",
        "                conv = nn.layers.Conv2D(bn, width, (3, 3), pad=\"SAME\")\n",
        "                bn = nn.layers.BatchNormalization(\n",
        "                    conv, [1], deterministic=deterministic\n",
        "                )\n",
        "                out = out + bn\n",
        "\n",
        "            out = nn.layers.Pool2D(out, (2, 2), pool_type=\"AVG\")\n",
        "            # The 1x1 convolution sets the channel count used by the next stage.\n",
        "            out = nn.layers.Conv2D(out, 32 * (i + 2), (1, 1))\n",
        "        # Average-pool with a window equal to the remaining spatial size.\n",
        "        out = nn.layers.Pool2D(out, out.shape.get()[-2:], pool_type=\"AVG\")\n",
        "    else:\n",
        "        out = images\n",
        "        for _ in range(6):\n",
        "            out = nn.layers.Dense(out, 4000)\n",
        "            out = nn.relu(\n",
        "                nn.layers.BatchNormalization(out, [1], deterministic=deterministic)\n",
        "            )\n",
        "\n",
        "    outputs = nn.layers.Dense(out, 10)\n",
        "\n",
        "    loss = nn.losses.sparse_softmax_crossentropy_logits(labels, outputs).mean()\n",
        "    nn.optimizers.Adam(loss, learning_rate)\n",
        "\n",
        "    accuracy = T.equal(outputs.argmax(1), labels).astype(\"float32\").mean()\n",
        "\n",
        "    train = symjax.function(\n",
        "        images,\n",
        "        labels,\n",
        "        deterministic,\n",
        "        outputs=[loss, accuracy, outputs],\n",
        "        updates=symjax.get_updates(),\n",
        "    )\n",
        "    test = symjax.function(images, labels, deterministic, outputs=accuracy)\n",
        "\n",
        "    for _ in range(n_epochs):\n",
        "        total, n_batches = 0.0, 0\n",
        "        for x, y in symjax.data.utils.batchify(\n",
        "            train_x, train_y, batch_size=batch_size, option=\"random\"\n",
        "        ):\n",
        "            # Index 1 of the outputs list is the batch accuracy.\n",
        "            total += train(x, y, 0)[1]\n",
        "            n_batches += 1\n",
        "        # Average over the batches actually seen, not an assumed count.\n",
        "        print(\"training\", total / max(n_batches, 1))\n",
        "\n",
        "        total, n_batches = 0.0, 0\n",
        "        for x, y in symjax.data.utils.batchify(\n",
        "            test_x, test_y, batch_size=batch_size, option=\"continuous\"\n",
        "        ):\n",
        "            total += test(x, y, 1)\n",
        "            n_batches += 1\n",
        "        print(total / max(n_batches, 1))\n",
        "\n",
        "\n",
        "cifar10 = symjax.data.cifar10()\n",
        "train_x, train_y = cifar10[\"train_set/images\"], cifar10[\"train_set/labels\"]\n",
        "test_x, test_y = cifar10[\"test_set/images\"], cifar10[\"test_set/labels\"]\n",
        "\n",
        "# Rescale both splits by the training-set maximum so they share one scale.\n",
        "# New float32 arrays are built instead of dividing in place: in-place\n",
        "# division fails on integer pixels and would mutate the loaded dataset.\n",
        "pixel_scale = float(train_x.max())\n",
        "train_x = np.asarray(train_x, dtype=\"float32\") / pixel_scale\n",
        "test_x = np.asarray(test_x, dtype=\"float32\") / pixel_scale\n",
        "\n",
        "\n",
        "# Training is left disabled here; uncomment one of the calls to run it.\n",
        "# classif_sj(train_x, train_y, test_x, test_y, False)\n",
        "# classif_tf(train_x, train_y, test_x, test_y, False)"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.7.6"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}