python - 如何让 TensorFlow 的 'import_graph_def' 返回 Tensors

标签 python machine-learning tensorflow restore

如果我尝试导入已保存的 TensorFlow图定义与

import tensorflow as tf
from tensorflow.python.platform import gfile

with gfile.FastGFile(FLAGS.model_save_dir.format(log_id) + '/graph.pb', 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
x, y, y_ = tf.import_graph_def(graph_def, 
                               return_elements=['data/inputs',
                                                'output/network_activation',
                                                'data/correct_outputs'],
                               name='')

返回值不是预期的 Tensor,而是其他东西:例如,将 x 获取为

Tensor("data/inputs:0", shape=(?, 784), dtype=float32)

我明白了

name: "data/inputs_1"
op: "Placeholder"
attr {
  key: "dtype"
  value {
    type: DT_FLOAT
  }
}
attr {
  key: "shape"
  value {
    shape {
    }
  }
}

也就是说,我得到的不是预期的张量 x,而是 x.op。这让我感到困惑,因为 documentation似乎在说我应该得到一个 Tensor(尽管那里有一堆 让人难以理解)。

如何让 tf.import_graph_def 返回特定的 Tensor 然后我可以使用(例如,在提供加载模型或运行分析时)?

最佳答案

名称'data/inputs''output/network_activation''data/correct_outputs' 实际上是操作名称。要让 tf.import_graph_def() 返回 tf.Tensor 对象,您应该将输出索引附加到操作名称,通常是 ':0' 对于单输出操作:

x, y, y_ = tf.import_graph_def(graph_def, 
                               return_elements=['data/inputs:0',
                                                'output/network_activation:0',
                                                'data/correct_outputs:0'],
                               name='')

关于python - 如何让 TensorFlow 的 'import_graph_def' 返回 Tensors,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/37146272/

相关文章:

machine-learning - 随机梯度下降的成本函数是针对所有行计算还是仅针对迭代行计算?

machine-learning - 从头开始的随机森林回归

python - model.fit(...) 和 "Failed to convert a NumPy array to a Tensor"

python - 超拉斯 : 'List' object has no attribute 'shape'

python - 如何在Tensorflow中的tf.estimator上使用tensorflow调试工具tfdbg?

python - 类型错误 : Mean() missing 1 required positional argument: 'data'

python - 生成基准表

python - docker-compose 卷未正确安装

python - 拟合后手动计算的损失比上一个时期的损失高一个数量级

python - 几秒钟后 Subprocess.Popen 停止(或出现故障)