# -*- coding: utf-8 -*-
import numpy as np
import os
import platform
from collections import defaultdict

from ..utils import logger
from ..utils.serialize import dumps, loads
from ..utils.develop import create_dummy_class  # noqa
from ..utils.utils import get_tqdm
from .base import DataFlow
from .common import FixedSizeData, MapData
from .format import HDF5Data, LMDBData
from .raw import DataFromGenerator, DataFromList

# Public API of this module. LMDBSerializer, TFRecordSerializer and HDF5Serializer
# are rebound to dummy classes at the bottom of the file when their optional
# dependency (lmdb / tensorflow / h5py) cannot be imported.
__all__ = ['LMDBSerializer', 'NumpySerializer', 'TFRecordSerializer', 'HDF5Serializer']

def _reset_df_and_get_size(df):
        sz = len(df)
    except NotImplementedError:
        sz = 0
    return sz

class LMDBSerializer():
    """
    Serialize a Dataflow to a lmdb database, where the keys are indices and values
    are serialized datapoints.

    You will need to ``pip install lmdb`` to use it.

    Example:

    .. code-block:: python

        LMDBSerializer.save(df, "output.lmdb")

        new_df = LMDBSerializer.load("output.lmdb", shuffle=True)
    """

    @staticmethod
    def save(df, path, write_frequency=5000):
        """
        Args:
            df (DataFlow): the DataFlow to serialize.
            path (str): output path. Either a directory or an lmdb file.
            write_frequency (int): the frequency to write back data to disk.
                A smaller value reduces memory usage: up to this many serialized
                datapoints are buffered and committed in one transaction.

        Besides one entry per datapoint (key ``'{:08}'.format(index)`` in ASCII),
        the database gets a ``b'__keys__'`` entry holding the serialized list of
        all datapoint keys.
        """
        assert isinstance(df, DataFlow), type(df)
        isdir = os.path.isdir(path)
        if isdir:
            assert not os.path.isfile(os.path.join(path, 'data.mdb')), "LMDB file exists!"
        else:
            assert not os.path.isfile(path), "LMDB file {} exists!".format(path)
        # It's OK to use super large map_size on Linux, but not on other platforms.
        # Elsewhere start small; grow_map() doubles it whenever it fills up.
        map_size = 1099511627776 * 2 if platform.system() == 'Linux' else 128 * 10**6
        db = lmdb.open(path, subdir=isdir,
                       map_size=map_size, readonly=False,
                       meminit=False, map_async=True)    # need sync() at the end

        def grow_map():
            # Double the map size. Only call this while no write transaction is active.
            new_size = db.info()['map_size'] * 2
            logger.info("Doubling LMDB map_size to {:.2f}GB".format(new_size / 10**9))
            db.set_mapsize(new_size)

        def write_batch(batch):
            # Commit every (key, value) pair of `batch` in a single transaction.
            # On MapFullError the transaction is aborted, which discards EVERY put
            # already made in it. So the map is grown and the whole batch is
            # replayed, not just the entry that failed.
            # LMDB transactions are not exception-safe (although they have a
            # context manager interface), so begin/abort/commit are explicit.
            while True:
                txn = db.begin(write=True)
                try:
                    for key, value in batch:
                        txn.put(key, value)
                except lmdb.MapFullError:
                    txn.abort()
                    grow_map()
                else:
                    txn.commit()
                    return

        try:
            size = _reset_df_and_get_size(df)
            idx = -1
            batch = []
            with get_tqdm(total=size) as pbar:
                for idx, dp in enumerate(df):
                    batch.append((u'{:08}'.format(idx).encode('ascii'), dumps(dp)))
                    pbar.update()
                    if len(batch) >= write_frequency:
                        write_batch(batch)
                        batch = []
                if batch:
                    write_batch(batch)

            keys = [u'{:08}'.format(k).encode('ascii') for k in range(idx + 1)]
            write_batch([(b'__keys__', dumps(keys))])

            logger.info("Flushing database ...")
            db.sync()
        finally:
            db.close()

    @staticmethod
    def load(path, shuffle=True):
        """
        Args:
            path (str): lmdb file or directory written by :meth:`LMDBSerializer.save`.
            shuffle (bool): forwarded to :class:`LMDBData`.

        Returns:
            DataFlow: a :class:`MapData` that deserializes each record read by
            :class:`LMDBData`.

        Note:
            If you found deserialization being the bottleneck, you can use :class:`LMDBData` as the reader
            and run deserialization as a mapper in parallel.
        """
        df = LMDBData(path, shuffle=shuffle)
        return MapData(df, LMDBSerializer._deserialize_lmdb)

    @staticmethod
    def _deserialize_lmdb(dp):
        # Deserializes dp[1]; presumably dp is a (key, serialized_value) pair
        # yielded by LMDBData -- TODO confirm.
        return loads(dp[1])


class NumpySerializer():
    """
    Serialize the entire dataflow to a npz dict.
    Note that this would have to store the entire dataflow in memory,
    and is also >10x slower than LMDB/TFRecord serializers.
    """

    @staticmethod
    def save(df, path):
        """
        Args:
            df (DataFlow): the DataFlow to serialize.
            path (str): output npz file.
        """
        buffer = []
        size = _reset_df_and_get_size(df)
        with get_tqdm(total=size) as pbar:
            for dp in df:
                buffer.append(dp)
                pbar.update()
        # Store exactly one object per datapoint in a 1-D object array. Filling it
        # element by element stops numpy from descending into the datapoints.
        # np.asarray(buffer, dtype=object) would do that, producing a deeper array
        # than intended or failing to broadcast, depending on component shapes.
        # The builtin `object` is used because the `np.object` alias is gone in
        # NumPy >= 1.24.
        packed = np.empty(len(buffer), dtype=object)
        for i, dp in enumerate(buffer):
            packed[i] = dp
        np.savez_compressed(path, buffer=packed)

    @staticmethod
    def load(path, shuffle=True):
        """
        Args:
            path (str): npz file written by :meth:`NumpySerializer.save`.
            shuffle (bool): forwarded to :class:`DataFromList`.

        Returns:
            DataFlow: a :class:`DataFromList` over the stored datapoints.
        """
        # allow_pickle defaults to False since numpy 1.16.3, but the object array
        # written by ``save`` can only be read back with pickling enabled.
        # SECURITY: unpickling can execute arbitrary code -- only load npz files
        # from trusted sources.
        # The context manager closes the underlying npz file once 'buffer' has
        # been read into memory.
        with np.load(path, allow_pickle=True) as npz:
            buffer = npz['buffer']
        return DataFromList(buffer, shuffle=shuffle)


class TFRecordSerializer(object):
    """
    Serialize datapoints to bytes (by tensorpack's default serializer) and write to a TFRecord file.

    Note that TFRecord does not support random access and is in fact not very performant.
    It's better to use :class:`LMDBSerializer`.
    """

    @staticmethod
    def save(df, path):
        """
        Args:
            df (DataFlow): the DataFlow to serialize.
            path (str): output tfrecord file.
        """
        total = _reset_df_and_get_size(df)
        with tf.python_io.TFRecordWriter(path) as writer:
            with get_tqdm(total=total) as progress:
                for datapoint in df:
                    # One TFRecord per datapoint, holding its serialized bytes.
                    writer.write(dumps(datapoint))
                    progress.update()

    @staticmethod
    def load(path, size=None):
        """
        Args:
            path (str): tfrecord file written by :meth:`TFRecordSerializer.save`.
            size (int): total number of records. If not provided, the returned dataflow will have no `__len__()`.
                It's needed because this metadata is not stored in the TFRecord file.

        Returns:
            DataFlow: records deserialized with ``loads``, wrapped in
            :class:`FixedSizeData` when ``size`` is given.
        """
        records = DataFromGenerator(tf.python_io.tf_record_iterator(path))
        decoded = MapData(records, loads)
        return decoded if size is None else FixedSizeData(decoded, size)


class HDF5Serializer(object):
    """
    Write datapoints to a HDF5 file.

    Note that HDF5 files are in fact not very performant and currently do not support lazy loading.
    It's better to use :class:`LMDBSerializer`.
    """

    @staticmethod
    def save(df, path, data_paths):
        """
        Args:
            df (DataFlow): the DataFlow to serialize.
            path (str): output hdf5 file.
            data_paths (list[str]): list of h5 paths. It should have the same
                length as each datapoint, and each path should correspond to one
                component of the datapoint.
        """
        total = _reset_df_and_get_size(df)
        # Pass 1: collect each component of every datapoint into an in-memory
        # column keyed by its h5 path.
        columns = defaultdict(list)
        with get_tqdm(total=total) as progress:
            for datapoint in df:
                assert len(datapoint) == len(data_paths), \
                    "Datapoint has {} components!".format(len(datapoint))
                for h5_path, component in zip(data_paths, datapoint):
                    columns[h5_path].append(component)
                progress.update()

        # Pass 2: write each column as a single dataset.
        with h5py.File(path, 'w') as hf:
            with get_tqdm(total=len(data_paths)) as progress:
                for h5_path in data_paths:
                    hf.create_dataset(h5_path, data=columns[h5_path])
                    progress.update()

    @staticmethod
    def load(path, data_paths, shuffle=True):
        """
        Args:
            path (str): hdf5 file to read.
            data_paths (list): list of h5 paths to be zipped.
            shuffle (bool): forwarded to :class:`HDF5Data`.
        """
        reader = HDF5Data(path, data_paths, shuffle)
        return reader


# Optional backends are imported only after the classes above are defined. That
# way a missing dependency rebinds the corresponding serializer name to
# create_dummy_class(...) (presumably a placeholder that reports the missing
# package when used) instead of breaking the whole module at import time.
try:
    import lmdb
except ImportError:
    LMDBSerializer = create_dummy_class('LMDBSerializer', 'lmdb')  # noqa

try:
    from tensorpack.compat import tfv1 as tf
except ImportError:
    TFRecordSerializer = create_dummy_class('TFRecordSerializer', 'tensorflow')  # noqa

try:
    import h5py
except ImportError:
    HDF5Serializer = create_dummy_class('HDF5Serializer', 'h5py')  # noqa


if __name__ == '__main__':
    # Rough benchmark: write the same fake dataflow with every serializer, read it
    # back once, and print timestamps around each phase.
    from .raw import FakeData
    import time
    ds = FakeData([[300, 300, 3], [1]], 1000)

    print(time.time())
    TFRecordSerializer.save(ds, 'out.tfrecords')
    print(time.time())
    df = TFRecordSerializer.load('out.tfrecords', size=1000)
    df.reset_state()
    for idx, dp in enumerate(df):
        pass
    print("TF Finished, ", idx)
    print(time.time())

    LMDBSerializer.save(ds, 'out.lmdb')
    print(time.time())
    df = LMDBSerializer.load('out.lmdb')
    df.reset_state()
    for idx, dp in enumerate(df):
        pass
    print("LMDB Finished, ", idx)
    print(time.time())

    NumpySerializer.save(ds, 'out.npz')
    print(time.time())
    df = NumpySerializer.load('out.npz')
    df.reset_state()
    for idx, dp in enumerate(df):
        pass
    print("Numpy Finished, ", idx)
    print(time.time())

    paths = ['p1', 'p2']
    HDF5Serializer.save(ds, 'out.h5', paths)
    print(time.time())
    df = HDF5Serializer.load('out.h5', paths)
    df.reset_state()
    for idx, dp in enumerate(df):
        pass
    print("HDF5 Finished, ", idx)
    print(time.time())