import os
import pickle
import tarfile
import time
from .utils import download_dataset
import numpy as np
# Names of the 100 CIFAR-100 classes, kept in alphabetical order.
# NOTE(review): presumably a class's position here is its integer "fine"
# label -- confirm against the dataset documentation before relying on it.
labels_list = [
    # a
    "apple", "aquarium_fish",
    # b
    "baby", "bear", "beaver", "bed", "bee", "beetle", "bicycle",
    "bottle", "bowl", "boy", "bridge", "bus", "butterfly",
    # c
    "camel", "can", "castle", "caterpillar", "cattle", "chair",
    "chimpanzee", "clock", "cloud", "cockroach", "couch", "crab",
    "crocodile", "cup",
    # d
    "dinosaur", "dolphin",
    # e
    "elephant",
    # f
    "flatfish", "forest", "fox",
    # g
    "girl",
    # h
    "hamster", "house",
    # k
    "kangaroo", "keyboard",
    # l
    "lamp", "lawn_mower", "leopard", "lion", "lizard", "lobster",
    # m
    "man", "maple_tree", "motorcycle", "mountain", "mouse", "mushroom",
    # o
    "oak_tree", "orange", "orchid", "otter",
    # p
    "palm_tree", "pear", "pickup_truck", "pine_tree", "plain", "plate",
    "poppy", "porcupine", "possum",
    # r
    "rabbit", "raccoon", "ray", "road", "rocket", "rose",
    # s
    "sea", "seal", "shark", "shrew", "skunk", "skyscraper", "snail",
    "snake", "spider", "squirrel", "streetcar", "sunflower",
    "sweet_pepper",
    # t
    "table", "tank", "telephone", "television", "tiger", "tractor",
    "train", "trout", "tulip", "turtle",
    # w
    "wardrobe", "whale", "willow_tree", "wolf", "woman", "worm",
]
# Name of the sub-directory (below the dataset root) that holds the archive;
# ``load`` joins it with the root path to locate ``cifar100.tar.gz``.
_dataset = "cifar100"

# Passed to ``download_dataset``.
# NOTE(review): presumably maps each remote URL to the local filename it is
# saved under -- confirm against ``utils.download_dataset``.
_urls = {
    "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz": "cifar100.tar.gz",
}
def _read_split(tar, member):
    """Unpickle one split (``train`` or ``test``) from the open archive.

    Returns a ``(images, fine_labels, coarse_labels)`` tuple where ``images``
    is a ``float32`` array reshaped to ``(-1, 3, 32, 32)`` and both label
    entries are NumPy arrays.
    """
    # NOTE(security): unpickling can execute arbitrary code. The archive comes
    # from a network download, so only use this loader with a trusted source.
    with tar.extractfile(member) as handle:
        split = pickle.loads(handle.read(), encoding="latin1")
    images = split["data"].reshape((-1, 3, 32, 32)).astype("float32")
    fine = np.array(split["fine_labels"])
    coarse = np.array(split["coarse_labels"])
    return images, fine, coarse


def load(path=None):
    """Image classification.

    The `CIFAR-100 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ dataset is
    just like the CIFAR-10, except it has 100 classes containing 600 images
    each. There are 500 training images and 100 testing images per class.
    The 100 classes in the CIFAR-100 are grouped into 20 superclasses. Each
    image comes with a "fine" label (the class to which it belongs) and a
    "coarse" label (the superclass to which it belongs).

    Parameters
    ----------
    path : str, optional
        Root directory for datasets. The archive is read from
        ``<path>/cifar100/cifar100.tar.gz``. When ``None``, the value of the
        ``DATASET_PATH`` environment variable is used.

    Returns
    -------
    dict
        Mapping with the keys ``"train_set/images"``, ``"train_set/labels"``,
        ``"train_set/coarse_labels"``, ``"test_set/images"``,
        ``"test_set/labels"`` and ``"test_set/coarse_labels"``. Images are
        ``float32`` arrays reshaped to ``(-1, 3, 32, 32)``; ``labels`` hold
        the fine labels and ``coarse_labels`` the superclass labels.

    Raises
    ------
    KeyError
        If ``path`` is ``None`` and ``DATASET_PATH`` is not set.
    """
    if path is None:
        path = os.environ["DATASET_PATH"]
    download_dataset(path, _dataset, _urls)

    # Timing deliberately starts after the download step, as before.
    t0 = time.time()

    archive = os.path.join(path, _dataset, "cifar100.tar.gz")
    # The context manager guarantees the archive is closed even if a member
    # is missing or unpickling fails.
    with tarfile.open(archive, "r:gz") as tar:
        train_images, train_fine, train_coarse = _read_split(
            tar, "cifar-100-python/train"
        )
        test_images, test_fine, test_coarse = _read_split(
            tar, "cifar-100-python/test"
        )

    data = {
        "train_set/images": train_images,
        "train_set/labels": train_fine,
        "train_set/coarse_labels": train_coarse,
        "test_set/images": test_images,
        "test_set/labels": test_fine,
        "test_set/coarse_labels": test_coarse,
    }
    print("Dataset cifar100 loaded in {0:.2f}s.".format(time.time() - t0))
    return data