python - 拟合模型时如何在keras/tensorflow中使用多线程?

标签 python python-3.x tensorflow keras

我有一个 20 核的 CPU,我正在尝试使用所有核来拟合模型。我设置了 tfintra_op_parallelism_threads=20 的 session 并调用model.fit在同一tf session 。

python 进程利用 2000% CPU(如 top 所述)。但是,当将以下代码与单核配置 (intra_op_parallelism_threads=1) 进行比较时,我得到了相同的学习率。

from keras.layers import Dense, Activation, Dropout
from keras.layers import Input, Conv1D
import numpy as np
from keras.layers.merge import concatenate
from keras.models import Model

import tensorflow as tf
from keras.backend import tensorflow_backend as K

with tf.Session(config=tf.ConfigProto(intra_op_parallelism_threads=20)) as sess:
        K.set_session(sess)

        size=20
        batch_size=16
        def xor_data_generator():
                while True:
                        data1 = np.random.choice([0, 1], size=(batch_size, size,size))
                        data2 = np.random.choice([0, 1], size=(batch_size, size,size))
                        labels  = np.bitwise_xor(data1, data2)
                        yield ([data1, data2], np.array(labels))

        a = Input(shape=(size,size))
        b = Input(shape=(size,size))
        merged = concatenate([a, b])
        hidden = Dense(2*size)(merged)
        conv1 = Conv1D(filters=size*16, kernel_size=1, activation='relu')(hidden)
        hidden = Dropout(0.1)(conv1)
        outputs = Dense(size, activation='sigmoid')(hidden)

        model = Model(inputs=[a, b], outputs=outputs)
        model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
        model.fit_generator(xor_data_generator(), steps_per_epoch=batch_size, epochs=10000) 

请注意,我不能使用 multi_gpu_model ,因为我的系统只有 20 个 CPU 内核。

如何分发 model.fit_generator(xor_data_generator(), steps_per_epoch=batch_size, epochs=10000)同时在不同的内核上?

最佳答案

看看 Keras 的 Sequence object编写您的自定义生成器。它是生成图像数据的 ImageDataGenerator 的基础对象。文档包含您可以修改的样板代码。如果你使用它,你可以将 fit.generator()use_multiprocessing 参数设置为 True。另见 this回答。

关于python - 拟合模型时如何在keras/tensorflow中使用多线程?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/52180645/

相关文章:

python - FastAPI 测试客户端在 POST 或 PUT 时重定向请求

python - “NoneType”对象在python小程序中没有属性 'get'

python - 如何使用shutil.copy修复Python 3中的 "FileNotFoundError: [Errno 2]"

tensorflow - 超出使用限制后,如何在 Google Colab 上再次使用 GPU?

python - TensorFlow matmul : Blas xGEMMBatched launch failed

python - 为什么 keras 中的自定义图像生成器会出现错误 "object cannot be interpreted as an integer"?

python - 使用pyxl设置列宽时出现关键错误

python - 扩展方法时将局部变量保留在作用域内

python-3.x - 打印在 map 中时不打印,Python

Python - 如果数字大于 0,则运行平均值