android - 保存和使用 TensorForestEstimator for Android 模型时出错

标签 android tensorflow skflow

我使用在 tensorflow 中实现的随机森林估计器来预测文本是否为英文。我使用以下代码(train_input_fn 函数返回特征和类标签)保存了我的模型(一个包含 2k 个样本和 2 个类标签 0/1(非英语/英语)的数据集):

model_path='test/'
TensorForestEstimator(params, model_dir='model/')
estimator.fit(input_fn=train_input_fn, max_steps=1)

运行上述代码后,graph.pbtxt和checkpoints被保存在模型文件夹中。现在我想在 Android 上使用它。我有两个问题:

  1. 作为第一步,我需要将图形和检查点卡住到一个 .pb 文件中,以便在 Android 上使用它。我尝试了 freeze_graph(我在这里使用了代码:https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/freeze_graph.py)。当我在我的模式中调用 freeze_graph 时,出现以下错误并且代码无法创建最终的 .pb 图:

    文件“/Users/XXXXXXX/freeze_graph.py”,第 105 行,freeze_graph _ = tf.import_graph_def(input_graph_def, name="") 文件“/anaconda/envs/tensorflow/lib/python2.7/site-packages/tensorflow/python/framework/importer.py”,第 258 行,在 import_graph_def op_def = op_dict[node.op] KeyError: u'CountExtremelyRandomStats'

这就是我调用 freeze_graph 的方式:

def save_model_android():
    checkpoint_state_name = "model.ckpt-1"
    input_graph_name = "graph.pbtxt"
    output_graph_name = "output_graph.pb"
    checkpoint_path = os.path.join(model_path, checkpoint_state_name)

    input_graph_path = os.path.join(model_path, input_graph_name)
    input_saver_def_path = None
    input_binary = False
    output_node_names = "output"
    restore_op_name = "save/restore_all"
    filename_tensor_name = "save/Const:0"
    output_graph_path = os.path.join(model_path, output_graph_name)
    clear_devices = True

    freeze_graph.freeze_graph(input_graph_path, input_saver_def_path,
                              input_binary, checkpoint_path,
                              output_node_names, restore_op_name,
                              filename_tensor_name, output_graph_path,
                              clear_devices, "")

我还尝试了在“tf.contrib.learn.datasets.load_iris”中对 iris 数据集进行卡住。我犯了同样的错误。所以我认为它与数据集无关。

  1. 作为第二步,我需要使用手机上的 .pb 文件来预测文本。我通过谷歌找到了相机演示示例,它包含很多代码。我想知道是否有分步教程如何通过传递特征向量并获取类标签在 Android 上使用 Tensorflow 模型。

提前致谢!

更新

通过使用最新版本的tensorflow(0.12),问题得到解决。但是,现在的问题是我应该将什么传递给 output_node_names ???我怎样才能得到图中的输出节点是什么?

最佳答案

Re (1) 看起来您正在 tensorflow 构建上运行 freeze_graph,它无法访问 contrib 操作。也许在调用 freeze_graph 之前尝试显式导入 tensorforest?

Re (2) 我不知道更简单的例子。

关于android - 保存和使用 TensorForestEstimator for Android 模型时出错,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/40849477/

相关文章:

android - 带有子组件的 Dagger 2 绑定(bind)委托(delegate)

java - 安卓Java : Show the value only of `id` when a listrow is selected

android - 如何在Interactor/UseCase中获取CoroutineScope

linux - 没有名为 'tensorflow.python.platform' 的模块

Tensorflow : Shape error in LSTM model expected shape=(None, 无,90),发现形状=[90, 1, 78]

python - skflow pandas数据集平均每2行

android - 如何访问对话框中的 EditText 字段?

python - 重置 tensorflow 流指标的变量

frameworks - TF Learn(又名 Scikit Flow)和 TFLearn(又名 TFLearn.org)有什么区别

python - skflow.TensorFlowDNNRegressor 参数