# -*- coding: utf-8 -*-
# File: bsds500.py


import glob
import numpy as np
import os

from ...utils.fs import download, get_dataset_path
from ..base import RNGDataFlow

__all__ = ['BSDS500']


DATA_URL = "http://www.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/BSR/BSR_bsds500.tgz"
DATA_SIZE = 70763455
IMG_W, IMG_H = 481, 321


class BSDS500(RNGDataFlow):
    """
    `Berkeley Segmentation Data Set and Benchmarks 500 dataset
    <http://www.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/resources.html#bsds500>`_.

    Produce ``(image, label)`` pairs, where ``image`` has shape (321, 481, 3(BGR)) and
    ranges in [0, 255]. ``label`` is a floating point image of shape (321, 481) in range [0, 1].
    The value of each pixel is ``number of times it is annotated as edge / total number of
    annotators for this image``.
    """
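
    # Worked example of the label semantics described above (numbers are illustrative,
    # not taken from the dataset): if 3 of 5 annotators mark a pixel as an edge,
    # that pixel's label value is 3 / 5 = 0.6.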
    def __init__(self, name, data_dir=None, shuffle=True):
        """
        Args:
            name (str): 'train', 'test', 'val'
            data_dir (str): a directory containing the original 'BSR' directory.
        """
        # check and download data
        if data_dir is None:
            data_dir = get_dataset_path('bsds500_data')
        if not os.path.isdir(os.path.join(data_dir, 'BSR')):
            download(DATA_URL, data_dir, expect_size=DATA_SIZE)
            filename = DATA_URL.split('/')[-1]
            filepath = os.path.join(data_dir, filename)
            import tarfile
            tarfile.open(filepath, 'r:gz').extractall(data_dir)
        self.data_root = os.path.join(data_dir, 'BSR', 'BSDS500', 'data')
        assert os.path.isdir(self.data_root)

        self.shuffle = shuffle
        assert name in ['train', 'test', 'val']
        self._load(name)
    def _load(self, name):
        image_glob = os.path.join(self.data_root, 'images', name, '*.jpg')
        image_files = glob.glob(image_glob)
        gt_dir = os.path.join(self.data_root, 'groundTruth', name)
        self.data = np.zeros((len(image_files), IMG_H, IMG_W, 3), dtype='uint8')
        self.label = np.zeros((len(image_files), IMG_H, IMG_W), dtype='float32')

        for idx, f in enumerate(image_files):
            im = cv2.imread(f, cv2.IMREAD_COLOR)
            assert im is not None
            # transpose portrait images to landscape so every image is (IMG_H, IMG_W)
            if im.shape[0] > im.shape[1]:
                im = np.transpose(im, (1, 0, 2))
            assert im.shape[:2] == (IMG_H, IMG_W), "{} != {}".format(im.shape[:2], (IMG_H, IMG_W))

            imgid = os.path.basename(f).split('.')[0]
            gt_file = os.path.join(gt_dir, imgid)
            gt = loadmat(gt_file)['groundTruth'][0]
            n_annot = gt.shape[0]
            # average the boundary maps of all annotators into a soft edge label in [0, 1]
            gt = sum(gt[k]['Boundaries'][0][0] for k in range(n_annot))
            gt = gt.astype('float32')
            gt *= 1.0 / n_annot
            if gt.shape[0] > gt.shape[1]:
                gt = gt.transpose()
            assert gt.shape == (IMG_H, IMG_W)

            self.data[idx] = im
            self.label[idx] = gt

    def __len__(self):
        return self.data.shape[0]

    def __iter__(self):
        idxs = np.arange(self.data.shape[0])
        if self.shuffle:
            self.rng.shuffle(idxs)
        for k in idxs:
            yield [self.data[k], self.label[k]]

try:
    from scipy.io import loadmat
    import cv2
except ImportError:
    from ...utils.develop import create_dummy_class
    BSDS500 = create_dummy_class('BSDS500', ['scipy.io', 'cv2'])  # noqa


if __name__ == '__main__':
    a = BSDS500('val')
    a.reset_state()
    for k in a:
        cv2.imshow("haha", k[1].astype('uint8') * 255)
        cv2.waitKey(1000)
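
# A minimal usage sketch (not part of the original file), assuming the standard
# tensorpack dataflow API: `BatchData` from tensorpack.dataflow and the usual
# `reset_state()` call before iteration.
#
#   from tensorpack.dataflow import BatchData
#   ds = BSDS500('train', shuffle=True)
#   ds = BatchData(ds, 8)       # stack datapoints into batches of 8
#   ds.reset_state()            # must be called before iterating
#   for images, labels in ds:
#       # images: (8, 321, 481, 3) uint8 BGR; labels: (8, 321, 481) float32 in [0, 1]
#       pass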