我有一个深度 CNN,可以很好地进行多类分类。我想“升级”挑战并针对多标签分类问题对其进行训练。
为此,我用 sigmoid 替换了 softmax,并尝试训练我的网络以最小化:
tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=y_, logits=y_pred)
但我最终得到了奇怪的预测:
Prediction for Im1 : [ 0.59275776 0.08751075 0.37567005 0.1636796 0.42361438 0.08701646
0.38991812 0.54468459 0.34593087 0.82790571]
Prediction for Im1 : [ 0.52609032 0.07885984 0.45780018 0.04995904 0.32828355 0.07349177
0.35400775 0.36479294 0.30002621 0.84438241]
Prediction for Im1 : [ 0.58714485 0.03258472 0.3349618 0.03199361 0.54665488 0.02271551
0.43719986 0.54638696 0.20344526 0.88144571]
因此,我想尝试为每个类别设置网络学习阈值,以确定样本是否属于该类别。
所以我将其添加到我的代码中:
initial = tf.truncated_normal([numberOfClasses], stddev=0.1)
W_thresh = tf.Variable(initial)
y_predict_thresh = int(y_predict > W_thresh)
但是我有一个错误:
TypeError: int() argument must be a string or a number, not 'Tensor'.
任何人有任何想法可以帮助我前进(如何避免这个错误?,我的数据集真的不平衡这一事实会导致这些“恒定”预测吗?关于多标签分类的其他建议?,...) ?
谢谢
编辑:我刚刚意识到,对于反向传播来说,进行阈值处理可能并不是很酷:/
最佳答案
不知道你是否还需要它,但是你可以使用tensorflow转换函数tf.to_int32、tf.to_int64
。在评估之前,您的表达式是 python 的对象,因此它不能简单地将其转换为 int()
。
这可以满足您的需要:
with tf.Session() as sess:
check = sess.run([tf.to_int64(W1 > W2)])
关于python - 多标签分类: How to learn threshold values?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/44265934/