python - Tensorflow in_top_kv

标签 python tensorflow shapes recurrent-neural-network tensor

您好,我正在尝试构建一个具有 11 个输入和 2 个输出的简单 rnn X=tf.placeholder(tf.float32,[无,n_steps,n_inputs]) y=tf.placeholder(tf.int32,[None,n_steps,n_outputs])

我知道 rnn 除了 [batch_size,n_steps,n_inputs] 形状的输入,所以这就是为什么我将占位符塑造成这样的原因

但是当我运行代码时,我得到一个错误

ValueError: Shape must be rank 2 but is rank 3 for 'in_top_k/InTopKV2' (op: 'InTopKV2') with input shapes: [1,270,2], [1,270,2], [].

错误似乎起源于此:correct = tf.nn.in_top_k(logits,tf.reshape(y,[1,n_steps,n_outputs]),1)

我尝试过 reshape logits、压缩 logits、扩大 y 维度,但似乎没有任何效果。

我注意到的一个区别是,当我用

压缩 logits 时
tf.squeeze(logits)

现在的错误是

ValueError: Shape must be rank 1 but is rank 3

这是我能够取得的唯一“进步”,我们将不胜感激。

请放轻松,这是我的第一个问题

最佳答案

您必须将输入重新整形为 2D 张量,然后您可以将结果重新整形为所需的形状:

logits_res = tf.reshape(logits, (-1, n_outputs))
y_res = tf.reshape(y, (-1, n_outputs))
correct_res = tf.nn.in_top_k(logits_res, y_res, 1)
correct = tf.reshape(correct_res, (-1, n_steps))

关于python - Tensorflow in_top_kv,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/53363349/

相关文章:

javascript - 在 raiseevent() 方法中存储 HTML 对象(div 元素) -胡椒机器人

python - 打破嵌套循环

linux - ImportError:/lib64/libc.so.6: 版本 `GLIBC_2.17' 在 RHEL 6.9 上导入 tensorflow 时找不到(圣地亚哥)

python - Tensorflow 2.0 意外 OOM

java - BlueJ 的矩形类

python - Django + mod_wsgi + apache2 : server hangs

python - 应用引擎 SDK : How do I view keys in a specific namespace using the Memcache Viewer?

python - Tensorflow 的 DNNLinearCombinedClassifier 打印回归损失而不是分类损失

Java - 从 PNG 图像创建形状 (NullPointerException)

python - 展平一个不规则列表并 reshape 另一个重建原始元素顺序的相同长度的平面列表