tensorpack.contrib package

class tensorpack.contrib.keras.KerasPhaseCallback(isTrain)

Bases: tensorpack.callbacks.base.Callback

Keras needs an extra input, the learning-phase flag, if the model uses learning_phase. This callback feeds that flag and is used in two places:

  1. By the trainer, with isTrain=True.
  2. By InferenceRunner, with isTrain=False, in the form of hooks.

If you use KerasModel or setup_keras_trainer(), this callback will be automatically added when needed.
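Conceptually, the callback's job is to supply one extra feed value per session run: the Keras learning-phase flag, where 1 means training and 0 means test/inference. The plain-Python sketch below models only that decision. The class name PhaseFeed and the key "keras_learning_phase" are hypothetical stand-ins; the real callback feeds the placeholder tensor that Keras creates, not a string key.

```python
class PhaseFeed:
    """Hypothetical stand-in for KerasPhaseCallback: decides which value
    to feed for the Keras learning-phase placeholder before each run."""

    # Hypothetical key; the real callback feeds the placeholder tensor itself.
    PLACEHOLDER = "keras_learning_phase"

    def __init__(self, is_train):
        if not isinstance(is_train, bool):
            raise TypeError("is_train must be a bool")
        self.is_train = is_train

    def before_run(self):
        # Keras convention: 1 = training phase, 0 = test/inference phase.
        return {self.PLACEHOLDER: int(self.is_train)}


# One instance for the trainer, one attached to inference as a hook.
train_feed = PhaseFeed(is_train=True).before_run()
infer_feed = PhaseFeed(is_train=False).before_run()
print(train_feed, infer_feed)
```

Two instances are needed because training and inference run the same graph but must see different values for the flag, which matters for layers such as dropout and batch normalization.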

tensorpack.contrib.keras.setup_keras_trainer(trainer, get_model, input_signature, target_signature, input, optimizer, loss, metrics)
Parameters
  • trainer (SingleCostTrainer) –

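  • get_model (input1, input2, ... -> tf.keras.Model) – A function which takes tensors, builds and returns a Keras model. It will be part of the tower function.

  • input_signature ([tf.TensorSpec]) – The signature for the input tensors.

  • target_signature ([tf.TensorSpec]) – The signature for the target tensors.

  • input (InputSource) –

  • optimizer (tf.train.Optimizer) –

  • loss (list of str) – loss names.

  • metrics (list of str) – metric names.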
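Because get_model becomes part of the tower function, that function has to split the tensors it receives into model inputs and targets, then pair each model output with one target and one loss. The framework-free sketch below illustrates that idea on plain floats. It assumes the tensors arrive as the inputs followed by the targets, in signature order; make_tower_cost, squared_error and toy_model are hypothetical names. Loss names given as strings would need resolving to loss functions (for example via tf.keras.losses.get); the sketch passes callables directly for brevity.

```python
def make_tower_cost(get_model, num_inputs, losses):
    """Build a cost function that takes model inputs followed by targets."""

    def cost(*tensors):
        inputs = tensors[:num_inputs]      # first len(input_signature) values
        targets = tensors[num_inputs:]     # remaining values are the targets
        outputs = get_model(*inputs)
        if not isinstance(outputs, (list, tuple)):
            outputs = [outputs]
        if not (len(outputs) == len(targets) == len(losses)):
            raise ValueError("need one target and one loss per model output")
        # Each loss compares one model output with its matching target.
        per_output = [fn(t, o) for fn, t, o in zip(losses, targets, outputs)]
        return sum(per_output)             # total cost to minimize

    return cost


def squared_error(target, output):         # hypothetical loss on plain floats
    return (target - output) ** 2


def toy_model(x, y):                       # hypothetical single-output "model"
    return x + y


cost = make_tower_cost(toy_model, num_inputs=2, losses=[squared_error])
print(cost(1.0, 2.0, 5.0))  # inputs 1.0, 2.0; target 5.0 -> (5 - 3) ** 2 = 4.0
```

Summing the per-output losses yields the single scalar cost that a SingleCostTrainer expects.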

class tensorpack.contrib.keras.KerasModel(get_model, input_signature=None, target_signature=None, input=None, trainer=None)

Bases: object

__init__(get_model, input_signature=None, target_signature=None, input=None, trainer=None)
Parameters
  • get_model (input1, input2, ... -> keras.Model) – A function which takes tensors, builds and returns a Keras model. It will be part of the tower function.

  • input_signature ([tf.TensorSpec]) – Required. The signature for the input tensors.

  • target_signature ([tf.TensorSpec]) – Required. The signature for the target tensors.

  • input (InputSource | DataFlow) – the InputSource or DataFlow where the input data comes from.

  • trainer (Trainer) – the trainer to use. By default, the number of available GPUs is detected and all of them are used.
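The default trainer depends on how many GPUs are available. A minimal sketch of that decision follows. It assumes, as a simplification, that GPU count comes only from the CUDA_VISIBLE_DEVICES environment variable (real detection may also query the driver), and the trainer labels it returns are hypothetical placeholders, not tensorpack class names.

```python
import os


def count_visible_gpus(env=None):
    """Count GPUs listed in CUDA_VISIBLE_DEVICES.

    Simplifying assumption: an unset or empty variable is treated as zero
    GPUs; real detection may fall back to querying the hardware."""
    env = os.environ if env is None else env
    value = env.get("CUDA_VISIBLE_DEVICES", "").strip()
    if not value:
        return 0
    return len([d for d in value.split(",") if d.strip()])


def default_trainer_name(num_gpus):
    # Hypothetical labels describing the choice, not real class names.
    if num_gpus <= 1:
        return "single-device trainer"
    return f"data-parallel trainer on {num_gpus} GPUs"


print(default_trainer_name(count_visible_gpus({"CUDA_VISIBLE_DEVICES": "0,1,2"})))
```

Passing an explicit trainer bypasses this choice entirely, which is useful when only some of the visible GPUs should be used.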

compile(optimizer, loss, metrics=None)
Parameters
  • optimizer (tf.train.Optimizer) –

  • loss (str or list of str) – loss name(s).

  • metrics (str or list of str) – metric name(s).
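Since compile accepts either a single name or a list for loss and metrics, the arguments have to be normalized before use. A minimal sketch, assuming a bare string is wrapped into a one-element list and an omitted metrics argument becomes an empty list; normalize_compile_args is a hypothetical helper name.

```python
def normalize_compile_args(loss, metrics=None):
    """Accept a string or a list of strings for both loss and metrics."""
    if isinstance(loss, str):
        loss = [loss]
    if metrics is None:
        metrics = []                 # metrics are optional
    elif isinstance(metrics, str):
        metrics = [metrics]
    return list(loss), list(metrics)


loss, metrics = normalize_compile_args("categorical_crossentropy", "categorical_accuracy")
print(loss, metrics)
```

Here 'categorical_crossentropy' and 'categorical_accuracy' are standard Keras loss and metric identifiers; with one model output, one loss name is expected.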

fit(validation_data=None, **kwargs)
Parameters
  • validation_data (DataFlow or InputSource) – data to run inference on. The inference callback built from it is inserted at the front of the callback list. To run it at a different position, leave validation_data unset and add the inference callback to the callback list yourself.

  • kwargs – same arguments as Trainer.train_with_defaults().
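The ordering rule for validation_data can be sketched in plain Python. Strings stand in for callback objects, and build_callbacks is a hypothetical helper, not part of tensorpack; it only shows that the inference callback lands at index 0, ahead of any callbacks passed through kwargs.

```python
def build_callbacks(callbacks=None, validation_data=None):
    """Return the callback list handed to the trainer (illustrative sketch)."""
    # Copy so the caller's list is left untouched in this sketch.
    callbacks = list(callbacks or [])
    if validation_data is not None:
        # The inference callback always goes first.
        callbacks.insert(0, f"inference on {validation_data}")
    return callbacks


print(build_callbacks(["model saver", "lr schedule"], validation_data="test set"))
```

Placing inference first means later callbacks in the same epoch can read its results (for example, a callback that keeps the best model by validation metric), which is a plausible reason for this default.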