{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\nMNIST classification\n====================\n\nExample of image classification on MNIST using a small subset of the\ntraining data and a small architecture.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import symjax.tensor as T\n",
        "from symjax import nn\n",
        "import symjax\n",
        "import numpy as np\n",
        "import matplotlib.pyplot as plt\n",
        "from symjax.data import mnist\n",
        "from symjax.data.utils import batchify\n",
        "\n",
        "import os\n",
        "\n",
        "# settings a reader may want to tune\n",
        "N_TRAIN = 2000  # number of training images kept\n",
        "BATCH_SIZE = 64\n",
        "N_EPOCHS = 20\n",
        "\n",
        "# keep a DATASET_PATH that is already set in the environment and only\n",
        "# fall back to the original hard-coded location when it is missing\n",
        "os.environ.setdefault(\"DATASET_PATH\", \"/home/vrael/DATASETS/\")\n",
        "symjax.current_graph().reset()\n",
        "# load the dataset (this rebinds `mnist` from the loader to the loaded data)\n",
        "mnist = mnist()\n",
        "\n",
        "# keep only the first N_TRAIN training images\n",
        "mnist[\"train_set/images\"] = mnist[\"train_set/images\"][:N_TRAIN]\n",
        "mnist[\"train_set/labels\"] = mnist[\"train_set/labels\"][:N_TRAIN]\n",
        "\n",
        "# renormalization: divide every image by its own maximum value\n",
        "mnist[\"train_set/images\"] /= mnist[\"train_set/images\"].max((1, 2, 3), keepdims=True)\n",
        "mnist[\"test_set/images\"] /= mnist[\"test_set/images\"].max((1, 2, 3), keepdims=True)\n",
        "\n",
        "# create the network\n",
        "images = T.Placeholder((BATCH_SIZE, 1, 28, 28), \"float32\", name=\"images\")\n",
        "labels = T.Placeholder((BATCH_SIZE,), \"int32\", name=\"labels\")\n",
        "deterministic = T.Placeholder((1,), \"bool\")\n",
        "\n",
        "\n",
        "layer = [nn.layers.Identity(images)]\n",
        "\n",
        "# two identical stages: conv -> batch-norm -> leaky relu -> 2x2 pooling\n",
        "for _ in range(2):\n",
        "    layer.append(nn.layers.Conv2D(layer[-1], 64, (3, 3), b=None))\n",
        "    # due to the small size of the dataset we can\n",
        "    # increase the update of the bn moving averages\n",
        "    layer.append(\n",
        "        nn.layers.BatchNormalization(\n",
        "            layer[-1], [1], deterministic, beta_1=0.9, beta_2=0.9\n",
        "        )\n",
        "    )\n",
        "    layer.append(nn.leaky_relu(layer[-1]))\n",
        "    layer.append(nn.layers.Pool2D(layer[-1], (2, 2)))\n",
        "\n",
        "\n",
        "# average pooling whose window is the full remaining spatial extent\n",
        "layer.append(nn.layers.Pool2D(layer[-1], layer[-1].shape.get()[2:], pool_type=\"AVG\"))\n",
        "\n",
        "layer.append(nn.layers.Dense(layer[-1], 10))\n",
        "\n",
        "# each layer is itself a tensor which represents its output and thus\n",
        "# any tensor operation can be used on the layer instance, for example\n",
        "for lyr in layer:\n",
        "    print(lyr.shape.get())\n",
        "\n",
        "\n",
        "loss = nn.losses.sparse_softmax_crossentropy_logits(labels, layer[-1]).mean()\n",
        "accuracy = nn.losses.accuracy(labels, layer[-1])\n",
        "\n",
        "nn.optimizers.Adam(loss, 0.001)\n",
        "\n",
        "# `test` only evaluates; `train` additionally applies the collected updates\n",
        "test = symjax.function(images, labels, deterministic, outputs=[loss, accuracy])\n",
        "\n",
        "train = symjax.function(\n",
        "    images,\n",
        "    labels,\n",
        "    deterministic,\n",
        "    outputs=[loss, accuracy],\n",
        "    updates=symjax.get_updates(),\n",
        ")\n",
        "\n",
        "# despite their names, both lists store one [mean loss, mean accuracy]\n",
        "# pair per epoch (same order as `outputs` above)\n",
        "test_accuracy = []\n",
        "train_accuracy = []\n",
        "\n",
        "for epoch in range(N_EPOCHS):\n",
        "    print(\"...epoch:\", epoch)\n",
        "    batch_metrics = []\n",
        "    for x, y in batchify(\n",
        "        mnist[\"test_set/images\"],\n",
        "        mnist[\"test_set/labels\"],\n",
        "        batch_size=BATCH_SIZE,\n",
        "        option=\"continuous\",\n",
        "    ):\n",
        "        batch_metrics.append(test(x, y, 1))\n",
        "    test_mean = np.mean(batch_metrics, 0)\n",
        "    print(\"Test Loss and Accu:\", test_mean)\n",
        "    test_accuracy.append(test_mean)\n",
        "\n",
        "    batch_metrics = []\n",
        "    for x, y in batchify(\n",
        "        mnist[\"train_set/images\"],\n",
        "        mnist[\"train_set/labels\"],\n",
        "        batch_size=BATCH_SIZE,\n",
        "        option=\"random_see_all\",\n",
        "    ):\n",
        "        batch_metrics.append(train(x, y, 0))\n",
        "    train_mean = np.mean(batch_metrics, 0)\n",
        "    train_accuracy.append(train_mean)\n",
        "    print(\"Train Loss and Accu\", train_mean)\n",
        "\n",
        "train_accuracy = np.array(train_accuracy)\n",
        "test_accuracy = np.array(test_accuracy)\n",
        "\n",
        "# column 1 is the accuracy\n",
        "plt.subplot(121)\n",
        "plt.plot(test_accuracy[:, 1], c=\"k\", label=\"test\")\n",
        "plt.plot(train_accuracy[:, 1], c=\"b\", label=\"train\")\n",
        "plt.xlabel(\"epochs\")\n",
        "plt.ylabel(\"accuracy\")\n",
        "plt.legend()\n",
        "\n",
        "# column 0 is the loss, so this panel is labelled accordingly\n",
        "plt.subplot(122)\n",
        "plt.plot(test_accuracy[:, 0], c=\"k\", label=\"test\")\n",
        "plt.plot(train_accuracy[:, 0], c=\"b\", label=\"train\")\n",
        "plt.xlabel(\"epochs\")\n",
        "plt.ylabel(\"loss\")\n",
        "plt.legend()\n",
        "\n",
        "plt.suptitle(f\"MNIST ({N_TRAIN} training images) classification task\")"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.7.6"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}