python - keras中model.compile的参数 'weighted_metrics'和model.fit_generator的参数 'class_weight'之间的区别?

标签 python keras deep-learning image-recognition

在训练用于图像分类的 keras 模型(来自 DOG BREED IDENTIFICATION 数据集的 120 个类,KAGGLE)时,我需要使用我在某处读到的类权重来平衡类,在示例中我看到人们使用 fit_generator 的参数 class_weight。但我在 model.compile 中发现了另一个参数,weighted_metrics,其在文档中的描述是:“在训练和测试期间由sample_weight或class_weight评估和加权的指标列表”。我要用这个吗?请举例说明该参数的用途。

#Calculating Class weights
counter = Counter(train_generator.classes)
max_value = float(max(counter.values()))

CLASS_WEIGHTS = {classid: max_value / num_occurences
                 for classid, num_occurences in counter.items()}
# Model Compile
model.compile(optimizer=Adam(lr=LR),
              loss=categorical_crossentropy,
              metrics=[categorical_accuracy],
              weighted_metrics=None) # <--------------- This parameter

STEPS_PER_EPOCH = train_generator.n//train_generator.batch_size
VAL_STEPS = val_generator.n//val_generator.batch_size

model.fit_generator(train_generator,
                    steps_per_epoch=STEPS_PER_EPOCH,
                    epochs=EPOCHS,
                    callbacks=callback_list,
                    verbose=1,
                    class_weight=CLASS_WEIGHTS,
                    validation_data=val_generator,
                    validation_steps=VAL_STEPS) # USED CLASS_WEIGHTS HERE

最佳答案

是的,您可以将它们用于不平衡的数据集。

weighted_metrics

是考虑到的指标列表

class_weights

您传入 fit_generator。

因此在您的示例中,您可以设置

weighted_metrics=['accuracy']

class_weight = {0:3, 1:4}

weighted_metrics 参数的目的是提供一个指标列表,该列表将考虑您在 fit_generator 中传递的 class_weights。

关于python - keras中model.compile的参数 'weighted_metrics'和model.fit_generator的参数 'class_weight'之间的区别?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/56868174/

相关文章:

python - 如何使用 Python 的 timeit 为代码段计时以测试性能?

python - 获取分割图像的最大连通分量

python - 与词法作用域和 for 循环作斗争

python - 如何分隔 csv 列中的值以用作 tf.contrib.learn.DNNRegressor 中的稀疏列_with_integerized_feature 中的多价特征列

python - 雪茄 - 10/Unpickle

python - 使用 Pandas 读取 CSV 并处理评论

machine-learning - Keras中的initial_epoch是什么意思?

python - 使用 Keras 进行贪婪分层训练

python - 如何保存 MNIST 在 tensorflow 上测试的训练数据权重以供将来使用?

python - 在 `tf.estimator` 中,如何在训练结束时(不是每次迭代时)将变量设置为 `tf.assign` ?