Learning tf.estimator

Basic structure

Two basic arguments: model_fn and params

nn = tf.estimator.Estimator(model_fn=model_fn, params=model_params)

  • model_fn: A function that implements the model logic for training, evaluation, and prediction. You are responsible for implementing that functionality.
  • params: An optional dict of hyperparameters (e.g., learning rate, dropout) that will be passed into the model_fn.

Note:
The Estimator also accepts the general configuration arguments model_dir (the directory where checkpoints and summaries are saved) and config (a tf.estimator.RunConfig).
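A minimal sketch of constructing an Estimator with all four arguments. It assumes TensorFlow 1.15 or a 2.x release before 2.16 (tf.estimator was removed in 2.16), used through `tensorflow.compat.v1`. The `model_fn` below is a hypothetical placeholder, and the `learning_rate` key is an illustrative hyperparameter name.

```python
import tempfile

import tensorflow.compat.v1 as tf


def model_fn(features, labels, mode, params):
    # Hypothetical placeholder: a real model_fn builds the graph and
    # returns a tf.estimator.EstimatorSpec (see "Building the model_fn").
    raise NotImplementedError


# Illustrative hyperparameters; the Estimator forwards this dict to model_fn.
model_params = {"learning_rate": 0.01}

nn = tf.estimator.Estimator(
    model_fn=model_fn,
    params=model_params,
    model_dir=tempfile.mkdtemp(),  # where checkpoints and summaries are written
    config=tf.estimator.RunConfig(tf_random_seed=42),  # general run settings
)
```

Constructing the Estimator only checks the model_fn signature; model_fn itself is not called until train(), evaluate(), or predict().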

Building the model_fn

Basic skeleton
Input: arguments received
  • features
    • A dict containing the features passed to the model via input_fn.
  • labels
    • A Tensor containing the labels passed to the model via input_fn.
    • Empty (None) for predict() calls, since these are the values the model will infer.
  • mode
    • tf.estimator.ModeKeys.TRAIN: the model_fn was invoked in training mode, namely via a train() call.
    • tf.estimator.ModeKeys.EVAL: the model_fn was invoked in evaluation mode, namely via an evaluate() call.
    • tf.estimator.ModeKeys.PREDICT: the model_fn was invoked in predict mode, namely via a predict() call.
  • params
    • A dict of hyperparameters used for training (the params dict passed to the Estimator constructor).
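The features and labels arguments come from input_fn. A minimal sketch of an input_fn that produces (features dict, labels) batches, assuming TensorFlow 1.15 or 2.x before 2.16 via `tensorflow.compat.v1`. The name `train_input_fn`, the feature keys `x1`/`x2`, and the toy data are all hypothetical.

```python
import numpy as np
import tensorflow.compat.v1 as tf


def train_input_fn(batch_size=2):
    # Hypothetical toy data: two numeric features and binary labels.
    features = {
        "x1": np.array([0.0, 1.0, 2.0, 3.0], dtype=np.float32),
        "x2": np.array([1.0, 0.0, 1.0, 0.0], dtype=np.float32),
    }
    labels = np.array([0, 0, 1, 1], dtype=np.int32)
    # Each element is a (features_dict, labels) pair; the Estimator unpacks
    # it into the `features` and `labels` arguments of model_fn.
    dataset = tf.data.Dataset.from_tensor_slices((features, labels))
    return dataset.batch(batch_size)
```

The function itself (not its result) is what gets handed to the Estimator, e.g. nn.train(input_fn=train_input_fn, steps=100).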
Body: logic flow
  1. Configure the model
    • a neural network
  2. Define the loss function
    • to calculate how closely the model’s predictions match the target values
  3. Define the training operation
    • to specify the optimizer algorithm to minimize the loss values calculated by the loss function.
  4. Generate predictions
  5. Return predictions/loss/train_op/eval_metric_ops in EstimatorSpec object
return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions, loss=loss, train_op=train_op, eval_metric_ops=eval_metric_ops)
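The five steps above can be sketched as a small logistic-regression model_fn. This is an illustrative sketch, not a canonical implementation: it assumes TensorFlow 1.15 or 2.x before 2.16 via `tensorflow.compat.v1`, hypothetical feature keys `x1`/`x2`, a `learning_rate` entry in params, and binary integer labels. The linear model, sigmoid cross-entropy loss, and gradient-descent optimizer are choices made for this example.

```python
import tensorflow.compat.v1 as tf


def model_fn(features, labels, mode, params):
    # 1. Configure the model: a single linear layer (logistic regression)
    #    over two numeric features; the keys "x1"/"x2" are hypothetical.
    x = tf.stack([features["x1"], features["x2"]], axis=1)
    w = tf.get_variable("w", shape=[2, 1], initializer=tf.zeros_initializer())
    b = tf.get_variable("b", shape=[1], initializer=tf.zeros_initializer())
    logits = tf.squeeze(tf.matmul(x, w) + b, axis=1)

    # 4. Generate predictions. Done before the loss so that PREDICT mode,
    #    where labels is None, can return early.
    probabilities = tf.sigmoid(logits)
    predicted_classes = tf.cast(probabilities > 0.5, tf.int32)
    predictions = {"class": predicted_classes, "probability": probabilities}
    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)

    # 2. Define the loss function: mean sigmoid cross-entropy.
    loss = tf.losses.sigmoid_cross_entropy(
        multi_class_labels=tf.cast(labels, tf.float32), logits=logits)

    # 3. Define the training operation (only needed in TRAIN mode).
    train_op = None
    if mode == tf.estimator.ModeKeys.TRAIN:
        optimizer = tf.train.GradientDescentOptimizer(params["learning_rate"])
        train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step())

    # Optional metrics, reported by evaluate().
    eval_metric_ops = {
        "accuracy": tf.metrics.accuracy(labels=labels, predictions=predicted_classes)
    }

    # 5. Return predictions/loss/train_op/eval_metric_ops in an EstimatorSpec.
    return tf.estimator.EstimatorSpec(
        mode=mode,
        predictions=predictions,
        loss=loss,
        train_op=train_op,
        eval_metric_ops=eval_metric_ops,
    )
```

Building only what each mode needs (no loss in PREDICT, no train_op outside TRAIN) keeps the three graphs the Estimator builds small, and matches which EstimatorSpec fields each mode requires.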
Output: return value

The model_fn must return a tf.estimator.EstimatorSpec object.

return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions, loss=loss, train_op=train_op, eval_metric_ops=eval_metric_ops)

  • mode (required in all modes)
    • Simply pass through the mode argument that model_fn received.
  • predictions (required in PREDICT mode)
    • A dict that maps key names of your choice to Tensors containing the predictions from the model.
    • In PREDICT mode, predict() yields the contents of this dict (by default, one dict per input example).
    • You can structure it in whatever format you'd like to consume it in.
  • loss (required in EVAL and TRAIN modes)
    • A Tensor containing a scalar loss value.
    • The output of the model's loss function, calculated over all the examples in the input batch.
    • Used in TRAIN mode for error handling and logging.
    • Automatically included as a metric in EVAL mode.
  • train_op (required only in TRAIN mode)
    • An Op that runs one step of training.
  • eval_metric_ops (optional)
    • A dict of name/value pairs specifying the metrics that will be calculated when the model runs in EVAL mode.
    • The name is a label of your choice for the metric.
    • The value is the result of your metric calculation.
    • The tf.metrics module provides predefined functions for a variety of common metrics.
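      # predicted_classes: a Tensor of predicted class ids (illustrative name), not the predictions dict
      eval_metric_ops = {"accuracy": tf.metrics.accuracy(labels=labels, predictions=predicted_classes)}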
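Each tf.metrics function returns a (value, update_op) pair, which is why it can be dropped straight into eval_metric_ops: during evaluate(), the Estimator runs update_op on every batch and reads value at the end. The sketch below runs tf.metrics.accuracy outside an Estimator to show this streaming behavior. It assumes TensorFlow 1.15 or 2.x before 2.16 via `tensorflow.compat.v1`; the placeholder names and fed values are hypothetical.

```python
import tensorflow.compat.v1 as tf

with tf.Graph().as_default():
    labels = tf.placeholder(tf.int32, shape=[None])
    predicted_classes = tf.placeholder(tf.int32, shape=[None])
    # value reads the running result; update_op accumulates one more batch.
    accuracy, update_op = tf.metrics.accuracy(labels=labels, predictions=predicted_classes)

    with tf.Session() as sess:
        # The metric's running counters (total, count) are local variables.
        sess.run(tf.local_variables_initializer())
        sess.run(update_op, {labels: [1, 0, 1], predicted_classes: [1, 1, 1]})  # 2 of 3 correct
        sess.run(update_op, {labels: [0], predicted_classes: [0]})  # 1 of 1 correct
        final_accuracy = sess.run(accuracy)

print(final_accuracy)  # 3 of 4 correct overall: 0.75
```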
Related APIs

Optimizers