Source code for tensorpack.dataflow.format

# -*- coding: utf-8 -*-
# File: format.py


import numpy as np
import os
import six

from ..utils import logger
from ..utils.argtools import log_once
from ..utils.serialize import loads
from ..utils.develop import create_dummy_class  # noqa
from ..utils.loadcaffe import get_caffe_pb
from ..utils.timer import timed_operation
from ..utils.utils import get_tqdm
from .base import DataFlowReentrantGuard, RNGDataFlow
from .common import MapData

__all__ = ['HDF5Data', 'LMDBData', 'LMDBDataDecoder',
           'CaffeLMDB', 'SVMLightData']

"""
Adapters for different data format.
"""


class HDF5Data(RNGDataFlow):
    """
    Zip data from different paths in an HDF5 file.

    Warning:
        The current implementation will load all data into memory. (TODO)
    """
    # TODO

    def __init__(self, filename, data_paths, shuffle=True):
        """
        Args:
            filename (str): h5 data file.
            data_paths (list): list of h5 paths to be zipped.
                For example `['images', 'labels']`.
            shuffle (bool): shuffle all data.
        """
        self.f = h5py.File(filename, 'r')
        logger.info("Loading {} to memory...".format(filename))
        self.dps = [self.f[k][...] for k in data_paths]
        lens = [len(k) for k in self.dps]
        assert all(k == lens[0] for k in lens)
        self._size = lens[0]
        self.shuffle = shuffle

    def __len__(self):
        return self._size

    def __iter__(self):
        idxs = list(range(self._size))
        if self.shuffle:
            self.rng.shuffle(idxs)
        for k in idxs:
            yield [dp[k] for dp in self.dps]
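
# A minimal usage sketch for HDF5Data. The file and dataset names below are
# hypothetical, not from the original source:
#
#     df = HDF5Data('train.h5', ['images', 'labels'], shuffle=True)
#     df.reset_state()  # required before iteration, as for any DataFlow
#     for image, label in df:
#         pass  # one aligned row from each of the zipped paths
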
class LMDBData(RNGDataFlow):
    """
    Read an LMDB database and produce (k,v) raw bytes pairs.
    The raw bytes are usually not what you're interested in.
    You might want to use :class:`LMDBDataDecoder` or apply a
    mapper function after :class:`LMDBData`.
    """
    def __init__(self, lmdb_path, shuffle=True, keys=None):
        """
        Args:
            lmdb_path (str): a directory or a file.
            shuffle (bool): shuffle the keys or not.
            keys (list[str] or str): list of str as the keys, used only when shuffle is True.
                It can also be a format string e.g. ``{:0>8d}`` which will be
                formatted with the indices from 0 to *total_size - 1*.

                If not given, it will then look in the database for ``__keys__`` which
                :func:`LMDBSerializer.save` used to store the list of keys.
                If still not found, it will iterate over the database to find
                all the keys.
        """
        self._lmdb_path = lmdb_path
        self._shuffle = shuffle

        self._open_lmdb()
        self._size = self._txn.stat()['entries']
        self._set_keys(keys)
        logger.info("Found {} entries in {}".format(self._size, self._lmdb_path))

        # Clean them up after finding the list of keys, since we don't want to fork them
        self._close_lmdb()

    def _set_keys(self, keys=None):
        def find_keys(txn, size):
            logger.warn("Traversing the database to find keys is slow. You should specify the keys.")
            keys = []
            with timed_operation("Loading LMDB keys ...", log_start=True), \
                    get_tqdm(total=size) as pbar:
                for k in txn.cursor():
                    assert k[0] != b'__keys__'
                    keys.append(k[0])
                    pbar.update()
            return keys

        self.keys = self._txn.get(b'__keys__')
        if self.keys is not None:
            self.keys = loads(self.keys)
            self._size -= 1  # don't count the '__keys__' entry itself

        if self._shuffle:  # keys are necessary when shuffle is True
            if keys is None:
                if self.keys is None:
                    self.keys = find_keys(self._txn, self._size)
            else:
                # check if a key format like '{:0>8d}' was given
                if isinstance(keys, six.string_types):
                    self.keys = [keys.format(x) for x in np.arange(self._size)]
                else:
                    self.keys = keys

    def _open_lmdb(self):
        self._lmdb = lmdb.open(self._lmdb_path,
                               subdir=os.path.isdir(self._lmdb_path),
                               readonly=True, lock=False, readahead=True,
                               map_size=1099511627776 * 2, max_readers=100)
        self._txn = self._lmdb.begin()

    def _close_lmdb(self):
        self._lmdb.close()
        del self._lmdb
        del self._txn

    def reset_state(self):
        self._guard = DataFlowReentrantGuard()
        super(LMDBData, self).reset_state()
        self._open_lmdb()  # open the LMDB in the worker process

    def __len__(self):
        return self._size

    def __iter__(self):
        with self._guard:
            if not self._shuffle:
                c = self._txn.cursor()
                for k, v in c:
                    if k != b'__keys__':
                        yield [k, v]
            else:
                self.rng.shuffle(self.keys)
                for k in self.keys:
                    v = self._txn.get(k)
                    yield [k, v]
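
# A usage sketch for LMDBData (the database path is hypothetical). With a key
# format string, keys are generated from indices instead of read from the
# database:
#
#     df = LMDBData('/path/to/db.lmdb', shuffle=True, keys='{:0>8d}')
#     df.reset_state()
#     for k, v in df:
#         pass  # k, v are raw bytes; decode them yourself or use LMDBDataDecoder
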
class LMDBDataDecoder(MapData):
    """ Read an LMDB database with a custom decoder and produce decoded outputs."""
    def __init__(self, lmdb_data, decoder):
        """
        Args:
            lmdb_data: a :class:`LMDBData` instance.
            decoder (k, v -> dp | None): a function taking k, v and returning
                a datapoint, or None to discard.
        """
        def f(dp):
            return decoder(dp[0], dp[1])
        super(LMDBDataDecoder, self).__init__(lmdb_data, f)
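
# A sketch of pairing LMDBData with a custom decoder. The deserialization used
# here is an assumption -- use whatever matches how the values were stored:
#
#     def decoder(k, v):
#         # return a datapoint, or None to drop this entry
#         return [k.decode('ascii'), loads(v)]
#
#     df = LMDBDataDecoder(LMDBData('/path/to/db.lmdb', shuffle=False), decoder)
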
def CaffeLMDB(lmdb_path, shuffle=True, keys=None):
    """
    Read a Caffe-format LMDB file where each value contains a ``caffe.Datum`` protobuf.
    Produces datapoints of the format: [HWC image, label].

    Note that Caffe LMDB format is not efficient: it stores serialized raw arrays
    rather than JPEG images.

    Args:
        lmdb_path, shuffle, keys: same as :class:`LMDBData`.

    Example:
        .. code-block:: python

            ds = CaffeLMDB("/tmp/validation", keys='{:0>8d}')
    """
    cpb = get_caffe_pb()
    lmdb_data = LMDBData(lmdb_path, shuffle, keys)

    def decoder(k, v):
        try:
            datum = cpb.Datum()
            datum.ParseFromString(v)
            # Datum stores a raw CHW uint8 array
            img = np.fromstring(datum.data, dtype=np.uint8)
            img = img.reshape(datum.channels, datum.height, datum.width)
        except Exception:
            log_once("Cannot read key {}".format(k), 'warn')
            return None
        return [img.transpose(1, 2, 0), datum.label]

    logger.warn("Caffe LMDB format doesn't store jpeg-compressed images, "
                "it's not recommended due to its inferior performance.")
    return LMDBDataDecoder(lmdb_data, decoder)
class DiskCacheData(RNGDataFlow):
    """ Produce datapoints from a :class:`diskcache.Index` database,
        assuming it is keyed by the integers 0 to len - 1. """
    def __init__(self, path, shuffle=True):
        self._db = diskcache.Index(path)
        self._shuffle = shuffle
        self._size = len(self._db)

    def __len__(self):
        return self._size

    def __iter__(self):
        if not self._shuffle:
            for k in range(self._size):
                yield self._db[k]
        else:
            keys = list(range(self._size))
            self.rng.shuffle(keys)
            for k in keys:
                yield self._db[k]
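
# A sketch of writing a database that DiskCacheData can read (the path is
# hypothetical): diskcache.Index is a persistent mutable mapping, so store
# each datapoint under its integer index.
#
#     db = diskcache.Index('/path/to/cache')
#     for i, dp in enumerate(datapoints):
#         db[i] = dp
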
class SVMLightData(RNGDataFlow):
    """ Read X,y from an SVMlight file, and produce [X_i, y_i] pairs. """

    def __init__(self, filename, shuffle=True):
        """
        Args:
            filename (str): input file
            shuffle (bool): shuffle the data
        """
        import sklearn.datasets  # noqa
        self.X, self.y = sklearn.datasets.load_svmlight_file(filename)
        self.X = np.asarray(self.X.todense())
        self.shuffle = shuffle

    def __len__(self):
        return len(self.y)

    def __iter__(self):
        idxs = np.arange(self.__len__())
        if self.shuffle:
            self.rng.shuffle(idxs)
        for idx in idxs:
            yield [self.X[idx, :], self.y[idx]]
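
# A usage sketch for SVMLightData (the filename is hypothetical). Note that
# __init__ densifies the sparse matrix, so this only suits datasets that fit
# in memory:
#
#     df = SVMLightData('train.svmlight')
#     df.reset_state()
#     for X_i, y_i in df:
#         pass  # X_i is a dense feature row, y_i its label
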
# If an optional dependency is missing, replace the corresponding classes
# with dummies that raise an error when instantiated.
try:
    import h5py
except ImportError:
    HDF5Data = create_dummy_class('HDF5Data', 'h5py')  # noqa

try:
    import lmdb
except ImportError:
    for klass in ['LMDBData', 'LMDBDataDecoder', 'CaffeLMDB']:
        globals()[klass] = create_dummy_class(klass, 'lmdb')

try:
    import diskcache
except ImportError:
    DiskCacheData = create_dummy_class('DiskCacheData', 'diskcache')  # noqa