Source code for tensorpack.tfutils.sessinit

# -*- coding: utf-8 -*-
# File: sessinit.py

import os
import numpy as np
import six

from ..compat import tfv1 as tf
from ..utils import logger
from .common import get_op_tensor_name
from .varmanip import SessionUpdate, get_checkpoint_path, get_savename_from_varname, is_training_name

__all__ = ['SessionInit', 'ChainInit',
           'SaverRestore', 'SaverRestoreRelaxed', 'DictRestore',
           'JustCurrentSession', 'get_model_loader', 'SmartInit']


class SessionInit(object):
    """ Base class for utilities to load variables into an (existing) session. """

    def init(self, sess):
        """
        Initialize a session.

        Args:
            sess (tf.Session): the session
        """
        self._setup_graph()
        self._run_init(sess)

    def _setup_graph(self):
        pass

    def _run_init(self, sess):
        pass


class JustCurrentSession(SessionInit):
    """ This is a no-op placeholder. """
    pass
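

# A minimal usage sketch (not part of tensorpack) of the SessionInit calling
# convention: subclasses override _setup_graph() and _run_init(), and callers
# invoke init() once the graph is built. `sess` is assumed to be a live
# tf.Session created by the caller:
#
#     sess = tf.Session()
#     init = JustCurrentSession()   # any SessionInit subclass works the same way
#     init.init(sess)               # runs _setup_graph(), then _run_init(sess)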


class CheckpointReaderAdapter(object):
    """
    An adapter to work around old checkpoint format, where the keys are op
    names instead of tensor names (with :0).
    """
    def __init__(self, reader):
        self._reader = reader
        m = self._reader.get_variable_to_shape_map()
        self._map = {k if k.endswith(':0') else k + ':0': v
                     for k, v in six.iteritems(m)}

    def get_variable_to_shape_map(self):
        return self._map

    def get_tensor(self, name):
        if self._reader.has_tensor(name):
            return self._reader.get_tensor(name)
        if name in self._map:
            assert name.endswith(':0'), name
            name = name[:-2]
        return self._reader.get_tensor(name)

    def has_tensor(self, name):
        return name in self._map

    # some checkpoint might not have ':0'
    def get_real_name(self, name):
        if self._reader.has_tensor(name):
            return name
        assert self.has_tensor(name)
        return name[:-2]


class MismatchLogger(object):
    def __init__(self, exists, nonexists):
        self._exists = exists
        self._nonexists = nonexists
        self._names = []

    def add(self, name):
        self._names.append(get_op_tensor_name(name)[0])

    def log(self):
        if len(self._names):
            logger.warn("The following variables are in the {}, but not found in the {}: {}".format(
                self._exists, self._nonexists, ', '.join(self._names)))


class SaverRestore(SessionInit):
    """
    Restore a tensorflow checkpoint saved by :class:`tf.train.Saver` or :class:`ModelSaver`.
    """

    def __init__(self, model_path, prefix=None, ignore=()):
        """
        Args:
            model_path (str): a model name (model-xxxx) or a ``checkpoint`` file.
            prefix (str): during restore, add a ``prefix/`` for every variable in this checkpoint.
            ignore (tuple[str]): tensor names that should be ignored during loading, e.g. learning-rate
        """
        if model_path.endswith('.npy') or model_path.endswith('.npz'):
            logger.warn("SaverRestore expects a TF checkpoint, but got a model path '{}'.".format(model_path) +
                        " To load from a dict, use 'DictRestore'.")
        model_path = get_checkpoint_path(model_path)
        self.path = model_path  # attribute used by AutoResumeTrainConfig!
        self.prefix = prefix
        self.ignore = [i if i.endswith(':0') else i + ':0' for i in ignore]

    def _setup_graph(self):
        dic = self._get_restore_dict()
        self.saver = tf.train.Saver(var_list=dic, name=str(id(dic)))

    def _run_init(self, sess):
        logger.info("Restoring checkpoint from {} ...".format(self.path))
        self.saver.restore(sess, self.path)

    @staticmethod
    def _read_checkpoint_vars(model_path):
        """ return a set of strings """
        reader = tf.train.NewCheckpointReader(model_path)
        reader = CheckpointReaderAdapter(reader)    # use an adapter to standardize the name
        ckpt_vars = reader.get_variable_to_shape_map().keys()
        return reader, set(ckpt_vars)

    def _match_vars(self, func):
        reader, chkpt_vars = SaverRestore._read_checkpoint_vars(self.path)
        graph_vars = tf.global_variables()
        chkpt_vars_used = set()

        mismatch = MismatchLogger('graph', 'checkpoint')
        for v in graph_vars:
            name = get_savename_from_varname(v.name, varname_prefix=self.prefix)
            if name in self.ignore and reader.has_tensor(name):
                logger.info("Variable {} in the graph will not be loaded from the checkpoint!".format(name))
            else:
                if reader.has_tensor(name):
                    func(reader, name, v)
                    chkpt_vars_used.add(name)
                else:
                    # use tensor name (instead of op name) for logging, to be consistent with the reverse case
                    if not is_training_name(v.name):
                        mismatch.add(v.name)
        mismatch.log()

        mismatch = MismatchLogger('checkpoint', 'graph')
        if len(chkpt_vars_used) < len(chkpt_vars):
            unused = chkpt_vars - chkpt_vars_used
            for name in sorted(unused):
                if not is_training_name(name):
                    mismatch.add(name)
        mismatch.log()

    def _get_restore_dict(self):
        var_dict = {}

        def f(reader, name, v):
            name = reader.get_real_name(name)
            assert name not in var_dict, "Restore conflict: {} and {}".format(v.name, var_dict[name].name)
            var_dict[name] = v

        self._match_vars(f)
        return var_dict
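

# A hedged usage sketch (not part of tensorpack): restoring a checkpoint by its
# prefix. "train_log/model-10000" and the ignored tensor name are hypothetical:
#
#     init = SaverRestore("train_log/model-10000", ignore=("learning_rate",))
#     init.init(sess)   # builds a tf.train.Saver and restores matching variables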


class SaverRestoreRelaxed(SaverRestore):
    """ Same as :class:`SaverRestore`, but has more relaxed constraints.

        It allows upcasting certain variables, or reshaping certain variables
        when there is a mismatch that can be fixed.
        When variable shape and value shape do not match, it will print a warning but will not crash.

        Another advantage is that it doesn't add any new ops to the graph.
    """
    def _setup_graph(self):
        # no need to set up a saver like the parent class does
        pass

    def _run_init(self, sess):
        logger.info("Restoring checkpoint from {} ...".format(self.path))

        matched_pairs = []

        def f(reader, name, v):
            val = reader.get_tensor(name)
            val = SessionUpdate.relaxed_value_for_var(val, v, ignore_mismatch=True)
            if val is not None:
                matched_pairs.append((v, val))

        with sess.as_default():
            self._match_vars(f)
            upd = SessionUpdate(sess, [x[0] for x in matched_pairs])
            upd.update({x[0].name: x[1] for x in matched_pairs})
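

# A hedged usage sketch (not part of tensorpack): a relaxed restore for a
# checkpoint whose dtypes or shapes may not all match the graph. The path is
# hypothetical:
#
#     init = SaverRestoreRelaxed("train_log/model-10000")
#     init.init(sess)   # mismatched variables are skipped with a warning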


class DictRestore(SessionInit):
    """
    Restore variables from a dictionary.
    """

    def __init__(self, variable_dict, ignore_mismatch=False):
        """
        Args:
            variable_dict (dict): a dict of {name: value}
            ignore_mismatch (bool): ignore failures when the value and the
                variable do not match in their shapes.
                If False, it will throw an exception on such errors.
                If True, it will only print a warning.
        """
        assert isinstance(variable_dict, dict), type(variable_dict)
        # use varname (with :0) for consistency
        self._prms = {get_op_tensor_name(n)[1]: v for n, v in six.iteritems(variable_dict)}
        self._ignore_mismatch = ignore_mismatch

    def _run_init(self, sess):
        variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)

        variable_names_list = [k.name for k in variables]
        variable_names = set(variable_names_list)
        param_names = set(six.iterkeys(self._prms))

        # intersect has the original ordering of variables
        intersect = [v for v in variable_names_list if v in param_names]

        # use opname (without :0) for clarity in logging
        logger.info("Variables to restore from dict: {}".format(
            ', '.join(get_op_tensor_name(x)[0] for x in intersect)))

        mismatch = MismatchLogger('graph', 'dict')
        for k in sorted(variable_names - param_names):
            if not is_training_name(k):
                mismatch.add(k)
        mismatch.log()
        mismatch = MismatchLogger('dict', 'graph')
        for k in sorted(param_names - variable_names):
            mismatch.add(k)
        mismatch.log()

        upd = SessionUpdate(sess, [v for v in variables if v.name in intersect],
                            ignore_mismatch=self._ignore_mismatch)
        logger.info("Restoring {} variables from dict ...".format(len(intersect)))
        upd.update({name: value for name, value in six.iteritems(self._prms) if name in intersect})
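

# A hedged usage sketch (not part of tensorpack): restoring weights from a dict
# of numpy arrays, e.g. loaded from a hypothetical "weights.npz":
#
#     import numpy as np
#     params = dict(np.load("weights.npz"))   # {variable name: ndarray}
#     init = DictRestore(params, ignore_mismatch=True)
#     init.init(sess)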


class ChainInit(SessionInit):
    """
    Initialize a session by a list of :class:`SessionInit` instances, executed one by one.
    This can be useful for, e.g., loading several models from different files
    to form a composition of models.
    """

    def __init__(self, sess_inits):
        """
        Args:
            sess_inits (list[SessionInit]): list of :class:`SessionInit` instances.
        """
        self.inits = sess_inits

    def _setup_graph(self):
        for i in self.inits:
            i._setup_graph()

    def _run_init(self, sess):
        for i in self.inits:
            i._run_init(sess)
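

# A hedged usage sketch (not part of tensorpack): composing two loaders so that
# a backbone checkpoint and a dict of head weights are restored one after the
# other. Both file names are hypothetical:
#
#     init = ChainInit([
#         SaverRestore("train_log/backbone-50000"),
#         DictRestore(dict(np.load("head_weights.npz"))),
#     ])
#     init.init(sess)   # later inits overwrite variables set by earlier ones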


def SmartInit(obj, *, ignore_mismatch=False):
    """
    Create a :class:`SessionInit` to be loaded to a session, automatically
    from any supported objects, with some smart heuristics. The object can be:

    + A TF checkpoint.
    + A dict of numpy arrays.
    + A npz file, to be interpreted as a dict.
    + An empty string or None, in which case the sessinit will be a no-op.
    + A list of supported objects, to be initialized one by one.

    Args:
        obj: a supported object
        ignore_mismatch (bool): ignore failures when the value and the
            variable do not match in their shapes.
            If False, it will throw an exception on such errors.
            If True, it will only print a warning.

    Returns:
        SessionInit:
    """
    if not obj:
        return JustCurrentSession()
    if isinstance(obj, list):
        return ChainInit([SmartInit(x, ignore_mismatch=ignore_mismatch) for x in obj])
    if isinstance(obj, six.string_types):
        obj = os.path.expanduser(obj)
        if obj.endswith(".npy") or obj.endswith(".npz"):
            assert tf.gfile.Exists(obj), "File {} does not exist!".format(obj)
            filename = obj
            logger.info("Loading dictionary from {} ...".format(filename))
            if filename.endswith('.npy'):
                obj = np.load(filename, encoding='latin1').item()
            elif filename.endswith('.npz'):
                obj = dict(np.load(filename))
        elif len(tf.gfile.Glob(obj + "*")):
            # Assume it is a TF checkpoint.
            # A TF checkpoint must be a prefix of an actual file.
            return (SaverRestoreRelaxed if ignore_mismatch else SaverRestore)(obj)
        else:
            raise ValueError("Invalid argument to SmartInit: " + obj)

    if isinstance(obj, dict):
        return DictRestore(obj, ignore_mismatch=ignore_mismatch)
    raise ValueError("Invalid argument to SmartInit: " + str(type(obj)))
get_model_loader = SmartInit
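

# A hedged usage sketch (not part of tensorpack): SmartInit (and its alias
# get_model_loader) dispatches on the type of its argument. All file names
# below are hypothetical:
#
#     SmartInit(None)                           # -> JustCurrentSession (no-op)
#     SmartInit("train_log/model-10000")        # -> SaverRestore
#     SmartInit("weights.npz")                  # -> DictRestore (npz read as dict)
#     SmartInit({"conv0/W:0": w})               # -> DictRestore
#     SmartInit(["model-10000", "extra.npz"])   # -> ChainInit of the above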