# -*- coding: utf-8 -*-
# File: gpu.py
import os
from . import logger
from .concurrency import subproc_call
from .nvml import NVMLContext
from .utils import change_env
# Names exported by ``from <this module> import *``.
__all__ = ['change_gpu', 'get_nr_gpu', 'get_num_gpu']
def change_gpu(val):
    """
    Select the GPU visible to CUDA through ``CUDA_VISIBLE_DEVICES``.

    Args:
        val: an integer, the index of the GPU or -1 to disable GPU.

    Returns:
        the result of ``change_env``: a context where
        ``CUDA_VISIBLE_DEVICES=val`` (the empty string when ``val`` is -1).
    """
    device = str(val)
    # -1 is the "no GPU" request; it is expressed as an empty device list.
    visible = '' if device == '-1' else device
    return change_env('CUDA_VISIBLE_DEVICES', visible)
def get_num_gpu():
    """
    Count the GPUs available to this process.

    Sources are consulted in this order, and the first one that answers wins:

    1. ``CUDA_VISIBLE_DEVICES`` (when set and non-empty): the number of
       comma-separated entries.
    2. ``nvidia-smi -L``: the number of non-empty output lines.
    3. NVML: the device count queried at the start of the call.
    4. TensorFlow: the number of local GPU devices.

    Returns:
        int: #available GPUs in CUDA_VISIBLE_DEVICES, or in the system.

    Raises:
        AssertionError: if NVML reported a device count and
            ``CUDA_VISIBLE_DEVICES`` lists more devices than that.
    """

    def warn_return(ret, message):
        # Pass `ret` through unchanged. A warning is logged only when GPUs
        # were found but the importable TensorFlow has no CUDA support.
        try:
            import tensorflow as tf
        except ImportError:
            return ret

        built_with_cuda = tf.test.is_built_with_cuda()
        if not built_with_cuda and ret > 0:
            logger.warn(message + "But TensorFlow was not built with CUDA support and could not use GPUs!")
        return ret

    try:
        # Use NVML to query device properties
        with NVMLContext() as ctx:
            nvml_num_dev = ctx.num_devices()
    except Exception:
        # Best effort: any NVML failure just means "count unknown".
        nvml_num_dev = None

    env = os.environ.get('CUDA_VISIBLE_DEVICES', None)
    if env:
        num_dev = len(env.split(','))
        # Only validate against NVML when NVML actually produced a count;
        # comparing an int with None raises TypeError on Python 3.
        if nvml_num_dev is not None:
            assert num_dev <= nvml_num_dev, \
                "Only {} GPU(s) available, but CUDA_VISIBLE_DEVICES is set to {}".format(nvml_num_dev, env)
        return warn_return(num_dev, "Found non-empty CUDA_VISIBLE_DEVICES. ")

    output, code = subproc_call("nvidia-smi -L", timeout=5)
    if code == 0:
        output = output.decode('utf-8')
        # Count non-empty lines only, so that empty output yields 0 rather
        # than the 1 produced by ``''.split('\n') == ['']``.
        device_lines = [line for line in output.splitlines() if line.strip()]
        return warn_return(len(device_lines), "Found nvidia-smi. ")

    if nvml_num_dev is not None:
        return warn_return(nvml_num_dev, "NVML found nvidia devices. ")

    # Fallback to TF
    logger.info("Loading local devices by TensorFlow ...")
    try:
        import tensorflow as tf
        # available since TF 1.14
        gpu_devices = tf.config.experimental.list_physical_devices('GPU')
    except AttributeError:
        from tensorflow.python.client import device_lib
        local_device_protos = device_lib.list_local_devices()
        # Note this will initialize all GPUs and therefore has side effect
        # https://github.com/tensorflow/tensorflow/issues/8136
        gpu_devices = [x.name for x in local_device_protos if x.device_type == 'GPU']
    return len(gpu_devices)
# Alias: ``get_nr_gpu`` is bound to the same function object as ``get_num_gpu``.
# NOTE(review): presumably an older spelling kept for existing callers — confirm.
get_nr_gpu = get_num_gpu