python - 使用 Tensorflow CNN 分类器获取精度和召回值

标签 python neural-network tensorflow classification precision-recall

我想知道是否有一个简单的解决方案来获取分类器类别的召回率和精度值?

为了说明一些背景,我在 Denny Britz 代码的帮助下使用 Tensorflow 实现了 20 类 CNN 分类器:https://github.com/dennybritz/cnn-text-classification-tf .

正如您在 text_cnn.py 末尾所看到的,他实现了一个简单的函数来计算全局精度:

# Accuracy
        with tf.name_scope("accuracy"):
            correct_predictions = tf.equal(self.predictions, tf.argmax(self.input_y, 1))
            self.accuracy = tf.reduce_mean(tf.cast(correct_predictions, "float"), name="accuracy")

关于如何做类似的事情来获得不同类别的召回率和精确度值有什么想法吗?

也许我的问题听起来很愚蠢,但说实话我对此有点迷失。感谢您的帮助。

最佳答案

使用 tf.metrics 对我来说很有效:

#define the method
x = tf.placeholder(tf.int32, )
y = tf.placeholder(tf.int32, )
acc, acc_op = tf.metrics.accuracy(labels=x, predictions=y)
rec, rec_op = tf.metrics.recall(labels=x, predictions=y)
pre, pre_op = tf.metrics.precision(labels=x, predictions=y)

#predict the class using your classifier
scorednn = list(DNNClassifier.predict_classes(input_fn=lambda: input_fn(testing_set)))
scoreArr = np.array(scorednn).astype(int)

#run the session to compare the label with the prediction
sess=tf.Session()
sess.run(tf.global_variables_initializer())
sess.run(tf.local_variables_initializer())
v = sess.run(acc_op, feed_dict={x: testing_set["target"],y: scoreArr}) #accuracy
r = sess.run(rec_op, feed_dict={x: testing_set["target"],y: scoreArr}) #recall
p = sess.run(pre_op, feed_dict={x: testing_set["target"],y: scoreArr}) #precision

print("accuracy %f", v)
print("recall %f", r)
print("precision %f", p)

结果:

accuracy %f 0.686667
recall %f 0.978723
precision %f 0.824373

注意:为了准确性,我会使用:

accuracy_score = DNNClassifier.evaluate(input_fn=lambda:input_fn(testing_set),steps=1)["accuracy"]

因为它更简单并且已经在评估中计算。

如果你不想要累积结果,也可以调用variables_initializer。

关于python - 使用 Tensorflow CNN 分类器获取精度和召回值,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/38978073/

相关文章:

python - 用 Python 计算 Keras 神经网络的准确性

python - 错误 : tensorboard 2. 0.2 要求 setuptools>=41.0.0,但您将拥有不兼容的 setuptools 40.6.2

tensorflow - 如何在不保存检查点的情况下运行 estimator.train

python - "dest_directory = FLAGS.model_dir"是什么意思?

python - Google API - 电子表格

python - 计算 pandas 中不重要的行数

python - 如何在没有交叉验证的情况下检查机器学习的准确性

python - 如何将我的训练数据输入到这个神经网络中

Python 相当于 R 数据框

Python __init__ 不执行变量