python - 如何设置 keras 自定义指标仅在纪元结束时调用?

标签 python tensorflow keras

我正在尝试为我的神经网络使用自定义指标,并且该指标只能在纪元结束时进行评估。我遇到的问题是,每批都会评估指标,这不是想要的行为。请注意,我正在使用发电机和 fit_generator与keras。

validation_data 加载了一个实现 keras.utils.Sequence 的生成器

class DataGenerator(keras.utils.Sequence): 
   def __init__(self, inputs, labels, batch_size):
    self.inputs = inputs
    self.labels = labels
    self.batch_size = batch_size

   def __getitem__(self, index):
    #some processing done here
    return batch_inputs, batch_labels

   def __len__(self):
    return int(np.floor(len(self.inputs) / self.batch_size))

我尝试实现 keras 文档建议的内容,但没有找到任何信息来指定该指标仅应在纪元结束时使用。

def auc_roc(y_true, y_pred):
   auc, up_opt = tf.metrics.auc(y_true, y_pred)
   K.get_session().run(tf.local_variables_initializer())
   with tf.control_dependencies([up_opt]):
       auc = tf.identity(auc)
   return auc

所以现在 auc_roc在每个批处理之后调用,而不是在 epoch 末尾调用一次.

最佳答案

from sklearn.metrics import roc_auc_score
from keras.callbacks import Callback

class IntervalEvaluation(Callback):
    def __init__(self, validation_data=(), interval=10):
        super(Callback, self).__init__()

        self.interval = interval
        self.X_val, self.y_val = validation_data

    def on_epoch_end(self, epoch, logs={}):
        if epoch % self.interval == 0:
            y_pred = self.model.predict_proba(self.X_val, verbose=0)
            score = roc_auc_score(self.y_val, y_pred)
            print("interval evaluation - epoch: {:d} - score: {:.6f}".format(epoch, score))

用法:

ival = IntervalEvaluation(validation_data=(x_test2, y_test2), interval=1)

更多信息: http://digital-thinking.de/keras-three-ways-to-use-custom-validation-metrics-in-keras/

关于python - 如何设置 keras 自定义指标仅在纪元结束时调用?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/54634435/

相关文章:

tensorflow - Pytorch相当于tf.map_fn与parallel_iterations?

python - 为什么 SparseCategoricalCrossentropy 不适用于此机器学习模型?

python - 如何在 Tensorflow 中结合 feature_columns、model_to_estimator 和数据集 API

python - 安装后无法导入keras

python - 关于keras.utils.Sequence的澄清

python-3.x - 如何在Python 3中对大量文本进行分类?

python - 如何从二进制文件中读取信息

python - 更改一系列数据框中元素的值python

python - 使用 Numeric Python 的数组的逐元素中值和百分位数

python - Scikit-learn微调: Postprocess predicted labels before evaluation