tensorpack.dataflow package

Relevant tutorials: DataFlow, Why DataFlow?.

class tensorpack.dataflow.DataFlow[source]

Bases: object

Base class for all DataFlow

abstract __iter__()[source]
  • A dataflow is an iterable. The __iter__() method should yield a list or dict each time. Note that dict is only partially supported at the moment: certain dataflows do not support dict.

  • The __iter__() method can be either finite (will stop iteration) or infinite (will not stop iteration). For a finite dataflow, __iter__() can be called again immediately after the previous call returned.

  • For many dataflows, the __iter__() method is non-reentrant: for a dataflow instance df, df.__iter__() cannot be called before the previous df.__iter__() call has finished (i.e., iteration has stopped). When a dataflow is non-reentrant, df.__iter__() should throw an exception if called before the previous call has finished. For such non-reentrant dataflows, if you need to use the same dataflow in two places, create two dataflow instances.

Yields

list/dict – The datapoint, i.e. list/dict of components.

__len__()[source]
  • A dataflow can optionally implement __len__(). If not implemented, it will throw NotImplementedError.

  • It returns an integer representing the size of the dataflow. The return value may not be accurate or meaningful at all. When we say the length is “accurate”, we mean that __iter__() will always yield this many datapoints before it stops iteration.

  • There could be many reasons why __len__() is inaccurate. For example, some dataflows have a dynamic size because they throw away datapoints on the fly. Some dataflows mix datapoints between consecutive passes over the dataset due to parallelism and buffering; in that case it does not make sense to stop the iteration anywhere.

  • Due to the above reasons, the length is only a rough guideline, and it’s up to the user how to interpret it. Inside tensorpack it’s only used in these places:

    • A default steps_per_epoch in training, but you probably want to customize it yourself, especially when using a data-parallel trainer.

    • The length of progress bar when processing a dataflow.

    • Used by InferenceRunner to get the number of iterations in inference. In this case users are responsible for making sure that __len__() is “accurate”. This is to guarantee that inference is run on a fixed set of images.

Returns

int – rough size of this dataflow.

Raises

NotImplementedError

reset_state()[source]
  • The caller must guarantee that reset_state() is called once and only once, by the process that uses the dataflow, before __iter__() is called. The calling thread of this method should stay alive to keep this dataflow alive.

  • It is meant for initialization that involves processes, e.g., initializing random number generators (RNG) or creating worker processes.

    Because it’s very common to use RNG in data processing, developers of dataflow can also subclass RNGDataFlow to have easier access to a properly-initialized RNG.

  • A dataflow is not fork-safe after reset_state() is called (because forking it then would violate the above guarantee). A few other dataflows are never fork-safe; this is mentioned in their docs.

  • If you are the caller of a dataflow yourself (either because you use dataflow outside of tensorpack, or because you’re writing a wrapper dataflow), it is your responsibility to follow the above guarantee.

  • Tensorpack’s built-in forking dataflows (MultiProcessRunner, MultiProcessMapData, etc) and other components that use dataflows (e.g. InputSource) already take care of calling this method.
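
Putting the above together, here is a minimal sketch of a custom DataFlow (the class name and file are illustrative, not part of tensorpack):

from tensorpack.dataflow import DataFlow

class MyTextFile(DataFlow):
    """ Produce [line] datapoints from a text file. """
    def __init__(self, path):
        self._path = path

    def __iter__(self):
        with open(self._path) as f:
            for line in f:
                yield [line.rstrip('\n')]   # a datapoint is a list of components

    def __len__(self):
        # optional, and only a rough guideline as explained above
        with open(self._path) as f:
            return sum(1 for _ in f)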

class tensorpack.dataflow.ProxyDataFlow(ds)[source]

Bases: tensorpack.dataflow.base.DataFlow

Base class for DataFlow that proxies another. Every method is proxied to self.ds unless overridden by a subclass.

__init__(ds)[source]
Parameters

ds (DataFlow) – DataFlow to proxy.

class tensorpack.dataflow.RNGDataFlow[source]

Bases: tensorpack.dataflow.base.DataFlow

A DataFlow with RNG

reset_state()[source]

Reset the RNG

rng = None

self.rng is a np.random.RandomState instance that is initialized correctly (with different seeds in each process) in RNGDataFlow.reset_state().
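
A minimal sketch of a subclass that draws from self.rng (the class is illustrative; reset_state() must be called before iterating so that self.rng exists):

from tensorpack.dataflow import RNGDataFlow

class RandomPatches(RNGDataFlow):
    """ Produce random 32x32 crops from a list of HWC images. """
    def __init__(self, images):
        self._images = images

    def __iter__(self):
        while True:
            img = self._images[self.rng.randint(len(self._images))]
            y = self.rng.randint(img.shape[0] - 32 + 1)
            x = self.rng.randint(img.shape[1] - 32 + 1)
            yield [img[y:y + 32, x:x + 32]]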

exception tensorpack.dataflow.DataFlowTerminated[source]

Bases: BaseException

An exception indicating that the DataFlow is unable to produce any more data, i.e. something went wrong so that calling get_data() cannot give a valid iterator any more. In most DataFlows this will never be raised.

class tensorpack.dataflow.TestDataSpeed(ds, size=5000, warmup=0)[source]

Bases: tensorpack.dataflow.base.ProxyDataFlow

Test the speed of a DataFlow

__init__(ds, size=5000, warmup=0)[source]
Parameters
  • ds (DataFlow) – the DataFlow to test.

  • size (int) – number of datapoints to fetch.

  • warmup (int) – warmup iterations

__iter__()[source]

Will run testing at the beginning, then produce data normally.

start()[source]

Start testing with a progress bar.
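
Typical usage, with df being the dataflow to benchmark:

TestDataSpeed(df, size=1000).start()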

class tensorpack.dataflow.PrintData(ds, num=1, name=None, max_depth=3, max_list=3)[source]

Bases: tensorpack.dataflow.base.ProxyDataFlow

Behave like an identity proxy, but print shape and range of the first few datapoints. Good for debugging.

Example

Place it somewhere in your dataflow like

def create_my_dataflow():
    ds = SomeDataSource('path/to/lmdb')
    ds = SomeInscrutableMappings(ds)
    ds = PrintData(ds, num=2, max_list=2)
    return ds
ds = create_my_dataflow()
# other code that uses ds

When datapoints are taken from the dataflow, it will print outputs like:

[0110 09:22:21 @common.py:589] DataFlow Info:
datapoint 0<2 with 4 components consists of
   0: float with value 0.0816501893251
   1: ndarray:int32 of shape (64,) in range [0, 10]
   2: ndarray:float32 of shape (64, 64) in range [-1.2248, 1.2177]
   3: list of len 50
      0: ndarray:int32 of shape (64, 64) in range [-128, 80]
      1: ndarray:float32 of shape (64, 64) in range [0.8400, 0.6845]
      ...
datapoint 1<2 with 4 components consists of
   0: float with value 5.88252075399
   1: ndarray:int32 of shape (64,) in range [0, 10]
   2: ndarray:float32 of shape (64, 64) with range [-0.9011, 0.8491]
   3: list of len 50
      0: ndarray:int32 of shape (64, 64) in range [-70, 50]
      1: ndarray:float32 of shape (64, 64) in range [0.7400, 0.3545]
      ...
__init__(ds, num=1, name=None, max_depth=3, max_list=3)[source]
Parameters
  • ds (DataFlow) – input DataFlow.

  • num (int) – number of datapoints to print.

  • name (str, optional) – name to identify this DataFlow.

  • max_depth (int, optional) – stop printing when the recursion into sub-elements goes deeper than this.

  • max_list (int, optional) – stop printing when a component has more sub-elements than this.

class tensorpack.dataflow.BatchData(ds, batch_size, remainder=False, use_list=False)[source]

Bases: tensorpack.dataflow.base.ProxyDataFlow

Stack datapoints into batches. It produces datapoints of the same number of components as ds, but each component has one new extra dimension of size batch_size. The batch can be either a list of original components, or (by default) a numpy array of original components.

__init__(ds, batch_size, remainder=False, use_list=False)[source]
Parameters
  • ds (DataFlow) – A dataflow that produces either list or dict. When use_list=False, the components of ds must be either scalars or np.ndarray, and have to be consistent in shapes.

  • batch_size (int) – batch size

  • remainder (bool) – When the remaining datapoints in ds are not enough to form a batch, whether or not to also produce the remaining data as a smaller batch. If set to False, all produced datapoints are guaranteed to have the same batch size. If set to True, len(ds) must be accurate.

  • use_list (bool) – if True, each component will contain a list of datapoints instead of a numpy array with an extra dimension.

__iter__()[source]
Yields

Batched data by stacking each component on an extra 0th dimension.

static aggregate_batch(data_holder, use_list=False)[source]

Aggregate a list of datapoints to one batched datapoint.

Parameters
  • data_holder (list[dp]) – each dp is either a list or a dict.

  • use_list (bool) – whether to batch data into a list or a numpy array.

Returns

dp – either a list or a dict, depending on the inputs. Each item is a batched version of the corresponding inputs.
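
For example, a typical use (the shapes shown assume 28x28 MNIST images):

ds = Mnist('train')       # datapoints: [img (28, 28), label]
ds = BatchData(ds, 128)   # datapoints: [imgs (128, 28, 28), labels (128,)]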

class tensorpack.dataflow.BatchDataByShape(ds, batch_size, idx)[source]

Bases: tensorpack.dataflow.common.BatchData

Group datapoints of the same shape together into batches. It doesn’t require the input DataFlow to be homogeneous: the input can have datapoints of different shapes, and batches will be formed from those that have the same shape.

Note

It is implemented by a dict{shape -> datapoints}. Therefore, datapoints of uncommon shapes may never be enough to form a batch and never get generated.

__init__(ds, batch_size, idx)[source]
Parameters
  • ds (DataFlow) – input DataFlow. dp[idx] has to be an np.ndarray.

  • batch_size (int) – batch size

  • idx (int) – dp[idx].shape will be used to group datapoints. Other components are assumed to be batch-able.

class tensorpack.dataflow.FixedSizeData(ds, size, keep_state=True)[source]

Bases: tensorpack.dataflow.base.ProxyDataFlow

Generate data from another DataFlow, but with a fixed total count.

__init__(ds, size, keep_state=True)[source]
Parameters
  • ds (DataFlow) – input dataflow

  • size (int) – size

  • keep_state (bool) – keep the iterator state of ds between calls to __iter__(), so that the next call will continue the previous iteration over ds, instead of reinitializing an iterator.

Example:

ds produces: 1, 2, 3, 4, 5; 1, 2, 3, 4, 5; ...
FixedSizeData(ds, 3, True): 1, 2, 3; 4, 5, 1; 2, 3, 4; ...
FixedSizeData(ds, 3, False): 1, 2, 3; 1, 2, 3; ...
FixedSizeData(ds, 6, False): 1, 2, 3, 4, 5, 1; 1, 2, 3, 4, 5, 1;...
class tensorpack.dataflow.MapData(ds, func)[source]

Bases: tensorpack.dataflow.base.ProxyDataFlow

Apply a mapper/filter on the datapoints of a DataFlow.

Note

  1. Please make sure func doesn’t modify its arguments in place, unless you’re certain it’s safe.

  2. If you discard some datapoints, len(MapData(ds)) will be incorrect.

Example

ds = Mnist('train')  # each datapoint is [img, label]
ds = MapData(ds, lambda dp: [dp[0] * 255, dp[1]])
__init__(ds, func)[source]
Parameters
  • ds (DataFlow) – input DataFlow

  • func (datapoint -> datapoint | None) – takes a datapoint and returns a new datapoint. Return None to discard/skip this datapoint.

class tensorpack.dataflow.MapDataComponent(ds, func, index=0)[source]

Bases: tensorpack.dataflow.common.MapData

Apply a mapper/filter on a datapoint component.

Note

  1. This dataflow itself doesn’t modify the datapoints. But please make sure func doesn’t modify its arguments in place, unless you’re certain it’s safe.

  2. If you discard some datapoints, len(MapDataComponent(ds, ..)) will be incorrect.

Example

ds = Mnist('train')  # each datapoint is [img, label]
ds = MapDataComponent(ds, lambda img: img * 255, 0)  # map the 0th component
__init__(ds, func, index=0)[source]
Parameters
  • ds (DataFlow) – input DataFlow which produces either list or dict.

  • func (TYPE -> TYPE|None) – takes dp[index], returns a new value for dp[index]. Return None to discard/skip this datapoint.

  • index (int or str) – index or key of the component.

class tensorpack.dataflow.RepeatedData(ds, num)[source]

Bases: tensorpack.dataflow.base.ProxyDataFlow

Take datapoints from another DataFlow and repeat the dataflow a certain number of times, i.e.: dp1, dp2, ..., dpn, dp1, dp2, ..., dpn, ...

__init__(ds, num)[source]
Parameters
  • ds (DataFlow) – input DataFlow

  • num (int) – number of times to repeat ds. Set to -1 to repeat ds infinitely.

__len__()[source]
Raises

ValueError
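
A common use is to make a finite dataflow loop forever for training:

ds = RepeatedData(ds, -1)   # repeat indefinitely; len(ds) will now raise ValueError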

class tensorpack.dataflow.RepeatedDataPoint(ds, num)[source]

Bases: tensorpack.dataflow.base.ProxyDataFlow

Take datapoints from another DataFlow and produce each of them a certain number of times, i.e.: dp1, dp1, ..., dp1, dp2, ..., dp2, ...

__init__(ds, num)[source]
Parameters
  • ds (DataFlow) – input DataFlow

  • num (int) – number of times to repeat each datapoint.

class tensorpack.dataflow.RandomChooseData(df_lists)[source]

Bases: tensorpack.dataflow.base.RNGDataFlow

Randomly choose from several DataFlows. Stop producing when any of them is exhausted.

__init__(df_lists)[source]
Parameters

df_lists (list) – a list of DataFlow, or a list of (DataFlow, probability) tuples. Probabilities must sum to 1 if used.

class tensorpack.dataflow.RandomMixData(df_lists)[source]

Bases: tensorpack.dataflow.base.RNGDataFlow

Perfectly mix datapoints from several DataFlows using their __len__(). Will stop when all DataFlows are exhausted.

__init__(df_lists)[source]
Parameters

df_lists (list) – a list of DataFlow. All DataFlow must implement __len__().

class tensorpack.dataflow.JoinData(df_lists)[source]

Bases: tensorpack.dataflow.base.DataFlow

Join the components from each DataFlow. See below for its behavior.

Note that you can’t join a DataFlow that produces lists with one that produces dicts.

Example:

df1 produces: [c1, c2]
df2 produces: [c3, c4]
joined: [c1, c2, c3, c4]

df1 produces: {"a":c1, "b":c2}
df2 produces: {"c":c3}
joined: {"a":c1, "b":c2, "c":c3}
__init__(df_lists)[source]
Parameters

df_lists (list) – a list of DataFlow. When these dataflows have different sizes, JoinData will stop when any of them is exhausted. The list could contain the same DataFlow instance more than once, but note that in that case __iter__ will then also be called many times.

__len__()[source]

Return the minimum size among all.

class tensorpack.dataflow.ConcatData(df_lists)[source]

Bases: tensorpack.dataflow.base.DataFlow

Concatenate several DataFlow. Produce datapoints from each DataFlow and start the next when one DataFlow is exhausted.

__init__(df_lists)[source]
Parameters

df_lists (list) – a list of DataFlow.

tensorpack.dataflow.SelectComponent(ds, idxs)[source]

Select / reorder components from datapoints.

Parameters
  • ds (DataFlow) – input DataFlow.

  • idxs (list) – a list of indices (or keys) of the components to select.

Example:

original df produces: [c1, c2, c3]
idxs: [2,1]
this df: [c3, c2]
class tensorpack.dataflow.LocallyShuffleData(ds, buffer_size, num_reuse=1, shuffle_interval=None)[source]

Bases: tensorpack.dataflow.base.ProxyDataFlow, tensorpack.dataflow.base.RNGDataFlow

Buffer the datapoints from a given dataflow, and shuffle them before producing them. This can be used as an alternative when a complete random shuffle is too expensive or impossible for the data source.

This dataflow has the following behavior:

  1. It takes datapoints from the given dataflow ds into an internal buffer of fixed size. Each datapoint is duplicated num_reuse times.

  2. Once the buffer is full, this dataflow starts to yield data from the beginning of the buffer, and new datapoints will be added to the end of the buffer. This is like a FIFO queue.

  3. The internal buffer is shuffled after every shuffle_interval datapoints that come from ds.

To maintain shuffling states, this dataflow is not reentrant.

Datapoints from one pass of ds will get mixed with datapoints from a different pass. As a result, the iterator of this dataflow will run indefinitely because it does not make sense to stop the iteration anywhere.

__init__(ds, buffer_size, num_reuse=1, shuffle_interval=None)[source]
Parameters
  • ds (DataFlow) – input DataFlow.

  • buffer_size (int) – size of the buffer.

  • num_reuse (int) – duplicate each datapoint this many times into the buffer to improve speed, at the risk that the duplication may hurt your model.

  • shuffle_interval (int) – shuffle the buffer after this many datapoints have been taken from the given dataflow. Frequent shuffles of a large buffer may slow things down, but infrequent shuffles may not provide enough randomness. Defaults to buffer_size / 3.
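
For example, a sketch of approximately shuffling a sequentially-read LMDB (the path is a placeholder; buffer_size trades memory for randomness):

ds = LMDBSerializer.load('/path/to/data.lmdb', shuffle=False)
ds = LocallyShuffleData(ds, buffer_size=10000)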

class tensorpack.dataflow.CacheData(ds, shuffle=False)[source]

Bases: tensorpack.dataflow.base.ProxyDataFlow

Completely cache the first pass of a DataFlow in memory, and produce from the cache thereafter.

NOTE: The user should not stop the iterator before it has reached the end, otherwise the cache may be incomplete.

__init__(ds, shuffle=False)[source]
Parameters
  • ds (DataFlow) – input DataFlow.

  • shuffle (bool) – whether to shuffle the cache before yielding from it.

class tensorpack.dataflow.HDF5Data(filename, data_paths, shuffle=True)[source]

Bases: tensorpack.dataflow.base.RNGDataFlow

Zip data from different paths in an HDF5 file.

Warning

The current implementation will load all data into memory. (TODO)

__init__(filename, data_paths, shuffle=True)[source]
Parameters
  • filename (str) – h5 data file.

  • data_paths (list) – list of h5 paths to be zipped. For example [‘images’, ‘labels’].

  • shuffle (bool) – shuffle all data.

class tensorpack.dataflow.LMDBData(lmdb_path, shuffle=True, keys=None)[source]

Bases: tensorpack.dataflow.base.RNGDataFlow

Read an LMDB database and produce (k, v) raw-bytes pairs. The raw bytes are usually not what you’re interested in: you might want to use LMDBDataDecoder, or apply a mapper function after LMDBData.

__init__(lmdb_path, shuffle=True, keys=None)[source]
Parameters
  • lmdb_path (str) – a directory or a file.

  • shuffle (bool) – shuffle the keys or not.

  • keys (list[str] or str) –

    list of str as the keys, used only when shuffle is True. It can also be a format string e.g. {:0>8d} which will be formatted with the indices from 0 to total_size - 1.

    If not given, it will then look in the database for __keys__ which LMDBSerializer.save() used to store the list of keys. If still not found, it will iterate over the database to find all the keys.
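
For example, a sketch of decoding the raw values yourself (decode_value is a hypothetical function matching however the values were written):

ds = LMDBData('/path/to/data.lmdb', shuffle=False)
ds = MapData(ds, lambda dp: decode_value(dp[1]))   # dp is [key, value]; return a datapoint (a list)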

class tensorpack.dataflow.LMDBDataDecoder(lmdb_data, decoder)[source]

Bases: tensorpack.dataflow.common.MapData

Read a LMDB database with a custom decoder and produce decoded outputs.

__init__(lmdb_data, decoder)[source]
Parameters
  • lmdb_data – a LMDBData instance.

  • decoder (k,v -> dp | None) – a function taking k, v and returning a datapoint, or return None to discard.

tensorpack.dataflow.CaffeLMDB(lmdb_path, shuffle=True, keys=None)[source]

Read a Caffe-format LMDB file where each value contains a caffe.Datum protobuf. Produces datapoints of the format: [HWC image, label].

Note that Caffe LMDB format is not efficient: it stores serialized raw arrays rather than JPEG images.

Parameters

lmdb_path, shuffle, keys – same as in LMDBData.

Example

ds = CaffeLMDB("/tmp/validation", keys='{:0>8d}')
class tensorpack.dataflow.SVMLightData(filename, shuffle=True)[source]

Bases: tensorpack.dataflow.base.RNGDataFlow

Read X,y from an SVMlight file, and produce [X_i, y_i] pairs.

__init__(filename, shuffle=True)[source]
Parameters
  • filename (str) – input file

  • shuffle (bool) – shuffle the data

class tensorpack.dataflow.ImageFromFile(files, channel=3, resize=None, shuffle=False)[source]

Bases: tensorpack.dataflow.base.RNGDataFlow

Produce images read from a list of files as (h, w, c) arrays.

__init__(files, channel=3, resize=None, shuffle=False)[source]
Parameters
  • files (list) – list of file paths.

  • channel (int) – 1 or 3. Will convert grayscale to RGB images if channel==3. Will produce (h, w, 1) array if channel==1.

  • resize (int or (h, w) tuple) – if given, resize the image.

  • shuffle (bool) – shuffle the list of files.

class tensorpack.dataflow.AugmentImageComponent(ds, augmentors, index=0, copy=True, catch_exceptions=False)[source]

Bases: tensorpack.dataflow.common.MapDataComponent

Apply image augmentors on 1 image component.

__init__(ds, augmentors, index=0, copy=True, catch_exceptions=False)[source]
Parameters
  • ds (DataFlow) – input DataFlow.

  • augmentors (AugmentorList) – a list of imgaug.ImageAugmentor to be applied in order.

  • index (int or str) – the index or key of the image component to be augmented in the datapoint.

  • copy (bool) – Some augmentors modify the input images. When copy is True, a copy will be made before any augmentors are applied, so that the original images are not modified. Turn it off to save time when you know it’s OK.

  • catch_exceptions (bool) – when set to True, will catch all exceptions and only warn you when there are too many (>100). Can be used to ignore occasional errors in data.
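
For example, a sketch using two augmentors that ship with tensorpack (choose augmentors appropriate for your data):

from tensorpack.dataflow import imgaug
augs = [imgaug.Resize((256, 256)), imgaug.Flip(horiz=True)]
ds = AugmentImageComponent(ds, augs, index=0)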

class tensorpack.dataflow.AugmentImageCoordinates(ds, augmentors, img_index=0, coords_index=1, copy=True, catch_exceptions=False)[source]

Bases: tensorpack.dataflow.common.MapData

Apply image augmentors on an image and a list of coordinates. Coordinates must be a Nx2 floating point array, each row is (x, y).

__init__(ds, augmentors, img_index=0, coords_index=1, copy=True, catch_exceptions=False)[source]
Parameters
  • ds (DataFlow) – input DataFlow.

  • augmentors (AugmentorList) – a list of imgaug.ImageAugmentor to be applied in order.

  • img_index (int or str) – index or key of the image component.

  • coords_index (int or str) – index or key of the coordinates component.

  • copy, catch_exceptions – same as in AugmentImageComponent.

class tensorpack.dataflow.AugmentImageComponents(ds, augmentors, index=(0, 1), coords_index=(), copy=True, catch_exceptions=False)[source]

Bases: tensorpack.dataflow.common.MapData

Apply image augmentors on several components, with shared augmentation parameters.

Example

ds = MyDataFlow()   # produce [image(HWC), segmask(HW), keypoint(Nx2)]
ds = AugmentImageComponents(
    ds, augs,
    index=(0,1), coords_index=(2,))
__init__(ds, augmentors, index=(0, 1), coords_index=(), copy=True, catch_exceptions=False)[source]
Parameters
  • ds (DataFlow) – input DataFlow.

  • augmentors (AugmentorList) – a list of imgaug.ImageAugmentor to be applied in order.

  • index – tuple of indices of the image components.

  • coords_index – tuple of indices of the coordinates components.

  • copy, catch_exceptions – same as in AugmentImageComponent.

class tensorpack.dataflow.MultiProcessRunner(ds, num_prefetch, num_proc)[source]

Bases: tensorpack.dataflow.base.ProxyDataFlow

Run a DataFlow in >=1 processes using Python multiprocessing utilities. It will fork the process that calls __init__(), and collect datapoints from ds in each process through a Python multiprocessing.Queue.

Note

  1. (Data integrity) An iterator cannot run faster automatically – what’s happening is that the process will be forked num_proc times. There will be num_proc dataflow running in parallel and independently. As a result, we have the following guarantee on the dataflow correctness:

    1. When num_proc=1, this dataflow produces the same data as the given dataflow in the same order.

    2. When num_proc>1, if each sample from the given dataflow is i.i.d., then this dataflow produces the same distribution of data as the given dataflow. This implies that there will be duplication, reordering, etc. You probably only want to use it for training.

      For example, if your original dataflow contains no randomness and produces the same first datapoint, then after parallel prefetching, the datapoint will be produced num_proc times at the beginning. Even when your original dataflow is fully shuffled, you still need to be aware of the Birthday Paradox and know that you’ll likely see duplicates.

    To utilize parallelism with more strict data integrity, you can use the parallel versions of MapData: MultiThreadMapData, MultiProcessMapData.

  2. This has more serialization overhead than MultiProcessRunnerZMQ when data is large.

  3. You can nest like this: MultiProcessRunnerZMQ(MultiProcessRunner(df, num_proc=a), num_proc=b). A total of a instances of df worker processes will be created.

  4. Fork happens in __init__. reset_state() is a no-op. DataFlow in the worker processes will be reset at the time of fork.

  5. This DataFlow does support Windows. However, Windows requires more strict picklability of processes, which means that some code that’s forkable on Linux may not be forkable on Windows. If that happens you’ll need to re-organize the parts of the code that are not forkable.

__init__(ds, num_prefetch, num_proc)[source]
Parameters
  • ds (DataFlow) – input DataFlow.

  • num_prefetch (int) – size of the queue to hold prefetched datapoints. Required.

  • num_proc (int) – number of processes to use. Required.
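
Typical usage:

ds = MultiProcessRunner(ds, num_prefetch=256, num_proc=8)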

class tensorpack.dataflow.MultiProcessRunnerZMQ(ds, num_proc=1, hwm=50)[source]

Bases: tensorpack.dataflow.parallel._MultiProcessZMQDataFlow

Run a DataFlow in >=1 processes, with ZeroMQ for communication. It will fork the process that calls reset_state(), and collect datapoints from the given dataflow in each process through a ZeroMQ IPC pipe. This is typically faster than MultiProcessRunner.

Note

  1. (Data integrity) An iterator cannot run faster automatically – what’s happening is that the process will be forked num_proc times. There will be num_proc dataflow running in parallel and independently. As a result, we have the following guarantee on the dataflow correctness:

    1. When num_proc=1, this dataflow produces the same data as the given dataflow in the same order.

    2. When num_proc>1, if each sample from the given dataflow is i.i.d., then this dataflow produces the same distribution of data as the given dataflow. This implies that there will be duplication, reordering, etc. You probably only want to use it for training.

      For example, if your original dataflow contains no randomness and produces the same first datapoint, then after parallel prefetching, the datapoint will be produced num_proc times at the beginning. Even when your original dataflow is fully shuffled, you still need to be aware of the Birthday Paradox and know that you’ll likely see duplicates.

    To utilize parallelism with more strict data integrity, you can use the parallel versions of MapData: MultiThreadMapData, MultiProcessMapData.

  2. reset_state() of the given dataflow will be called once and only once in the worker processes.

  3. The fork of processes happens in this dataflow’s reset_state() method. Please note that forking a TensorFlow GPU session may be unsafe. If you’re managing this dataflow on your own, it’s better to fork before creating the session.

  4. (Fork-safety) After the fork has happened, this dataflow becomes not fork-safe. i.e., if you fork an already reset instance of this dataflow, it won’t be usable in the forked process. Therefore, do not nest two MultiProcessRunnerZMQ.

  5. (Thread-safety) ZMQ is not thread safe. Therefore, do not call get_data() of the same dataflow in more than one thread.

  6. This dataflow does not support Windows. Use MultiProcessRunner, which works on Windows.

  7. (For Mac only) A UNIX named pipe will be created in the current directory. However, certain non-local filesystems such as NFS/GlusterFS/AFS don’t always support pipes. You can change the directory by export TENSORPACK_PIPEDIR=/other/dir. In particular, you can use somewhere under ‘/tmp’, which is usually local.

    Note that some non-local filesystems may appear to support pipes, so the code may appear to run but then crash with bizarre errors. Also note that ZMQ limits the maximum length of the pipe path. If you hit the limit, you can set the directory to a softlink which points to a local directory.

__init__(ds, num_proc=1, hwm=50)[source]
Parameters
  • ds (DataFlow) – input DataFlow.

  • num_proc (int) – number of processes to use.

  • hwm (int) – the zmq “high-water mark” (queue size) for both sender and receiver.

class tensorpack.dataflow.MultiThreadRunner(get_df, num_prefetch, num_thread)[source]

Bases: tensorpack.dataflow.base.DataFlow

Create multiple dataflow instances and run them each in one thread. Collect outputs from them with a queue.

Note

  1. (Data integrity) An iterator cannot run faster automatically – what’s happening is that each thread will create a dataflow iterator. There will be num_thread dataflow running in parallel and independently. As a result, we have the following guarantee on the dataflow correctness:

    1. When num_thread=1, this dataflow produces the same data as the given dataflow in the same order.

    2. When num_thread>1, if each sample from the given dataflow is i.i.d., then this dataflow produces the same distribution of data as the given dataflow. This implies that there will be duplication, reordering, etc. You probably only want to use it for training.

      For example, if your original dataflow contains no randomness and produces the same first datapoint, then after parallel prefetching, the datapoint will be produced num_thread times at the beginning. Even when your original dataflow is fully shuffled, you still need to be aware of the Birthday Paradox and know that you’ll likely see duplicates.

    To utilize parallelism with more strict data integrity, you can use the parallel versions of MapData: MultiThreadMapData, MultiProcessMapData.

__init__(get_df, num_prefetch, num_thread)[source]
Parameters
  • get_df (-> DataFlow) – a callable which returns a DataFlow. Each thread will call this function to get the DataFlow to use. Therefore do not return the same DataFlow object for each call, unless your dataflow is stateless.

  • num_prefetch (int) – size of the queue

  • num_thread (int) – number of threads
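
For example (MyDataFlow is a placeholder for your own dataflow):

ds = MultiThreadRunner(lambda: MyDataFlow(), num_prefetch=128, num_thread=8)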

class tensorpack.dataflow.MultiThreadMapData(ds, num_thread=None, map_func=None, *, buffer_size=200, strict=False)[source]

Bases: tensorpack.dataflow.parallel_map._ParallelMapData

Same as MapData, but start threads to run the mapping function. This is useful when the mapping function is the bottleneck, but you don’t want to start processes for the entire dataflow pipeline.

The semantics of this class is identical to MapData except for the ordering. Threads run in parallel and can take different time to run the mapping function. Therefore the order of datapoints won’t be preserved.

When strict=True, MultiThreadMapData(df, ...) is guaranteed to produce the exact set of data as MapData(df, ...), if both are iterated until StopIteration. But the produced data will have different ordering. The behavior of strict mode is undefined if the given dataflow df is infinite.

When strict=False, the data that’s produced by MultiThreadMapData(df, ...) is a reordering of the data produced by RepeatedData(MapData(df, ...), -1). In other words, first pass of MultiThreadMapData.__iter__ may contain datapoints from the second pass of df.__iter__.

Note

  1. You should avoid starting many threads in your main process to reduce GIL contention.

    The threads will only start in the process which calls reset_state(). Therefore you can use MultiProcessRunnerZMQ(MultiThreadMapData(...), 1) to reduce GIL contention.

__init__(ds, num_thread=None, map_func=None, *, buffer_size=200, strict=False)[source]
Parameters
  • ds (DataFlow) – the dataflow to map

  • num_thread (int) – number of threads to use

  • map_func (callable) – datapoint -> datapoint | None. Return None to discard/skip the datapoint.

  • buffer_size (int) – number of datapoints in the buffer

  • strict (bool) – use “strict mode”, see notes above.
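
For example, a sketch that parallelizes an I/O-bound mapping (the mapping function here is illustrative; it assumes datapoints of [filename, label]):

import cv2

def load_image(dp):
    fname, label = dp
    img = cv2.imread(fname, cv2.IMREAD_COLOR)
    return [img, label] if img is not None else None   # None skips unreadable files

ds = MultiThreadMapData(ds, num_thread=16, map_func=load_image, buffer_size=1000)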

tensorpack.dataflow.MultiProcessMapData

alias of tensorpack.dataflow.parallel_map.MultiProcessMapDataZMQ

class tensorpack.dataflow.MultiProcessMapDataZMQ(ds, num_proc=None, map_func=None, *, buffer_size=200, strict=False)[source]

Bases: tensorpack.dataflow.parallel_map._ParallelMapData, tensorpack.dataflow.parallel._MultiProcessZMQDataFlow

Same as MapData, but start processes to run the mapping function, and communicate with ZeroMQ pipe.

The semantics of this class is identical to MapData except for the ordering. Processes run in parallel and can take different time to run the mapping function. Therefore the order of datapoints won’t be preserved.

When strict=True, MultiProcessMapData(df, ...) is guaranteed to produce the exact set of data as MapData(df, ...), if both are iterated until StopIteration. But the produced data will have different ordering. The behavior of strict mode is undefined if the given dataflow df is infinite.

When strict=False, the data that’s produced by MultiProcessMapData(df, ...) is a reordering of the data produced by RepeatedData(MapData(df, ...), -1). In other words, first pass of MultiProcessMapData.__iter__ may contain datapoints from the second pass of df.__iter__.

__init__(ds, num_proc=None, map_func=None, *, buffer_size=200, strict=False)[source]
Parameters
  • ds (DataFlow) – the dataflow to map

  • num_proc (int) – number of processes to use

  • map_func (callable) – datapoint -> datapoint | None. Return None to discard/skip the datapoint.

  • buffer_size (int) – number of datapoints in the buffer

  • strict (bool) – use “strict mode”, see notes above.

tensorpack.dataflow.MultiProcessMapAndBatchData

alias of tensorpack.dataflow.parallel_map.MultiProcessMapAndBatchDataZMQ

class tensorpack.dataflow.MultiProcessMapAndBatchDataZMQ(ds, num_proc, map_func, batch_size, buffer_size=None)[source]

Bases: tensorpack.dataflow.parallel._MultiProcessZMQDataFlow

Similar to MultiProcessMapDataZMQ, except that this DataFlow also does batching in parallel in the worker processes. Therefore it can be helpful if you wish to hide the latency of batching.

When num_proc==1, the behavior of this class is identical to BatchData(MapData(ds, map_func), batch_size).

When num_proc>1, the datapoints may be grouped in arbitrary order, or grouped with datapoints from a different pass of the given dataflow.

__init__(ds, num_proc, map_func, batch_size, buffer_size=None)[source]
Parameters
  • ds (DataFlow) – the dataflow to map

  • num_proc (int) – number of processes to use

  • map_func (callable) – datapoint -> datapoint | None. Return None to discard/skip the datapoint.

  • batch_size (int) – batch size

  • buffer_size (int) – number of datapoints (not batched) in the buffer. Defaults to batch_size * 10

class tensorpack.dataflow.FakeData(shapes, size=1000, random=True, dtype='float32', domain=(0, 1))[source]

Bases: tensorpack.dataflow.base.RNGDataFlow

Generate fake data of given shapes

__init__(shapes, size=1000, random=True, dtype='float32', domain=(0, 1))[source]
Parameters
  • shapes (list) – a list of lists/tuples. Shapes of each component.

  • size (int) – size of this DataFlow.

  • random (bool) – whether to randomly generate data every iteration. Note that merely generating the data could sometimes be time-consuming!

  • dtype (str or list) – data type as string, or a list of data types.

  • domain (tuple or list) – (min, max) tuple, or a list of such tuples

class tensorpack.dataflow.DataFromQueue(queue)[source]

Bases: tensorpack.dataflow.base.DataFlow

Produce data from a queue

__init__(queue)[source]
Parameters

queue (queue) – a queue with get() method.

class tensorpack.dataflow.DataFromList(lst, shuffle=True)[source]

Bases: tensorpack.dataflow.base.RNGDataFlow

Wrap a list of datapoints to a DataFlow

__init__(lst, shuffle=True)[source]
Parameters
  • lst (list) – input list. Each element is a datapoint.

  • shuffle (bool) – shuffle data.

class tensorpack.dataflow.DataFromGenerator(gen)[source]

Bases: tensorpack.dataflow.base.DataFlow

Wrap a generator into a DataFlow. The dataflow will not have a length.

__init__(gen)[source]
Parameters

gen – iterable, or a callable that returns an iterable
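
For example:

import numpy as np

def gen():
    while True:
        yield [np.random.rand(28, 28)]

ds = DataFromGenerator(gen)    # a generator object such as gen() also works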

class tensorpack.dataflow.DataFromIterable(iterable)[source]

Bases: tensorpack.dataflow.base.DataFlow

Wrap an iterable of datapoints to a DataFlow

__init__(iterable)[source]
Parameters

iterable – an iterable object

tensorpack.dataflow.send_dataflow_zmq(df, addr, hwm=50, format=None, bind=False)[source]

Run DataFlow and send data to a ZMQ socket addr. It will serialize and send each datapoint to this address with a PUSH socket. This function never returns.

Parameters
  • df (DataFlow) – Will infinitely loop over the DataFlow.

  • addr – a ZMQ socket endpoint.

  • hwm (int) – ZMQ high-water mark (buffer size)

  • format (str) – The serialization format. Default format uses utils.serialize. This format works with dataflow.RemoteDataZMQ. An alternate format is ‘zmq_ops’, used by https://github.com/tensorpack/zmq_ops and input_source.ZMQInput.

  • bind (bool) – whether to bind or connect to the endpoint address.
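
A minimal sketch of a sender/receiver pair over TCP (the host and port are placeholders):

# on the machine that produces data (connects by default):
send_dataflow_zmq(df, 'tcp://training-host:8877')

# on the machine that consumes data (RemoteDataZMQ binds by default):
df = RemoteDataZMQ('tcp://0.0.0.0:8877')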

class tensorpack.dataflow.RemoteDataZMQ(addr1, addr2=None, hwm=50, bind=True)[source]

Bases: tensorpack.dataflow.base.DataFlow

Produce data from ZMQ PULL socket(s). It is the receiver-side counterpart of send_dataflow_zmq(), which uses tensorpack.utils.serialize for serialization. See http://tensorpack.readthedocs.io/tutorial/efficient-dataflow.html#distributed-dataflow

cnt1, cnt2

number of data points received from addr1 and addr2

Type

int

__init__(addr1, addr2=None, hwm=50, bind=True)[source]
Parameters
  • addr1,addr2 (str) – addr of the zmq endpoint to connect to. Use both if you need two protocols (e.g. both IPC and TCP). I don’t think you’ll ever need 3.

  • hwm (int) – ZMQ high-water mark (buffer size)

  • bind (bool) – whether to connect or bind the endpoint

bind_or_connect(socket, addr)[source]
class tensorpack.dataflow.LMDBSerializer[source]

Bases: object

Serialize a Dataflow to an lmdb database, where the keys are indices and values are serialized datapoints.

You will need to pip install lmdb to use it.

Example:

LMDBSerializer.save(my_df, "output.lmdb")

new_df = LMDBSerializer.load("output.lmdb", shuffle=True)
static load(path, shuffle=True)[source]

Note

If you find deserialization to be the bottleneck, you can use LMDBData as the reader and run deserialization as a mapper in parallel, as shown in the sketch below.
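
A sketch of that approach, assuming the file was written by LMDBSerializer.save() (which serializes datapoints with tensorpack.utils.serialize):

from tensorpack.utils.serialize import loads

def decode(dp):
    # dp is [key, serialized_value] as produced by LMDBData
    return loads(dp[1])

ds = LMDBData('/path/to/data.lmdb', shuffle=False)
ds = MultiProcessMapData(ds, num_proc=4, map_func=decode)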

static save(df, path, write_frequency=5000)[source]
Parameters
  • df (DataFlow) – the DataFlow to serialize.

  • path (str) – output path. Either a directory or an lmdb file.

  • write_frequency (int) – the frequency to write back data to disk. A smaller value reduces memory usage.

class tensorpack.dataflow.NumpySerializer[source]

Bases: object

Serialize the entire dataflow to a npz dict. Note that this would have to store the entire dataflow in memory, and is also >10x slower than LMDB/TFRecord serializers.

static load(path, shuffle=True)[source]
static save(df, path)[source]
Parameters
  • df (DataFlow) – the DataFlow to serialize.

  • path (str) – output npz file.

class tensorpack.dataflow.TFRecordSerializer[source]

Bases: object

Serialize datapoints to bytes (by tensorpack’s default serializer) and write to a TFRecord file.

Note that TFRecord does not support random access and is in fact not very performant. It’s better to use LMDBSerializer.

static load(path, size=None)[source]
Parameters

size (int) – total number of records. If not provided, the returned dataflow will have no __len__(). It’s needed because this metadata is not stored in the TFRecord file.

static save(df, path)[source]
Parameters
  • df (DataFlow) – the DataFlow to serialize.

  • path (str) – output tfrecord file.
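
Example (passing size to load() is only needed if the loaded dataflow should have a length; here it assumes my_df implements __len__()):

TFRecordSerializer.save(my_df, "output.tfrecord")
new_df = TFRecordSerializer.load("output.tfrecord", size=len(my_df))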

class tensorpack.dataflow.HDF5Serializer[source]

Bases: object

Write datapoints to a HDF5 file.

Note that HDF5 files are in fact not very performant and currently do not support lazy loading. It’s better to use LMDBSerializer.

static load(path, data_paths, shuffle=True)[source]
Parameters

data_paths (list) – list of h5 paths to be zipped.

static save(df, path, data_paths)[source]
Parameters
  • df (DataFlow) – the DataFlow to serialize.

  • path (str) – output hdf5 file.

  • data_paths (list[str]) – list of h5 paths. It should have the same length as each datapoint, and each path should correspond to one component of the datapoint.