# -*- coding: utf-8 -*-
# File: varmanip.py
import glob
import operator
import numpy as np
import os
import pprint
import six
import tensorflow as tf
from ..compat import tfv1
from ..utils import logger
from .common import get_op_tensor_name
__all__ = ['SessionUpdate', 'dump_session_params',
'load_chkpt_vars', 'save_chkpt_vars',
'load_checkpoint_vars', 'save_checkpoint_vars',
'get_checkpoint_path', 'get_all_checkpoints']
def get_savename_from_varname(
varname, varname_prefix=None,
savename_prefix=None):
"""
Args:
        varname (str): a variable name in the graph
        varname_prefix (str): an optional prefix that may need to be removed from varname
        savename_prefix (str): an optional prefix to prepend to the save name
Returns:
str: the name used to save the variable
"""
name = varname
if varname_prefix is not None \
and name.startswith(varname_prefix):
name = name[len(varname_prefix) + 1:]
if savename_prefix is not None:
name = savename_prefix + '/' + name
return name
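# Illustrative examples (the variable names are hypothetical):
#   get_savename_from_varname("tower0/conv/W:0", varname_prefix="tower0")
#       -> "conv/W:0"
#   get_savename_from_varname("conv/W:0", savename_prefix="backbone")
#       -> "backbone/conv/W:0"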
class SessionUpdate(object):
""" Update the variables in a session """
def __init__(self, sess, vars_to_update, ignore_mismatch=False):
"""
Args:
sess (tf.Session): a session object
vars_to_update: a collection of variables to update
            ignore_mismatch (bool): ignore failures when the value and the
                variable do not match.
"""
self.sess = sess
self.name_map = {v.name: v for v in vars_to_update}
self.ignore_mismatch = ignore_mismatch
@staticmethod
def relaxed_value_for_var(value, var, ignore_mismatch=False):
"""
        Returns a relaxed (possibly reshaped or upcast) version of value,
        to be loaded to the given variable.
        Args:
            value (ndarray): a numpy array to be loaded to var
            var (tf.Variable): the variable to load the value into
            ignore_mismatch (bool): ignore failures when the value and the
                variable do not match.
        Returns:
            ndarray: a possibly reshaped or cast version of value.
            Returns None if `ignore_mismatch==True` and the value and the
            variable do not match.
"""
assert isinstance(var, tf.Variable)
name = var.op.name
# check incompatible shape
varshape = tuple(var.get_shape().as_list())
if varshape != value.shape:
if np.prod(varshape) != np.prod(value.shape):
if ignore_mismatch:
logger.warn(
"Cannot load an array of shape {} into variable '{}' whose shape is {}.".format(
value.shape, name, varshape))
return None
else:
raise ValueError(
"Trying to load an array of shape {} into variable '{}' whose shape is {}.".format(
value.shape, name, varshape))
        # TODO only allow reshape when shapes differ by empty axes
logger.warn("The tensor is reshaped from {} to {} when assigned to '{}'".format(
value.shape, varshape, name))
value = value.reshape(varshape)
# Be permissive, and allow some common type incompatibility problems
def allow_cast(to_type, from_type):
# to_type: a tf dtype
# from_type: a numpy dtype
from_type = tf.as_dtype(from_type)
# allow up/down casting between floating points
if from_type.is_floating and to_type.is_floating:
return True
if from_type.is_integer and to_type.is_integer:
# only allow up-casting between integers
if to_type.min <= from_type.min and to_type.max >= from_type.max:
return True
return False
if hasattr(value, 'dtype'):
vartype = var.dtype.as_numpy_dtype
if vartype != value.dtype:
msg = "Variable {} has dtype {} but was given a value of dtype {}.".format(name, var.dtype, value.dtype)
if allow_cast(var.dtype.base_dtype, value.dtype):
value = vartype(value)
logger.warn(msg + " The value will be loaded after casting!")
else:
                    raise ValueError(msg)
return value
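    # A minimal sketch of the relaxation (hypothetical variable): loading a
    # float64 array of shape (1, 512) into a float32 variable of shape (512,)
    # reshapes, then casts, logging a warning for each step:
    #   v = tfv1.get_variable('w', shape=[512], dtype=tf.float32)
    #   SessionUpdate.relaxed_value_for_var(np.zeros((1, 512), np.float64), v)
    #   # -> float32 ndarray of shape (512,)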
def update(self, prms):
"""
Args:
prms(dict): dict of {variable name: value}
Any name in prms must be in the graph and in vars_to_update.
"""
with self.sess.as_default():
fetches = []
feeds = {}
for name, value in six.iteritems(prms):
assert name in self.name_map
var = self.name_map[name]
value = SessionUpdate.relaxed_value_for_var(
value, var, ignore_mismatch=self.ignore_mismatch)
# This is the implementation of `var.load`
if value is not None:
fetches.append(var.initializer)
feeds[var.initializer.inputs[1]] = value
self.sess.run(fetches, feed_dict=feeds)
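# A minimal usage sketch (assumes an existing graph and session; the variable
# name and value are hypothetical). Note that keys carry the ":0" tensor suffix:
#   updater = SessionUpdate(sess, tfv1.trainable_variables(), ignore_mismatch=True)
#   updater.update({'conv1/W:0': np.zeros((3, 3, 3, 64), dtype=np.float32)})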
def dump_session_params(path):
"""
    Dump the values of all TRAINABLE + MODEL variables to a dict, and save it
    in npz format (loadable by :func:`sessinit.SmartInit`).
    Args:
        path (str): the file name to save the parameters. Must end with '.npz'.
"""
# save variables that are GLOBAL, and either TRAINABLE or MODEL
var = tfv1.get_collection(tfv1.GraphKeys.TRAINABLE_VARIABLES)
var.extend(tfv1.get_collection(tfv1.GraphKeys.MODEL_VARIABLES))
# TODO dedup
    assert len(set(var)) == len(var), "TRAINABLE and MODEL variables overlap!"
gvars = {k.name for k in tfv1.global_variables()}
var = [v for v in var if v.name in gvars]
result = {}
for v in var:
result[v.name] = v.eval()
save_checkpoint_vars(result, path)
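# Typical usage sketch (requires the variables to exist and an active default
# session for `v.eval()`; the path is hypothetical):
#   with sess.as_default():
#       dump_session_params('train_log/params.npz')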
def save_checkpoint_vars(dic, path):
"""
Save variables in dic to path.
Args:
dic: {name: value}. values have to be numpy arrays
path: save as npz if the name ends with '.npz', otherwise save as a checkpoint.
"""
logger.info("Variables to save to {}:".format(path))
keys = sorted(dic.keys())
logger.info(pprint.pformat(keys))
assert not path.endswith('.npy')
if path.endswith('.npz'):
np.savez_compressed(path, **dic)
else:
with tfv1.Graph().as_default(), \
tfv1.Session() as sess:
for k, v in six.iteritems(dic):
k = get_op_tensor_name(k)[0]
_ = tfv1.Variable(name=k, initial_value=v) # noqa
sess.run(tfv1.global_variables_initializer())
saver = tfv1.train.Saver()
saver.save(sess, path, write_meta_graph=False)
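# Sketch with hypothetical values: the same dict can be written in either format.
#   dic = {'conv1/W:0': np.random.rand(3, 3).astype('float32')}
#   save_checkpoint_vars(dic, '/tmp/params.npz')   # saved as a numpy archive
#   save_checkpoint_vars(dic, '/tmp/model-0')      # saved as a TF checkpoint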
def get_checkpoint_path(path):
"""
Work around TF problems in checkpoint path handling.
Args:
path: a user-input path
Returns:
str: the argument that can be passed to `tf.train.NewCheckpointReader`
"""
if os.path.basename(path) == path:
path = os.path.join('.', path) # avoid #4921 and #6142
if os.path.basename(path) == 'checkpoint':
assert tfv1.gfile.Exists(path), path
path = tfv1.train.latest_checkpoint(os.path.dirname(path))
# to be consistent with either v1 or v2
# fix paths if provided a wrong one
new_path = path
if '00000-of-00001' in path:
new_path = path.split('.data')[0]
elif path.endswith('.index'):
new_path = path.split('.index')[0]
if new_path != path:
logger.info(
"Checkpoint path {} is auto-corrected to {}.".format(path, new_path))
path = new_path
assert tfv1.gfile.Exists(path) or tfv1.gfile.Exists(path + '.index'), path
return path
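# Examples of the corrections performed (hypothetical paths):
#   "train_log/model-5000.index"                -> "train_log/model-5000"
#   "train_log/model-5000.data-00000-of-00001"  -> "train_log/model-5000"
#   "train_log/checkpoint"                      -> the latest checkpoint in train_log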
def get_all_checkpoints(dir: str, prefix: str = "model"):
"""
    Get a sorted list of all checkpoints found in the given directory.
Args:
dir (str): checkpoint directory
prefix (str): common prefix among all checkpoints (without the final "-")
Returns:
list[(str, int)]: list of (name, step) sorted by step.
Name is a checkpoint handle that can be passed to
`tf.train.NewCheckpointReader` or :func:`load_checkpoint_vars`.
"""
def step_from_filename(name):
name = os.path.basename(name)
name = name[len("{}-".format(prefix)):-len(".index")]
return int(name)
    checkpoints = glob.glob(os.path.join(dir, "{}-*.index".format(prefix)))
    # strip the ".index" suffix so the name can be passed to NewCheckpointReader
    checkpoints = [(f[:-len(".index")], step_from_filename(f)) for f in checkpoints]
checkpoints = sorted(checkpoints, key=operator.itemgetter(1))
return checkpoints
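# Sketch (hypothetical directory contents): a directory with "model-1000.index"
# and "model-2000.index" yields
#   get_all_checkpoints("train_log")
#   -> [("train_log/model-1000", 1000), ("train_log/model-2000", 2000)]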
def load_checkpoint_vars(path):
""" Load all variables from a checkpoint to a dict.
Args:
path(str): path to a checkpoint.
Returns:
dict: a name:value dict
"""
if path.endswith(".npz"):
ret = dict(np.load(path))
ret = {get_op_tensor_name(k)[0]: v for k, v in ret.items()}
return ret
path = get_checkpoint_path(path)
reader = tfv1.train.NewCheckpointReader(path)
var_names = reader.get_variable_to_shape_map().keys()
result = {}
for n in var_names:
result[n] = reader.get_tensor(n)
return result
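# Sketch (hypothetical path): both npz archives and TF checkpoints are accepted.
#   params = load_checkpoint_vars('train_log/model-5000')
#   # params is a dict like {'conv1/W': ndarray, 'conv1/b': ndarray, ...}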
def is_training_name(name):
"""
**Guess** if this variable is only used in training.
    Only used internally to reduce logging noise. Do not use it.
"""
# TODO: maybe simply check against TRAINABLE_VARIABLES and MODEL_VARIABLES?
# TODO or use get_slot_names()
name = get_op_tensor_name(name)[0]
if name.endswith('/Adam') or name.endswith('/Adam_1'):
return True
if name.endswith('/Momentum'):
return True
if name.endswith('/Adadelta') or name.endswith('/Adadelta_1'):
return True
if name.endswith('/RMSProp') or name.endswith('/RMSProp_1'):
return True
if name.endswith('/Adagrad'):
return True
if name.startswith('EMA/') or '/EMA/' in name: # all the moving average summaries
return True
if name.startswith('AccumGrad') or name.endswith('/AccumGrad'):
return True
if name.startswith('apply_gradients'):
return True
return False
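# For example, "conv1/W/Adam" and "EMA/train_loss" are guessed to be
# training-only, while "conv1/W" is not.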
load_chkpt_vars = load_checkpoint_vars
save_chkpt_vars = save_checkpoint_vars