# Source code for (header left over from an HTML documentation export)

# -*- coding: utf-8 -*-
# File:

import traceback
from contextlib import contextmanager
from time import perf_counter as timer  # noqa
from ..compat import tfv1 as tf

from ..utils import logger
from ..utils.utils import humanize_time_delta
from .base import Callback
from .hooks import CallbackToHook

__all__ = ['Callbacks']

class CallbackTimeLogger(object):
    """
    Accumulate wall-clock durations of callbacks and log the heavy ones.
    """

    def __init__(self):
        # (name, seconds) pairs, in the order they were recorded
        self.times = []
        # running sum of all recorded durations, in seconds
        self.tot = 0

    def add(self, name, time):
        """
        Record one duration.

        Args:
            name (str): label of the timed callback.
            time (float): elapsed seconds.
        """
        self.tot += time
        self.times.append((name, time))

    @contextmanager
    def timed_callback(self, name):
        """
        Context manager that times its ``with`` body and records it via
        :meth:`add` under ``name``.

        Args:
            name (str): label to record the duration under.
        """
        s = timer()
        # the ``with`` body runs here; without this yield nothing is timed
        yield
        self.add(name, timer() - s)

    def log(self):
        """
        Log the time of some heavy callbacks.

        Does nothing if the total recorded time is under 2 seconds.
        Otherwise, logs the total and lists every callback that took more
        than 1 second and more than 30% of the total.
        """
        if self.tot < 2:
            # also guarantees self.tot > 0 for the division below
            return
        msgs = [
            name + ": " + humanize_time_delta(t)
            for name, t in self.times
            if t / self.tot > 0.3 and t > 1
        ]
        logger.info(
            "Callbacks took {:.3f} sec in total. {}".format(
                self.tot, '; '.join(msgs)))

class Callbacks(Callback):
    """
    A container to hold all callbacks, and trigger them iteratively.

    This is only used by the base trainer to run all the callbacks.
    Users do not need to use this class.
    """

    def __init__(self, cbs):
        """
        Args:
            cbs(list): a list of :class:`Callback` instances.
        """
        # check type
        for cb in cbs:
            assert isinstance(cb, Callback), cb.__class__
        self.cbs = cbs

    def _setup_graph(self):
        """Call ``setup_graph(self.trainer)`` on every child callback."""
        with tf.name_scope(None):   # clear the name scope
            for cb in self.cbs:
                cb.setup_graph(self.trainer)

    def _before_train(self):
        """Call ``before_train()`` on every child callback, in order."""
        for cb in self.cbs:
            cb.before_train()

    def _after_train(self):
        """
        Call ``after_train()`` on every child callback.

        An exception from one callback is printed and does not prevent the
        remaining callbacks from running.
        """
        for cb in self.cbs:
            # make sure callbacks are properly finalized
            try:
                cb.after_train()
            except Exception:
                traceback.print_exc()

    def get_hooks(self):
        """Return one :class:`CallbackToHook` wrapper per child callback."""
        return [CallbackToHook(cb) for cb in self.cbs]

    def trigger_step(self):
        """Call ``trigger_step()`` on every child callback, in order."""
        # NOTE(review): overrides the public trigger_step rather than a
        # _trigger_step hook, presumably to bypass base-class wrapping on this
        # per-step path; confirm against Callback.
        for cb in self.cbs:
            cb.trigger_step()

    def _trigger_epoch(self):
        """
        Call ``trigger_epoch()`` on every child callback.

        Each call is timed, and slow callbacks are reported through
        :class:`CallbackTimeLogger`.
        """
        tm = CallbackTimeLogger()

        for cb in self.cbs:
            display_name = str(cb)
            with tm.timed_callback(display_name):
                cb.trigger_epoch()
        tm.log()

    def _before_epoch(self):
        """Call ``before_epoch()`` on every child callback, in order."""
        for cb in self.cbs:
            cb.before_epoch()

    def _after_epoch(self):
        """Call ``after_epoch()`` on every child callback, in order."""
        for cb in self.cbs:
            cb.after_epoch()