# Source code for symjax.data.cifar10

import os
import pickle
import tarfile
import time
from .utils import download_dataset

import numpy as np
from tqdm import tqdm


# Name of this dataset; passed to ``download_dataset`` and used by ``load`` as
# the sub-directory (under the user-supplied path) that holds the archive.
_dataset = "cifar10"
# Download table handed to ``download_dataset``: remote URL -> file name.
# NOTE(review): presumably the value is the local name the helper saves the
# download under -- confirm against ``.utils.download_dataset``.
_urls = {
    "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz": "cifar-10-python.tar.gz"
}


# Human-readable class name for each integer label stored in the CIFAR-10
# batches (same order as the official class list).
label_to_name = {
    0: "airplane",
    1: "automobile",
    2: "bird",
    3: "cat",
    4: "deer",
    5: "dog",
    6: "frog",
    7: "horse",
    8: "ship",  # CIFAR-10 class 8 is "ship" (was misspelled as "sheep")
    9: "truck",
}


def _read_batch(tar, member):
    """Unpickle one batch file from the open CIFAR-10 archive.

    Parameters
    ----------

    tar: tarfile.TarFile
        the opened dataset archive

    member: str
        name of the batch file inside the archive

    Returns
    -------

    images: array
        the batch ``"data"`` entry reshaped to ``(-1, 3, 32, 32)``

    labels: list
        the batch ``"labels"`` entry, unchanged
    """
    raw = tar.extractfile(member).read()
    # NOTE(security): unpickling can execute arbitrary code. The archive is
    # downloaded from the network, so only use this with a trusted source.
    # ``latin1`` is kept from the original call (presumably needed because
    # the batches were pickled by Python 2 -- confirm).
    batch = pickle.loads(raw, encoding="latin1")
    return batch["data"].reshape((-1, 3, 32, 32)), batch["labels"]


def load(path=None):
    """Image classification.

    The `CIFAR-10 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ dataset
    was collected by Alex Krizhevsky, Vinod Nair, and Geoffrey
    Hinton. It consists of 60000 32x32 colour images in 10 classes, with
    6000 images per class. There are 50000 training images and 10000 test
    images. The dataset is divided into five training batches and one test
    batch, each with 10000 images. The test batch contains exactly 1000
    randomly selected images from each class. The training batches contain
    the remaining images in random order, but some training batches may
    contain more images from one class than another. Between them, the
    training batches contain exactly 5000 images from each class.

    Parameters
    ----------

    path: str (optional)
        default ($DATASET_PATH), the path to look for the data and
        where the data will be downloaded if not present

    Returns
    -------

    data: dict
        a dictionary with the following entries

        - ``"train_set/images"``: float32 array of shape (N, 3, 32, 32)
        - ``"train_set/labels"``: int32 array
        - ``"test_set/images"``: float32 array of shape (N, 3, 32, 32)
        - ``"test_set/labels"``: int32 array
        - ``"label_to_name"``: the module-level label -> class name mapping

    Raises
    ------

    KeyError
        if ``path`` is None and the ``DATASET_PATH`` environment variable
        is not set
    """
    if path is None:
        path = os.environ["DATASET_PATH"]

    download_dataset(path, _dataset, _urls)

    t0 = time.time()

    # NOTE(review): this file name differs from the one listed in ``_urls``
    # ("cifar-10-python.tar.gz"); confirm which name ``download_dataset``
    # actually writes before changing either side.
    archive = os.path.join(path, _dataset, "cifar10.tar.gz")

    # The context manager guarantees the archive handle is released even if
    # a batch is missing or fails to unpickle.
    with tarfile.open(archive, "r:gz") as tar:
        # Load train set
        train_images = []
        train_labels = []
        for k in tqdm(range(1, 6), desc="Loading cifar10", ascii=True):
            images, labels = _read_batch(
                tar, "cifar-10-batches-py/data_batch_" + str(k)
            )
            train_images.append(images)
            train_labels.append(labels)

        # Load test set
        test_images, test_labels = _read_batch(
            tar, "cifar-10-batches-py/test_batch"
        )

    train_images = np.concatenate(train_images, 0).astype("float32")
    train_labels = np.concatenate(train_labels, 0).astype("int32")
    test_images = test_images.astype("float32")
    test_labels = np.array(test_labels).astype("int32")

    data = {
        "train_set/images": train_images,
        "train_set/labels": train_labels,
        "test_set/images": test_images,
        "test_set/labels": test_labels,
        "label_to_name": label_to_name,
    }

    print("Dataset cifar10 loaded in {0:.2f}s.".format(time.time() - t0))

    return data