# Source code for tensorpack.dataflow.dataset.caltech101

# -*- coding: utf-8 -*-
# File: caltech101.py


import os

from ...utils import logger
from ...utils.fs import download, get_dataset_path
from ..base import RNGDataFlow

__all__ = ["Caltech101Silhouettes"]


def maybe_download(url, work_directory):
    """
    Fetch ``url`` into ``work_directory`` unless a file with the same
    basename is already present there.

    Args:
        url (str): remote file URL; its last path component is used as
            the local file name.
        work_directory (str): directory to look in / download into.

    Returns:
        str: path of the local file, whether it pre-existed or was just
        downloaded.
    """
    local_path = os.path.join(work_directory, url.rsplit("/", 1)[-1])
    if os.path.exists(local_path):
        # Already cached locally; skip the network round-trip.
        return local_path
    logger.info("Downloading to {}...".format(local_path))
    download(url, work_directory)
    return local_path


class Caltech101Silhouettes(RNGDataFlow):
    """
    Produces [image, label] in Caltech101 Silhouettes dataset,
    image is 28x28 in the range [0,1], label is an int in the range [0,100].
    """

    _DIR_NAME = "caltech101_data"
    _SOURCE_URL = "https://people.cs.umass.edu/~marlin/data/"
    # Single .mat file holding all three splits under keys
    # "<split>_data" / "<split>_labels".
    _DATA_FILE = "caltech101_silhouettes_28_split1.mat"

    def __init__(self, name, shuffle=True, dir=None):
        """
        Args:
            name (str): 'train', 'test', 'val'
            shuffle (bool): shuffle the dataset
            dir (str): directory where the .mat file is looked up / downloaded
                to. Defaults to ``get_dataset_path(self._DIR_NAME)``.
        """
        if dir is None:
            dir = get_dataset_path(self._DIR_NAME)
        assert name in ['train', 'test', 'val']
        self.name = name
        self.shuffle = shuffle

        filepath = maybe_download(self._SOURCE_URL + self._DATA_FILE, dir)
        self.data = scipy.io.loadmat(filepath)

        # Infer the sample count instead of hard-coding the per-split sizes
        # (4100 / 2307 / 2264 for split1); each row is a flattened 28x28 image.
        self.images = self.data["{}_data".format(name)].reshape((-1, 28, 28))
        # Shift labels down by one so they start at 0 (the class docstring
        # promises the range [0,100]).
        self.labels = self.data["{}_labels".format(name)].ravel() - 1

    def __len__(self):
        return self.images.shape[0]

    def __iter__(self):
        """Yield ``[image, label]`` pairs, in random order if ``shuffle``."""
        idxs = list(range(len(self)))
        if self.shuffle:
            # self.rng comes from RNGDataFlow (presumably initialised in
            # reset_state(), which callers must invoke first -- verify).
            self.rng.shuffle(idxs)
        for k in idxs:
            yield [self.images[k], self.labels[k]]
# scipy is an optional dependency. Without it, replace the real dataset
# class with a placeholder produced by create_dummy_class. This must run
# after the class definition above, because it rebinds that name.
try:
    import scipy.io
except ImportError:
    from ...utils.develop import create_dummy_class
    Caltech101Silhouettes = create_dummy_class('Caltech101Silhouettes', 'scipy.io')  # noqa


if __name__ == "__main__":
    # Manual smoke test: load the training split and drop into an IPython
    # shell on the first datapoint.
    ds = Caltech101Silhouettes("train")
    ds.reset_state()
    for _ in ds:
        from IPython import embed
        embed()
        break