# Source code for tensorpack.models.tflayer

# -*- coding: utf-8 -*-
# File: tflayer.py

import functools
import six
import tensorflow as tf

from ..tfutils.common import get_tf_version_tuple
from ..tfutils.varreplace import custom_getter_scope
from ..utils.argtools import get_data_format

# Empty: a star-import (``from ... import *``) of this module exports no names.
__all__ = []


def map_common_tfargs(kwargs):
    """
    Translate tensorpack-style layer keyword arguments into tf.layers-style ones.

    The dict is modified in place and the same object is returned:

    * ``data_format`` is normalized with ``get_data_format(..., keras_mode=True)``;
      an explicit ``None`` value is simply dropped.
    * ``nl`` is removed and wrapped into ``activation`` (the wrapper accepts an
      optional ``name`` and forwards it to ``nl``).
    * ``W_init`` / ``b_init`` are renamed to ``kernel_initializer`` /
      ``bias_initializer``.

    When an old-style name and its new-style counterpart are both present, the
    old-style value overwrites the new-style one.

    Args:
        kwargs (dict): keyword arguments given to a layer.

    Returns:
        dict: ``kwargs`` itself, after translation.
    """
    fmt = kwargs.pop('data_format', None)
    if fmt is not None:
        kwargs['data_format'] = get_data_format(fmt, keras_mode=True)

    nonlinearity = kwargs.pop('nl', None)
    if nonlinearity is not None:
        kwargs['activation'] = lambda x, name=None: nonlinearity(x, name=name)

    # Processed in this order so the resulting key order matches the historical behavior.
    for legacy_key, tflayer_key in (('W_init', 'kernel_initializer'),
                                    ('b_init', 'bias_initializer')):
        if legacy_key in kwargs:
            kwargs[tflayer_key] = kwargs.pop(legacy_key)
    return kwargs


def convert_to_tflayer_args(args_names, name_mapping):
    """
    Decorator factory adapting tensorpack-style layer arguments to tf.layers style.

    The decorated function is always called as ``func(inputs, **kwargs)``.
    Before that call:

    1. ``map_common_tfargs`` is applied to the keyword arguments: ``data_format``
       becomes tf.layers style, ``nl`` becomes ``activation``, and
       ``W_init``/``b_init`` are renamed to ``kernel_initializer``/``bias_initializer``.
    2. Keyword arguments listed in ``name_mapping`` are renamed to their
       tf.layers names. Giving both a name and its mapped name is an error.
    3. Positional arguments after ``inputs`` are turned into keyword arguments,
       in order, using ``args_names``. They overwrite keyword arguments of the
       same name, which is what lets argscope work.

    Args:
        args_names (list[str]): names for the positional arguments after ``inputs``.
        name_mapping (dict): keyword name -> tf.layers keyword name.

    Returns:
        A decorator.

    Raises:
        AssertionError: if more positional arguments are given than ``args_names``
            has entries, or if a keyword and its mapped name are both given.
            (These are ``assert`` statements and are skipped under ``python -O``.)
    """

    def decorator(func):
        @functools.wraps(func)
        def decorated_func(inputs, *args, **kwargs):
            kwargs = map_common_tfargs(kwargs)

            assert len(args) <= len(args_names), \
                "Please use kwargs instead of positional args to call this model, " \
                "except for the following arguments: {}".format(', '.join(args_names))
            from_positional = dict(zip(args_names, args))

            call_kwargs = {}
            for key, value in kwargs.items():
                target = name_mapping.get(key)
                if target is None:
                    target = key
                else:
                    assert target not in kwargs, \
                        "Argument {} and {} conflicts!".format(key, target)
                call_kwargs[target] = value

            # Positional values win over keyword values of the same name (needed for argscope).
            call_kwargs.update(from_positional)
            return func(inputs, **call_kwargs)

        return decorated_func

    return decorator


def rename_get_variable(mapping):
    """
    Build a context that renames variables by their basename.

    Only the last ``/``-separated component of a variable name is looked up in
    ``mapping``; the scope prefix is left untouched.

    Args:
        mapping(dict): an old -> new mapping for variable basename. e.g. {'kernel': 'W'}

    Returns:
        A context (created by ``custom_getter_scope``) where the variables are renamed.
    """
    def custom_getter(getter, name, *args, **kwargs):
        prefix, sep, leaf = name.rpartition('/')
        if leaf in mapping:
            name = prefix + sep + mapping[leaf]
        return getter(name, *args, **kwargs)

    return custom_getter_scope(custom_getter)


def rename_tflayer_get_variable():
    """
    Rename all :func:`tf.get_variable` calls with rules that transform tflayer
    style names to tensorpack style.

    The basename renames applied are:

    * ``kernel`` -> ``W``
    * ``bias`` -> ``b``
    * ``moving_mean`` -> ``mean/EMA``
    * ``moving_variance`` -> ``variance/EMA``

    Returns:
        A context where the variables are renamed.

    Example:

    .. code-block:: python

        with rename_tflayer_get_variable():
            x = tf.layers.conv2d(input, 3, 3, name='conv0')
            # variables will be named 'conv0/W', 'conv0/b'
    """
    mapping = {
        'kernel': 'W',
        'bias': 'b',
        'moving_mean': 'mean/EMA',
        'moving_variance': 'variance/EMA',
    }
    return rename_get_variable(mapping)
def monkeypatch_tf_layers():
    """
    Expose layer classes under ``tf.layers`` on old TensorFlow versions.

    On TensorFlow older than 1.4, and only when ``tf.layers.Dense`` is missing,
    import ``Dense``, ``BatchNormalization``, ``Conv2DTranspose``, ``Conv2D``,
    ``MaxPooling2D`` and ``AveragePooling2D`` from ``tensorflow.python.layers``
    and attach them as attributes of ``tf.layers``. On newer versions this
    does nothing.
    """
    if get_tf_version_tuple() < (1, 4):
        if not hasattr(tf.layers, 'Dense'):
            # Imports are kept local: these private module paths are only needed
            # (and only relied upon) for the old TensorFlow versions handled here.
            from tensorflow.python.layers.core import Dense
            tf.layers.Dense = Dense

            from tensorflow.python.layers.normalization import BatchNormalization
            tf.layers.BatchNormalization = BatchNormalization

            from tensorflow.python.layers.convolutional import Conv2DTranspose, Conv2D
            tf.layers.Conv2DTranspose = Conv2DTranspose
            tf.layers.Conv2D = Conv2D

            from tensorflow.python.layers.pooling import MaxPooling2D, AveragePooling2D
            tf.layers.MaxPooling2D = MaxPooling2D
            tf.layers.AveragePooling2D = AveragePooling2D


# Applied once, at import time of this module.
monkeypatch_tf_layers()