python - 用于大型数据集的 sklearn utils compute_class_weight 函数

标签 python tensorflow machine-learning scikit-learn data-science

我正在大约 20+ GB 上训练 tensorflow keras 顺序模型postgres 数据库中基于文本的分类数据我需要为模型赋予类权重。
这就是我正在做的事情。

class_weights = sklearn.utils.class_weight.compute_class_weight('balanced', classes, y)

model.fit(x, y, epochs=100, batch_size=32, class_weight=class_weights, validation_split=0.2, callbacks=[early_stopping])

由于我无法将整个内容加载到内存中,我想我可以使用 fit_generator keras模型中的方法。

但是我怎么能计算类权重 在这个数据上? sklearn 不为此提供任何特殊功能,它是这个 的正确工具吗? ?

我想在多个 上进行操作随机样本但是有没有更好的方法 全部数据可以用吗?

最佳答案

您可以使用生成器,也可以计算类权重。

假设你有这样的发电机

train_generator = train_datagen.flow_from_directory(
        'train_directory',
        target_size=(224, 224),
        batch_size=32,
        class_mode = "categorical"
        )

训练集的类权重可以这样计算
class_weights = class_weight.compute_class_weight(
           'balanced',
            np.unique(train_generator.classes), 
            train_generator.classes)

[编辑 1]
由于您在评论中提到了 postgres sql,因此我在此处添加了原型(prototype)答案。

首先使用来自 postgres sql 的单独查询获取每个类的计数,并使用它来计算类权重。你可以手动计算。基本逻辑是权重最小的类的计数取值为 1,其余类根据与权重最小的类的相对计数得到 <1。

例如,您有 3 个类别 A、B、C 和 100,200,150,然后类别权重变为 {A:1,B:0.5,C:0.66}

从 postgres sql 获取值后,让我们手动计算它。

[询问]
cur.execute("SELECT class, count(*) FROM table group by classes order by 1")
rows = cur.fetchall()

上面的查询将返回具有从最低到最高排序的元组(类名,每个类的计数)的行。

然后下面的代码将创建类权重字典
class_weights = {}
for row in rows:
    class_weights[row[0]]=rows[0][1]/row[1] 
    #dividing the least value the current value to get the weight, 
    # so that the least value becomes 1, 
    # and other values becomes < 1

关于python - 用于大型数据集的 sklearn utils compute_class_weight 函数,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/60408901/

相关文章:

Tor 的 Python 脚本异常

json - tensorflow 服务 : Expects arg[0] to be float but string is provided

tensorflow - 难以理解tf.contrib.seq2seq.TrainingHelper

machine-learning - 如何禁用 tensorflow 中特定层的动量?

python - 将输入语句连接到python中的函数

python - Keras 准确率没有提高超过 50%

python - 如何使用 Selenium 提取 src 属性的值

tensorflow - 得到形状 [4575, 32, 32, 3],但想要 [4575] Tensorflow

r - 使用 RNN 预测多元时间序列

machine-learning - sklearn 中的哪些预测模型受训练数据框中列顺序的影响?