Source code for tensorpack.models.conv2d

# -*- coding: utf-8 -*-
# File: conv2d.py


from ..compat import tfv1 as tf  # model code should avoid this import in the first place

from ..tfutils.common import get_tf_version_tuple
from ..utils.argtools import get_data_format, shape2d, shape4d, log_once
from .common import VariableHolder, layer_register
from .tflayer import convert_to_tflayer_args, rename_get_variable

__all__ = ['Conv2D', 'Deconv2D', 'Conv2DTranspose']


@layer_register(log_shape=True)
@convert_to_tflayer_args(
    args_names=['filters', 'kernel_size'],
    name_mapping={
        'out_channel': 'filters',
        'kernel_shape': 'kernel_size',
        'stride': 'strides',
    })
def Conv2D(
        inputs,
        filters,
        kernel_size,
        strides=(1, 1),
        padding='same',
        data_format='channels_last',
        dilation_rate=(1, 1),
        activation=None,
        use_bias=True,
        kernel_initializer=None,
        bias_initializer=tf.zeros_initializer(),
        kernel_regularizer=None,
        bias_regularizer=None,
        activity_regularizer=None,
        split=1):
    """
    Similar to `tf.layers.Conv2D`, but with some differences:

    1. Default kernel initializer is variance_scaling_initializer(2.0).
    2. Default padding is 'same'.
    3. Supports a 'split' argument to do group convolution.

    Variable Names:

    * ``W``: weights
    * ``b``: bias
    """
    if kernel_initializer is None:
        if get_tf_version_tuple() <= (1, 12):
            kernel_initializer = tf.contrib.layers.variance_scaling_initializer(2.0)  # deprecated
        else:
            kernel_initializer = tf.keras.initializers.VarianceScaling(2.0, distribution='untruncated_normal')
    dilation_rate = shape2d(dilation_rate)

    if split == 1 and dilation_rate == [1, 1]:
        # tf.layers.Conv2D has bugs with dilations (https://github.com/tensorflow/tensorflow/issues/26797)
        with rename_get_variable({'kernel': 'W', 'bias': 'b'}):
            layer = tf.layers.Conv2D(
                filters,
                kernel_size,
                strides=strides,
                padding=padding,
                data_format=data_format,
                dilation_rate=dilation_rate,
                activation=activation,
                use_bias=use_bias,
                kernel_initializer=kernel_initializer,
                bias_initializer=bias_initializer,
                kernel_regularizer=kernel_regularizer,
                bias_regularizer=bias_regularizer,
                activity_regularizer=activity_regularizer,
                _reuse=tf.get_variable_scope().reuse)
            ret = layer.apply(inputs, scope=tf.get_variable_scope())
            ret = tf.identity(ret, name='output')

        ret.variables = VariableHolder(W=layer.kernel)
        if use_bias:
            ret.variables.b = layer.bias
    else:
        # group conv implementation
        data_format = get_data_format(data_format, keras_mode=False)
        in_shape = inputs.get_shape().as_list()
        channel_axis = 3 if data_format == 'NHWC' else 1
        in_channel = in_shape[channel_axis]
        assert in_channel is not None, "[Conv2D] Input cannot have unknown channel!"
        assert in_channel % split == 0, in_channel

        assert kernel_regularizer is None and bias_regularizer is None and activity_regularizer is None, \
            "Not supported by group conv or dilated conv!"

        out_channel = filters
        assert out_channel % split == 0, out_channel
        assert dilation_rate == [1, 1] or get_tf_version_tuple() >= (1, 5), 'TF>=1.5 required for dilated conv.'

        kernel_shape = shape2d(kernel_size)
        filter_shape = kernel_shape + [in_channel // split, out_channel]
        stride = shape4d(strides, data_format=data_format)

        kwargs = {"data_format": data_format}
        if get_tf_version_tuple() >= (1, 5):
            kwargs['dilations'] = shape4d(dilation_rate, data_format=data_format)

        # match the input dtype (e.g. tf.float16), since the default dtype of variables is tf.float32
        inputs_dtype = inputs.dtype
        W = tf.get_variable(
            'W', filter_shape, dtype=inputs_dtype, initializer=kernel_initializer)

        if use_bias:
            b = tf.get_variable('b', [out_channel], dtype=inputs_dtype, initializer=bias_initializer)

        if split == 1:
            conv = tf.nn.conv2d(inputs, W, stride, padding.upper(), **kwargs)
        else:
            conv = None
            if get_tf_version_tuple() >= (1, 13):
                try:
                    conv = tf.nn.conv2d(inputs, W, stride, padding.upper(), **kwargs)
                except ValueError:
                    log_once("CUDNN group convolution support is only available with "
                             "https://github.com/tensorflow/tensorflow/pull/25818 . "
                             "Will fall back to a loop-based slow implementation instead!", 'warn')
            if conv is None:
                inputs = tf.split(inputs, split, channel_axis)
                kernels = tf.split(W, split, 3)
                outputs = [tf.nn.conv2d(i, k, stride, padding.upper(), **kwargs)
                           for i, k in zip(inputs, kernels)]
                conv = tf.concat(outputs, channel_axis)

        ret = tf.nn.bias_add(conv, b, data_format=data_format) if use_bias else conv
        if activation is not None:
            ret = activation(ret)
        ret = tf.identity(ret, name='output')

        ret.variables = VariableHolder(W=W)
        if use_bias:
            ret.variables.b = b
    return ret
@layer_register(log_shape=True)
@convert_to_tflayer_args(
    args_names=['filters', 'kernel_size', 'strides'],
    name_mapping={
        'out_channel': 'filters',
        'kernel_shape': 'kernel_size',
        'stride': 'strides',
    })
def Conv2DTranspose(
        inputs,
        filters,
        kernel_size,
        strides=(1, 1),
        padding='same',
        data_format='channels_last',
        activation=None,
        use_bias=True,
        kernel_initializer=None,
        bias_initializer=tf.zeros_initializer(),
        kernel_regularizer=None,
        bias_regularizer=None,
        activity_regularizer=None):
    """
    A wrapper around `tf.layers.Conv2DTranspose`.
    Some differences to maintain backward-compatibility:

    1. Default kernel initializer is variance_scaling_initializer(2.0).
    2. Default padding is 'same'.

    Variable Names:

    * ``W``: weights
    * ``b``: bias
    """
    if kernel_initializer is None:
        if get_tf_version_tuple() <= (1, 12):
            kernel_initializer = tf.contrib.layers.variance_scaling_initializer(2.0)  # deprecated
        else:
            kernel_initializer = tf.keras.initializers.VarianceScaling(2.0, distribution='untruncated_normal')

    if get_tf_version_tuple() <= (1, 12):
        with rename_get_variable({'kernel': 'W', 'bias': 'b'}):
            layer = tf.layers.Conv2DTranspose(
                filters,
                kernel_size,
                strides=strides,
                padding=padding,
                data_format=data_format,
                activation=activation,
                use_bias=use_bias,
                kernel_initializer=kernel_initializer,
                bias_initializer=bias_initializer,
                kernel_regularizer=kernel_regularizer,
                bias_regularizer=bias_regularizer,
                activity_regularizer=activity_regularizer,
                _reuse=tf.get_variable_scope().reuse)
            ret = layer.apply(inputs, scope=tf.get_variable_scope())
            ret = tf.identity(ret, name='output')
        ret.variables = VariableHolder(W=layer.kernel)
        if use_bias:
            ret.variables.b = layer.bias
    else:
        # Our own implementation, to avoid Keras bugs. https://github.com/tensorflow/tensorflow/issues/25946
        assert kernel_regularizer is None and bias_regularizer is None and activity_regularizer is None, \
            "Unsupported arguments due to Keras bug in TensorFlow 1.13"
        data_format = get_data_format(data_format, keras_mode=False)
        shape_dyn = tf.shape(inputs)
        shape_sta = inputs.shape.as_list()
        strides2d = shape2d(strides)
        kernel_shape = shape2d(kernel_size)

        assert padding.lower() in ['valid', 'same'], "Padding {} is not supported!".format(padding)
        if padding.lower() == 'valid':
            shape_res2d = [max(kernel_shape[0] - strides2d[0], 0),
                           max(kernel_shape[1] - strides2d[1], 0)]
        else:
            shape_res2d = shape2d(0)

        if data_format == 'NCHW':
            channels_in = shape_sta[1]
            out_shape_dyn = tf.stack(
                [shape_dyn[0], filters,
                 shape_dyn[2] * strides2d[0] + shape_res2d[0],
                 shape_dyn[3] * strides2d[1] + shape_res2d[1]])
            out_shape3_sta = [filters,
                              None if shape_sta[2] is None else shape_sta[2] * strides2d[0] + shape_res2d[0],
                              None if shape_sta[3] is None else shape_sta[3] * strides2d[1] + shape_res2d[1]]
        else:
            channels_in = shape_sta[-1]
            out_shape_dyn = tf.stack(
                [shape_dyn[0],
                 shape_dyn[1] * strides2d[0] + shape_res2d[0],
                 shape_dyn[2] * strides2d[1] + shape_res2d[1],
                 filters])
            out_shape3_sta = [None if shape_sta[1] is None else shape_sta[1] * strides2d[0] + shape_res2d[0],
                              None if shape_sta[2] is None else shape_sta[2] * strides2d[1] + shape_res2d[1],
                              filters]

        inputs_dtype = inputs.dtype
        W = tf.get_variable('W', kernel_shape + [filters, channels_in],
                            dtype=inputs_dtype, initializer=kernel_initializer)
        if use_bias:
            b = tf.get_variable('b', [filters], dtype=inputs_dtype, initializer=bias_initializer)
        conv = tf.nn.conv2d_transpose(
            inputs, W, out_shape_dyn,
            shape4d(strides, data_format=data_format),
            padding=padding.upper(), data_format=data_format)
        conv.set_shape(tf.TensorShape([shape_sta[0]] + out_shape3_sta))

        ret = tf.nn.bias_add(conv, b, data_format=data_format) if use_bias else conv
        if activation is not None:
            ret = activation(ret)
        ret = tf.identity(ret, name='output')

        ret.variables = VariableHolder(W=W)
        if use_bias:
            ret.variables.b = b
    return ret
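

# --- Usage sketch (editorial addition, not part of upstream tensorpack) ---
# The common use of ``Conv2DTranspose`` is 2x spatial upsampling: with the
# default padding='same' and strides=2, the output height/width are exactly
# twice the input's. The scope name and shapes here are arbitrary assumptions,
# again under graph-mode TFv1 and the tensorpack layer-calling convention.
def _example_transpose_upsample():
    x = tf.placeholder(tf.float32, [None, 16, 16, 128], name='example_feature')
    # [None, 16, 16, 128] -> [None, 32, 32, 64]
    return Conv2DTranspose('example_upsample', x, filters=64, kernel_size=3, strides=2)
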
Deconv2D = Conv2DTranspose  # backward-compatible alias, also exported in ``__all__``