Write a Callback

Everything other than the training iterations happens in callbacks. Most of the fancy things you want to do will probably end up here.

Callbacks are called during training. When each callback method gets called is demonstrated in the following snippet.

def train(self):
  # ... a predefined trainer may create graph for the model here ...
  callbacks.setup_graph()
  # ... create session, initialize session, finalize graph ...
  # start training:
  with sess.as_default():
    callbacks.before_train()
    for epoch in range(starting_epoch, max_epoch + 1):
      callbacks.before_epoch()
      for step in range(steps_per_epoch):
        self.run_step()  # callbacks.{before,after}_run are hooked with session
        callbacks.trigger_step()
      callbacks.after_epoch()
      callbacks.trigger_epoch()
    callbacks.after_train()

Note that at each place, the callbacks are called in the order in which they are given to the trainer.
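
For reference, callbacks are typically given to the trainer as a list. Below is a minimal sketch, assuming the common TrainConfig-based setup; MyModel, my_dataflow, and MyCallback are hypothetical placeholders:

from tensorpack import TrainConfig, SimpleTrainer, launch_train_with_config
from tensorpack.callbacks import ModelSaver, EstimatedTimeLeft

config = TrainConfig(
    model=MyModel(),          # a hypothetical ModelDesc
    dataflow=my_dataflow,     # a hypothetical DataFlow
    callbacks=[
        ModelSaver(),         # built-in: save checkpoints
        EstimatedTimeLeft(),  # built-in: print the estimated time left
        MyCallback(),         # your custom callback (see below)
    ],
)
launch_train_with_config(config, SimpleTrainer())

At each hook point in the loop above, ModelSaver runs before EstimatedTimeLeft, which runs before MyCallback.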

Explain the Callback Methods

To write a callback, subclass Callback and implement the corresponding underscore-prefixed methods. You can override any of the following methods in the new callback (a complete sketch follows this list):

  • _setup_graph(self)

    Create any tensors/ops in the graph which you might need to use in the callback. This method exists to fully separate “define” from “run”, and to avoid the common mistake of creating ops inside loops. All changes to the graph should be made in this method.

    To access tensors/ops which are already defined, you can use TF methods such as graph.get_tensor_by_name. If you’re using a TowerTrainer, more tools are available; see “What you can do in the callback” below.

  • _before_train(self)

    Can be used to run some manual initialization of variables, or start some services for the training.

  • _after_train(self)

    Usually some finalization work.

  • _before_epoch(self), _after_epoch(self)

    As can be seen from the scheduling snippet above, _trigger_epoch should be enough for most cases. Use these two methods only when you really need something to happen immediately before/after an epoch, and when you do, make sure they are very fast, to avoid delaying other callbacks that use them.

  • _before_run(self, ctx), _after_run(self, ctx, values)

    These are the equivalent of tf.train.SessionRunHook. Please refer to the TensorFlow documentation for the detailed API. They are used to run extra ops / evaluate extra tensors / feed extra values along with the actual training iterations.

    IMPORTANT: note the difference between running along with an iteration and running after an iteration. When you write

    def _before_run(self, _):
      return tf.train.SessionRunArgs(fetches=my_op)
    

    then the training loop becomes equivalent to sess.run([training_op, my_op]).

    However, if you write my_op.run() in _trigger_step, the training loop becomes sess.run(training_op); sess.run(my_op). The difference usually matters; choose carefully.

    If you want to run ops that depend on your inputs, it’s better to run them along with the training iteration, to avoid wasting a datapoint and messing up the hooks of the InputSource.

  • _trigger_step(self)

    Do something (e.g. run ops, print stuff) after each step has finished. Do only light work here, because it runs on every step and can affect training speed.

  • _trigger_epoch(self)

    Do something after each epoch has finished. This method calls self.trigger() by default.

  • _trigger(self)

    Define something to do here without knowing how often it will get called. By default it will get called by _trigger_epoch, but you can customize its scheduling with PeriodicTrigger, to let it run every k steps or every k epochs.
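
Putting the methods together, here is a minimal sketch of a custom callback that fetches an existing tensor along with every iteration and writes its running mean to the monitors. The tensor name 'tower0/total_cost:0' is an assumption for illustration; substitute a tensor that actually exists in your graph:

import tensorflow as tf
from tensorpack.callbacks import Callback

class MySketchCallback(Callback):
    def _setup_graph(self):
        # "Define" phase: look everything up once, never inside loops.
        # 'tower0/total_cost:0' is an assumed name; use your own tensor's name.
        self._loss = self.trainer.graph.get_tensor_by_name('tower0/total_cost:0')
        self._sum, self._count = 0.0, 0

    def _before_run(self, ctx):
        # Fetch the tensor together with the training op, in one sess.run.
        return tf.train.SessionRunArgs(fetches=self._loss)

    def _after_run(self, ctx, run_values):
        self._sum += run_values.results
        self._count += 1

    def _trigger(self):
        # Called by _trigger_epoch by default; reschedule with PeriodicTrigger.
        if self._count:
            self.trainer.monitors.put_scalar('mean_loss', self._sum / self._count)
            self._sum, self._count = 0.0, 0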

What you can do in the callback

  • Access tensors / ops (details mentioned above):

    • For existing tensors/ops created in the tower, access them through self.trainer.towers.

    • Extra tensors/ops have to be created in the _setup_graph callback method.

  • Access the current graph and session via self.trainer.graph, self.trainer.sess, and self.trainer.hooked_sess. Note that calling (hooked_)sess.run to evaluate tensors may have unexpected effects in certain scenarios. In general, use sess.run to evaluate tensors that do not depend on the inputs, and use _{before,after}_run to evaluate tensors that do depend on the inputs, together with those inputs.

  • Write stuff to the monitor backend with self.trainer.monitors.put_xxx. The monitors may direct your events to the TensorFlow events file, a JSON file, stdout, etc. You can access historical monitor data as well; see the docs for Monitors. A sketch that uses the monitors follows this list.

  • Access the current status of training, such as self.epoch_num and self.global_step; see the docs of Callback.

  • Stop training by raising StopTraining() (with from tensorpack.train import StopTraining).

  • Anything else that can be done with plain Python.
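
As an illustration of several of the points above, here is a hedged sketch that reads the latest monitor data, writes a scalar of its own, and stops training on divergence. The monitor key 'total_cost' is an assumption; use a key that your training actually logs:

import math
from tensorpack.callbacks import Callback
from tensorpack.train import StopTraining

class DivergenceGuard(Callback):
    def _trigger(self):
        # Read back data from the monitor backend ('total_cost' is assumed).
        loss = self.trainer.monitors.get_latest('total_cost')
        # Write our own scalar to the same backend.
        self.trainer.monitors.put_scalar('guard/last_loss', loss)
        if math.isnan(loss) or math.isinf(loss):
            raise StopTraining()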

Typical Steps about Writing/Using a Callback

  • Define the callback in __init__, prepare for it in _setup_graph, _before_train.

  • Decide whether you want to do something along with the training iterations or not. If yes, implement the logic with _{before,after}_run; otherwise, implement it in _trigger or _trigger_step.

  • You can choose to implement only “what to do”, and leave “when to do” to wrappers such as PeriodicTrigger, PeriodicCallback, or EnableCallbackIf (see the sketch after this list). Of course, you also have the freedom to implement “what to do” and “when to do” together.
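
For example, to decouple “when” from “what”, the sketch callback from above can be wrapped so that its _trigger runs every 1000 steps instead of once per epoch (MySketchCallback is the hypothetical sketch defined earlier):

from tensorpack.callbacks import PeriodicTrigger

# Run MySketchCallback._trigger every 1000 steps instead of every epoch.
callbacks = [
    PeriodicTrigger(MySketchCallback(), every_k_steps=1000),
]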

Examples

The built-in callbacks listed in the API docs and the callbacks in tensorpack examples are great references for learning how to write a callback.

Custom callbacks in tensorpack examples can be found by running ack Callback in the examples directory. Some interesting ones are:

  • Run inference during training with a predictor: CycleGAN, RCNN

  • Run some ops to modify weights during training: WGAN, MOCO