代码摘自Tensorflow tutorial 。该函数在 MNIST 数据集(0-9 的手写图片数据集)上运行操作。为什么要将标签转换为 int64
,我认为 int32
就足够了。
def loss(logits,labels):
labels = tf.to_int64(labels)
cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
logits,labels,name='xentropy')
loss = tf.reduce_mean(cross_entropy,name='xentropy_mean')
return loss
最佳答案
这个documentation说它可以是 int32
或 int64
。因此,您可以选择其中之一。在这里,他们更愿意选择 int64
。
引用文档:
labels
: Tensor of shape[d_0, d_1, ..., d_{r-2}]
and dtypeint32
orint64
. Each entry inlabels
must be an index in[0, num_classes)
. Other values will raise an exception when this op is run on CPU, and returnNaN
for corresponding corresponding loss and gradient rows on GPU.
关于python - 为什么我们在 tensorflow 的损失函数中需要 `int64` 作为 MNIST 标签?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/41174294/