python - logits 和标签必须是可广播的 : logits_size=[32, 1] labels_size=[16,1]

标签 python tensorflow

我正在尝试学习 tensorflow,但出现以下错误:
logits 和标签必须是可广播的:logits_size=[32,1] labels_size=[16,1]

当我将其作为输入时,代码运行良好:

self.input = np.ones((500, 784))
self.y = np.ones((500, 1))

但是,当我添加额外的维度时,会抛出错误:
self.input = np.ones((500, 2, 784))
    self.y = np.ones((500, 1))

构建图的代码
    self.x = tf.placeholder(tf.float32, shape=[None] + self.config.state_size)
    self.y = tf.placeholder(tf.float32, shape=[None, 1])

    # network architecture
    d1 = tf.layers.dense(self.x, 512, activation=tf.nn.relu, name="dense1")
    d2 = tf.layers.dense(d1, 1, name="dense2")


    with tf.name_scope("loss"):
        self.cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=self.y, logits=d2))
        self.train_step = tf.train.AdamOptimizer(self.config.learning_rate).minimize(self.cross_entropy,
                                                                                     global_step=self.global_step_tensor)
        correct_prediction = tf.equal(tf.argmax(d2, 1), tf.argmax(self.y, 1))
        self.accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

有人可以向我解释为什么会发生这种情况以及如何解决这个问题吗?

最佳答案

logits是通常赋予网络输出的名称,这些是您的预测。 [32, 10]的尺寸告诉我您的批次大小为 32 和 10 个输出,这在 mnist 中很常见,因为您似乎正在使用它。

您的标签尺寸[16, 10] ,也就是说,您提供了 16 个大小为 10 的标签/向量。您提供的标签数量与网络的输出冲突,它们应该相同。

我不太清楚你对输入中的额外维度做了什么,但我想你一定不小心以某种方式将样本加倍。也许[500, 2, 784]形状正在 reshape 为 [1000, 784]沿途自动在某个地方,然后与 500 个标签不匹配。另外,您的 self.y应形[500, 10]不是,[500, 1] ,您的标签需要采用 one-hot 编码格式。例如。单个形状标签 [1, 10]数字 3 将是 [[0,0,0,1,0,0,0,0,0,0,0]] , 不是以数字表示,例如[3]因为您似乎在此处的健全性测试中设置了它。

关于python - logits 和标签必须是可广播的 : logits_size=[32, 1] labels_size=[16,1],我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/52977700/

相关文章:

tensorflow - 如何分析在 tf-serving 上运行的 tensorflow 模型?

python - 迭代不同的数据框

python - Pandas:从其他 Dataframe 中的某些列创建新 Dataframe

python - Tensorflow 多个 session.run( ) 在同一次迭代中

python - 将 batchnorm(TensorFlow) 的 is_training 变为 False

machine-learning - 我应该在哪里将 dropout 应用于卷积层?

python - 模块未找到错误 : No module named 'tensorflow.contrib' ; 'tensorflow' is not a package

python - Pandas :时间戳系列中的唯一天数

python - 嵌入 libPython 时重定向 stdout 和 stderr

python - 添加 PyCharm 中未显示的远程选项