# Source code for symjax.data.fashionmnist

import os
import gzip
import urllib.request
import numpy as np
import time
from .utils import download_dataset


# Remote Fashion-MNIST archives, keyed by download URL. Each value is the
# file name that `load` opens inside the "fashionmnist" dataset folder.
_BASE_URL = "http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/"
_urls = {
    _BASE_URL + remote_name: local_name
    for remote_name, local_name in (
        ("train-images-idx3-ubyte.gz", "train-images.gz"),
        ("train-labels-idx1-ubyte.gz", "train-labels.gz"),
        ("t10k-images-idx3-ubyte.gz", "test-images.gz"),
        ("t10k-labels-idx1-ubyte.gz", "test-labels.gz"),
    )
}


def _read_gzip_uint8(filename, header_size):
    """Decompress ``filename`` and return its bytes after ``header_size`` as uint8."""
    with gzip.open(filename, "rb") as stream:
        return np.frombuffer(stream.read(), dtype=np.uint8, offset=header_size)


def load(path=None):
    """Grayscale image classification.

    `Zalando <https://jobs.zalando.com/tech/>`_ 's article image
    classification.
    `Fashion-MNIST <https://github.com/zalandoresearch/fashion-mnist>`_ is a
    dataset of `Zalando <https://jobs.zalando.com/tech/>`_ 's article images
    consisting of a training set of 60,000 examples and a test set of 10,000
    examples. Each example is a 28x28 grayscale image, associated with a
    label from 10 classes. Fashion-MNIST is intended to serve as a direct
    drop-in replacement for the original MNIST dataset for benchmarking
    machine learning algorithms. It shares the same image size and structure
    of training and testing splits.

    Parameters
    ----------
    path : str, optional
        Root directory holding the datasets. The archives are read from the
        ``fashionmnist`` sub-folder of this directory. When ``None`` the
        value of the ``DATASET_PATH`` environment variable is used.

    Returns
    -------
    dict
        Mapping with the keys ``"train_set/images"``, ``"train_set/labels"``,
        ``"test_set/images"`` and ``"test_set/labels"``. Images are
        ``float32`` arrays reshaped to ``(-1, 1, 28, 28)``; labels are
        one-dimensional ``uint8`` arrays.

    Raises
    ------
    KeyError
        If ``path`` is ``None`` and ``DATASET_PATH`` is not set.
    """
    if path is None:
        path = os.environ["DATASET_PATH"]

    download_dataset(path, "fashionmnist", _urls)

    t0 = time.time()
    print("\tLoading fashionmnist")

    # Join instead of concatenating so a ``path`` without a trailing
    # separator still resolves to "<path>/fashionmnist/...".
    # NOTE(review): assumes download_dataset stores the archives under
    # os.path.join(path, "fashionmnist") -- confirm against .utils.
    folder = os.path.join(path, "fashionmnist")

    def read_labels(name):
        # Label archives: the first 8 bytes are header, not labels.
        return _read_gzip_uint8(os.path.join(folder, name), 8)

    def read_images(name):
        # Image archives: the first 16 bytes are header, not pixels.
        pixels = _read_gzip_uint8(os.path.join(folder, name), 16)
        return pixels.reshape((-1, 1, 28, 28)).astype("float32")

    data = {
        "train_set/images": read_images("train-images.gz"),
        "train_set/labels": read_labels("train-labels.gz"),
        "test_set/images": read_images("test-images.gz"),
        "test_set/labels": read_labels("test-labels.gz"),
    }

    # The message previously named "mnist", copied from the MNIST loader.
    print("Dataset fashionmnist loaded in {0:.2f}s.".format(time.time() - t0))

    return data