Source code for tensorpack.callbacks.steps

# -*- coding: utf-8 -*-
# File: steps.py

""" Some common step callbacks. """

import tqdm
from six.moves import zip

from ..compat import tfv1 as tf
from ..tfutils.common import get_global_step_var, get_op_tensor_name
from ..utils import logger
from ..utils.naming import GLOBAL_STEP_INCR_OP_NAME
from ..utils.utils import get_tqdm_kwargs
from .base import Callback

__all__ = ['TensorPrinter', 'ProgressBar', 'SessionRunTimeout']


class TensorPrinter(Callback):
    """ Prints the value of some tensors in each step.

    It's an example of how ``before_run/after_run`` works.
    """

    def __init__(self, names):
        """
        Args:
            names(list): list of string, the names of the tensors to print.
        """
        names = [get_op_tensor_name(n)[1] for n in names]
        logger.warn("Using tf.Print in the graph is much faster than TensorPrinter!")
        self._names = names

    def _setup_graph(self):
        self._fetches = self.get_tensors_maybe_in_tower(self._names)

    def _before_run(self, _):
        return self._fetches

    def _after_run(self, _, vals):
        args = vals.results
        assert len(args) == len(self._names), len(args)
        for n, v in zip(self._names, args):
            logger.info("{}: {}".format(n, v))
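
# A minimal usage sketch (an assumption, not part of the original file):
# TensorPrinter is typically passed in the callbacks list of a tensorpack
# TrainConfig. The tensor name 'total_cost' and the `model`/`df` objects
# below are hypothetical.
#
#     from tensorpack import TrainConfig
#     from tensorpack.callbacks import TensorPrinter
#
#     config = TrainConfig(
#         model=model,      # your ModelDesc (assumed to exist)
#         dataflow=df,      # your DataFlow (assumed to exist)
#         callbacks=[TensorPrinter(['total_cost'])],
#     )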


class ProgressBar(Callback):
    """ A progress bar based on tqdm.

    This callback is one of the :func:`DEFAULT_CALLBACKS()`.
    """

    _chief_only = False

    def __init__(self, names=()):
        """
        Args:
            names(tuple[str]): the names of the tensors to monitor
                on the progress bar.
        """
        super(ProgressBar, self).__init__()
        self._names = [get_op_tensor_name(n)[1] for n in names]
        self._tags = [get_op_tensor_name(n)[0].split("/")[-1] for n in names]
        self._bar = None

    def _before_train(self):
        self._last_updated = self.local_step

        self._total = self.trainer.steps_per_epoch
        self._tqdm_args = get_tqdm_kwargs(leave=True)

        self._fetches = self.get_tensors_maybe_in_tower(self._names) or None
        if self._fetches:
            for t in self._fetches:
                assert t.shape.ndims == 0, "ProgressBar can only print scalars, not {}".format(t)
            self._fetches = tf.train.SessionRunArgs(self._fetches)
            self._tqdm_args['bar_format'] = self._tqdm_args['bar_format'] + "{postfix} "

    def _before_epoch(self):
        self._bar = tqdm.trange(self._total, **self._tqdm_args)

    def _after_epoch(self):
        self._bar.close()

    def _before_run(self, _):
        # update progress bar when local step changed (one step is finished)
        if self.local_step != self._last_updated:
            self._last_updated = self.local_step
            return self._fetches
        else:
            return None

    def _after_run(self, _, run_values):
        res = run_values.results
        if res:
            self._bar.set_postfix(zip(self._tags, res))

    def _trigger_step(self):
        self._bar.update()

    def _after_train(self):
        if self._bar:  # training may get killed before the first step
            self._bar.close()
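
# A minimal usage sketch (an assumption, not part of the original file):
# ProgressBar is already included in DEFAULT_CALLBACKS(), so pass tensor
# names explicitly only when you want them shown on the bar. 'total_cost'
# is a hypothetical tensor name.
#
#     from tensorpack.callbacks import ProgressBar
#
#     callbacks = [ProgressBar(names=['total_cost'])]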


class MaintainStepCounter(Callback):
    """
    It maintains the global step in the graph, making sure it's increased
    by one at every `hooked_sess.run`.
    This callback is used internally by the trainer; you don't need to worry about it.
    """

    _chief_only = False
    """
    In distributed training, we let each worker maintain its local global_step.
    """

    def _setup_graph(self):
        # ensure it exists
        gs_var = get_global_step_var()
        with tf.name_scope(None):
            self.gs_incr_op = tf.assign_add(
                gs_var, 1, name=GLOBAL_STEP_INCR_OP_NAME).op
        self._fetches = tf.train.SessionRunArgs(self.gs_incr_op)

    def _before_train(self):
        if self.global_step != 0:
            logger.info("Start training with global_step={}".format(self.global_step))

    def _before_run(self, _):
        # always increase global_step when hooked_sess.run is called
        return self._fetches

    def _after_run(self, _, __):
        # Keep python-side global_step in agreement with TF-side
        self.trainer.loop._global_step += 1


class SessionRunTimeout(Callback):
    """
    Add timeout option to each sess.run call.
    """

    def __init__(self, timeout_in_ms):
        """
        Args:
            timeout_in_ms (int):
        """
        self._timeout = int(timeout_in_ms)

        opt = tf.RunOptions(timeout_in_ms=timeout_in_ms)
        self._runargs = tf.train.SessionRunArgs(fetches=[], options=opt)

    def _before_run(self, _):
        return self._runargs
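
# A minimal usage sketch (an assumption, not part of the original file):
# give every hooked_sess.run call a 60-second deadline. TensorFlow raises
# tf.errors.DeadlineExceededError when a step exceeds the timeout set in
# RunOptions.
#
#     from tensorpack.callbacks import SessionRunTimeout
#
#     callbacks = [SessionRunTimeout(timeout_in_ms=60 * 1000)]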