python - How to run asynchronous predictions with the TensorFlow Estimator API?

Tags: python, tensorflow, tensorflow-estimator

I am using the tf.estimator API to predict punctuation. I trained on preprocessed data using TFRecords and tf.train.shuffle_batch. Now I want to make predictions. I can feed static NumPy data into tf.constant and return it from an input_fn.
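For context, that static approach looks roughly like this (the data array below is just an illustrative stand-in for my preprocessed sequences, not my actual code):

import numpy as np
import tensorflow as tf

# Illustrative stand-in for preprocessed data: a batch of padded id sequences.
preprocessed = np.zeros((32, 200), dtype=np.int32)

def static_input_fn():
    # The whole dataset is baked into the graph as a constant,
    # so it cannot change between calls to estimator.predict().
    return tf.constant(preprocessed)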

However, I am working with sequence data, so I need to feed one example at a time, and the next input depends on the previous output. I also want to be able to handle data arriving via HTTP requests.

Every time estimator.predict is called, it reloads the checkpoint and rebuilds the entire graph. This is slow and expensive, so I need a way to feed data to the input_fn dynamically.
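To make that cost concrete, the naive way to do this would be something like the following sketch (get_numpy_data, save_to_file, model_fn and model_file are the same placeholder names used in the code below); every iteration rebuilds the graph and restores the checkpoint from scratch:

import tensorflow as tf

def make_input_fn(x):
    def input_fn():
        return tf.constant(x)
    return input_fn

estimator = tf.estimator.Estimator(model_fn, model_dir=model_file)
output = None

while True:
    x = get_numpy_data(output)
    if x is None:
        break
    # Each call to predict() builds a new graph and reloads the checkpoint.
    output = next(estimator.predict(input_fn=make_input_fn(x)))
    save_to_file(output)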

My current attempt looks roughly like this:

import tensorflow as tf

# Placeholder and queue are created here, outside of input_fn.
feature_input = tf.placeholder(tf.int32, shape=[1, MAX_SUBSEQUENCE_LEN])
q = tf.FIFOQueue(1, tf.int32, shapes=[[1, MAX_SUBSEQUENCE_LEN]])
enqueue_op = q.enqueue(feature_input)

def input_fn():
    return q.dequeue()

estimator = tf.estimator.Estimator(model_fn, model_dir=model_file)
predictor = estimator.predict(input_fn=input_fn)
sess = tf.Session()
output = None

while True:
    # The next input depends on the previous prediction.
    x = get_numpy_data(output)
    if x is None:
        break
    sess.run(enqueue_op, {feature_input: x})
    output = predictor.next()
    save_to_file(output)

sess.close()

But I get the following error: ValueError: Input graph and Layer graph are not the same: Tensor("EmbedSequence/embedding_lookup:0", shape=(1, 200, 128), dtype=float32) is not from the passed-in graph.

How can I asynchronously feed data into my existing graph through the input_fn and get predictions back one at a time?

Best Answer

It turns out the main problem is that all tensors need to be created inside the input_fn, otherwise they are not added to the same graph. I needed to run an enqueue op, but I had no way to access anything returned from the input function.

I ended up subclassing the Estimator class and creating a custom predict function that lets me dynamically add data to the prediction queue and return the results:

# async_estimator.py

import six
import tensorflow as tf
from tensorflow.python.estimator.estimator import Estimator
from tensorflow.python.estimator.estimator import _check_hooks_type
from tensorflow.python.estimator import model_fn as model_fn_lib
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.training import saver
from tensorflow.python.training import training


class AsyncEstimator(Estimator):

    def async_predictor(self,
                dtype,
                shape=None,
                predict_keys=None,
                hooks=None,
                checkpoint_path=None):
        """Returns a tuple of functions: first runs predicitons on the model, second cleans up
        Args:
          dtype: the dtype of the input
          shape: the shape of the input placeholder (optional)
          predict_keys: list of `str`, name of the keys to predict. It is used if
            the `EstimatorSpec.predictions` is a `dict`. If `predict_keys` is used
            then rest of the predictions will be filtered from the dictionary. If
            `None`, returns all.
          hooks: List of `SessionRunHook` subclass instances. Used for callbacks
            inside the prediction call.
          checkpoint_path: Path of a specific checkpoint to predict. If `None`, the
            latest checkpoint in `model_dir` is used.
        Returns:
          (predict, finish): tuple of functions

            predict: runs a single prediction and returns the results
                Args:
                    x: NumPy array of input
                Returns:
                    Evaluated value of the prediction

            finish: closes the session, allowing the program to exit

        Raises:
          ValueError: Could not find a trained model in model_dir.
          ValueError: if batch lengths of predictions are not the same.
          ValueError: If there is a conflict between `predict_keys` and
            `predictions`. For example if `predict_keys` is not `None` but
            `EstimatorSpec.predictions` is not a `dict`.
        """
        hooks = _check_hooks_type(hooks)
        # Check that model has been trained.
        if not checkpoint_path:
            checkpoint_path = saver.latest_checkpoint(self._model_dir)
        if not checkpoint_path:
            raise ValueError('Could not find trained model in model_dir: {}.'.format(
                self._model_dir))

        with ops.Graph().as_default() as g:
            random_seed.set_random_seed(self._config.tf_random_seed)
            training.create_global_step(g)
            input_placeholder = tf.placeholder(dtype=dtype, shape=shape)
            queue = tf.FIFOQueue(1, dtype, shapes=shape)
            enqueue_op = queue.enqueue(input_placeholder)
            features = queue.dequeue()
            estimator_spec = self._call_model_fn(features, None,
                                                 model_fn_lib.ModeKeys.PREDICT)
            predictions = self._extract_keys(estimator_spec.predictions, predict_keys)
            mon_sess = training.MonitoredSession(
                    session_creator=training.ChiefSessionCreator(
                        checkpoint_filename_with_path=checkpoint_path,
                        scaffold=estimator_spec.scaffold,
                        config=self._session_config),
                    hooks=hooks)

            def predict(x):
                if mon_sess.should_stop():
                    raise StopIteration
                mon_sess.run(enqueue_op, {input_placeholder: x})
                preds_evaluated = mon_sess.run(predictions)
                if not isinstance(predictions, dict):
                    return preds_evaluated
                else:
                    preds = []
                    for i in range(self._extract_batch_length(preds_evaluated)):
                        preds.append({
                            key: value[i]
                            for key, value in six.iteritems(preds_evaluated)
                        })
                    return preds

            def finish():
                mon_sess.close()

            return predict, finish

And here is the rough code that uses it:

import tensorflow as tf
from async_estimator import AsyncEstimator


def doPrediction(model_fn, model_dir, max_seq_length):
    estimator = AsyncEstimator(model_fn, model_dir=model_dir)
    predict, finish = estimator.async_predictor(dtype=tf.int32, shape=(1, max_seq_length))
    output = None

    while True:
        # my input is dependent on the previous output
        x = get_numpy_data(output)
        if x is None:
            break
        output = predict(x)
        save_to_disk(output)

    finish()

Note: this is a simple solution that fits my needs and may have to be modified for other cases. It works on TensorFlow 1.2.1.

Hopefully TF will officially adopt something like this to make dynamic prediction serving with Estimator easier.

Regarding python - How to run asynchronous predictions with the TensorFlow Estimator API?, we found a similar question on Stack Overflow: https://stackoverflow.com/questions/45111480/
