我有两个类别:正类 (1) 和负类 (0)。
数据集非常不平衡,所以目前我的小批量大部分包含 0。事实上,许多批处理只包含 0。我想尝试为正例和负例设置单独的成本;请参阅下面的代码。
我的代码的问题是我收到了很多nan
,因为bound_index列表将为空。解决这个问题的优雅方法是什么?
def calc_loss_debug(logits, labels):
logits = tf.reshape(logits, [-1])
labels = tf.reshape(labels, [-1])
index_bound = tf.where(tf.equal(labels, tf.constant(1, dtype=tf.float32)))
index_unbound = tf.where(tf.equal(labels, tf.constant(0, dtype=tf.float32)))
entropies = tf.nn.sigmoid_cross_entropy_with_logits(logits, labels)
entropies_bound = tf.gather(entropies, index_bound)
entropies_unbound = tf.gather(entropies, index_unbound)
loss_bound = tf.reduce_mean(entropies_bound)
loss_unbound = tf.reduce_mean(entropies_unbound)
最佳答案
由于您有 0 和 1 标签,因此您可以使用这样的结构轻松避免 tf.where
labels = ...
entropies = ...
labels_complement = tf.constant(1.0, dtype=tf.float32) - labels
entropy_ones = tf.reduce_sum(tf.mul(labels, entropies))
entropy_zeros = tf.reduce_sum(tf.mul(labels_complement, entropies))
要获得平均损失,您需要除以批处理中 0 和 1 的数量,可以轻松计算为
num_ones = tf.reduce_sum(labels)
num_zeros = tf.reduce_sum(labels_complement)
当然,还是要防止batch中没有1的时候被0除。我建议使用 tf.cond(tf.equal(num_ones, 0), ...) 。
关于machine-learning - TensorFlow:如何实现二元分类的每类损失函数,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/39894312/