# -*- coding: utf-8 -*-
# File: base.py
from abc import ABCMeta, abstractmethod
import six
import tensorflow as tf
from ..input_source import PlaceholderInput
from ..tfutils.common import get_tensors_by_names, get_op_tensor_name
from ..tfutils.tower import PredictTowerContext
# Public API of this module (AsyncPredictorBase is intentionally not listed).
__all__ = [
    'PredictorBase',
    'OnlinePredictor',
    'OfflinePredictor',
]
@six.add_metaclass(ABCMeta)
class PredictorBase(object):
    """
    Base class for all predictors.

    Subclasses implement :meth:`_do_call`. Calling the predictor forwards
    all positional arguments to it as a single tuple.

    Attributes:
        return_input (bool): whether the call will also return (inputs, outputs)
            or just outputs. Defaults to False; subclasses typically set it in
            their constructor.
    """

    # Class-level default: a subclass that never assigns ``self.return_input``
    # would otherwise make every ``__call__`` fail with AttributeError.
    return_input = False

    def __call__(self, *dp):
        """
        Call the predictor on some inputs.

        Example:
            When you have a predictor defined with two inputs, call it with:

            .. code-block:: python

                predictor(e1, e2)

        Returns:
            list[array]: list of outputs. If :attr:`return_input` is True,
            returns ``(dp, outputs)`` instead, where ``dp`` is the tuple of
            positional arguments.
        """
        output = self._do_call(dp)
        if self.return_input:
            return (dp, output)
        return output

    @abstractmethod
    def _do_call(self, dp):
        """
        Args:
            dp (tuple): input datapoint. Must have the same length as
                input_names.

        Returns:
            output as defined by the config
        """
class AsyncPredictorBase(PredictorBase):
    """
    Base class for all async predictors.

    Work is scheduled through :meth:`put_task`. A synchronous call
    (``predictor(*dp)``) is built on top of it by waiting on the returned
    future.
    """

    @abstractmethod
    def put_task(self, dp, callback=None):
        """
        Submit one datapoint for asynchronous prediction.

        Args:
            dp (list): the inputs of one datapoint. Whether it is batched
                or not depends on the predictor implementation.
            callback: a thread-safe callable, invoked with either the outputs,
                or ``(inputs, outputs)`` when `return_input` is True.

        Returns:
            concurrent.futures.Future: a future holding the results.
        """

    @abstractmethod
    def start(self):
        """ Start the workers. """

    def _do_call(self, dp):
        # Blocking path: schedule the task, then wait for its result.
        # NOTE: in Tornado, Future.result() doesn't wait, so this relies on
        # put_task returning a future whose result() blocks.
        return self.put_task(dp).result()
class OnlinePredictor(PredictorBase):
    """
    A predictor which directly uses an existing session and given tensors.

    Attributes:
        sess: The tf.Session object associated with this predictor.
        input_tensors (list): the inputs, with string names normalized.
        output_tensors (list): the outputs, with string names normalized.
    """

    #: Value passed as ``accept_options`` to ``Session.make_callable``.
    #: See Session.make_callable.
    ACCEPT_OPTIONS = False

    def __init__(self, input_tensors, output_tensors,
                 return_input=False, sess=None):
        """
        Args:
            input_tensors (list): list of tf.Tensor or names.
            output_tensors (list): list of tf.Tensor or names.
            return_input (bool): same as :attr:`PredictorBase.return_input`.
            sess (tf.Session): the session this predictor runs in. If None,
                will use the default session at the first call.
                Note that in TensorFlow, default session is thread-local.
        """
        def normalize_name(t):
            # String names are mapped to tensor names via get_op_tensor_name
            # (so an op name like "out" becomes a tensor name); tensor objects
            # pass through unchanged.
            if isinstance(t, six.string_types):
                return get_op_tensor_name(t)[1]
            return t

        self.return_input = return_input
        self.input_tensors = [normalize_name(x) for x in input_tensors]
        self.output_tensors = [normalize_name(x) for x in output_tensors]
        self.sess = sess

        if sess is not None:
            # Build eagerly, from the *normalized* names -- the same ones the
            # lazy path in _do_call uses.
            self._callable = self._make_callable(sess)
        else:
            self._callable = None

    def _make_callable(self, sess):
        """Create the session callable that feeds the inputs and fetches the outputs."""
        return sess.make_callable(
            fetches=self.output_tensors,
            feed_list=self.input_tensors,
            accept_options=self.ACCEPT_OPTIONS)

    def _do_call(self, dp):
        assert len(dp) == len(self.input_tensors), \
            "{} != {}".format(len(dp), len(self.input_tensors))
        if self.sess is None:
            # Bind lazily to the caller's default session on first use.
            self.sess = tf.get_default_session()
            assert self.sess is not None, "Predictor isn't called under a default session!"
        if self._callable is None:
            self._callable = self._make_callable(self.sess)
        return self._callable(*dp)
class OfflinePredictor(OnlinePredictor):
    """ A predictor built from a given config.
    A single-tower model will be built without any prefix.

    Example:

    .. code-block:: python

        config = PredictConfig(model=my_model,
                               input_names=['image'],
                               # use names of tensors defined in the model
                               output_names=['linear/output', 'prediction'])
        predictor = OfflinePredictor(config)
        image = np.random.rand(1, 100, 100, 3)  # the shape of "image" defined in the model
        linear_output, prediction = predictor(image)
    """

    def __init__(self, config):
        """
        Args:
            config (PredictConfig): the config to use.
        """
        self.graph = config._maybe_create_graph()
        with self.graph.as_default():
            # Create placeholders from the input signature and build the model
            # once, in a tower context with an empty name prefix.
            placeholder_input = PlaceholderInput()
            placeholder_input.setup(config.input_signature)
            with PredictTowerContext(''):
                config.tower_func(*placeholder_input.get_input_tensors())

            # Look the tensors up by name in the graph that was just built.
            input_tensors = get_tensors_by_names(config.input_names)
            output_tensors = get_tensors_by_names(config.output_names)

            # _setup_graph() must run before the session is created;
            # _run_init(sess) runs once the session exists.
            config.session_init._setup_graph()
            sess = config.session_creator.create_session()
            config.session_init._run_init(sess)
            super(OfflinePredictor, self).__init__(
                input_tensors, output_tensors,
                return_input=config.return_input, sess=sess)