Tensorpack trainers contain the logic of:
Building the graph.
Running the iterations (with callbacks).
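The two responsibilities listed above can be modeled in a few lines of plain Python. This is a toy sketch, not tensorpack code: `ToyTrainer`, `Callback`, `setup_graph`, and `main_loop` are hypothetical names chosen for illustration, and the "graph" is just a Python function.

```python
# A toy model of a trainer's two jobs: build the graph once, then run the
# iterations while notifying callbacks. All names here are hypothetical,
# not tensorpack's actual API.


class Callback:
    """Hypothetical callback with hooks around every training step."""

    def before_step(self, step):
        pass

    def after_step(self, step, loss):
        pass


class LossLogger(Callback):
    def __init__(self):
        self.history = []

    def after_step(self, step, loss):
        self.history.append((step, loss))


class ToyTrainer:
    def __init__(self, build_graph_fn, callbacks):
        self.build_graph_fn = build_graph_fn
        self.callbacks = callbacks
        self.train_op = None

    def setup_graph(self):
        # Job 1: build the graph exactly once.
        self.train_op = self.build_graph_fn()

    def main_loop(self, steps):
        # Job 2: run the iterations, calling callbacks around each one.
        for step in range(steps):
            for cb in self.callbacks:
                cb.before_step(step)
            loss = self.train_op(step)
            for cb in self.callbacks:
                cb.after_step(step, loss)


def build_graph():
    # Stand-in for graph construction: returns a "train op" whose loss decays.
    return lambda step: 1.0 / (step + 1)


logger = LossLogger()
trainer = ToyTrainer(build_graph, [logger])
trainer.setup_graph()
trainer.main_loop(steps=3)
print(logger.history)
```

Keeping graph construction separate from the loop is what lets the loop stay the same no matter how the graph was built (single GPU, multiple GPUs, and so on).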
Usually you won't touch these methods directly, but will use a higher-level interface on trainers. You'll only need to select which trainer to use. But some basic knowledge of how they work is useful:
Following the terminology in TensorFlow, a tower function is a callable that takes input tensors and adds one replica of the model to the graph.
Most types of neural network training fall into this category. All trainers in tensorpack are subclasses of TowerTrainer. The concept of tower is used mainly to support:
Data-parallel multi-GPU training, where a replica is built on each GPU.
Graph construction for inference, where a replica is built under inference mode.
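The replication idea can be sketched without TensorFlow: one tower function, called once per GPU for training and once more for inference, with a context object telling each call how to behave. `ToyTowerContext` and `build_towers` are hypothetical stand-ins assumed for illustration; they are a simplified model of the TowerContext described below, which holds more than these three fields.

```python
# Sketch: the same tower function is called once per replica.
# ToyTowerContext and build_towers are hypothetical, for illustration only.
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyTowerContext:
    name: str          # scope name of this replica, e.g. "tower0"
    is_training: bool  # training replica or inference replica
    reuse: bool        # whether variables must be reused rather than created


def tower_func(inputs, ctx):
    # Adds "one replica of the model": here it just describes what it built.
    mode = "train" if ctx.is_training else "infer"
    return f"{ctx.name}/{mode}(reuse={ctx.reuse}, n={len(inputs)})"


def build_towers(tower_fn, num_gpus, batches):
    outputs = []
    # Data-parallel training: one replica per GPU; only the first creates variables.
    for i in range(num_gpus):
        ctx = ToyTowerContext(f"tower{i}", is_training=True, reuse=i > 0)
        outputs.append(tower_fn(batches[i], ctx))
    # Inference: one more replica under inference mode, reusing the variables.
    ctx = ToyTowerContext("inference", is_training=False, reuse=True)
    outputs.append(tower_fn(batches[0], ctx))
    return outputs


print(build_towers(tower_func, 2, [[1, 2], [3, 4, 5]]))
```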
You'll provide a tower function to use
The function needs to follow some conventions:
It will always be called under a TowerContext, which contains information about reuse, training/inference, scope name, etc.
It might get called multiple times for data-parallel training or inference.
To respect variable reuse, use tf.get_variable instead of tf.Variable in the function, unless you want to force creation of new variables.
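Why `tf.get_variable` respects reuse while `tf.Variable` does not can be illustrated with a pure-Python store that imitates TF1 semantics: when reusing, look up the existing variable by name (and raise an error otherwise), versus always creating a fresh variable. `ToyVariableStore` and its methods are hypothetical; this only models the behavior and does not call TensorFlow.

```python
# A toy variable store mimicking TF1 reuse semantics:
#   get_variable    ~ tf.get_variable: looks up by name when reuse=True
#   create_variable ~ tf.Variable: always creates a new, uniquified variable
# Pure-Python model for illustration; not TensorFlow.


class ToyVariableStore:
    def __init__(self):
        self.variables = {}
        self.counter = 0

    def get_variable(self, name, init, reuse):
        if reuse:
            # Reuse requested: the variable must already exist.
            if name not in self.variables:
                raise ValueError(f"variable {name} does not exist")
            return self.variables[name]
        # Creation requested: the name must be free.
        if name in self.variables:
            raise ValueError(f"variable {name} already exists")
        self.variables[name] = [init]  # a list stands in for mutable state
        return self.variables[name]

    def create_variable(self, name, init):
        # Ignores reuse entirely: every call makes a new variable.
        self.counter += 1
        unique = f"{name}_{self.counter}"
        self.variables[unique] = [init]
        return self.variables[unique]


store = ToyVariableStore()
w0 = store.get_variable("W", 0.5, reuse=False)  # tower 0 creates W
w1 = store.get_variable("W", 0.5, reuse=True)   # tower 1 shares it
print(w0 is w1)  # True: both towers train the same weights

v0 = store.create_variable("V", 0.5)  # tower 0
v1 = store.create_variable("V", 0.5)  # tower 1
print(v0 is v1)  # False: each tower silently gets its own copy
```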
In particular, when working with the ModelDesc interface, its build_graph method will be the tower function.
For data-parallel multi-GPU training, different multi-GPU trainers implement different parallel logic. They take care of device placement, gradient averaging and synchronization in an efficient way, and all reach the same performance as the official TF benchmarks. It takes only one line of code change to use them.
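The gradient averaging performed by synchronous data-parallel training can be illustrated with plain Python on a one-parameter model y = w * x with squared loss. Each "GPU" computes a gradient on its own batch, the gradients are averaged, and a single update is applied to the shared weight. This is a numerical sketch of the idea under those simplifying assumptions, not how any tensorpack trainer is implemented.

```python
# Sketch of synchronous data-parallel SGD with gradient averaging.
# Model: y = w * x, loss = mean((w*x - y)^2). Pure Python, no devices.


def grad_on_batch(w, batch):
    # d/dw of mean((w*x - y)^2) over the batch = mean(2*x*(w*x - y)).
    return sum(2 * x * (w * x - y) for x, y in batch) / len(batch)


def sync_step(w, per_gpu_batches, lr):
    # Each replica computes a gradient on its own batch...
    grads = [grad_on_batch(w, b) for b in per_gpu_batches]
    # ...the gradients are averaged (the synchronization point)...
    avg = sum(grads) / len(grads)
    # ...and one update is applied to the shared weight.
    return w - lr * avg


w = 0.0
# Two "GPUs", each with its own batch of two samples drawn from y = 2x.
batches = [[(1.0, 2.0), (2.0, 4.0)], [(3.0, 6.0), (4.0, 8.0)]]
for _ in range(50):
    w = sync_step(w, batches, lr=0.05)
print(round(w, 4))  # converges toward 2.0
```

Because every replica starts each step from the same `w` and receives the same averaged update, the replicas never drift apart, which is what "synchronous" buys over independent per-GPU updates.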
Note some common problems when using these trainers:
In each iteration, all GPUs (all replicas of the model) take tensors from the InputSource, instead of taking one batch for all of them and splitting it. So the total batch size would become (batch size of InputSource/DataFlow) * #GPU.
Splitting a tensor for data-parallel training makes no sense at all; it only puts unnecessary shape constraints on the data. By letting each GPU train on its own input tensors, the GPUs can train on inputs of different shapes simultaneously.
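A minimal sketch of this input behavior, assuming a hypothetical `toy_input_source` generator in place of a real InputSource/DataFlow: each tower pulls a whole batch of its own, so one iteration consumes batch size times the number of GPUs samples, and different towers may receive batches of different shapes.

```python
# Sketch: data-parallel towers each pull their own batch from a shared source
# instead of splitting one big batch. toy_input_source is hypothetical.
import itertools


def toy_input_source(batch_size):
    # Yields batches of image shapes; the size alternates between batches,
    # which would make splitting a single stacked tensor impossible.
    for i in itertools.count():
        side = 32 if i % 2 == 0 else 48
        yield [(side, side)] * batch_size


def run_iteration(source, num_gpus):
    # One iteration: every tower takes a whole batch of its own.
    return [next(source) for _ in range(num_gpus)]


batch_size, num_gpus = 16, 4
source = toy_input_source(batch_size)
per_tower = run_iteration(source, num_gpus)
total = sum(len(b) for b in per_tower)
print(total)                      # 64 == batch_size * num_gpus
print([b[0] for b in per_tower])  # [(32, 32), (48, 48), (32, 32), (48, 48)]
```

So if you want a given total batch size, set the DataFlow batch size to that total divided by the number of GPUs.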
The tower function (your model code) will get called multiple times. You'll need to be very careful when modifying global state in those functions, e.g. adding ops to TF collections.
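The pitfall with global state can be shown in plain Python. `COLLECTIONS` stands in for a TF collection and `is_main_tower` is a hypothetical flag assumed for illustration; the point is only that every extra call to the tower function repeats its side effects unless they are guarded.

```python
# Sketch of the pitfall: a tower function that appends to a global
# collection adds one entry per call, so N towers produce N entries.
# COLLECTIONS and is_main_tower are hypothetical stand-ins.

COLLECTIONS = {"summaries": []}


def careless_tower(tower_name):
    # Side effect runs on every call, i.e. once per replica.
    COLLECTIONS["summaries"].append(f"{tower_name}/loss")


def careful_tower(tower_name, is_main_tower):
    # Only the first training tower touches global state.
    if is_main_tower:
        COLLECTIONS["summaries"].append(f"{tower_name}/loss")


for i in range(4):
    careless_tower(f"tower{i}")
print(len(COLLECTIONS["summaries"]))  # 4: one entry per tower

COLLECTIONS["summaries"].clear()
for i in range(4):
    careful_tower(f"tower{i}", is_main_tower=(i == 0))
print(COLLECTIONS["summaries"])  # ['tower0/loss']
```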