# -*- coding: utf-8 -*-
# File: concurrency.py
import multiprocessing
import numpy as np
from concurrent.futures import Future
import tensorflow as tf
from six.moves import queue, range
from ..compat import tfv1
from ..tfutils.model_utils import describe_trainable_vars
from ..utils import logger
from ..utils.concurrency import DIE, ShareSessionThread, StoppableThread
from .base import AsyncPredictorBase, OfflinePredictor, OnlinePredictor

__all__ = ['MultiThreadAsyncPredictor']


class MultiProcessPredictWorker(multiprocessing.Process):
""" Base class for predict worker that runs offline in multiprocess"""

    def __init__(self, idx, config):
"""
Args:
            idx (int): index of the worker. Only the 0th worker prints logs.
config (PredictConfig): the config to use.
"""
super(MultiProcessPredictWorker, self).__init__()
self.name = "MultiProcessPredictWorker-{}".format(idx)
self.idx = idx
self.config = config

    def _init_runtime(self):
""" Call _init_runtime under different CUDA_VISIBLE_DEVICES, you'll
have workers that run on multiGPUs
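
        Example:
            A minimal sketch (``workers`` is a placeholder for a list of
            :class:`MultiProcessPredictWorker`, not a name from this module;
            it also assumes worker processes are forked and therefore inherit
            the parent's environment):

            .. code-block:: python

                import os
                for idx, w in enumerate(workers):
                    # pin each worker process to one GPU before starting it
                    os.environ['CUDA_VISIBLE_DEVICES'] = str(idx)
                    w.start()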
"""
if self.idx != 0:
from tensorpack.models.registry import disable_layer_logging
disable_layer_logging()
self.predictor = OfflinePredictor(self.config)
if self.idx == 0:
with self.predictor.graph.as_default():
describe_trainable_vars()


class MultiProcessQueuePredictWorker(MultiProcessPredictWorker):
"""
    An offline predictor worker that takes input from and sends output to
    :class:`multiprocessing.Queue`. Each process exits when it sees
    :class:`DIE` in its input queue.
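
    Example:
        A minimal sketch of the queue protocol (``config`` and ``dp`` are
        placeholders for an existing :class:`PredictConfig` and a datapoint,
        i.e. a list of inputs):

        .. code-block:: python

            inq, outq = multiprocessing.Queue(), multiprocessing.Queue()
            worker = MultiProcessQueuePredictWorker(0, inq, outq, config)
            worker.start()
            inq.put((0, dp))            # submit (task_id, dp)
            tid, output = outq.get()    # receive (task_id, output)
            inq.put((DIE, None))        # tell the worker to exit
            worker.join()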
"""

    def __init__(self, idx, inqueue, outqueue, config):
"""
Args:
idx, config: same as in :class:`MultiProcessPredictWorker`.
            inqueue (multiprocessing.Queue): input queue to get datapoints from.
                Elements are (task_id, dp).
            outqueue (multiprocessing.Queue): output queue to put results in.
                Elements are (task_id, output).
"""
super(MultiProcessQueuePredictWorker, self).__init__(idx, config)
self.inqueue = inqueue
self.outqueue = outqueue
assert isinstance(self.inqueue, multiprocessing.queues.Queue)
assert isinstance(self.outqueue, multiprocessing.queues.Queue)

    def run(self):
self._init_runtime()
while True:
tid, dp = self.inqueue.get()
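            # DIE is a sentinel task id: acknowledge it downstream and exit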
if tid == DIE:
self.outqueue.put((DIE, None))
return
else:
self.outqueue.put((tid, self.predictor(*dp)))


class PredictorWorkerThread(StoppableThread, ShareSessionThread):
    """ A worker thread that takes (input, Future) pairs from a queue,
    batches the inputs, and runs a predict function on each batch. """

    def __init__(self, queue, pred_func, id, batch_size=5):
super(PredictorWorkerThread, self).__init__()
self.name = "PredictorWorkerThread-{}".format(id)
self.queue = queue
self.func = pred_func
self.daemon = True
self.batch_size = batch_size
self.id = id

    def run(self):
with self.default_sess():
while not self.stopped():
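                # block for one task, then greedily batch more (see fetch_batch)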
batched, futures = self.fetch_batch()
try:
outputs = self.func(*batched)
except tf.errors.CancelledError:
for f in futures:
f.cancel()
logger.warn("In PredictorWorkerThread id={}, call was cancelled.".format(self.id))
return
# print "Worker {} batched {} Queue {}".format(
# self.id, len(futures), self.queue.qsize())
# debug, for speed testing
# if not hasattr(self, 'xxx'):
# self.xxx = outputs = self.func(batched)
# else:
# outputs = [[self.xxx[0][0]] * len(batched[0]), [self.xxx[1][0]] * len(batched[0])]
for idx, f in enumerate(futures):
f.set_result([k[idx] for k in outputs])

    def fetch_batch(self):
""" Fetch a batch of data without waiting"""
inp, f = self.queue.get()
nr_input_var = len(inp)
batched, futures = [[] for _ in range(nr_input_var)], []
for k in range(nr_input_var):
batched[k].append(inp[k])
futures.append(f)
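        # opportunistically fill the rest of the batch without blocking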
while len(futures) < self.batch_size:
try:
inp, f = self.queue.get_nowait()
for k in range(nr_input_var):
batched[k].append(inp[k])
futures.append(f)
except queue.Empty:
break # do not wait
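        # stack each input component into a single array along the batch axis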
for k in range(nr_input_var):
batched[k] = np.asarray(batched[k])
return batched, futures


class MultiThreadAsyncPredictor(AsyncPredictorBase):
"""
    A multithreaded online async predictor which runs a list of :class:`OnlinePredictor`.
    It performs an extra round of batching internally.
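
    Example:
        A minimal usage sketch (``predictors`` and ``dp`` are placeholders
        for a list of :class:`OnlinePredictor` and a datapoint):

        .. code-block:: python

            pred = MultiThreadAsyncPredictor(predictors, batch_size=16)
            pred.start()
            fut = pred.put_task(dp)
            outputs = fut.result()   # blocks until this task's result is ready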
"""

    def __init__(self, predictors, batch_size=5):
"""
Args:
predictors (list): a list of OnlinePredictor available to use.
            batch_size (int): the maximum size of an internal batch.
"""
assert len(predictors)
self._need_default_sess = False
for k in predictors:
assert isinstance(k, OnlinePredictor), type(k)
if k.sess is None:
self._need_default_sess = True
# TODO support predictors.return_input here
assert not k.return_input
self.input_queue = queue.Queue(maxsize=len(predictors) * 100)
self.threads = [
PredictorWorkerThread(
self.input_queue, f, id, batch_size=batch_size)
for id, f in enumerate(predictors)]

    def start(self):
if self._need_default_sess:
            assert tfv1.get_default_session() is not None, \
                "No session is bound to the predictors. " \
                "MultiThreadAsyncPredictor.start() has to be called under a default session!"
for t in self.threads:
t.start()

    def put_task(self, dp, callback=None):
"""
Args:
            dp (list): A datapoint as inputs. It could be either batched or not
                batched, depending on the predictor implementation.
            callback: a thread-safe callback. When the result is ready, it will
                be called with the :class:`Future` object.
Returns:
concurrent.futures.Future: a Future of results.
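
        Example:
            A sketch of the callback style (``on_done`` and ``predictor`` are
            hypothetical names, not part of this module):

            .. code-block:: python

                def on_done(fut):
                    outputs = fut.result()   # already done when the callback fires
                    ...

                predictor.put_task(dp, callback=on_done)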
"""
f = Future()
if callback is not None:
f.add_done_callback(callback)
self.input_queue.put((dp, f))
return f