Source code for tensorpack.models.nonlin

# -*- coding: utf-8 -*-
# File:

import tensorflow as tf

from ..utils.develop import log_deprecated
from ..compat import tfv1
from .batch_norm import BatchNorm
from .common import VariableHolder, layer_register
from .utils import disable_autograph

__all__ = ['Maxout', 'PReLU', 'BNReLU']

[docs]@layer_register(use_scope=None) def Maxout(x, num_unit): """ Maxout as in the paper `Maxout Networks <>`_. Args: x (tf.Tensor): a NHWC or NC tensor. Channel has to be known. num_unit (int): a int. Must be divisible by C. Returns: tf.Tensor: of shape NHW(C/num_unit) named ``output``. """ input_shape = x.get_shape().as_list() ndim = len(input_shape) assert ndim == 4 or ndim == 2 ch = input_shape[-1] assert ch is not None and ch % num_unit == 0 if ndim == 4: x = tf.reshape(x, [-1, input_shape[1], input_shape[2], ch / num_unit, num_unit]) else: x = tf.reshape(x, [-1, ch / num_unit, num_unit]) return tf.reduce_max(x, ndim, name='output')
[docs]@layer_register() @disable_autograph() def PReLU(x, init=0.001, name=None): """ Parameterized ReLU as in the paper `Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification <>`_. Args: x (tf.Tensor): input init (float): initial value for the learnable slope. name (str): deprecated argument. Don't use Variable Names: * ``alpha``: learnable slope. """ if name is not None: log_deprecated("PReLU(name=...)", "The output tensor will be named `output`.") init = tfv1.constant_initializer(init) alpha = tfv1.get_variable('alpha', [], initializer=init) x = ((1 + alpha) * x + (1 - alpha) * tf.abs(x)) ret = tf.multiply(x, 0.5, name=name or None) ret.variables = VariableHolder(alpha=alpha) return ret
[docs]@layer_register(use_scope=None) def BNReLU(x, name=None): """ A shorthand of BatchNormalization + ReLU. Args: x (tf.Tensor): the input name: deprecated, don't use. """ if name is not None: log_deprecated("BNReLU(name=...)", "The output tensor will be named `output`.") x = BatchNorm('bn', x) x = tf.nn.relu(x, name=name) return x