Source code for tensorpack.dataflow.parallel_map

# -*- coding: utf-8 -*-
# File: parallel_map.py
import copy
import ctypes
import multiprocessing as mp
import numpy as np
import threading
import zmq
from six.moves import queue

from ..utils.concurrency import StoppableThread, enable_death_signal
from ..utils.serialize import dumps_once as dumps, loads_once as loads
from .base import DataFlow, DataFlowReentrantGuard, ProxyDataFlow
from .common import RepeatedData, BatchData
from .parallel import _bind_guard, _get_pipe_name, _MultiProcessZMQDataFlow, _repeat_iter, _zmq_catch_error

__all__ = ['MultiThreadMapData',
           'MultiProcessMapData', 'MultiProcessMapDataZMQ',
           'MultiProcessMapAndBatchData', 'MultiProcessMapAndBatchDataZMQ']


class _ParallelMapData(ProxyDataFlow):
    def __init__(self, ds, buffer_size, strict=False):
        super(_ParallelMapData, self).__init__(ds)
        assert buffer_size > 0, buffer_size
        self._buffer_size = buffer_size
        self._buffer_occupancy = 0  # actual #elements in buffer, only useful in strict mode
        self._strict = strict

    def reset_state(self):
        super(_ParallelMapData, self).reset_state()
        if not self._strict:
            ds = RepeatedData(self.ds, -1)
        else:
            ds = self.ds
        self._iter = ds.__iter__()

    def _recv(self):
        # To be implemented by subclasses: receive one mapped datapoint from the workers.
        pass

    def _send(self, dp):
        # To be implemented by subclasses: send one datapoint to the workers.
        pass

    def _recv_filter_none(self):
        ret = self._recv()
        assert ret is not None, \
            "[{}] Map function cannot return None when strict mode is used.".format(type(self).__name__)
        return ret

    def _fill_buffer(self, cnt=None):
        if cnt is None:
            cnt = self._buffer_size - self._buffer_occupancy
        try:
            for _ in range(cnt):
                dp = next(self._iter)
                self._send(dp)
        except StopIteration:
            raise RuntimeError(
                "[{}] buffer_size cannot be larger than the size of the DataFlow when strict=True! "
                "Please use a smaller buffer_size!".format(type(self).__name__))
        self._buffer_occupancy += cnt

    def get_data_non_strict(self):
        for dp in self._iter:
            self._send(dp)
            ret = self._recv()
            if ret is not None:
                yield ret

    def get_data_strict(self):
        self._fill_buffer()
        for dp in self._iter:
            self._send(dp)
            yield self._recv_filter_none()
        self._iter = self.ds.__iter__()   # refresh

        # first clear the buffer, then fill
        for k in range(self._buffer_size):
            dp = self._recv_filter_none()
            self._buffer_occupancy -= 1
            if k == self._buffer_size - 1:
                self._fill_buffer()
            yield dp

    def __iter__(self):
        if self._strict:
            yield from self.get_data_strict()
        else:
            yield from self.get_data_non_strict()
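
# The two hooks above form a simple pipelining contract: `_send` hands one
# datapoint to the workers, `_recv` blocks for one mapped result, and
# `_fill_buffer` keeps `buffer_size` datapoints in flight so workers never
# starve. A minimal synchronous sketch of a conforming subclass follows;
# it is illustrative only and not part of the library (`_SynchronousMapData`
# and `map_func` are hypothetical names):
#
#   class _SynchronousMapData(_ParallelMapData):
#       def __init__(self, ds, map_func, buffer_size=10):
#           super(_SynchronousMapData, self).__init__(ds, buffer_size)
#           self._map_func = map_func
#           self._queue = queue.Queue()
#
#       def _send(self, dp):
#           # "submit" a job; here the mapping runs eagerly instead of in a worker
#           self._queue.put(self._map_func(dp))
#
#       def _recv(self):
#           # collect one result, in FIFO order for this trivial backend
#           return self._queue.get()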


class MultiThreadMapData(_ParallelMapData):
    """
    Same as :class:`MapData`, but start threads to run the mapping function.
    This is useful when the mapping function is the bottleneck, but you don't
    want to start processes for the entire dataflow pipeline.

    The semantics of this class are **identical** to :class:`MapData` except for the ordering.
    Threads run in parallel and can take different time to run the
    mapping function. Therefore the order of datapoints won't be preserved.

    When ``strict=True``, ``MultiThreadMapData(df, ...)``
    is guaranteed to produce the exact set of data as ``MapData(df, ...)``,
    if both are iterated until ``StopIteration``. But the produced data will have different ordering.
    The behavior of strict mode is undefined if the given dataflow ``df`` is infinite.

    When ``strict=False``, the data that's produced by ``MultiThreadMapData(df, ...)``
    is a reordering of the data produced by ``RepeatedData(MapData(df, ...), -1)``.
    In other words, the first pass of ``MultiThreadMapData.__iter__`` may contain
    datapoints from the second pass of ``df.__iter__``.

    Note:
        1. You should avoid starting many threads in your main process to reduce GIL contention.

           The threads will only start in the process which calls :meth:`reset_state()`.
           Therefore you can use ``MultiProcessRunnerZMQ(MultiThreadMapData(...), 1)``
           to reduce GIL contention.
    """

    class _Worker(StoppableThread):
        def __init__(self, inq, outq, evt, map_func):
            super(MultiThreadMapData._Worker, self).__init__(evt)
            self.inq = inq
            self.outq = outq
            self.func = map_func
            self.daemon = True

        def run(self):
            try:
                while True:
                    dp = self.queue_get_stoppable(self.inq)
                    if self.stopped():
                        return
                    # cannot ignore None here. will lead to unsynced send/recv
                    obj = self.func(dp)
                    self.queue_put_stoppable(self.outq, obj)
            except Exception:
                if self.stopped():
                    pass        # skip duplicated error messages
                else:
                    raise
            finally:
                self.stop()
    def __init__(self, ds, num_thread=None, map_func=None, *, buffer_size=200, strict=False):
        """
        Args:
            ds (DataFlow): the dataflow to map
            num_thread (int): number of threads to use
            map_func (callable): datapoint -> datapoint | None. Return None to
                discard/skip the datapoint.
            buffer_size (int): number of datapoints in the buffer
            strict (bool): use "strict mode", see notes above.
        """
        if strict:
            # In strict mode, buffer size cannot be larger than the total number of datapoints
            try:
                buffer_size = min(buffer_size, len(ds))
            except Exception:  # ds may not have a length
                pass

        super(MultiThreadMapData, self).__init__(ds, buffer_size, strict)
        assert num_thread > 0, num_thread

        self._strict = strict
        self.num_thread = num_thread
        self.map_func = map_func
        self._threads = []
        self._evt = None
    def reset_state(self):
        super(MultiThreadMapData, self).reset_state()
        if self._threads:
            self._threads[0].stop()
            for t in self._threads:
                t.join()

        self._in_queue = queue.Queue()
        self._out_queue = queue.Queue()
        self._evt = threading.Event()
        self._threads = [MultiThreadMapData._Worker(
            self._in_queue, self._out_queue, self._evt, self.map_func)
            for _ in range(self.num_thread)]
        for t in self._threads:
            t.start()

        self._guard = DataFlowReentrantGuard()

        # Call once at the beginning, to ensure inq+outq has a total of buffer_size elements
        self._fill_buffer()

    def _recv(self):
        return self._out_queue.get()

    def _send(self, dp):
        self._in_queue.put(dp)

    def __iter__(self):
        with self._guard:
            yield from super(MultiThreadMapData, self).__iter__()

    def __del__(self):
        if self._evt is not None:
            self._evt.set()
        for p in self._threads:
            p.stop()
            p.join(timeout=5.0)
            # if p.is_alive():
            #     logger.warn("Cannot join thread {}.".format(p.name))
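
# A minimal usage sketch (illustrative, not part of the library):
# `expensive_load` stands for any user callable that takes one datapoint and
# returns a datapoint, or None to drop it.
#
#   df = MultiThreadMapData(df, num_thread=8, map_func=expensive_load,
#                           buffer_size=200, strict=True)
#   df.reset_state()
#   for dp in df:
#       pass  # consume mapped datapoints; ordering is not preserved
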
class MultiProcessMapDataZMQ(_ParallelMapData, _MultiProcessZMQDataFlow):
    """
    Same as :class:`MapData`, but start processes to run the mapping function,
    and communicate with ZeroMQ pipe.

    The semantics of this class are **identical** to :class:`MapData` except for the ordering.
    Processes run in parallel and can take different time to run the
    mapping function. Therefore the order of datapoints won't be preserved.

    When ``strict=True``, ``MultiProcessMapData(df, ...)``
    is guaranteed to produce the exact set of data as ``MapData(df, ...)``,
    if both are iterated until ``StopIteration``. But the produced data will have different ordering.
    The behavior of strict mode is undefined if the given dataflow ``df`` is infinite.

    When ``strict=False``, the data that's produced by ``MultiProcessMapData(df, ...)``
    is a reordering of the data produced by ``RepeatedData(MapData(df, ...), -1)``.
    In other words, the first pass of ``MultiProcessMapData.__iter__`` may contain
    datapoints from the second pass of ``df.__iter__``.
    """

    class _Worker(mp.Process):
        def __init__(self, identity, map_func, pipename, hwm):
            super(MultiProcessMapDataZMQ._Worker, self).__init__()
            self.identity = identity
            self.map_func = map_func
            self.pipename = pipename
            self.hwm = hwm

        def run(self):
            enable_death_signal(_warn=self.identity == b'0')
            ctx = zmq.Context()
            socket = ctx.socket(zmq.REP)
            socket.setsockopt(zmq.IDENTITY, self.identity)
            socket.set_hwm(self.hwm)
            socket.connect(self.pipename)

            while True:
                dp = loads(socket.recv(copy=False))
                dp = self.map_func(dp)
                socket.send(dumps(dp), copy=False)
    def __init__(self, ds, num_proc=None, map_func=None, *, buffer_size=200, strict=False):
        """
        Args:
            ds (DataFlow): the dataflow to map
            num_proc (int): number of processes to use
            map_func (callable): datapoint -> datapoint | None. Return None to
                discard/skip the datapoint.
            buffer_size (int): number of datapoints in the buffer
            strict (bool): use "strict mode", see notes above.
        """
        if strict:
            # In strict mode, buffer size cannot be larger than the total number of datapoints
            try:
                buffer_size = min(buffer_size, len(ds))
            except Exception:  # ds may not have a length
                pass

        _ParallelMapData.__init__(self, ds, buffer_size, strict)
        _MultiProcessZMQDataFlow.__init__(self)
        assert num_proc > 0, num_proc
        self.num_proc = num_proc
        self.map_func = map_func
        self._strict = strict
        self._procs = []
    def _create_worker(self, id, pipename, hwm):
        return MultiProcessMapDataZMQ._Worker(id, self.map_func, pipename, hwm)

    def reset_state(self):
        _MultiProcessZMQDataFlow.reset_state(self)
        _ParallelMapData.reset_state(self)
        self._guard = DataFlowReentrantGuard()

        self.context = zmq.Context()
        self.socket = self.context.socket(zmq.DEALER)
        self.socket.set_hwm(self._buffer_size * 2)
        pipename = _get_pipe_name('dataflow-map')
        _bind_guard(self.socket, pipename)

        self._proc_ids = [u'{}'.format(k).encode('utf-8') for k in range(self.num_proc)]
        worker_hwm = int(self._buffer_size * 2 // self.num_proc)
        self._procs = [self._create_worker(self._proc_ids[k], pipename, worker_hwm)
                       for k in range(self.num_proc)]

        self._start_processes()
        self._fill_buffer()     # pre-fill the buffer

    def _send(self, dp):
        msg = [b"", dumps(dp)]
        self.socket.send_multipart(msg, copy=False)

    def _recv(self):
        msg = self.socket.recv_multipart(copy=False)
        dp = loads(msg[1])
        return dp

    def __iter__(self):
        with self._guard, _zmq_catch_error(type(self).__name__):
            yield from super(MultiProcessMapDataZMQ, self).__iter__()
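
# A usage sketch (illustrative, not part of the library): same call pattern
# as MultiThreadMapData, but `map_func` runs in `num_proc` worker processes
# that exchange serialized datapoints with the parent over ZeroMQ pipes.
# `expensive_load` is again a hypothetical user function:
#
#   df = MultiProcessMapDataZMQ(df, num_proc=8, map_func=expensive_load,
#                               buffer_size=200, strict=False)
#   df.reset_state()
#   for dp in df:
#       pass
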
class MultiProcessMapAndBatchDataZMQ(_MultiProcessZMQDataFlow):
    """
    Similar to :class:`MultiProcessMapDataZMQ`, except that this DataFlow
    also does batching in parallel in the worker processes.
    Therefore it can be helpful if you wish to hide the latency of batching.

    When ``num_proc==1``, the behavior of this class is identical to
    ``BatchData(MapData(ds, map_func), batch_size)``.

    When ``num_proc>1``, the datapoints may be grouped in arbitrary order,
    or grouped with datapoints from a different pass of the given dataflow.
    """

    class _Dispatcher(mp.Process):
        def __init__(self, ds, pipename, hwm):
            super(MultiProcessMapAndBatchDataZMQ._Dispatcher, self).__init__()
            self.ds = RepeatedData(ds, -1)
            self.pipename = pipename
            self.hwm = hwm

        def run(self):
            enable_death_signal()
            ctx = zmq.Context()
            socket = ctx.socket(zmq.PUSH)
            socket.set_hwm(self.hwm)
            socket.bind(self.pipename)
            self.ds.reset_state()
            for dp in self.ds:
                socket.send(dumps(dp), copy=False)

    class _Worker(mp.Process):
        def __init__(self, identity, map_func, input_pipe, result_pipe, hwm, batch_size):
            super(MultiProcessMapAndBatchDataZMQ._Worker, self).__init__()
            self.identity = identity
            self.map_func = map_func
            self.input_pipe = input_pipe
            self.result_pipe = result_pipe
            self.hwm = hwm
            self.batch_size = batch_size

        def run(self):
            enable_death_signal(_warn=self.identity == b'0')
            ctx = zmq.Context()

            # recv jobs
            socket = ctx.socket(zmq.PULL)
            socket.setsockopt(zmq.IDENTITY, self.identity)
            socket.set_hwm(self.hwm * self.batch_size)
            socket.connect(self.input_pipe)

            # send results
            out_socket = ctx.socket(zmq.PUSH)
            out_socket.set_hwm(max(self.hwm, 5))
            out_socket.connect(self.result_pipe)

            batch = []
            while True:
                dp = loads(socket.recv(copy=False))
                dp = self.map_func(dp)
                if dp is not None:
                    batch.append(dp)
                    if len(batch) == self.batch_size:
                        dp = BatchData.aggregate_batch(batch)
                        out_socket.send(dumps(dp), copy=False)
                        del batch[:]
    def __init__(self, ds, num_proc, map_func, batch_size, buffer_size=None):
        """
        Args:
            ds (DataFlow): the dataflow to map
            num_proc (int): number of processes to use
            map_func (callable): datapoint -> datapoint | None. Return None to
                discard/skip the datapoint.
            batch_size (int): batch size
            buffer_size (int): number of datapoints (not batched) in the buffer.
                Defaults to batch_size * 10
        """
        super(MultiProcessMapAndBatchDataZMQ, self).__init__()
        if buffer_size is None:
            buffer_size = batch_size * 10
        # resolve the default before this check; comparing against None would raise TypeError
        assert batch_size < buffer_size
        self.ds = ds
        self.num_proc = num_proc
        self.map_func = map_func
        self.batch_size = batch_size
        self.buffer_size = buffer_size
    def reset_state(self):
        _MultiProcessZMQDataFlow.reset_state(self)
        self._guard = DataFlowReentrantGuard()

        job_pipe = _get_pipe_name("dataflow_MaB_job")
        result_pipe = _get_pipe_name("dataflow_MaB_result")

        self.context = zmq.Context()
        self.socket = self.context.socket(zmq.PULL)
        self.socket.set_hwm(max(5, self.buffer_size // self.batch_size))
        _bind_guard(self.socket, result_pipe)

        dispatcher = MultiProcessMapAndBatchDataZMQ._Dispatcher(self.ds, job_pipe, self.buffer_size)

        self._proc_ids = [u'{}'.format(k).encode('utf-8') for k in range(self.num_proc)]
        worker_hwm = max(3, self.buffer_size // self.num_proc // self.batch_size)
        self._procs = [MultiProcessMapAndBatchDataZMQ._Worker(
            self._proc_ids[k], self.map_func, job_pipe, result_pipe, worker_hwm, self.batch_size)
            for k in range(self.num_proc)]
        self._procs.append(dispatcher)
        self._start_processes()

    def __iter__(self):
        with self._guard, _zmq_catch_error(type(self).__name__):
            while True:
                yield loads(self.socket.recv(copy=False))
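
# A usage sketch (illustrative, not part of the library): mapping and
# batching both happen in the workers, so the main process only receives
# already-aggregated batches. `expensive_load` is a hypothetical user function:
#
#   df = MultiProcessMapAndBatchDataZMQ(df, num_proc=8,
#                                       map_func=expensive_load, batch_size=32)
#   df.reset_state()
#   for batch in df:
#       pass  # each `batch` aggregates 32 mapped datapoints
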
def _pool_map(data):
    global SHARED_ARR, WORKER_ID, MAP_FUNC
    res = MAP_FUNC(data)
    if res is None:
        return None
    shared = np.reshape(SHARED_ARR, res.shape)
    assert shared.dtype == res.dtype
    shared[:] = res
    return WORKER_ID


# TODO shutdown pool, improve speed.
class MultiProcessMapDataComponentSharedArray(DataFlow):
    """
    Similar to :class:`MapDataComponent`, but perform IPC by shared memory,
    therefore more efficient when data (result of map_func) is large.
    It requires `map_func` to always return a numpy array of fixed shape and dtype, or None.
    """
    def __init__(self, ds, nr_proc, map_func, output_shape, output_dtype, index=0):
        """
        Args:
            ds (DataFlow): the dataflow to map on
            nr_proc (int): number of processes
            map_func (data component -> ndarray | None): the mapping function
            output_shape (tuple): the shape of the output of map_func
            output_dtype (np.dtype): the type of the output of map_func
            index (int): the index of the datapoint component to map on.
        """
        self.ds = ds
        self.nr_proc = nr_proc
        self.map_func = map_func
        self.output_shape = output_shape
        self.output_dtype = np.dtype(output_dtype).type
        self.index = index

        self._shared_mem = [self._create_shared_arr() for k in range(nr_proc)]
        id_queue = mp.Queue()
        for k in range(nr_proc):
            id_queue.put(k)

        def _init_pool(arrs, queue, map_func):
            id = queue.get()
            global SHARED_ARR, WORKER_ID, MAP_FUNC
            SHARED_ARR = arrs[id]
            WORKER_ID = id
            MAP_FUNC = map_func

        self._pool = mp.pool.Pool(
            processes=nr_proc,
            initializer=_init_pool,
            initargs=(self._shared_mem, id_queue, map_func))

    def _create_shared_arr(self):
        TYPE = {
            np.float32: ctypes.c_float,
            np.float64: ctypes.c_double,
            np.uint8: ctypes.c_uint8,
            np.int8: ctypes.c_int8,
            np.int32: ctypes.c_int32,
        }
        ctype = TYPE[self.output_dtype]
        arr = mp.RawArray(ctype, int(np.prod(self.output_shape)))
        return arr

    def __len__(self):
        return len(self.ds)

    def reset_state(self):
        self.ds.reset_state()
        self._guard = DataFlowReentrantGuard()

    def __iter__(self):
        ds_itr = _repeat_iter(self.ds.get_data)
        with self._guard:
            while True:
                dps = []
                for k in range(self.nr_proc):
                    dps.append(copy.copy(next(ds_itr)))
                to_map = [x[self.index] for x in dps]
                res = self._pool.map_async(_pool_map, to_map)

                for index in res.get():
                    if index is None:
                        continue
                    arr = np.reshape(self._shared_mem[index], self.output_shape)
                    dp = dps[index]
                    dp[self.index] = arr.copy()
                    yield dp


# alias
MultiProcessMapData = MultiProcessMapDataZMQ
MultiProcessMapAndBatchData = MultiProcessMapAndBatchDataZMQ


if __name__ == '__main__':
    import time

    class Zero(DataFlow):
        def __init__(self, size):
            self._size = size

        def __iter__(self):
            for k in range(self._size):
                yield [k]

        def __len__(self):
            return self._size

    def f(x):
        if x[0] < 10:
            time.sleep(1)
        return x

    ds = Zero(100)
    ds = MultiThreadMapData(ds, 50, f, buffer_size=50, strict=True)
    ds.reset_state()
    for idx, k in enumerate(ds):
        print("Bang!", k)
        if idx == 100:
            break
    print("END!")
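
# A usage sketch for MultiProcessMapDataComponentSharedArray defined above
# (illustrative, not part of the library; `decode_to_array` stands for any
# component mapper returning a fixed-shape, fixed-dtype numpy array or None):
#
#   df = MultiProcessMapDataComponentSharedArray(
#       df, nr_proc=4, map_func=decode_to_array,
#       output_shape=(224, 224, 3), output_dtype=np.uint8, index=0)
#   df.reset_state()
#   for dp in df:
#       pass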