Note
Click here to download the full example code
Basic (linear) deconvolution filter learningΒΆ
demonstration on how to learn a deconvolutional filter based on some flavors of gradietn descent assuming we know the true output
Out:
... mnist.pkl.gz already exists
Loading mnist
Dataset mnist loaded in 0.95s.
/home/vrael/anaconda3/lib/python3.7/site-packages/matplotlib/figure.py:445: UserWarning: Matplotlib is currently using agg, which is a non-GUI backend, so cannot show the figure.
% get_backend())
import symjax
import symjax.tensor as T
import matplotlib.pyplot as plt
import numpy as np
from scipy.signal import convolve2d
import os
os.environ["DATASET_PATH"] = "/home/vrael/DATASETS/"
symjax.current_graph().reset()
true_image = symjax.data.mnist()
# 2d image
true_image = true_image["train_set/images"][0, 0]
true_image /= true_image.max()
np.random.seed(0)
noisy_image = convolve2d(true_image, np.random.randn(5, 5) / 5, "same")
# GRADIENT DESCENT
filter_1 = T.Variable(np.random.randn(8, 8) / 8, dtype="float32")
filter_2 = T.Variable(filter_1.value, dtype="float32")
reconstruction_1 = T.signal.convolve2d(noisy_image, filter_1, "same")
reconstruction_2 = T.signal.convolve2d(noisy_image, filter_2, "same")
loss1 = T.abs(reconstruction_1 - true_image).mean()
loss2 = (T.abs(reconstruction_2 - true_image) ** 2).mean()
lr = symjax.nn.schedules.PiecewiseConstant(0.05, {5000: 0.01, 8000: 0.005})
symjax.nn.optimizers.Adam(loss1 + loss2, lr)
train = symjax.function(outputs=[loss1, loss2], updates=symjax.get_updates())
rec = symjax.function(outputs=[reconstruction_1, reconstruction_2])
losses_1 = list()
losses_2 = list()
for i in range(10000):
losses = train()
losses_1.append(losses[0])
losses_2.append(losses[1])
reconstruction_1, reconstruction_2 = rec()
plt.figure(figsize=(12, 6))
plt.subplot(221)
plt.semilogy(losses_1, "-x")
plt.ylabel("log-loss (l1)")
plt.xlabel("number of gradient updates")
plt.subplot(222)
plt.semilogy(losses_2, "-x")
plt.ylabel("log-loss (l2)")
plt.xlabel("number of gradient updates")
plt.subplot(245)
plt.imshow(reconstruction_1, aspect="auto", origin="lower", cmap="plasma")
plt.xticks([])
plt.yticks([])
plt.title("reconstruction (l1)")
plt.subplot(246)
plt.imshow(reconstruction_2, aspect="auto", origin="lower", cmap="plasma")
plt.xticks([])
plt.yticks([])
plt.title("reconstruction (l2)")
plt.subplot(247)
plt.imshow(true_image, aspect="auto", origin="lower", cmap="plasma")
plt.xticks([])
plt.yticks([])
plt.title("True image")
plt.subplot(248)
plt.imshow(noisy_image, aspect="auto", origin="lower", cmap="plasma")
plt.xticks([])
plt.yticks([])
plt.title("Convolved image")
plt.tight_layout()
plt.show()
Total running time of the script: ( 0 minutes 32.168 seconds)