tensorpack.train package

class tensorpack.train.Trainer(config)[source]

Bases: object

Base class for a trainer.


Variables:
  • config (TrainConfig) – the config used in this trainer.

  • sess (tf.Session) – the current session in use.

  • hooked_sess (tf.MonitoredSession) – the session with hooks.

  • monitors (Monitors) – the monitors. Callbacks can use it for logging.

  • epoch_num (int) – the number of epochs that have finished.

  • local_step (int) – the number of steps that have finished in the current epoch.

  • global_step (int) – the number of steps that have finished or are currently running.

Parameters:config (TrainConfig) – the train config.
get_predictor(input_names, output_names, tower=0)[source]
  • input_names (list) – list of input names.

  • output_names (list) – list of output names.

  • tower (int) – return the predictor on the k-th tower, defined by config.predict_tower.


Returns: an OnlinePredictor.

get_predictors(input_names, output_names, n)[source]

Return n predictors.

is_chief = True

main_loop()[source]

Run the main training loop.


register_callback(cb)[source]

Use this method before Trainer._setup() finishes to register a callback with the trainer.

The hooks of the registered callback will be bound to the self.hooked_sess session.


run_step()[source]

Abstract method: run one iteration. Subclasses should define what an “iteration” is.


setup()[source]

Set up the trainer and get ready for the main loop.


train()[source]

Start training.


vs_name_for_predictor

The variable scope name a predictor should be built in.

exception tensorpack.train.StopTraining[source]

Bases: exceptions.BaseException

An exception thrown to stop training.

class tensorpack.train.TrainConfig(dataflow=None, data=None, model=None, callbacks=None, extra_callbacks=None, monitors=None, session_creator=None, session_config=None, session_init=None, starting_epoch=1, steps_per_epoch=None, max_epoch=99999, nr_tower=1, tower=None, predict_tower=[0], **kwargs)[source]

Bases: object

Config for trainer.

__init__(dataflow=None, data=None, model=None, callbacks=None, extra_callbacks=None, monitors=None, session_creator=None, session_config=None, session_init=None, starting_epoch=1, steps_per_epoch=None, max_epoch=99999, nr_tower=1, tower=None, predict_tower=[0], **kwargs)[source]
  • dataflow (DataFlow) – the dataflow to train.

  • data (InputSource) – an InputSource instance. Only one of dataflow or data has to be present.

  • model (ModelDesc) – the model to train.

  • callbacks (list) – a list of Callback to perform during training.

  • extra_callbacks (list) – the same as callbacks. This argument is only used to provide the defaults. The defaults are [MovingAverageSummary(), ProgressBar(), MergeAllSummaries(), RunUpdateOps()]. The list of callbacks that will be used in the end are callbacks + extra_callbacks.

  • monitors (list) – a list of TrainingMonitor. Defaults to [TFEventWriter(), JSONWriter(), ScalarPrinter()].

  • session_creator (tf.train.SessionCreator) – Defaults to sesscreate.NewSessionCreator() with the config returned by tfutils.get_default_sess_config().

  • session_config (tf.ConfigProto) – when session_creator is None, use this to create the session.

  • session_init (SessionInit) – how to initialize variables of a session. Defaults to do nothing.

  • starting_epoch (int) – The index of the first epoch.

  • steps_per_epoch (int) – the number of steps (defined by Trainer.run_step()) to run in each epoch. Defaults to the input data size.

  • max_epoch (int) – maximum number of epoch to run training.

  • nr_tower (int) – number of training towers.

  • tower (list of int) – list of training towers in relative id.

  • predict_tower (list of int) – list of prediction towers in their relative GPU id. Use -1 for CPU.
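As a hedged illustration (MyModel and my_dataflow below are placeholders, not part of this API), a typical TrainConfig might be assembled like this, leaving the other arguments at their documented defaults:

```python
from tensorpack import TrainConfig, ModelSaver

config = TrainConfig(
    model=MyModel(),           # a ModelDesc subclass (placeholder)
    dataflow=my_dataflow,      # a DataFlow yielding training data (placeholder)
    callbacks=[ModelSaver()],  # combined with the default extra_callbacks
    steps_per_epoch=1000,      # otherwise defaults to the input data size
    max_epoch=100,
)
```

Only one of dataflow or data needs to be given; the final callback list is callbacks + extra_callbacks.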

class tensorpack.train.DistributedReplicatedTrainer(config, server)[source]

Bases: tensorpack.train.multigpu.MultiGPUTrainerBase

Distributed replicated training. Each worker process builds the same model on one or more GPUs. Gradients across GPUs are averaged within the worker, then synchronously applied to the global copy of the variables located on the PS. Each worker then copies the latest variables from the PS back to its local copy.


Note: Gradients are not averaged across workers.

__init__(config, server)[source]
  • config (TrainConfig) – the train config.

  • server (tf.train.Server) – the server object, with ps and workers.

class tensorpack.train.FeedfreeTrainerBase(config)[source]

Bases: tensorpack.train.base.Trainer

A base trainer which runs iterations without a feed_dict (and is therefore faster). Expects config.data to be a FeedfreeInput.

build_train_tower(*args, **kwargs)[source]

run_step()[source]

Simply run self.train_op.

class tensorpack.train.SingleCostFeedfreeTrainer(*args, **kwargs)[source]

Bases: tensorpack.train.feedfree.FeedfreeTrainerBase

A feedfree Trainer which assumes a single cost.

tensorpack.train.SimpleFeedfreeTrainer(*args, **kwargs)[source]
tensorpack.train.QueueInputTrainer(config, input_queue=None)[source]

A wrapper trainer which automatically wraps config.dataflow in a QueueInput. It is equivalent to SimpleTrainer(config) with config.data = QueueInput(dataflow).

  • config (TrainConfig) – a TrainConfig instance. config.dataflow must exist.

  • input_queue (tf.QueueBase) – an input queue. Defaults to the QueueInput default.
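As a hedged usage fragment (my_dataflow and MyModel are placeholders), the equivalence stated above means the two forms below build the same training setup:

```python
# 1) Let the wrapper insert the QueueInput for you:
QueueInputTrainer(TrainConfig(dataflow=my_dataflow, model=MyModel())).train()

# 2) Or wrap the dataflow yourself:
SimpleTrainer(TrainConfig(data=QueueInput(my_dataflow), model=MyModel())).train()
```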

class tensorpack.train.MultiGPUTrainerBase(config)[source]

Bases: tensorpack.train.feedfree.FeedfreeTrainerBase

Base class for multi-GPU training.

static build_on_multi_tower(towers, func, devices=None, var_strategy='shared', vs_names=None)[source]
  • towers – list of gpu relative ids

  • func – a lambda to be called inside each tower

  • devices – a list of devices to be used. By default will use GPUs in towers.

  • var_strategy (str) – ‘shared’ or ‘replicated’

  • vs_names (list[str]) – list of variable scope names to use.


Returns: a list of the outputs of func, evaluated on each tower.
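The contract of build_on_multi_tower can be pictured with a plain-Python sketch (no TensorFlow involved; the device strings are illustrative only): func is invoked once per tower, and the per-tower results are collected into a list:

```python
def build_on_towers_sketch(towers, func, devices=None):
    """Call `func` once per tower and collect the outputs in tower order."""
    devices = devices or ['/gpu:{}'.format(t) for t in towers]
    outputs = []
    for tower, device in zip(towers, devices):
        # In the real trainer this runs under tf.device(device) inside a
        # shared or replicated variable scope; here we just call func.
        outputs.append(func(tower))
    return outputs

results = build_on_towers_sketch([0, 1], lambda t: t * 10)
# results == [0, 10]
```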


tensorpack.train.SyncMultiGPUTrainer(config)[source]

Alias for SyncMultiGPUTrainerParameterServer(config, ps_device='gpu'), as this is the most commonly used synchronous multi-GPU trainer (though it may not be more efficient than the others).

class tensorpack.train.AsyncMultiGPUTrainer(config, scale_gradient=True)[source]

Bases: tensorpack.train.multigpu.MultiGPUTrainerBase

A multi-tower, multi-GPU trainer where each tower independently and asynchronously updates the model, without averaging gradients across towers.

__init__(config, scale_gradient=True)[source]
  • config (TrainConfig) –

  • scale_gradient (bool) – if True, will scale each gradient by 1.0/nr_gpu.
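The effect of scale_gradient can be sketched without TensorFlow: each tower's gradients are multiplied by 1.0/nr_gpu, so an asynchronous update has roughly the magnitude of a single-GPU update (a plain-Python illustration, not the trainer's actual code):

```python
def scale_gradients(grads, nr_gpu):
    """Scale each (gradient, variable-name) pair by 1.0 / nr_gpu."""
    factor = 1.0 / nr_gpu
    return [(g * factor, name) for g, name in grads]

scaled = scale_gradients([(4.0, 'w'), (8.0, 'b')], nr_gpu=4)
# scaled == [(1.0, 'w'), (2.0, 'b')]
```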

class tensorpack.train.LeastLoadedDeviceSetter(worker_device, ps_devices)[source]

Bases: object

Helper class to assign variables on the least loaded ps-device.

__init__(worker_device, ps_devices)[source]
  • worker_device – the device to use for compute ops.

  • ps_devices – a list of devices to use for Variable ops.
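The placement strategy can be sketched in plain Python (a simplified stand-in for the real class, which acts as a tf.device function and measures a variable's load in bytes):

```python
class LeastLoadedSetterSketch:
    """Assign each variable to the ps device carrying the fewest bytes so far."""
    def __init__(self, ps_devices):
        self.load = {d: 0 for d in ps_devices}

    def place(self, var_name, nbytes):
        # Pick the device with the smallest load; break ties by device name.
        device = min(sorted(self.load), key=lambda d: self.load[d])
        self.load[device] += nbytes
        return device

setter = LeastLoadedSetterSketch(['/gpu:0', '/gpu:1'])
a = setter.place('w1', 400)   # '/gpu:0' (both empty, ties broken by name)
b = setter.place('w2', 100)   # '/gpu:1' ('/gpu:0' now carries 400 bytes)
c = setter.place('w3', 100)   # '/gpu:1' again (100 < 400)
```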

class tensorpack.train.SyncMultiGPUTrainerReplicated(config, gpu_prefetch=True)[source]

Bases: tensorpack.train.multigpu.MultiGPUTrainerBase

Data-parallel multi-GPU trainer where each GPU holds a replica of the whole model. Each gradient update is broadcast and synced across GPUs.

__init__(config, gpu_prefetch=True)[source]
Parameters: config, gpu_prefetch – same as in SyncMultiGPUTrainerParameterServer.
static get_post_init_ops()[source]
class tensorpack.train.SyncMultiGPUTrainerParameterServer(config, ps_device='gpu', gpu_prefetch=True)[source]

Bases: tensorpack.train.multigpu.MultiGPUTrainerBase

A data-parallel multi-GPU trainer which synchronizes the gradients computed from each tower, averages them, and applies the update to variables stored across all GPUs or on CPU.

__init__(config, ps_device='gpu', gpu_prefetch=True)[source]
  • config (TrainConfig) –

  • ps_device – either ‘gpu’ or ‘cpu’, where variables are stored.

  • gpu_prefetch (bool) – whether to prefetch the data to each GPU. This usually improves performance.
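The averaging step described above can be illustrated in plain Python: for each variable, the gradients from all towers are averaged before a single update is applied (illustrative only; the trainer does this with TensorFlow ops):

```python
def average_grads(tower_grads):
    """tower_grads: one list per tower, each of (gradient, variable-name) pairs."""
    n = len(tower_grads)
    averaged = []
    # zip(*tower_grads) groups the per-tower gradients of each variable.
    for grads_for_var in zip(*tower_grads):
        name = grads_for_var[0][1]
        mean = sum(g for g, _ in grads_for_var) / n
        averaged.append((mean, name))
    return averaged

# Two towers, two variables 'w' and 'b':
out = average_grads([[(2.0, 'w'), (0.0, 'b')],
                     [(4.0, 'w'), (2.0, 'b')]])
# out == [(3.0, 'w'), (1.0, 'b')]
```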

class tensorpack.train.SimpleTrainer(config)[source]

Bases: tensorpack.train.base.Trainer

A naive single-tower, single-cost demo trainer. Supports both InputSource and DataFlow. When a DataFlow is given, the InputSource used will be FeedInput(df).

Parameters:config (TrainConfig) – the training config.