# -*- coding: utf-8 -*-
# File: export.py

"""
A collection of functions to ease the process of exporting
a model for production.

"""

import tensorflow as tf
from tensorflow.python.framework import graph_util
from tensorflow.python.platform import gfile
from tensorflow.python.tools import optimize_for_inference_lib

from ..compat import tfv1
from ..input_source import PlaceholderInput
from ..tfutils.common import get_tensors_by_names, get_tf_version_tuple
from ..tfutils.tower import PredictTowerContext
from ..utils import logger

__all__ = ['ModelExporter']


class ModelExporter(object):
    """Export models for inference."""

    def __init__(self, config):
        """Initialise the export process.

        Args:
            config (PredictConfig): the config to use.
                The graph will be built with the tower function defined by this `PredictConfig`.
                Then the input / output names will be used to export models for inference.
        """
        super(ModelExporter, self).__init__()
        self.config = config
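    # Illustrative sketch (comment only, not part of the module): a `PredictConfig`
    # for export is typically assembled from a trained ModelDesc, a checkpoint and
    # the tensor names to expose. `MyModel`, the checkpoint path and the names
    # 'input' / 'prediction' below are hypothetical placeholders:
    #
    #     from tensorpack.predict import PredictConfig
    #     from tensorpack.tfutils.sessinit import SmartInit
    #
    #     pred_config = PredictConfig(
    #         model=MyModel(),                                # your ModelDesc subclass
    #         session_init=SmartInit("/path/to/checkpoint"),  # or get_model_loader(...)
    #         input_names=["input"],
    #         output_names=["prediction"])
    #     exporter = ModelExporter(pred_config)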
    def export_compact(self, filename, optimize=True, toco_compatible=False):
        """Create a self-contained inference-only graph and write final graph (in pb format) to disk.

        Args:
            filename (str): path to the output graph
            optimize (bool): whether to use TensorFlow's `optimize_for_inference`
                to prune and optimize the graph. This does not work on all types of graphs.
            toco_compatible (bool): See TensorFlow's
                `optimize_for_inference
                <https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/optimize_for_inference.py>`_
                for details. Only available after TF 1.8.
        """
        if toco_compatible:
            assert optimize, "toco_compatible is only effective when optimize=True!"
        self.graph = self.config._maybe_create_graph()
        with self.graph.as_default():
            input = PlaceholderInput()
            input.setup(self.config.input_signature)
            with PredictTowerContext(''):
                self.config.tower_func(*input.get_input_tensors())

            input_tensors = get_tensors_by_names(self.config.input_names)
            output_tensors = get_tensors_by_names(self.config.output_names)

            self.config.session_init._setup_graph()
            # we cannot use "self.config.session_creator.create_session()" here since it finalizes the graph
            sess = tfv1.Session(config=tfv1.ConfigProto(allow_soft_placement=True))
            self.config.session_init._run_init(sess)

            dtypes = [n.dtype for n in input_tensors]

            # freeze variables to constants
            frozen_graph_def = graph_util.convert_variables_to_constants(
                sess,
                self.graph.as_graph_def(),
                [n.name[:-2] for n in output_tensors],
                variable_names_whitelist=None,
                variable_names_blacklist=None)

            # prune unused nodes from graph
            if optimize:
                toco_args = () if get_tf_version_tuple() < (1, 8) else (toco_compatible, )
                frozen_graph_def = optimize_for_inference_lib.optimize_for_inference(
                    frozen_graph_def,
                    [n.name[:-2] for n in input_tensors],
                    [n.name[:-2] for n in output_tensors],
                    [dtype.as_datatype_enum for dtype in dtypes],
                    *toco_args)

            with gfile.FastGFile(filename, "wb") as f:
                f.write(frozen_graph_def.SerializeToString())
                logger.info("Output graph written to {}.".format(filename))
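    # Illustrative sketch (comment only): the frozen graph written by
    # `export_compact` can be loaded back with plain TensorFlow. The tensor names
    # 'input:0' / 'prediction:0' and the variable `input_data` are hypothetical:
    #
    #     with gfile.GFile("/tmp/frozen_model.pb", "rb") as f:
    #         graph_def = tfv1.GraphDef()
    #         graph_def.ParseFromString(f.read())
    #     with tfv1.Graph().as_default() as graph:
    #         tf.import_graph_def(graph_def, name="")
    #         with tfv1.Session(graph=graph) as sess:
    #             prediction = sess.run("prediction:0", feed_dict={"input:0": input_data})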
    def export_serving(self, filename, tags=None, signature_name='prediction_pipeline'):
        """
        Converts a checkpoint and graph to a servable for TensorFlow Serving.
        Use TF's `SavedModelBuilder` to export a trained model without tensorpack dependency.

        Args:
            filename (str): path for export directory
            tags (tuple): tuple of user specified tags. Defaults to just "SERVING".
            signature_name (str): name of signature for prediction

        Note:
            This produces

            .. code-block:: none

                variables/       # output from the vanilla Saver
                    variables.data-?????-of-?????
                    variables.index
                saved_model.pb   # a `SavedModel` protobuf

            Currently, we only support a single signature, which is the general PredictSignatureDef:
            https://github.com/tensorflow/serving/blob/master/tensorflow_serving/g3doc/signature_defs.md
        """
        if tags is None:
            tags = (tf.saved_model.SERVING if get_tf_version_tuple() >= (1, 12)
                    else tf.saved_model.tag_constants.SERVING, )

        self.graph = self.config._maybe_create_graph()
        with self.graph.as_default():
            input = PlaceholderInput()
            input.setup(self.config.input_signature)
            with PredictTowerContext(''):
                self.config.tower_func(*input.get_input_tensors())

            input_tensors = get_tensors_by_names(self.config.input_names)
            saved_model = tfv1.saved_model.utils
            inputs_signatures = {t.name: saved_model.build_tensor_info(t) for t in input_tensors}
            output_tensors = get_tensors_by_names(self.config.output_names)
            outputs_signatures = {t.name: saved_model.build_tensor_info(t) for t in output_tensors}

            self.config.session_init._setup_graph()
            # we cannot use "self.config.session_creator.create_session()" here since it finalizes the graph
            sess = tfv1.Session(config=tfv1.ConfigProto(allow_soft_placement=True))
            self.config.session_init._run_init(sess)

            builder = tfv1.saved_model.builder.SavedModelBuilder(filename)

            prediction_signature = tfv1.saved_model.signature_def_utils.build_signature_def(
                inputs=inputs_signatures,
                outputs=outputs_signatures,
                method_name=tfv1.saved_model.signature_constants.PREDICT_METHOD_NAME)

            builder.add_meta_graph_and_variables(
                sess, list(tags),
                signature_def_map={signature_name: prediction_signature})
            builder.save()
            logger.info("SavedModel created at {}.".format(filename))
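# Illustrative sketch (comment only): a SavedModel directory written by
# `export_serving` can be loaded back without tensorpack, e.g. with the TF1-style
# loader. The export path below is a placeholder; the tag mirrors the default
# SERVING tag used above:
#
#     with tfv1.Session(graph=tfv1.Graph()) as sess:
#         tfv1.saved_model.loader.load(
#             sess, [tfv1.saved_model.tag_constants.SERVING], "/tmp/exported_model")
#
# For TensorFlow Serving, place the directory under a numbered version
# subdirectory and point the server's --model_base_path at the parent directory.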