Learning tf.estimator
Basic structure
Two basic arguments, model_fn and params:
```python
nn = tf.estimator.Estimator(model_fn=model_fn, params=model_params)
```
- model_fn: A function that contains all the logic needed to support training, evaluation, and prediction. You are responsible for implementing that functionality.
- params: An optional dict of hyperparameters (e.g., learning rate, dropout) that will be passed into the model_fn.
Note: the Estimator also accepts the general configuration arguments model_dir and config.
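The following plain-Python sketch (no TensorFlow) shows how these pieces connect. The estimator stores model_fn and params. It then calls model_fn(features, labels, mode, params), passing a different mode from train(), evaluate(), and predict().

ToyEstimator and echo_model_fn are hypothetical names. The real train(), evaluate(), and predict() methods take an input_fn rather than raw data. The mode strings used here are the values of tf.estimator.ModeKeys.

```python
# Plain-Python sketch, no TensorFlow. ToyEstimator and echo_model_fn are
# hypothetical names used only to illustrate the call pattern.


class ModeKeys:
    # String values of tf.estimator.ModeKeys.
    TRAIN = "train"
    EVAL = "eval"
    PREDICT = "infer"


class ToyEstimator:
    def __init__(self, model_fn, params=None, model_dir=None, config=None):
        self.model_fn = model_fn
        self.params = dict(params or {})  # copied so later caller edits don't leak in
        self.model_dir = model_dir        # where a real Estimator writes checkpoints
        self.config = config              # run configuration (unused in this sketch)

    def _call_model_fn(self, features, labels, mode):
        # The estimator chooses the mode; params are forwarded unchanged.
        return self.model_fn(features, labels, mode, self.params)

    # NOTE: the real train/evaluate/predict take an input_fn, not raw data.
    def train(self, features, labels):
        return self._call_model_fn(features, labels, ModeKeys.TRAIN)

    def evaluate(self, features, labels):
        return self._call_model_fn(features, labels, ModeKeys.EVAL)

    def predict(self, features):
        return self._call_model_fn(features, None, ModeKeys.PREDICT)  # no labels


def echo_model_fn(features, labels, mode, params):
    # Reports what the estimator passed in.
    return {"mode": mode, "has_labels": labels is not None, "lr": params["learning_rate"]}


nn = ToyEstimator(model_fn=echo_model_fn, params={"learning_rate": 0.01})
print(nn.train({"x": [1.0]}, [0]))  # {'mode': 'train', 'has_labels': True, 'lr': 0.01}
print(nn.predict({"x": [1.0]}))     # {'mode': 'infer', 'has_labels': False, 'lr': 0.01}
```

The user never passes mode explicitly. It follows from which method is called, which is why a single model_fn can serve all three uses.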
Building model_fn
Basic skeleton
Input: model_fn(features, labels, mode, params) accepts the following arguments
- features
  - A dict containing the features passed to the model via input_fn.
- labels
  - A Tensor containing the labels passed to the model via input_fn.
  - Will be None for predict() calls, as these are the values the model will infer.
- mode
  - tf.estimator.ModeKeys.TRAIN: the model_fn was invoked in training mode, namely via a train() call.
  - tf.estimator.ModeKeys.EVAL: the model_fn was invoked in evaluation mode, namely via an evaluate() call.
  - tf.estimator.ModeKeys.PREDICT: the model_fn was invoked in predict mode, namely via a predict() call.
- params
  - A dict of hyperparameters (e.g., learning rate) used for training, as passed to the Estimator constructor.
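The following plain-Python sketch (no TensorFlow) shows what the four arguments look like. It assumes hypothetical input functions: one returns a features dict plus a list of labels, and the other returns features only, as at prediction time. The function names and feature names are illustrative, not part of the tf.estimator API.

```python
# Plain-Python sketch; train_input_fn, predict_input_fn and inspect_model_fn
# are hypothetical names.


def train_input_fn():
    # Returns (features, labels); features is a dict of named feature columns.
    features = {"x1": [1.0, 2.0, 3.0], "x2": [0.5, 0.0, -0.5]}
    labels = [1, 0, 1]
    return features, labels


def predict_input_fn():
    # At prediction time only features are supplied.
    return {"x1": [4.0], "x2": [1.5]}


def inspect_model_fn(features, labels, mode, params):
    batch_size = len(next(iter(features.values())))
    if mode == "infer":  # value of tf.estimator.ModeKeys.PREDICT
        if labels is not None:
            raise ValueError("predict() calls receive labels=None")
    elif len(labels) != batch_size:
        raise ValueError("expected one label per example")
    return {
        "feature_names": sorted(features),
        "batch_size": batch_size,
        "has_labels": labels is not None,
        "mode": mode,
        "learning_rate": params["learning_rate"],
    }


params = {"learning_rate": 0.001}
features, labels = train_input_fn()
print(inspect_model_fn(features, labels, "train", params))
print(inspect_model_fn(predict_input_fn(), None, "infer", params))
```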
Body: the logic
- Configure the model
  - e.g., a neural network
- Define the loss function
  - to calculate how closely the model's predictions match the target values
- Define the training operation
  - to specify the optimizer algorithm that minimizes the loss values calculated by the loss function
- Generate predictions
- Return predictions/loss/train_op/eval_metric_ops in an EstimatorSpec object
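The five steps can be sketched in plain Python with a one-unit logistic regression. This is a hypothetical illustration, not TensorFlow code, and it simplifies the real setup in three ways:

- The EstimatorSpec namedtuple only mimics the real class's field names.
- The weights are kept in params, whereas a real model holds them in TensorFlow variables.
- train_op is a Python callable standing in for a training Op.

```python
import math
from collections import namedtuple

# Hypothetical plain-Python stand-in for tf.estimator.EstimatorSpec.
EstimatorSpec = namedtuple(
    "EstimatorSpec",
    ["mode", "predictions", "loss", "train_op", "eval_metric_ops"],
    defaults=[None, None, None, None],
)


def model_fn(features, labels, mode, params):
    # Simplification: weights live in params; in TensorFlow they would be variables.
    w = params["weights"]
    xs = features["x"]

    # 1. Configure the model: a single logistic unit (no bias, for brevity).
    logits = [sum(wi * xi for wi, xi in zip(w, x)) for x in xs]
    probs = [1.0 / (1.0 + math.exp(-z)) for z in logits]

    # 4. Generate predictions (computed in every mode here, required in PREDICT).
    predictions = {"probabilities": probs, "classes": [int(p > 0.5) for p in probs]}
    if mode == "infer":
        return EstimatorSpec(mode, predictions=predictions)

    # 2. Define the loss: mean binary cross-entropy over all input examples.
    loss = -sum(y * math.log(p) + (1 - y) * math.log(1 - p)
                for y, p in zip(labels, probs)) / len(labels)
    if mode == "eval":
        return EstimatorSpec(mode, predictions=predictions, loss=loss)

    # 3. Define the training operation: one gradient-descent step on the loss.
    def train_op():
        for j in range(len(w)):
            grad = sum((p - y) * x[j] for p, y, x in zip(probs, labels, xs)) / len(labels)
            w[j] -= params["learning_rate"] * grad

    # 5. Return everything in the spec.
    return EstimatorSpec(mode, predictions=predictions, loss=loss, train_op=train_op)


params = {"weights": [0.0, 0.0], "learning_rate": 0.5}
features, labels = {"x": [[1.0, 2.0], [1.0, -1.0]]}, [1, 0]

spec = model_fn(features, labels, "train", params)
print(spec.loss)          # log(2) ≈ 0.6931, since all weights start at 0
spec.train_op()           # one training step updates params["weights"]
print(params["weights"])  # [0.0, 0.375]
print(model_fn(features, labels, "train", params).loss < spec.loss)  # True
```

Returning early per mode mirrors how a real model_fn usually skips building the loss and training op when they are not needed.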
Output: return value
model_fn must return a tf.estimator.EstimatorSpec object:
```python
return tf.estimator.EstimatorSpec(
    mode=mode,
    predictions=predictions,
    loss=loss,
    train_op=train_op,
    eval_metric_ops=eval_metric_ops)
```
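The per-mode requirements listed below can be summarized in a small plain-Python stand-in (no TensorFlow). make_spec and _Spec are hypothetical names. The real EstimatorSpec performs similar validation when it is constructed, but its error messages are not reproduced here.

```python
from collections import namedtuple

TRAIN, EVAL, PREDICT = "train", "eval", "infer"  # string values of tf.estimator.ModeKeys

_Spec = namedtuple("_Spec", ["mode", "predictions", "loss", "train_op", "eval_metric_ops"])


def make_spec(mode, predictions=None, loss=None, train_op=None, eval_metric_ops=None):
    """Hypothetical stand-in that enforces which fields each mode requires."""
    if mode not in (TRAIN, EVAL, PREDICT):
        raise ValueError(f"unknown mode: {mode!r}")
    if mode == PREDICT and predictions is None:
        raise ValueError("predictions are required in PREDICT mode")
    if mode in (TRAIN, EVAL) and loss is None:
        raise ValueError("loss is required in TRAIN and EVAL mode")
    if mode == TRAIN and train_op is None:
        raise ValueError("train_op is required in TRAIN mode")
    return _Spec(mode, predictions, loss, train_op, eval_metric_ops or {})


# PREDICT only needs predictions; TRAIN without a train_op is rejected.
print(make_spec(PREDICT, predictions={"classes": [1, 0]}))
try:
    make_spec(TRAIN, loss=0.42)
except ValueError as err:
    print("rejected:", err)
```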
- mode (required in all modes)
  - Pass the mode argument received by model_fn through unchanged.
- predictions (required in PREDICT mode)
  - A dict that maps key names of your choice to Tensors containing the predictions from the model.
  - In PREDICT mode, this dict is what predict() returns (predict() yields the results one example at a time).
  - You can construct it in whatever format you'd like to consume it.
- loss (required in EVAL and TRAIN mode)
  - A Tensor containing a scalar loss value: the output of the model's loss function, calculated over all the input examples.
  - Used in TRAIN mode for error handling and logging, and automatically included as a metric in EVAL mode.
- train_op (required only in TRAIN mode)
  - An Op that runs one step of training.
- eval_metric_ops (optional)
  - A dict of name/value pairs specifying the metrics that will be calculated when the model runs in EVAL mode.
  - The name is a label of your choice for the metric; the value is the result of your metric calculation, typically the (metric_value, update_op) tuple returned by a metric function.
  - The tf.metrics module provides predefined functions for a variety of common metrics, for example:
```python
eval_metric_ops = {"accuracy": tf.metrics.accuracy(labels, predictions)}
```
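tf.metrics functions such as tf.metrics.accuracy are streaming metrics. Each returns a (value, update_op) pair: the update op accumulates counts across every evaluation batch, and the value is read from the accumulated totals.

The hypothetical StreamingAccuracy class below sketches that idea in plain Python (no TensorFlow). Plain callables stand in for the value tensor and the update op.

```python
class StreamingAccuracy:
    """Hypothetical plain-Python analogue of a streaming metric like tf.metrics.accuracy."""

    def __init__(self):
        self.correct = 0  # running number of correct predictions
        self.seen = 0     # running number of examples

    def update(self, labels, predictions):
        # Accumulate one batch, then report the running value.
        self.correct += sum(1 for y, p in zip(labels, predictions) if y == p)
        self.seen += len(labels)
        return self.value()

    def value(self):
        return self.correct / self.seen if self.seen else 0.0


acc = StreamingAccuracy()
# name -> (value, update) pair, mirroring the (metric_value, update_op) tuples.
eval_metric_ops = {"accuracy": (acc.value, acc.update)}

value_fn, update_fn = eval_metric_ops["accuracy"]
for batch_labels, batch_preds in [([1, 0, 1], [1, 1, 1]), ([0], [0])]:
    update_fn(batch_labels, batch_preds)
print(value_fn())  # 0.75
```

Because the counts accumulate, the reported value is accuracy over all evaluated examples (3 of 4). It is not the average of the per-batch accuracies, which would be about 0.83 here.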