Source code for tensorpack.tfutils.argscope

# -*- coding: utf-8 -*-
# File: argscope.py

import copy
from collections import defaultdict
from contextlib import contextmanager
from functools import wraps
from inspect import getmembers, isfunction
import tensorflow as tf

from ..compat import is_tfv2
from ..utils import logger
from .model_utils import get_shape_str
from .tower import get_current_tower_context

# Public names exported by this module.
__all__ = ['argscope', 'get_arg_scope', 'enable_argscope_for_module',
           'enable_argscope_for_function']

# Module-level stack of argscope dicts. ``argscope`` appends a scope on entry
# and deletes the last element on exit; ``get_arg_scope`` reads the last
# element, so the innermost active scope is always at index -1.
_ArgScopeStack = []


@contextmanager
def argscope(layers, **kwargs):
    """
    Context manager that sets default keyword arguments for the given layers.

    Args:
        layers (list or layer): layer or list of layers to apply the arguments.
            Each layer must carry the ``__argscope_enabled__`` attribute.
        **kwargs: default arguments, stored under each layer's ``__name__``.

    Returns:
        a context where all appearance of these layer will by default have the
        arguments specified by kwargs.

    Raises:
        AssertionError: if a layer does not have ``__argscope_enabled__``.

    Example:
        .. code-block:: python

            with argscope(Conv2D, kernel_shape=3, nl=tf.nn.relu, out_channel=32):
                x = Conv2D('conv0', x)
                x = Conv2D('conv1', x)
                x = Conv2D('conv2', x, out_channel=64)  # override argscope

    """
    if not isinstance(layers, list):
        layers = [layers]

    for l in layers:
        assert hasattr(l, '__argscope_enabled__'), "Argscope not supported for {}".format(l)

    # need to deepcopy so that changes to new_scope does not affect outer scope
    new_scope = copy.deepcopy(get_arg_scope())
    for l in layers:
        new_scope[l.__name__].update(kwargs)
    _ArgScopeStack.append(new_scope)
    try:
        yield
    finally:
        # Pop even when the with-body raises; otherwise the scope pushed above
        # would stay on the stack and keep applying after the block exits.
        del _ArgScopeStack[-1]
def get_arg_scope():
    """
    Return the argscope that is currently in effect.

    Returns:
        dict: the current argscope, i.e. the last entry of the argscope
        stack, or a new empty ``defaultdict(dict)`` when the stack is empty.
        An argscope is a dict of dict: ``dict[layername] = {arg: val}``
    """
    if not _ArgScopeStack:
        # No scope is active: hand out a fresh, empty mapping.
        return defaultdict(dict)
    return _ArgScopeStack[-1]
def enable_argscope_for_function(func, log_shape=True):
    """Decorator for function to support argscope

    Example:

        .. code-block:: python

            from mylib import myfunc
            myfunc = enable_argscope_for_function(myfunc)

    Args:
        func: A function mapping one or multiple tensors to one or multiple
            tensors.
        log_shape (bool): Specify whether the shapes of the first input and
            the first output tensor should be logged. Logging happens only
            when the current tower context's ``ns_name`` does not contain
            "tower" or the context is the main training tower.

    Remarks:
        Only the first positional argument is treated as the input tensor, and
        if ``func`` returns a tuple only its first element is described in the
        log message. Calls without any positional argument are not logged.

    Returns:
        The decorated function.
    """

    assert callable(func), "func should be a callable"

    @wraps(func)
    def wrapped_func(*args, **kwargs):
        # Defaults come from the current argscope; explicit kwargs win.
        actual_args = copy.copy(get_arg_scope()[func.__name__])
        actual_args.update(kwargs)
        out_tensor = func(*args, **actual_args)

        # Everything below is only needed for logging. Guard on ``args`` so a
        # keyword-only call does not fail with IndexError on ``args[0]``
        # after ``func`` has already run.
        if log_shape and args:
            ctx = get_current_tower_context()
            if ('tower' not in ctx.ns_name.lower()) or ctx.is_main_training_tower:
                name = kwargs['name'] if 'name' in kwargs else func.__name__
                # we assume the first parameter is the most interesting
                in_tensor = args[0]
                if isinstance(out_tensor, tuple):
                    out_tensor_descr = out_tensor[0]
                else:
                    out_tensor_descr = out_tensor
                logger.info("{:<12}: {} --> {}".format(
                    "'" + name + "'",
                    get_shape_str(in_tensor),
                    get_shape_str(out_tensor_descr)))

        return out_tensor

    # Marker checked by ``argscope`` before it accepts a layer.
    wrapped_func.__argscope_enabled__ = True
    return wrapped_func
def enable_argscope_for_module(module, log_shape=True):
    """
    Overwrite all functions of a given module to support argscope.
    Note that this function monkey-patches the module and therefore could
    have unexpected consequences.
    It has been only tested to work well with ``tf.layers`` module.

    Example:

        .. code-block:: python

            import tensorflow as tf
            enable_argscope_for_module(tf.layers)

    Args:
        module: the module whose functions are replaced in place by their
            argscope-enabled wrappers.
        log_shape (bool): print input/output shapes of each function.
    """
    # Under TF v2, a request for ``tf.layers`` is redirected to the
    # ``tf.compat.v1.layers`` module.
    if is_tfv2() and module == tf.layers:
        module = tf.compat.v1.layers

    # Only plain functions are wrapped; other members are left untouched.
    for attr_name, member in getmembers(module, isfunction):
        wrapped = enable_argscope_for_function(member, log_shape=log_shape)
        setattr(module, attr_name, wrapped)