# Source code for tensorpack.predict.concurrency

# -*- coding: utf-8 -*-
# File: concurrency.py

import multiprocessing
import numpy as np
from concurrent.futures import Future
import tensorflow as tf
from six.moves import queue, range

from ..compat import tfv1
from ..tfutils.model_utils import describe_trainable_vars
from ..utils import logger
from ..utils.concurrency import DIE, ShareSessionThread, StoppableThread
from .base import AsyncPredictorBase, OfflinePredictor, OnlinePredictor

__all__ = ['MultiThreadAsyncPredictor']

class MultiProcessPredictWorker(multiprocessing.Process):
    """ Base class for predict worker that runs offline in multiprocess"""

    def __init__(self, idx, config):
        """
        Args:
            idx (int): index of the worker. the 0th worker will print log.
            config (PredictConfig): the config to use.
        """
        super(MultiProcessPredictWorker, self).__init__()
        # Give every worker process a distinguishable name (visible in logs).
        self.name = "MultiProcessPredictWorker-{}".format(idx)
        self.idx = idx
        self.config = config

    def _init_runtime(self):
        """
        Build ``self.predictor`` (an :class:`OfflinePredictor`) from ``self.config``.

        Call this inside the worker process (e.g. from ``run``). If each worker
        is started under a different CUDA_VISIBLE_DEVICES, you'll have workers
        that run on multiple GPUs.

        Only worker 0 produces log output: other workers disable layer logging,
        and worker 0 describes the trainable variables of its graph.
        """
        if self.idx != 0:
            # Imported lazily: only the non-zero workers need it.
            from ..models.registry import disable_layer_logging
            disable_layer_logging()
        self.predictor = OfflinePredictor(self.config)
        if self.idx == 0:
            with self.predictor.graph.as_default():
                describe_trainable_vars()

class MultiProcessQueuePredictWorker(MultiProcessPredictWorker):
    """
    An offline predictor worker that takes input and produces output by queue.
    Each process will exit when it sees :class:`DIE`.
    """

    def __init__(self, idx, inqueue, outqueue, config):
        """
        Args:
            idx, config: same as in :class:`MultiProcessPredictWorker`.
            inqueue (multiprocessing.Queue): input queue to get data point. elements are (task_id, dp)
            outqueue (multiprocessing.Queue): output queue to put result. elements are (task_id, output)
        """
        # Import the submodule explicitly: ``import multiprocessing`` alone does
        # not guarantee that ``multiprocessing.queues`` has been loaded.
        from multiprocessing.queues import Queue as MPQueue
        super(MultiProcessQueuePredictWorker, self).__init__(idx, config)
        self.inqueue = inqueue
        self.outqueue = outqueue
        assert isinstance(self.inqueue, MPQueue), type(self.inqueue)
        assert isinstance(self.outqueue, MPQueue), type(self.outqueue)

    def run(self):
        """
        Process entry point: build the predictor, then serve tasks until DIE.

        For every ``(task_id, dp)`` read from ``inqueue``, put
        ``(task_id, predictor(*dp))`` on ``outqueue``. On ``DIE``, put
        ``(DIE, None)`` on ``outqueue`` and return, which ends the process.
        """
        # ``self.predictor`` only exists after _init_runtime; build it here,
        # inside the child process.
        self._init_runtime()
        while True:
            tid, dp = self.inqueue.get()
            if tid == DIE:
                # Forward the sentinel so the consumer knows this worker is done.
                self.outqueue.put((DIE, None))
                return
            self.outqueue.put((tid, self.predictor(*dp)))

class PredictorWorkerThread(StoppableThread, ShareSessionThread):
    """
    A daemon thread that takes ``(inputs, future)`` tasks from a queue, runs
    them through ``pred_func`` in small batches, and resolves the futures.
    """

    def __init__(self, queue, pred_func, id, batch_size=5):
        """
        Args:
            queue (queue.Queue): queue of ``(inputs, future)`` pairs, where
                ``inputs`` has one element per input variable and ``future``
                is a :class:`concurrent.futures.Future`.
            pred_func (callable): called as ``pred_func(*batched_inputs)``; its
                return value is iterated, and each item is indexed by the
                position of a task within the batch.
            id (int): index of this worker, used in the thread name and logs.
            batch_size (int): maximum number of tasks batched into one call.
        """
        super(PredictorWorkerThread, self).__init__()
        self.name = "PredictorWorkerThread-{}".format(id)
        self.queue = queue
        self.func = pred_func
        self.daemon = True
        self.batch_size = batch_size
        self.id = id

    def run(self):
        """
        Thread body: until stopped, fetch a batch, run ``self.func`` on it and
        give every future its own slice of the outputs.

        If the call is cancelled (``tf.errors.CancelledError``), the pending
        futures are cancelled and the thread exits. Any other exception is set
        on the pending futures, so callers waiting on them are woken up instead
        of blocking forever, and is then re-raised, ending the thread.
        """
        with self.default_sess():
            while not self.stopped():
                batched, futures = self.fetch_batch()
                try:
                    outputs = self.func(*batched)
                except tf.errors.CancelledError:
                    for f in futures:
                        f.cancel()
                    logger.warn("In PredictorWorkerThread id={}, call was cancelled.".format(
                        self.id))
                    return
                except Exception as e:
                    for f in futures:
                        f.set_exception(e)
                    raise

                for idx, f in enumerate(futures):
                    # k[idx] is the part of output k that belongs to the idx-th task.
                    f.set_result([k[idx] for k in outputs])

    def fetch_batch(self):
        """
        Fetch a batch of data without waiting for the batch to fill up.

        Blocks until one task is available, then also takes tasks that are
        already queued, up to ``self.batch_size`` tasks in total.

        Returns:
            tuple: ``(batched, futures)``. ``batched[k]`` is ``np.asarray`` of
            the k-th input of every fetched task; ``futures`` holds the tasks'
            futures in the same order.
        """
        inp, f = self.queue.get()   # the only blocking call
        nr_input_var = len(inp)
        batched = [[] for _ in range(nr_input_var)]
        futures = []

        def _append_task(task_inp, task_future):
            # Add one task's inputs column-wise, and its future.
            for k in range(nr_input_var):
                batched[k].append(task_inp[k])
            futures.append(task_future)

        _append_task(inp, f)
        while len(futures) < self.batch_size:
            try:
                inp, f = self.queue.get_nowait()
            except queue.Empty:
                break   # do not wait
            _append_task(inp, f)

        return [np.asarray(column) for column in batched], futures

class MultiThreadAsyncPredictor(AsyncPredictorBase):
    """
    A multithreaded online async predictor which runs a list of OnlinePredictor.
    It would do an extra batching internally.
    """

    def __init__(self, predictors, batch_size=5):
        """
        Args:
            predictors (list): a list of OnlinePredictor available to use.
            batch_size (int): the maximum of an internal batch.
        """
        assert len(predictors), "predictors must be a non-empty list!"
        self._need_default_sess = False
        for pred in predictors:
            assert isinstance(pred, OnlinePredictor), type(pred)
            if pred.sess is None:
                # Such a predictor needs a default session at start() time.
                self._need_default_sess = True
            # TODO support predictors.return_input here
            assert not pred.return_input
        # Bounded: put_task blocks when the queue is full instead of growing it.
        self.input_queue = queue.Queue(maxsize=len(predictors) * 100)
        # One worker thread per predictor, all consuming the shared input queue.
        self.threads = [
            PredictorWorkerThread(
                self.input_queue, pred, idx, batch_size=batch_size)
            for idx, pred in enumerate(predictors)]

    def start(self):
        """
        Start all worker threads.

        If any predictor was built without a session, this must be called under
        a default session. The worker threads presumably share that session via
        ShareSessionThread; verify against its implementation.
        """
        if self._need_default_sess:
            assert tfv1.get_default_session() is not None, \
                "No session is bound to predictors; " \
                "MultiThreadAsyncPredictor.start() has to be called under a default session!"
        for t in self.threads:
            t.start()

    def put_task(self, dp, callback=None):
        """
        Args:
            dp (list): A datapoint as inputs. It could be either batched or not
                batched depending on the predictor implementation).
            callback: a thread-safe callback. When the results are ready, it will
                be called with the "future" object.

        Returns:
            concurrent.futures.Future: a Future of results.
        """
        f = Future()
        if callback is not None:
            f.add_done_callback(callback)
        self.input_queue.put((dp, f))
        return f