# Source code for tensorpack.utils.gpu

# -*- coding: utf-8 -*-
# File: gpu.py

import os

from . import logger
from .concurrency import subproc_call
from .nvml import NVMLContext
from .utils import change_env

# Public API of this module.
__all__ = [
    'change_gpu',
    'get_nr_gpu',
    'get_num_gpu',
]

def change_gpu(val):
    """
    Args:
        val: an integer, the index of the GPU or -1 to disable GPU.
            A list/tuple of integers selects several GPUs (they are joined
            as ``"0,1"``). Any other value is converted with ``str``, so a
            string such as ``"0,1"`` is used as-is.

    Returns:
        a context where ``CUDA_VISIBLE_DEVICES=val``.
        A value of ``-1`` is translated to an empty string, which hides all GPUs.
    """
    if isinstance(val, (list, tuple)):
        # str([0, 1]) would give '[0, 1]', which is not a valid device list.
        val = ','.join(str(v) for v in val)
    else:
        val = str(val)
    if val == '-1':
        val = ''
    return change_env('CUDA_VISIBLE_DEVICES', val)
def get_num_gpu():
    """
    Returns:
        int: #available GPUs in CUDA_VISIBLE_DEVICES, or in the system.

    Sources are tried in this order; the first one that gives an answer wins:

    1. ``CUDA_VISIBLE_DEVICES``, if the variable is set. An empty (or
       whitespace-only) value hides every GPU from CUDA -- it is what
       ``change_gpu(-1)`` sets -- so 0 is returned.
    2. The number of non-empty lines printed by ``nvidia-smi -L``.
    3. The device count reported by NVML.
    4. The GPU devices listed by TensorFlow. On TF versions without
       ``tf.config.experimental``, this initializes all GPUs.

    For sources 1-3, a warning is logged if TensorFlow is importable but
    was not built with CUDA support and the count is positive.

    Raises:
        AssertionError: if NVML reports a device count and
            ``CUDA_VISIBLE_DEVICES`` lists more devices than that.
    """

    def warn_return(ret, message):
        # Return ``ret`` unchanged; only log when TF could not use the GPUs found.
        try:
            import tensorflow as tf
        except ImportError:
            return ret

        built_with_cuda = tf.test.is_built_with_cuda()
        if not built_with_cuda and ret > 0:
            logger.warn(message + "But TensorFlow was not built with CUDA support and could not use GPUs!")
        return ret

    try:
        # Use NVML to query device properties. Best-effort: any exception is
        # treated as "NVML unavailable".
        with NVMLContext() as ctx:
            nvml_num_dev = ctx.num_devices()
    except Exception:
        nvml_num_dev = None

    env = os.environ.get('CUDA_VISIBLE_DEVICES', None)
    if env is not None:
        if not env.strip():
            # Set but empty: CUDA exposes no devices at all.
            return 0
        num_dev = len(env.split(','))
        # Only validate against NVML when it produced a count; comparing an
        # int with None would raise TypeError instead of a useful message.
        if nvml_num_dev is not None:
            assert num_dev <= nvml_num_dev, \
                "Only {} GPU(s) available, but CUDA_VISIBLE_DEVICES is set to {}".format(nvml_num_dev, env)
        return warn_return(num_dev, "Found non-empty CUDA_VISIBLE_DEVICES. ")

    output, code = subproc_call("nvidia-smi -L", timeout=5)
    if code == 0:
        # Count one device per non-empty output line, so blank output gives
        # 0 instead of 1.
        lines = [line for line in output.decode('utf-8').splitlines() if line.strip()]
        return warn_return(len(lines), "Found nvidia-smi. ")

    if nvml_num_dev is not None:
        return warn_return(nvml_num_dev, "NVML found nvidia devices. ")

    # Fallback to TF
    logger.info("Loading local devices by TensorFlow ...")

    try:
        import tensorflow as tf
        # available since TF 1.14
        gpu_devices = tf.config.experimental.list_physical_devices('GPU')
    except AttributeError:
        from tensorflow.python.client import device_lib
        local_device_protos = device_lib.list_local_devices()
        # Note this will initialize all GPUs and therefore has side effect
        gpu_devices = [x.name for x in local_device_protos if x.device_type == 'GPU']
    return len(gpu_devices)
# Alias of get_num_gpu, exported via __all__ (presumably an older name kept
# for backward compatibility -- confirm before removing).
get_nr_gpu = get_num_gpu