python - Keras 模型永远用 dask 数据框进行训练

标签 python dataframe keras large-data dask

我正在处理内存不足的大型数据集,我被介绍了 Dask 数据框。我从文档中了解到 Dask 不会将整个数据集加载到内存中。相反,它创建了多个线程,这些线程将按需从磁盘中获取记录。所以我假设批量大小 = 500 的 keras 模型,在训练时它的内存中应该只有 500 条记录。但是当我开始训练时。这需要永远。可能是我做错了什么。请提出建议。

训练数据的形状:1000000 * 1290

import glob
import dask.dataframe
paths_train = glob.glob(r'x_train_d_final*.csv')

X_train_d = dd.read_csv('.../x_train_d_final0.csv')
Y_train1 = keras.utils.to_categorical(Y_train.iloc[,1], num_classes)
batch_size = 500
num_classes = 2
epochs = 5

model = Sequential()
model.add(Dense(645, activation='sigmoid', input_shape=(1290,),kernel_initializer='glorot_normal'))
#model.add(Dense(20, activation='sigmoid',kernel_initializer='glorot_normal'))
model.add(Dense(num_classes, activation='sigmoid'))

model.compile(loss='binary_crossentropy',
          optimizer=Adam(decay=0),
          metrics=['accuracy'])

history = model.fit(X_train_d.to_records(), Y_train,
                batch_size=batch_size,
                epochs=epochs,
                verbose=1,
                class_weight = {0:1,1:6.5},
                shuffle=False)

最佳答案

您应该使用 Sequential model 中的 fit_generator()带发电机或带 Sequence实例。两者都提供了一种只加载一部分数据的正确方法。

Keras 文档提供了一个很好的例子:

def generate_arrays_from_file(path):
    while 1:
        f = open(path)
        for line in f:
            # create Numpy arrays of input data
            # and labels, from each line in the file
            x, y = process_line(line)
            yield (x, y)
        f.close()

model.fit_generator(generate_arrays_from_file('/my_file.txt'),
                    steps_per_epoch=1000, epochs=10)

关于python - Keras 模型永远用 dask 数据框进行训练,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/47300713/

相关文章:

python - curve_fit 和 scipy.odr 的比较 - 绝对西格玛

keras - keras中的卷积层

python - 使用 Tensorflow 2 的 Keras 功能 API 时传递 `training=true`

python - pandas describe函数的统计意义是什么,如何使用?

r - 有没有办法将多列中的值转为列名?

python - LSTM 的 model.reset_states 会影响模型中的任何其他非 LSTM 层吗?

python - 从列表中删除单个字符

python - 如何使用 StreamHandler 捕获记录器 stderr 上的输出?

python - 将稀疏csv文件读入pandas

R:在表格中将一些行按一列移动