Source code for tensorpack.tfutils.gradproc

# -*- coding: utf-8 -*-
# File: gradproc.py


import inspect
import re
from abc import ABCMeta, abstractmethod
import six
import tensorflow as tf

from ..compat import tfv1
from ..utils import logger
from .summary import add_moving_summary
from .symbolic_functions import print_stat, rms

__all__ = ['GradientProcessor',
           'FilterNoneGrad', 'GlobalNormClip', 'MapGradient', 'SummaryGradient',
           'PrintGradient', 'CheckGradient', 'ScaleGradient']


@six.add_metaclass(ABCMeta)
class GradientProcessor(object):
    """
    Base class for all gradient processors.
    Gradient processors can be applied to optimizers by
    :func:`optimizer.apply_grad_processors`.

    Subclasses should override the ``_process()`` method.
    """
    _name_scope = None

    def process(self, grads):
        """
        Process the symbolic gradients.

        Args:
            grads (list): list of (grad, var).
        Returns:
            list: processed gradients, with the same type as input.
        """
        # reuse the old name_scope, if process() is called multiple times
        if self._name_scope is None:
            with tfv1.name_scope(type(self).__name__) as scope:
                self._name_scope = scope
                return self._process(grads)
        else:
            with tfv1.name_scope(self._name_scope):
                return self._process(grads)

    @abstractmethod
    def _process(self, grads):
        pass

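# Usage sketch (illustrative; not part of the original module): processors are
# attached to an optimizer via ``optimizer.apply_grad_processors``, and a
# subclass only needs to override ``_process``.  The class name
# ``_NoOpProcessor`` and the optimizer/learning rate below are hypothetical
# placeholders.
#
#   from tensorpack.tfutils import optimizer
#
#   class _NoOpProcessor(GradientProcessor):
#       def _process(self, grads):
#           # receives and returns a list of (grad, var) pairs
#           return grads
#
#   opt = optimizer.apply_grad_processors(
#       tfv1.train.AdamOptimizer(1e-3), [_NoOpProcessor()])
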
class FilterNoneGrad(GradientProcessor):
    """
    Skip the update and print a warning (instead of crashing),
    when the gradient of certain variables is None.
    """

    def __init__(self, verbose=True):
        """
        Args:
            verbose (bool): whether to print warning about None gradients.
        """
        super(FilterNoneGrad, self).__init__()
        self._verbose = verbose

    def _process(self, grads):
        g = []
        to_print = []
        for grad, var in grads:
            if grad is None:
                to_print.append(var.op.name)
            else:
                g.append((grad, var))
        if self._verbose and len(to_print):
            message = ', '.join(to_print)
            logger.warn("No gradient w.r.t {} trainable variables: {}".format(len(to_print), message))
        return g

class GlobalNormClip(GradientProcessor):
    """ Clip by global norm.
        The global norm is computed over the norms of **all** gradients together.

        See :func:`tf.clip_by_global_norm` for more information.
    """

    def __init__(self, global_norm):
        """
        Args:
            global_norm(float): the threshold to clip with.
        """
        super(GlobalNormClip, self).__init__()
        self._norm = float(global_norm)

    def _process(self, grads):
        g = [k[0] for k in grads]
        v = [k[1] for k in grads]
        g, _ = tf.clip_by_global_norm(g, self._norm, name='clip_by_global_norm')
        return list(zip(g, v))

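# Usage sketch (illustrative): drop variables that received no gradient, then
# clip the remaining gradients to a global norm of 5.  The threshold, optimizer
# and learning rate are arbitrary placeholders.
#
#   from tensorpack.tfutils import optimizer, gradproc
#   opt = optimizer.apply_grad_processors(
#       tfv1.train.MomentumOptimizer(0.01, 0.9),
#       [gradproc.FilterNoneGrad(), gradproc.GlobalNormClip(5.0)])
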
class MapGradient(GradientProcessor):
    """
    Apply a function on all gradients whose variable name matches a regex.
    Keep the other gradients unchanged.

    It can be used for gradient clipping, etc.
    """

    def __init__(self, func, regex='.*'):
        """
        Args:
            func: a user-supplied function which takes one or two arguments.
                The argument(s) can be either a `grad` tensor, or `grad` and `var`.
                The function should return the new gradient to be used.
                If it returns None, the gradient is discarded (hence no update to the variable will happen).
            regex (str): used to match variables. Defaults to match all variables.
        """
        args = inspect.getfullargspec(func).args
        arg_num = len(args) - inspect.ismethod(func)
        assert arg_num in [1, 2], \
            "The function must take 1 or 2 arguments! ({})".format(args)
        if arg_num == 1:
            self.func = lambda grad, var: func(grad)
        else:
            self.func = func

        if not regex.endswith('$'):
            regex = regex + '$'
        self.regex = regex
        super(MapGradient, self).__init__()

    def _process(self, grads):
        ret = []
        matched = False
        for grad, var in grads:
            if re.match(self.regex, var.op.name):
                matched = True
                grad = self.func(grad, var)
                if grad is not None:
                    ret.append((grad, var))
            else:
                ret.append((grad, var))
        if not matched:
            logger.warn("[MapGradient] No match was found for regex {}.".format(self.regex))
        return ret

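# Usage sketch (illustrative): per-gradient value clipping via MapGradient,
# applied only to variables under an "fc" scope.  The regex and clipping bounds
# are arbitrary placeholders; ``tf.clip_by_value`` is standard TensorFlow.
#
#   from tensorpack.tfutils import optimizer, gradproc
#   clipper = gradproc.MapGradient(
#       lambda grad: tf.clip_by_value(grad, -0.1, 0.1), regex='fc.*')
#   opt = optimizer.apply_grad_processors(
#       tfv1.train.AdamOptimizer(1e-3), [clipper])
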
# TODO has dependency problems: sess.run may not depend on grad
# maybe group maintain op and grad ?
class SummaryGradient(MapGradient):
    """
    For each gradient tensor, summarize its histogram and add it to moving
    summaries.
    """
    # avoid duplicate summaries from towers
    # TODO this is global. not good.
    _summaried_gradient = set()

    def __init__(self, regex='.*', collections=None):
        """
        Args:
            regex(str): same as in :class:`MapGradient`.
            collections (list[str]): list of collection names
        """
        super(SummaryGradient, self).__init__(self._mapper, regex)
        self._coll = collections

    def _mapper(self, grad, var):
        name = var.op.name
        if re.match('tower[0-9]+/', name):
            # replicated training, var may come from different towers
            return grad
        if name not in SummaryGradient._summaried_gradient:
            SummaryGradient._summaried_gradient.add(name)
            tfv1.summary.histogram(name + '-grad', grad, collections=self._coll)
            add_moving_summary(rms(grad, name=name + '/rms'))
        return grad

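# Usage sketch (illustrative): record a histogram and an RMS moving summary for
# the gradient of every variable whose name contains "conv".  The regex is an
# arbitrary placeholder.
#
#   from tensorpack.tfutils import optimizer, gradproc
#   opt = optimizer.apply_grad_processors(
#       tfv1.train.AdamOptimizer(1e-3),
#       [gradproc.SummaryGradient(regex='.*conv.*')])
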
class PrintGradient(MapGradient):
    """
    Print the gradients every step with :func:`symbolic_functions.print_stat`.
    """
    _printed = set()
    # TODO this is global. not good.

    def __init__(self, regex='.*'):
        """
        Args:
            regex(str): same as in :class:`MapGradient`.
        """
        super(PrintGradient, self).__init__(self._mapper, regex)

    def _mapper(self, grad, var):
        name = var.op.name
        if name not in PrintGradient._printed:
            PrintGradient._printed.add(name)
            grad = print_stat(grad, message=name + '-grad')
        return grad

class CheckGradient(MapGradient):
    """
    Run :func:`tf.check_numerics` for each gradient.
    """

    def __init__(self):
        super(CheckGradient, self).__init__(self._mapper)

    def _mapper(self, grad, var):
        # this was very slow.... see #3649
        # op = tf.Assert(tf.reduce_all(tf.is_finite(var)), [var], summarize=100)
        grad = tf.check_numerics(grad, 'CheckGradient/' + var.op.name)
        return grad

class ScaleGradient(MapGradient):
    """
    Scale certain gradients by a multiplier.
    """

    def __init__(self, multipliers, verbose=True):
        """
        Args:
            multipliers (tuple or list): tuple of (regex, float), or list of such tuples.
            verbose (bool): whether to print logs or not

        Example:
            Use double learning rate for all the biases (as in Caffe), and freeze layer0:

            .. code-block:: python

                from tensorpack.tfutils import optimizer, gradproc
                opt = optimizer.apply_grad_processors(
                    opt, [gradproc.ScaleGradient(
                        [('.*/b', 2.), ('layer0/.*', 0.)]
                    )])
        """
        if not isinstance(multipliers, list):
            multipliers = [multipliers]
        self.multipliers = multipliers
        assert verbose in [True, False], verbose
        self._verbose = verbose
        super(ScaleGradient, self).__init__(self._mapper)

    def _mapper(self, grad, var):
        varname = var.op.name
        for regex, val in self.multipliers:
            # always match against the whole name
            if not regex.endswith('$'):
                regex = regex + '$'

            if re.match(regex, varname):
                if self._verbose:
                    logger.info("Gradient of '{}' is multiplied by {}".format(varname, val))
                if val != 0:    # skip zero to speed up
                    return grad * val
                else:
                    return None
        return grad