# Source code for tensorpack.models.shapes

# -*- coding: utf-8 -*-
# File: shapes.py


import tensorflow as tf

from .common import layer_register

__all__ = ['ConcatWith']


@layer_register(use_scope=None)
def ConcatWith(x, tensor, dim):
    """
    A wrapper around ``tf.concat`` to cooperate with :class:`LinearWrap`.

    Args:
        x (tf.Tensor): input
        tensor (list[tf.Tensor]): a tensor or list (or tuple) of tensors to
            concatenate with x. x will be at the beginning
        dim (int): the dimension along which to concatenate

    Returns:
        tf.Tensor: ``tf.concat([x] + tensor, dim)``
    """
    # Use isinstance rather than an exact type comparison so that tuples and
    # list subclasses are spliced in element-wise after ``x`` instead of being
    # wrapped as one (malformed) operand.
    if isinstance(tensor, (list, tuple)):
        others = list(tensor)
    else:
        # A single tensor: promote it to a one-element list.
        others = [tensor]
    return tf.concat([x] + others, dim)