# -*- coding: utf-8 -*-
# File: tower.py
from abc import ABCMeta, abstractmethod, abstractproperty
import six
from ..compat import tfv1 as tf
from ..utils import logger
from ..utils.argtools import call_only_once
from ..utils.develop import HIDE_DOC
from ..utils.naming import MOVING_SUMMARY_OPS_KEY
from .collection import CollectionGuard
from .common import get_op_or_tensor_by_name, get_op_tensor_name
__all__ = ['get_current_tower_context', 'BaseTowerContext', 'TowerContext',
'TowerFuncWrapper', 'TowerFunc',
'TowerTensorHandle', 'TowerTensorHandles']
_CurrentTowerContext = None
[docs]@six.add_metaclass(ABCMeta)
class BaseTowerContext(object):
""" A context where the current model is built in.
You need to use :func:`TowerContext` to create a :class:`BaseTowerContext`.
"""
[docs] @HIDE_DOC
def __init__(self, ns_name, vs_name=''):
"""
This is not supposed to be used by users.
You need to use :func:`TowerContext` to create a :class:`BaseTowerContext`.
Args:
ns_name (str): The name scope of the tower.
vs_name (str): Open a new variable scope with this name.
"""
self._name = ns_name
self._vs_name = vs_name
if len(vs_name):
assert len(ns_name), "TowerContext(vs_name) cannot be used with an empty name!"
@abstractproperty
def is_main_training_tower(self):
"""
bool: Whether this tower is the main (i.e., the first) training tower.
"""
pass
@abstractproperty
def has_own_variables(self):
"""
bool: Whether this tower is supposed to have its own trainable variables.
"""
pass
@property
def name(self):
"""
str: The name scope of the tower.
"""
return self._name
@property
def vs_name(self):
"""
str: The variable scope of the tower.
"""
return self._vs_name
@property
def ns_name(self):
"""
str: The name scope of the tower.
"""
return self._name
[docs] def get_collection_in_tower(self, key):
"""
From a collection, get items that are __added__ to the collection in this tower.
Note that it works by tracking the collection at the beginning and end of
the tower function.
Therefore it does not guarantee that the items are __created__ in this tower.
"""
return self._collection_guard.get_collection_in_tower(key)
@call_only_once
def _get_scopes(self):
"""
Returns the ns and vs for this tower.
"""
if not len(self._name):
# work around https://github.com/tensorflow/tensorflow/issues/14703
return [tf.variable_scope(tf.get_variable_scope())]
ret = []
if len(self._vs_name):
ret.append(tf.variable_scope(self._vs_name))
else:
# caller should have handled reuse outside of TowerContext
ret.append(tf.variable_scope(tf.get_variable_scope()))
# always clear existing ns # TODO check existing ns
if len(self._name):
ret.append(tf.name_scope(self._name + '/'))
return ret
@abstractmethod
def _keys_to_freeze(self):
pass
def __enter__(self):
global _CurrentTowerContext
assert _CurrentTowerContext is None, "Cannot nest TowerContext!"
_CurrentTowerContext = self
self._collection_guard = CollectionGuard(
self._name,
check_diff=not self.is_main_training_tower,
freeze_keys=self._keys_to_freeze())
self._ctxs = self._get_scopes()
self._ctxs.append(self._collection_guard)
for c in self._ctxs:
c.__enter__()
# check that ns_name is always the same as _name
ns = tf.get_default_graph().get_name_scope()
assert ns == self._name, \
"Name conflict: name_scope inside tower '{}' becomes '{}'!".format(self._name, ns) \
+ " You may need a different name for the tower!"
return self
def __exit__(self, exc_type, exc_val, exc_tb):
global _CurrentTowerContext
_CurrentTowerContext = None
if not self.has_own_variables:
diff_trainable_vars = self._collection_guard.get_collection_in_tower(tf.GraphKeys.TRAINABLE_VARIABLES)
assert len(diff_trainable_vars) == 0, \
"New TRAINABLE_VARIABLES shouldn't be created in {}: ".format(
self._name) + ', '.join([k.name for k in diff_trainable_vars])
for c in self._ctxs[::-1]:
c.__exit__(exc_type, exc_val, exc_tb)
return False
def __str__(self):
return "TowerContext(name={}, is_training={})".format(
self._name, self._is_training)
@property
def is_training(self):
"""
bool: whether the context is training or not
"""
return self._is_training
class TrainTowerContext(BaseTowerContext):
def __init__(self, ns_name, vs_name='', index=0, total=1):
"""
Args:
index (int): index of this tower, only used in training.
total (int): total number of towers to be built.
"""
super(TrainTowerContext, self).__init__(ns_name, vs_name)
self._is_training = True
self.index = int(index)
self.total = int(total)
if self.index > 0:
assert self.total > self.index, "(index, total) = ({}, {})".format(self.index, self.total)
vs = tf.get_variable_scope()
assert vs.name == '', "Cannot nest TrainTowerContext with an existing variable scope!"
if vs_name:
assert not vs.reuse, \
"Cannot create tower {} with vs_name={} under reuse=True!".format(ns_name, vs_name)
self._original_vs_reuse = vs.reuse
@property
def is_main_training_tower(self):
return self.index == 0
@property
def has_own_variables(self):
if self._original_vs_reuse:
return False
return self.index == 0 or len(self._vs_name) > 0
def _keys_to_freeze(self):
if self.index == 0:
return []
return [tf.GraphKeys.SUMMARIES, MOVING_SUMMARY_OPS_KEY]
class PredictTowerContext(BaseTowerContext):
def __init__(self, ns_name, vs_name=''):
super(PredictTowerContext, self).__init__(ns_name, vs_name)
self._is_training = False
self._initial_vs_reuse = tf.get_variable_scope().reuse
@property
def has_own_variables(self):
return not self._initial_vs_reuse
@property
def is_main_training_tower(self):
return False
def _keys_to_freeze(self):
# freeze UPDATE_OPS during inference because they should never be used
return [tf.GraphKeys.SUMMARIES, MOVING_SUMMARY_OPS_KEY, tf.GraphKeys.UPDATE_OPS]
[docs]def get_current_tower_context():
"""
When called inside a TowerContext, returns the TowerContext.
Returns:
a :class:`BaseTowerContext` instance or None, if not called under a TowerContext.
"""
return _CurrentTowerContext
[docs]def TowerContext(tower_name, is_training, vs_name=''):
"""
The context for a tower function, containing metadata about the current tower.
Tensorpack trainers use :class:`TowerContext` to manage tower function.
Many tensorpack layers have to be called under a :class:`TowerContext`.
Example:
.. code-block:: python
with TowerContext('', is_training=True):
# call a tensorpack layer or a tower function
"""
if is_training:
return TrainTowerContext(tower_name, vs_name=vs_name)
else:
return PredictTowerContext(tower_name, vs_name=vs_name)
[docs]class TowerFunc(object):
"""
A tower function (see
`tutorial on tower function
<http://tensorpack.readthedocs.io/tutorial/extend/trainer.html#tower-trainer>`_)
It keeps track of the name scope, variable scope and input/output tensors
each time the function is called.
:class:`TowerTrainer` needs this so that it knows how to build a predictor.
Conceptually, this class is roughly equivalent to `tf.function` with input signature, introduced in TF 2.0.
"""
[docs] def __init__(self, tower_fn, input_signature):
"""
Args:
tower_func: a function which builds one tower in the graph.
It takes several input tensors and could return anything.
input_signature ([TensorSpec]): list of :class:`tf.TensorSpec`.
They are used to figure out the names for the input tensors.
"""
assert callable(tower_fn), tower_fn
self._inputs_names = [k.name for k in input_signature]
assert len(set(self._inputs_names)) == len(self._inputs_names), \
"Duplicated names in input_signature! " + str(self._inputs_names)
for name in self._inputs_names:
if any(k in name for k in [':', '/', ' ']):
raise ValueError("Invalid input name: '{}'".format(name))
self._tower_fn = tower_fn
self._input_signature = input_signature
self._handles = []
def __new__(cls, tower_fn, _):
# to avoid double-wrapping a function
if isinstance(tower_fn, TowerFunc):
return tower_fn
else:
return super(TowerFunc, cls).__new__(cls)
def __call__(self, *args):
ctx = get_current_tower_context()
assert ctx is not None, "Function must be called under TowerContext!"
output = self._tower_fn(*args)
handle = TowerTensorHandle(ctx, args, output, self._input_signature)
self._handles.append(handle)
return output
@property
def towers(self):
"""
TowerTensorHandles: a :class:`TowerTensorHandles` object, that can
access the tower handles by either indices or names.
"""
return TowerTensorHandles(self._handles)
@property
def input_signature(self):
return self._input_signature
TowerFuncWrapper = TowerFunc
[docs]class TowerTensorHandles(object):
"""
Wrap a list of :class:`TowerTensorHandle`,
to support access to them by index or names.
"""
def __init__(self, handles):
self._handles = handles
self._name_to_handle = {k.ns_name: k for k in handles}
def __len__(self):
return len(self._handles)
[docs] def __getitem__(self, name_or_index):
"""
Args:
name_or_index (str or int):
Returns:
a :class:`TowerTensorHandle`.
"""
if isinstance(name_or_index, int):
return self._handles[name_or_index]
return self._name_to_handle[name_or_index]
[docs] def training(self):
"""
Returns:
A :class:`TowerTensorHandles`, containing only the training towers.
"""
handles = [h for h in self._handles if h.is_training]
return TowerTensorHandles(handles)
[docs] def inference(self):
"""
Returns:
A :class:`TowerTensorHandles`, containing only the inference towers.
"""
handles = [h for h in self._handles if not h.is_training]
return TowerTensorHandles(handles)
[docs]class TowerTensorHandle(object):
"""
When a function is called multiple times under each tower,
it becomes hard to keep track of the scope and access those tensors
in each tower.
This class provides easy access to the tensors as well as the
inputs/outputs created in each tower.
"""
@HIDE_DOC
def __init__(self, ctx, inputs, outputs, input_signature=None):
self._ctx = ctx
self._extra_tensor_names = {}
if input_signature is not None:
assert len(input_signature) == len(inputs)
self._extra_tensor_names = {
get_op_tensor_name(x.name)[1]: y for x, y in zip(input_signature, inputs)}
self._inputs = inputs
self._outputs = outputs
# TODO: deprecated. Remove them later
self.input = inputs
self.output = outputs
@property
def vs_name(self):
return self._ctx.vs_name
@property
def ns_name(self):
return self._ctx.ns_name
[docs] def get_tensor(self, name):
"""
Get a tensor in this tower. The name argument can be:
1. The name of a tensor/variable without any tower prefix.
2. A name in the input signature, if it is used when building the tower.
In the second case, this method will return the tensor that's used as the corresponding
input to the tower. Note that this tensor may have a different name (e.g. may be an output of a queue).
"""
name = get_op_tensor_name(name)[1]
if len(self.ns_name):
name_with_ns = self.ns_name + "/" + name
else:
name_with_ns = name
try:
ret = get_op_or_tensor_by_name(name_with_ns)
except KeyError:
if name in self._extra_tensor_names:
return self._extra_tensor_names[name]
else:
if name in self._extra_tensor_names:
mapped_tensor = self._extra_tensor_names[name]
logger.info(
"'{}' may refer to both the Tensor/Placeholder '{}' or the input to the tower '{}'.".format(
name, ret.name, mapped_tensor.name) +
" Assuming it is the input '{}'.".format(mapped_tensor.name))
return mapped_tensor
return ret
# should also allow variables in get_tensor
return self.get_variable(name)
[docs] def get_tensors(self, names):
"""
Like :meth:`get_tensor`, but takes a list and returns a list.
"""
return [self.get_tensor(name) for name in names]
[docs] def __getitem__(self, name):
"""
The same as :meth:`get_tensor`.
"""
return self.get_tensor(name)
[docs] def get_variable(self, name):
"""
Get a variable used in this tower.
The name should not contain the variable scope prefix of the tower.
When the tower has the same variable scope and name scope, this is equivalent to
:meth:`get_tensor`.
"""
name = get_op_tensor_name(name)[1]
if len(self.vs_name):
name_with_vs = self.vs_name + "/" + name
else:
name_with_vs = name
return get_op_or_tensor_by_name(name_with_vs)
[docs] def get_variables(self, names):
"""
Like :meth:`get_variable`, but takes a list and returns a list.
"""
return [self.get_variable(name) for name in names]
[docs] def get_collection(self, key=None, name=None):
"""
See :meth:`BaseTowerContext.get_collection_in_tower`.
Args:
key (str): the key of the collection
name: deprecated
"""
if name is not None:
logger.warn("TowerTensorHandle.get_collection(name=..) was renamed to (key=..) !")
key = name
return self._ctx.get_collection_in_tower(key)
@property
def inputs(self):
"""
list[Tensor]: The list of input tensors used to build the tower.
"""
return self._inputs
@property
def outputs(self):
"""
list[Tensor]: The outputs returned by the tower function.
"""
return self._outputs
@property
def is_training(self):
return self._ctx.is_training