# -*- coding: utf-8 -*-
# tensorpack.tfutils.common

# -*- coding: utf-8 -*-
# File:

import tensorflow as tf

from ..compat import tfv1
from ..utils.argtools import graph_memoized
from .collect_env import collect_env_info

__all__ = ['get_default_sess_config',
           # 'get_op_tensor_name',
           # 'get_tensors_by_names',
           # 'get_op_or_tensor_by_name',

def get_default_sess_config(mem_fraction=0.99):
    """
    Return a tf.ConfigProto to use as default session config.
    You can modify the returned config to fit your needs.

    Args:
        mem_fraction(float): see the `per_process_gpu_memory_fraction` option
            in TensorFlow's GPUOptions protobuf.

    Returns:
        tf.ConfigProto: the config to use.
    """
    conf = tfv1.ConfigProto()

    conf.allow_soft_placement = True
    # conf.log_device_placement = True

    conf.intra_op_parallelism_threads = 1
    conf.inter_op_parallelism_threads = 0
    # TF benchmark use cpu_count() - gpu_thread_count(), e.g. 80 - 8 * 2
    # Didn't see much difference.

    conf.gpu_options.per_process_gpu_memory_fraction = mem_fraction

    # This hurt performance of large data pipeline:
    # conf.gpu_options.force_gpu_compatible = True

    # allow_growth overrides the mem_fraction cap in practice: memory is
    # allocated on demand, up to the fraction above.
    conf.gpu_options.allow_growth = True

    # from tensorflow.core.protobuf import rewriter_config_pb2 as rwc
    # conf.graph_options.rewrite_options.memory_optimization = \
    #     rwc.RewriterConfig.HEURISTICS

    # May hurt performance?
    # conf.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1
    # conf.graph_options.place_pruned_graph = True
    return conf

@graph_memoized  # create at most once per graph; `graph_memoized` is imported above
def get_global_step_var():
    """
    Returns:
        tf.Tensor: the global_step variable in the current graph. Create if doesn't exist.
    """
    # Force creation under the root variable scope regardless of the caller's
    # current scope, so the variable name is stable.
    scope = tfv1.VariableScope(reuse=False, name='')  # the root vs
    with tfv1.variable_scope(scope):
        var = tfv1.train.get_or_create_global_step()
    return var

def get_global_step_value():
    """
    Returns:
        int: global_step value in current graph and session

    Has to be called under a default session.
    """
    # tfv1.train.global_step evaluates the variable in the given session.
    return tfv1.train.global_step(
        tfv1.get_default_session(),
        get_global_step_var())

def get_op_tensor_name(name):
    """
    Will automatically determine if ``name`` is a tensor name (ends with ':x')
    or a op name.
    If it is an op name, the corresponding tensor name is assumed to be ``op_name + ':0'``.

    Args:
        name(str): name of an op or a tensor

    Returns:
        tuple: (op_name, tensor_name)
    """
    if len(name) >= 3 and name[-2] == ':':
        return name[:-2], name
    else:
        return name, name + ':0'

def get_tensors_by_names(names):
    """
    Get a list of tensors in the default graph by a list of names.

    Args:
        names (list): list of op or tensor names; op names are resolved to
            their ``:0`` output via :func:`get_op_tensor_name`.

    Returns:
        list: the corresponding tf.Tensor objects, in the same order.
    """
    ret = []
    G = tfv1.get_default_graph()
    for n in names:
        opn, varn = get_op_tensor_name(n)
        # look up by the normalized tensor name (always has a ':k' suffix)
        ret.append(G.get_tensor_by_name(varn))
    return ret

def get_op_or_tensor_by_name(name):
    """
    Get either tf.Operation of tf.Tensor from names.

    Args:
        name (list[str] or str): names of operations or tensors.

    Returns:
        tf.Operation/tf.Tensor, or a list of them when ``name`` is a list.

    Raises:
        KeyError, if the name doesn't exist
    """
    G = tfv1.get_default_graph()

    def f(n):
        # a ':k' suffix means a tensor name; otherwise treat as an op name
        if len(n) >= 3 and n[-2] == ':':
            return G.get_tensor_by_name(n)
        else:
            return G.get_operation_by_name(n)

    if not isinstance(name, list):
        return f(name)
    else:
        return list(map(f, name))

def gpu_available_in_session():
    """Whether the default session has access to at least one GPU device."""
    sess = tfv1.get_default_session()
    return any(
        dev.device_type.lower() == 'gpu'
        for dev in sess.list_devices())

def get_tf_version_tuple():
    """
    Return TensorFlow version as a 2-element tuple (for comparison), e.g. (1, 15).
    """
    # '1.15.0' -> (1, 15); tuples compare element-wise so this is safe to
    # use with expressions like `get_tf_version_tuple() >= (1, 14)`.
    return tuple(map(int, tf.__version__.split('.')[:2]))