# -*- coding: utf-8 -*-
# File: mnist.py
import gzip
import numpy
import os
from ...utils import logger
from ...utils.fs import download, get_dataset_path
from ..base import RNGDataFlow
# Public API of this module: the two dataflow classes.
__all__ = [
    'Mnist',
    'FashionMnist',
]
def maybe_download(url, work_directory):
    """Download the data from Yann's website, unless it's already here.

    The local file name is the last ``/``-separated component of ``url``,
    placed inside ``work_directory``. When no file exists at that path,
    ``download(url, work_directory)`` is invoked first.

    Args:
        url (str): remote location of the file.
        work_directory (str): local directory to store the file in.

    Returns:
        str: the local path of the (possibly freshly downloaded) file.
    """
    local_path = os.path.join(work_directory, url.rsplit('/', 1)[-1])
    if os.path.exists(local_path):
        return local_path
    logger.info("Downloading to {}...".format(local_path))
    download(url, work_directory)
    return local_path
def _read32(bytestream):
dt = numpy.dtype(numpy.uint32).newbyteorder('>')
return numpy.frombuffer(bytestream.read(4), dtype=dt)[0]
def extract_images(filename):
    """Extract the images into a 4D float32 numpy array [index, y, x, depth].

    Reads a gzip-compressed IDX image file whose header is four big-endian
    uint32 fields (magic number 2051, image count, rows, cols), followed by
    ``count * rows * cols`` unsigned bytes of pixel data.

    Args:
        filename (str): path to the ``.gz`` image file.

    Returns:
        numpy.ndarray: float32 array of shape ``(count, rows, cols, 1)`` with
        pixel values scaled from [0, 255] into [0, 1].

    Raises:
        ValueError: if the magic number is not 2051, or the file holds fewer
            pixel bytes than the header declares.
    """
    with gzip.open(filename) as bytestream:
        magic = _read32(bytestream)
        if magic != 2051:
            raise ValueError(
                'Invalid magic number %d in MNIST image file: %s' %
                (magic, filename))
        # Convert to Python ints so the size product below cannot wrap
        # around in numpy's fixed-width uint32 arithmetic.
        num_images = int(_read32(bytestream))
        rows = int(_read32(bytestream))
        cols = int(_read32(bytestream))
        expected = num_images * rows * cols
        buf = bytestream.read(expected)
    # Report a short read clearly instead of failing inside reshape().
    if len(buf) != expected:
        raise ValueError(
            'Truncated MNIST image file %s: expected %d bytes of pixel data, '
            'got %d' % (filename, expected, len(buf)))
    data = numpy.frombuffer(buf, dtype=numpy.uint8)
    data = data.reshape(num_images, rows, cols, 1)
    return data.astype('float32') / 255.0
def extract_labels(filename):
    """Extract the labels into a 1D uint8 numpy array [index].

    Reads a gzip-compressed IDX label file whose header is two big-endian
    uint32 fields (magic number 2049, item count), followed by one unsigned
    byte per label.

    Args:
        filename (str): path to the ``.gz`` label file.

    Returns:
        numpy.ndarray: uint8 array of shape ``(count,)``.

    Raises:
        ValueError: if the magic number is not 2049, or the file holds fewer
            label bytes than the header declares.
    """
    with gzip.open(filename) as bytestream:
        magic = _read32(bytestream)
        if magic != 2049:
            raise ValueError(
                'Invalid magic number %d in MNIST label file: %s' %
                (magic, filename))
        num_items = int(_read32(bytestream))
        buf = bytestream.read(num_items)
    # Without this check a short read silently yields fewer labels than declared.
    if len(buf) != num_items:
        raise ValueError(
            'Truncated MNIST label file %s: expected %d labels, got %d' %
            (filename, num_items, len(buf)))
    return numpy.frombuffer(buf, dtype=numpy.uint8)
class Mnist(RNGDataFlow):
    """
    Produces [image, label] in MNIST dataset,
    image is 28x28 in the range [0,1], label is an int.
    """

    # Subdirectory name passed to get_dataset_path() when no `dir` is given.
    _DIR_NAME = 'mnist_data'
    # Base URL; file names are appended to it when downloading.
    # NOTE(review): if this host becomes unreachable, override _SOURCE_URL
    # with a mirror that serves the same file names.
    _SOURCE_URL = 'http://yann.lecun.com/exdb/mnist/'

    def __init__(self, train_or_test, shuffle=True, dir=None):
        """
        Args:
            train_or_test (str): either 'train' or 'test'
            shuffle (bool): shuffle the dataset
            dir (str): directory holding (or receiving) the downloaded ``.gz``
                files. Defaults to ``get_dataset_path(self._DIR_NAME)``.
        """
        if dir is None:
            dir = get_dataset_path(self._DIR_NAME)
        assert train_or_test in ['train', 'test']
        self.train_or_test = train_or_test
        self.shuffle = shuffle

        def get_images_and_labels(image_file, label_file):
            # Download (if needed) and decode one images/labels file pair.
            f = maybe_download(self._SOURCE_URL + image_file, dir)
            images = extract_images(f)
            f = maybe_download(self._SOURCE_URL + label_file, dir)
            labels = extract_labels(f)
            assert images.shape[0] == labels.shape[0]
            return images, labels

        if self.train_or_test == 'train':
            self.images, self.labels = get_images_and_labels(
                'train-images-idx3-ubyte.gz',
                'train-labels-idx1-ubyte.gz')
        else:
            self.images, self.labels = get_images_and_labels(
                't10k-images-idx3-ubyte.gz',
                't10k-labels-idx1-ubyte.gz')

    def __len__(self):
        """Number of [image, label] pairs in the selected split."""
        return self.images.shape[0]

    def __iter__(self):
        """Yield [image, label] pairs, in shuffled order if `shuffle` was set.

        `image` is a 2D float32 array (rows x cols); `label` is a uint8 scalar.
        """
        idxs = list(range(len(self)))
        if self.shuffle:
            self.rng.shuffle(idxs)
        for k in idxs:
            # Drop the trailing depth axis of the (rows, cols, 1) image; this
            # equals the former reshape((28, 28)) for MNIST without
            # hard-coding the image size.
            img = self.images[k][:, :, 0]
            label = self.labels[k]
            yield [img, label]
class FashionMnist(Mnist):
    """
    Same API as :class:`Mnist`, but loads Fashion-MNIST: images of 10
    clothing/accessory categories instead of handwritten digits.
    """

    _DIR_NAME = 'fashion_mnist_data'
    _SOURCE_URL = 'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/'

    def get_label_names(self):
        """
        Returns:
            [str]: the name of each class, indexed by integer label
        """
        # copied from https://github.com/zalandoresearch/fashion-mnist
        return ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
                'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
if __name__ == '__main__':
    # Manual smoke test: build the training split, pull the first datapoint
    # into `_`, and open an IPython shell (with `ds` and `_` in scope).
    ds = Mnist('train')
    ds.reset_state()
    _exhausted = object()
    _ = next(iter(ds), _exhausted)
    if _ is not _exhausted:
        # Imported lazily so IPython is only required once data was produced.
        from IPython import embed
        embed()