# Source code for tensorpack.dataflow.dataset.cifar

# -*- coding: utf-8 -*-
# File: cifar.py

#         Yukun Chen <cykustc@gmail.com>

import numpy as np
import os
import pickle
import tarfile

from ...utils import logger
from ...utils.fs import download, get_dataset_path
from ..base import RNGDataFlow

# Public API of this module: the CIFAR dataflow classes.
__all__ = [
    'CifarBase',
    'Cifar10',
    'Cifar100',
]


# Each entry is (download URL, expected size passed to ``download(expect_size=...)``).
# HTTPS is used because the downloaded archives are later deserialized with
# ``pickle.load``; fetching them over plain HTTP would let a network attacker
# substitute a malicious pickle.
DATA_URL_CIFAR_10 = ('https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz', 170498071)
DATA_URL_CIFAR_100 = ('https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz', 169001437)


def maybe_download_and_extract(dest_directory, cifar_classnum):
    """Download and extract the CIFAR tarball from Alex's website.

    Copied from the TensorFlow example. If the extracted folder
    (``cifar-10-batches-py`` or ``cifar-100-python``) already exists under
    ``dest_directory``, nothing is downloaded.

    Args:
        dest_directory (str): directory to download the tarball into and
            extract it under.
        cifar_classnum (int): 10 or 100, selecting CIFAR-10 or CIFAR-100.
    """
    assert cifar_classnum == 10 or cifar_classnum == 100
    if cifar_classnum == 10:
        cifar_foldername = 'cifar-10-batches-py'
        url, expected_size = DATA_URL_CIFAR_10
    else:
        cifar_foldername = 'cifar-100-python'
        url, expected_size = DATA_URL_CIFAR_100

    if os.path.isdir(os.path.join(dest_directory, cifar_foldername)):
        logger.info("Found cifar{} data in {}.".format(cifar_classnum, dest_directory))
        return

    filename = url.split('/')[-1]
    filepath = os.path.join(dest_directory, filename)
    download(url, dest_directory, expect_size=expected_size)
    # Use a context manager so the archive handle is always closed.
    with tarfile.open(filepath, 'r:gz') as tar:
        # The archive comes from the network: when the running Python supports
        # extraction filters, use the 'data' filter. It rejects absolute paths,
        # '..' components, links escaping the destination, and special files.
        if hasattr(tarfile, 'data_filter'):
            tar.extractall(dest_directory, filter='data')
        else:
            tar.extractall(dest_directory)


def read_cifar(filenames, cifar_classnum):
    """Load CIFAR batch files into a list of ``[image, label]`` pairs.

    Each row of the pickled ``b'data'`` array is reshaped to ``(3, 32, 32)``
    and transposed to ``(32, 32, 3)``.

    Args:
        filenames (list[str]): paths of the pickled CIFAR batch files.
        cifar_classnum (int): 10 (labels read from ``b'labels'``) or 100
            (labels read from ``b'fine_labels'``).

    Returns:
        list: ``[image, label]`` pairs, in file order.

    Raises:
        ValueError: if a file holds a different number of images and labels.

    Note:
        The files are deserialized with ``pickle``; only use trusted files.
    """
    assert cifar_classnum == 10 or cifar_classnum == 100
    label_key = b'labels' if cifar_classnum == 10 else b'fine_labels'
    ret = []
    for fname in filenames:
        # ``with`` closes the file even if unpickling fails.
        with open(fname, 'rb') as fo:
            # encoding='bytes' lets Python 3 load the original Python 2 pickles.
            dic = pickle.load(fo, encoding='bytes')
        data = dic[b'data']
        label = dic[label_key]
        # Use the real record count rather than a hard-coded 10000/50000. The
        # old "'train' in fname" test also looked at directory names, so a test
        # file under e.g. ".../train_runs/..." was read as 50000 records.
        if len(data) != len(label):
            raise ValueError("{}: found {} images but {} labels".format(
                fname, len(data), len(label)))
        for k in range(len(label)):
            img = data[k].reshape(3, 32, 32)
            img = np.transpose(img, [1, 2, 0])
            ret.append([img, label[k]])
    return ret


def get_filenames(dir, cifar_classnum):
    """Compute the paths of the extracted CIFAR files under ``dir``.

    Args:
        dir (str): dataset directory containing the extracted archive folder.
        cifar_classnum (int): 10 or 100.

    Returns:
        tuple: ``(train_files, test_files, meta_file)``, where the first two
        are lists of paths and ``meta_file`` is a single path.
    """
    assert cifar_classnum == 10 or cifar_classnum == 100
    if cifar_classnum == 10:
        root = os.path.join(dir, 'cifar-10-batches-py')
        # CIFAR-10 ships its training set as five numbered batches.
        train_files = [os.path.join(root, 'data_batch_{}'.format(idx))
                       for idx in range(1, 6)]
        test_files = [os.path.join(root, 'test_batch')]
        meta_file = os.path.join(root, 'batches.meta')
    elif cifar_classnum == 100:
        root = os.path.join(dir, 'cifar-100-python')
        train_files = [os.path.join(root, 'train')]
        test_files = [os.path.join(root, 'test')]
        meta_file = os.path.join(root, 'meta')
    return train_files, test_files, meta_file


def _parse_meta(filename, cifar_classnum):
    with open(filename, 'rb') as f:
        obj = pickle.load(f)
        return obj['label_names' if cifar_classnum == 10 else 'fine_label_names']


class CifarBase(RNGDataFlow):
    """
    Produces [image, label] in Cifar10/100 dataset,
    image is 32x32x3 in the range [0,255].
    label is an int.
    """
    def __init__(self, train_or_test, shuffle=None, dir=None, cifar_classnum=10):
        """
        Args:
            train_or_test (str): 'train' or 'test'
            shuffle (bool): defaults to True for training set.
            dir (str): path to the dataset directory. Defaults to
                ``get_dataset_path('cifar{N}_data')``; the data is downloaded
                there if missing.
            cifar_classnum (int): 10 or 100

        Raises:
            ValueError: if an expected batch file is missing after download.
        """
        assert train_or_test in ['train', 'test']
        assert cifar_classnum == 10 or cifar_classnum == 100
        self.cifar_classnum = cifar_classnum
        if dir is None:
            dir = get_dataset_path('cifar{}_data'.format(cifar_classnum))
        maybe_download_and_extract(dir, self.cifar_classnum)
        train_files, test_files, meta_file = get_filenames(dir, cifar_classnum)
        if train_or_test == 'train':
            self.fs = train_files
        else:
            self.fs = test_files
        for f in self.fs:
            if not os.path.isfile(f):
                raise ValueError('Failed to find file: ' + f)
        self._label_names = _parse_meta(meta_file, cifar_classnum)
        self.train_or_test = train_or_test
        # The whole split is loaded into memory as a list of [image, label].
        self.data = read_cifar(self.fs, cifar_classnum)
        self.dir = dir

        if shuffle is None:
            shuffle = train_or_test == 'train'
        self.shuffle = shuffle

    def __len__(self):
        # Report the number of records actually loaded rather than the nominal
        # 50000/10000 split sizes, so len() always agrees with __iter__.
        return len(self.data)

    def __iter__(self):
        idxs = np.arange(len(self.data))
        if self.shuffle:
            # self.rng is provided by RNGDataFlow.
            self.rng.shuffle(idxs)
        for k in idxs:
            # NOTE(review): yields the cached [image, label] list itself, not a
            # copy; a consumer that mutates it in place would alter the dataset.
            yield self.data[k]

    def get_per_pixel_mean(self, names=('train', 'test')):
        """
        Args:
            names (tuple[str]): the names ('train' or 'test') of the datasets

        Returns:
            a mean image of all images in the given datasets, with size 32x32x3

        Raises:
            ValueError: if ``names`` selects no dataset.
        """
        for name in names:
            assert name in ['train', 'test'], name
        train_files, test_files, _ = get_filenames(self.dir, self.cifar_classnum)
        all_files = []
        if 'train' in names:
            all_files.extend(train_files)
        if 'test' in names:
            all_files.extend(test_files)
        if not all_files:
            # Averaging nothing would silently produce NaN.
            raise ValueError("get_per_pixel_mean() needs at least one of 'train'/'test'")
        all_imgs = [x[0] for x in read_cifar(all_files, self.cifar_classnum)]
        # Stack the images in their native dtype (presumably uint8 for CIFAR)
        # and let np.mean accumulate in float32. This avoids first making a
        # full float32 copy of every image.
        arr = np.array(all_imgs)
        mean = np.mean(arr, axis=0, dtype='float32')
        return mean

    def get_label_names(self):
        """
        Returns:
            [str]: name of each class.
        """
        return self._label_names

    def get_per_channel_mean(self, names=('train', 'test')):
        """
        Args:
            names (tuple[str]): the names ('train' or 'test') of the datasets

        Returns:
            An array of three values as mean of each channel, for all images in the given datasets.
        """
        mean = self.get_per_pixel_mean(names)
        return np.mean(mean, axis=(0, 1))
class Cifar10(CifarBase):
    """
    Produces [image, label] in Cifar10 dataset,
    image is 32x32x3 in the range [0,255].
    label is an int.
    """
    def __init__(self, train_or_test, shuffle=None, dir=None):
        """
        Args:
            train_or_test (str): either 'train' or 'test'.
            shuffle (bool): shuffle the dataset, default to shuffle in training
            dir (str): path to the dataset directory; see :class:`CifarBase`.
        """
        super().__init__(train_or_test, shuffle=shuffle, dir=dir, cifar_classnum=10)
class Cifar100(CifarBase):
    """ Similar to Cifar10"""
    def __init__(self, train_or_test, shuffle=None, dir=None):
        """
        Args:
            train_or_test (str): either 'train' or 'test'.
            shuffle (bool): shuffle the dataset, default to shuffle in training
            dir (str): path to the dataset directory; see :class:`CifarBase`.
        """
        super().__init__(train_or_test, shuffle=shuffle, dir=dir, cifar_classnum=100)
if __name__ == '__main__':
    # Demo: print the per-channel mean, then dump the first 100 training
    # images as JPEG files in the current directory.
    ds = Cifar10('train')
    mean = ds.get_per_channel_mean()
    print(mean)

    import cv2
    ds.reset_state()
    for i, dp in enumerate(ds):
        if i == 100:
            break
        img = dp[0]
        # CIFAR stores each image as R, G, B planes, so the HWC array built by
        # read_cifar is RGB. cv2.imwrite expects BGR, so reverse the channel
        # axis; otherwise red and blue are swapped in the saved files.
        cv2.imwrite("{:04d}.jpg".format(i), img[:, :, ::-1])