Pixel interpolation learning

This toy example demonstrates how to use coordinate interpolation with learnable parameters to map one image onto another: the output is obtained simply by interpolating the original image values at learned coordinates.
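To make the interpolation step concrete, here is a minimal NumPy sketch of order-1 (bilinear) sampling at fractional coordinates. It is not symjax's implementation: the helper name `bilinear_sample`, the clipping of out-of-range positions to the border, and the small 4x4 test image are all assumptions made for illustration.

```python
import numpy as np


def bilinear_sample(image, coords):
    """Sample a 2D `image` at fractional (row, col) positions (order-1 interpolation).

    `coords` has shape (2, N): row positions first, column positions second.
    Out-of-range positions are clipped to the border (an assumption of this sketch).
    """
    h, w = image.shape
    r = np.clip(coords[0], 0, h - 1)
    c = np.clip(coords[1], 0, w - 1)
    # top-left corner of the enclosing pixel cell, kept one step inside the border
    r0 = np.minimum(np.floor(r).astype(int), h - 2)
    c0 = np.minimum(np.floor(c).astype(int), w - 2)
    dr, dc = r - r0, c - c0
    # interpolate along columns on the two enclosing rows, then along rows
    top = (1 - dc) * image[r0, c0] + dc * image[r0, c0 + 1]
    bottom = (1 - dc) * image[r0 + 1, c0] + dc * image[r0 + 1, c0 + 1]
    return (1 - dr) * top + dr * bottom


# hypothetical 4x4 image and its identity sampling grid
image = np.arange(16.0).reshape(4, 4)
rows, cols = np.meshgrid(np.arange(4), np.arange(4), indexing="ij")
grid = np.stack([rows.ravel(), cols.ravel()]).astype("float64")

# sampling at the identity grid reproduces the image
identity = bilinear_sample(image, grid).reshape(4, 4)
# moving every column coordinate by half a pixel blends horizontal neighbours
half_shift = bilinear_sample(image, grid + np.array([[0.0], [0.5]])).reshape(4, 4)
```

Because the sampled value varies continuously with the coordinates, the coordinates themselves can be treated as parameters and adjusted by gradient-based optimization, which is what the script below does with a `T.Variable`.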

Figure: training loss (top row); input, target, and reconstruction (middle row); initialized and learned coordinates (bottom row).

Out:

        ... mnist.pkl.gz already exists
Loading mnist
Dataset mnist loaded in 0.67s.
/home/vrael/anaconda3/lib/python3.7/site-packages/jax/lib/xla_bridge.py:130: UserWarning: No GPU/TPU found, falling back to CPU.
  warnings.warn('No GPU/TPU found, falling back to CPU.')
[[ 0.  0.  0. ... 27. 27. 27.]
 [ 0.  1.  2. ... 25. 26. 27.]]
/home/vrael/anaconda3/lib/python3.7/site-packages/matplotlib/tight_layout.py:345: UserWarning: tight_layout not applied: number of columns in subplot specifications mustbe multiples of one another.
  warnings.warn('tight_layout not applied: '
/home/vrael/anaconda3/lib/python3.7/site-packages/matplotlib/figure.py:445: UserWarning: Matplotlib is currently using agg, which is a non-GUI backend, so cannot show the figure.
  % get_backend())

import symjax
import symjax.tensor as T
import matplotlib.pyplot as plt
import numpy as np

import os

os.environ["DATASET_PATH"] = "/home/vrael/DATASETS/"

symjax.current_graph().reset()


mnist = symjax.data.mnist()
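# keep the first two images of the digit 2; index 0 on the second axis
# reduces each one to a 2D (28x28) array
images = mnist["train_set/images"][mnist["train_set/labels"] == 2][:2, 0]
# rescale so that the largest pixel value is 1
images /= images.max()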

np.random.seed(0)

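# identity sampling grid (row indices first, then column indices),
# stored as a learnable variable
coordinates = T.meshgrid(T.range(28), T.range(28))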
coordinates = T.Variable(
    T.stack([coordinates[1].flatten(), coordinates[0].flatten()]).astype("float32")
)
interp = T.interpolation.map_coordinates(images[0], coordinates, order=1).reshape(
    (28, 28)
)

# mean squared error between the resampled first image and the second image
loss = ((interp - images[1]) ** 2).mean()

lr = symjax.nn.schedules.PiecewiseConstant(0.05, {5000: 0.01, 8000: 0.005})
symjax.nn.optimizers.Adam(loss, lr)

train = symjax.function(outputs=loss, updates=symjax.get_updates())

rec = symjax.function(outputs=interp)

losses = list()

# coordinate values before training, kept for the final plot
original = coordinates.value

for i in range(100):
    losses.append(train())

reconstruction = rec()

after = coordinates.value


plt.figure(figsize=(12, 6))

plt.subplot(311)
plt.semilogy(losses, "-x")
plt.ylabel("loss (l2)")
plt.title("Training loss")


plt.subplot(334)
plt.imshow(images[0], aspect="auto", cmap="plasma")
plt.xticks([])
plt.yticks([])
plt.title("input")

plt.subplot(335)
plt.imshow(images[1], aspect="auto", cmap="plasma")
plt.xticks([])
plt.yticks([])
plt.title("target")

plt.subplot(336)
plt.imshow(reconstruction, aspect="auto", cmap="plasma")
plt.xticks([])
plt.yticks([])
plt.title("reconstruction")


print(original)

plt.subplot(325)
plt.scatter(original[1][::-1], original[0], s=3)
plt.xticks([])
plt.yticks([])
plt.title("Initialized coordinates")

plt.subplot(326)
plt.scatter(after[1][::-1], after[0], s=3)
plt.xticks([])
plt.yticks([])
plt.title("Learned coordinates")


plt.tight_layout()
plt.show()
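The script relies on symjax to differentiate the loss with respect to the coordinates. As a rough illustration of what that gradient looks like, the following NumPy sketch writes out the derivative of bilinear interpolation by hand and runs plain gradient descent on the squared error. Several things here are assumptions rather than part of the original script: the helper `sample_with_grads`, the 8x8 ramp image, the shifted target, the step size, and the use of plain gradient descent on the per-pixel squared error in place of Adam on the mean.

```python
import numpy as np


def sample_with_grads(image, coords):
    """Bilinear sample plus its derivative with respect to (row, col) positions."""
    h, w = image.shape
    r = np.clip(coords[0], 0, h - 1)
    c = np.clip(coords[1], 0, w - 1)
    r0 = np.minimum(np.floor(r).astype(int), h - 2)
    c0 = np.minimum(np.floor(c).astype(int), w - 2)
    dr, dc = r - r0, c - c0
    top = (1 - dc) * image[r0, c0] + dc * image[r0, c0 + 1]
    bottom = (1 - dc) * image[r0 + 1, c0] + dc * image[r0 + 1, c0 + 1]
    value = (1 - dr) * top + dr * bottom
    # derivative along rows: difference between the two column-interpolated rows
    d_row = bottom - top
    # derivative along columns: row-interpolated horizontal finite difference
    d_col = (1 - dr) * (image[r0, c0 + 1] - image[r0, c0]) + dr * (
        image[r0 + 1, c0 + 1] - image[r0 + 1, c0]
    )
    return value, np.stack([d_row, d_col])


# hypothetical data: a horizontal ramp and the same ramp shifted by one pixel
source = np.tile(np.arange(8.0), (8, 1))  # source[r, c] = c
target = np.minimum(source + 1.0, 7.0)

rows, cols = np.meshgrid(np.arange(8), np.arange(8), indexing="ij")
init = np.stack([rows.ravel(), cols.ravel()]).astype("float64")
coords = init.copy()

losses = []
for _ in range(20):
    value, grads = sample_with_grads(source, coords)
    error = value - target.ravel()
    losses.append(np.mean(error ** 2))
    # chain rule: d(error^2)/d(coords) = 2 * error * d(value)/d(coords)
    coords -= 0.25 * 2.0 * error * grads
```

On this ramp the image does not change along rows, so only the column coordinates move: each one drifts toward the neighbouring pixel whose value matches the target, and the recorded loss shrinks at every step. Real images are not this well behaved, which is one reason an adaptive optimizer such as Adam is a sensible choice in the script.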

Total running time of the script: ( 0 minutes 1.922 seconds)
