Trainer
[TOC]
Generic trainer for TensorFlow models.
Other Functions and Classes
class skflow.TensorFlowTrainer
General trainer class.
Attributes:
  model: Model object.
  gradients: Gradients tensor.

- - -
skflow.TensorFlowTrainer.__init__(loss, global_step, optimizer, learning_rate, clip_gradients=5.0) {#TensorFlowTrainer.init}
Builds the trainer part of the graph.
Args:
  loss: Tensor that evaluates to the model’s loss.
  global_step: Tensor with the global step of the model.
  optimizer: Name of the optimizer class (SGD, Adam, Adagrad) or the optimizer class itself.
  learning_rate: If this is a constant float value, no decay function is used. Alternatively, a custom decay function can be passed that accepts global_step as a parameter and returns a Tensor, e.g. an exponential decay function:

      def exp_decay(global_step):
          return tf.train.exponential_decay(
              learning_rate=0.1, global_step=global_step,
              decay_steps=2, decay_rate=0.001)
Raises:
ValueError: if learning_rate is not a float or a callable.
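The float-or-callable contract for learning_rate can be illustrated without building a graph. The sketch below is plain Python and is not the library's implementation: `resolve_learning_rate` and `exp_decay_value` are hypothetical names introduced only for illustration. It assumes the non-staircase formula documented for `tf.train.exponential_decay`, namely `learning_rate * decay_rate ** (global_step / decay_steps)`, and operates on ordinary numbers rather than Tensors.

```python
def exp_decay_value(step, learning_rate=0.1, decay_steps=2, decay_rate=0.001):
    """Plain-Python value of a non-staircase exponential decay at `step`.

    Mirrors the parameters of the exp_decay example above, but returns a
    float instead of a Tensor (an assumption made for illustration).
    """
    return learning_rate * decay_rate ** (step / decay_steps)


def resolve_learning_rate(learning_rate, global_step):
    """Hypothetical helper: pick the rate to use at `global_step`.

    A float is used as-is (no decay); a callable is treated as a decay
    function of the global step; anything else is rejected.
    """
    if isinstance(learning_rate, float):
        return learning_rate
    if callable(learning_rate):
        return learning_rate(global_step)
    raise ValueError(
        "learning_rate must be a float or a callable, got %r" % (learning_rate,))
```

With these example parameters the rate shrinks by a factor of 1000 every 2 steps, which is a very aggressive schedule; real training runs would typically use a much larger decay_steps and a decay_rate closer to 1.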
skflow.TensorFlowTrainer.initialize(sess) {#TensorFlowTrainer.initialize}
Initializes all variables.
Args:
sess: Session object.
Returns:
Values of initializers.
skflow.TensorFlowTrainer.train(sess, feed_dict_fn, steps, monitor, summary_writer=None, summaries=None, feed_params_fn=None) {#TensorFlowTrainer.train}
Trains a model for a given number of steps, using the given feed_dict function.
Args:
  sess: Session object.
  feed_dict_fn: Function that will return a feed dictionary.
  steps: Number of steps to run.
  monitor: Monitor object to track training progress and induce early stopping.
  summary_writer: SummaryWriter object to use for writing summaries.
  summaries: Joined object of all summaries that should be run.
Returns:
List of losses for each step.
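The control flow described for train (call feed_dict_fn each step, record the loss, let the monitor end training early) can be sketched in plain Python. This is a rough illustration under stated assumptions, not the skflow code: `run_step` stands in for running the training op in a session, and `PatienceMonitor` with its `update` and `should_stop` methods is a hypothetical interface, not the library's monitor API.

```python
class PatienceMonitor:
    """Hypothetical monitor: stop once the loss has not improved
    for `patience` consecutive steps."""

    def __init__(self, patience=3):
        self.patience = patience
        self.best_loss = float("inf")
        self.steps_since_best = 0

    def update(self, loss):
        # Track the best loss seen so far and how long ago it occurred.
        if loss < self.best_loss:
            self.best_loss = loss
            self.steps_since_best = 0
        else:
            self.steps_since_best += 1

    def should_stop(self):
        return self.steps_since_best >= self.patience


def train_loop(run_step, feed_dict_fn, steps, monitor):
    """Run up to `steps` steps and return the list of per-step losses.

    `run_step(feed_dict)` is an assumed stand-in for evaluating the
    training op and loss; it must return the loss as a number.
    """
    losses = []
    for _ in range(steps):
        feed_dict = feed_dict_fn()   # fresh inputs for this step
        loss = run_step(feed_dict)
        losses.append(loss)
        monitor.update(loss)
        if monitor.should_stop():    # early stopping requested
            break
    return losses
```

Because the loop may exit early, the returned list can be shorter than steps.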