python - 多标签分类: How to learn threshold values?

标签 python tensorflow deep-learning

我有一个深度 CNN,可以很好地进行多类分类。我想“升级”挑战并针对多标签分类问题对其进行训练。

为此,我用 sigmoid 替换了 softmax,并尝试训练我的网络以最小化:

tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=y_, logits=y_pred)

但我最终得到了奇怪的预测:

Prediction for Im1 : [ 0.59275776  0.08751075  0.37567005  0.1636796   0.42361438  0.08701646

0.38991812 0.54468459 0.34593087 0.82790571]

Prediction for Im1 : [ 0.52609032  0.07885984  0.45780018  0.04995904  0.32828355  0.07349177

0.35400775 0.36479294 0.30002621 0.84438241]

Prediction for Im1 : [ 0.58714485  0.03258472  0.3349618   0.03199361  0.54665488  0.02271551

0.43719986 0.54638696 0.20344526 0.88144571]

因此,我想尝试为每个类别设置网络学习阈值,以确定样本是否属于该类别。

所以我将其添加到我的代码中:

initial = tf.truncated_normal([numberOfClasses], stddev=0.1)
W_thresh = tf.Variable(initial)

y_predict_thresh = int(y_predict > W_thresh)

但是我有一个错误:

TypeError: int() argument must be a string or a number, not 'Tensor'.

任何人有任何想法可以帮助我前进(如何避免这个错误?,我的数据集真的不平衡这一事实会导致这些“恒定”预测吗?关于多标签分类的其他建议?,...) ?

谢谢

编辑:我刚刚意识到,对于反向传播来说,进行阈值处理可能并不是很酷:/

最佳答案

不知道你是否还需要它,但是你可以使用tensorflow转换函数tf.to_int32、tf.to_int64。在评估之前,您的表达式是 python 的对象,因此它不能简单地将其转换为 int()

这可以满足您的需要:

with tf.Session() as sess:
    check = sess.run([tf.to_int64(W1 > W2)])

关于python - 多标签分类: How to learn threshold values?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/44265934/

相关文章:

python - sklearn 中的分类树给出不一致的答案

python - 为什么运行 python 脚本会导致 Apache 出错

python - estimator.predict 引发 "ValueError: None values not supported"

machine-learning - 验证和测试准确性差异很大

machine-learning - 在测试 MNIST 时,caffe 测试错误没有名为 "net"的字段

python - 带有列表元素 : split, pad 的 Pandas 数据框

python - 如何从矩阵中选择特定的列索引?

python - 使用 Tensorflow contrib keras 时导入语句

python - 如何解释 model.predict 返回的结果?

python - 图像批处理中的随机补丁