python - "Cannot convert a ndarray into a Tensor or Operation."尝试从 tensorflow 中的 session.run 获取值时出错

标签 python numpy tensorflow

我在 tensorflow 中创建了一个孪生网络。我正在使用以下代码计算两个张量之间的距离:

distance = tf.sqrt(tf.reduce_sum(tf.square(tf.subtract(question1_predictions, question2_predictions)), reduction_indices=1))

我能够毫无错误地训练模型。在推理部分,我正在检索 distance张量如下:

test_state, distance = sess.run([question1_final_state, distance], feed_dict=feed)

Tensorflow 抛出错误:

Fetch argument array([....], dtype=float32) has invalid type , must be a string or Tensor. (Can not convert a ndarray into a Tensor or Operation.)

当我打印 distance张量,在 session.run 之前和之后在训练部分,它显示为<class 'tensorflow.python.framework.ops.Tensor'> .所以张量的替换distance用 numpy distance正在发生在 session.run推理部分。按照推理部分的代码:

with graph.as_default():
    saver = tf.train.Saver()

with tf.Session(graph=graph) as sess:
    sess.run(tf.global_variables_initializer(), feed_dict={embedding_placeholder: embedding_matrix})
    saver.restore(sess, tf.train.latest_checkpoint('checkpoints'))

    test_state = sess.run(initial_state)

    for ii, (x1, x2, batch_test_ids) in enumerate(get_test_batches(test_question1_features, test_question2_features, test_ids, batch_size), 1):
        feed = {question1_inputs: x1,
                question2_inputs: x2,
                keep_prob: 1,
                initial_state: test_state
               }
        test_state, distance = sess.run([question1_final_state, distance], feed_dict=feed)

最佳答案

看起来你用 numpy 数组 distance = sess.run(distance) 覆盖了 Tensor distance = tf.sqrt(...)

你的循环是罪魁祸首。将 t_state, distance = sess.run([question1_final_state, distance] 更改为 t_state, distance_other = sess.run([question1_final_state, distance]

关于python - "Cannot convert a ndarray into a Tensor or Operation."尝试从 tensorflow 中的 session.run 获取值时出错,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/44088706/

相关文章:

python - 在集群上使用 Python、Cython 和 GSL

python - 将包含日期时间对象的元组转换为 numpy 结构化数组时出现 `TypeError: float() argument must be a string or a number`

python - 如何在自定义 Keras 损失中将 softmax 输出转换为 one-hot 格式

python - 模块 'tensorflow._api.v2.train' 没有属性 'GradientDescentOptimizer'

tensorflow - 为什么在 Keras 指标函数中使用 axis=-1 ?

python - 使用 Spark 获取值超过某个阈值的所有列的名称

python - 使用请求通过 http 下载文件时的进度条

python - 是否可以使用BLAS来加速稀疏矩阵乘法?

python - 使用 3d 数组的索引来填充 4d 数组

python - 反转 3D 矩阵中嵌套数组的顺序