python - 在单个 GPU 上并行拟合多个 Keras 模型

标签 python tensorflow multiprocessing keras

我正在尝试在单个 GPU 上并行安装多个小型 Keras 模型。由于某些原因,我需要将它们从列表中移除并一次一步地训练它们。由于我对标准多处理模块并不幸运,所以我使用了 pathos。

我尝试做的是这样的:

from pathos.multiprocessing import ProcessPool as Pool
import tensorflow as tf
import keras.backend as K

def multiprocess_step(self, model):
    K.set_session(sess)
    with sess.graph.as_default():
        model = step(model, sess)
        return model

def step(model, sess):
    K.set_session(sess)
    with sess.graph.as_default():
        model.fit(x=data['X_train'], y=data['y_train'],
               batch_size=batch_size
               validation_data=(data['X_test'], data['y_test']), 
               verbose=verbose,
               shuffle=True,
               initial_epoch=self.step_num - 1)
        return model

config = tf.ConfigProto()
config.gpu_options.allow_growth = True
config.gpu_options.visible_device_list = "0"
sess = tf.Session(config=config)

K.set_session(sess)
with sess.graph.as_default():
    pool = Pool(8).map
    model_list = pool(multiprocess_step, model_list)

但无论我尝试什么,我总是收到一个错误,声称模型似乎不在同一个图表上......

ValueError: Tensor("training/RMSprop/Variable:0", shape=(25, 352), dtype=float32_ref) 必须来自与 Tensor("RMSprop/rho/read:0", shape=(), dtype=float32).

异常源自 model.fit() 行,所以我一定是在 session 图的分配上做错了什么,尽管我试图在每个可能的位置设置它?

有没有人有类似的经历?

最佳答案

Keras issue tracker 上提出了以下建议.与使用多处理相比,我不确定该方法的相对优点。

in_1 = Input()
lstm_1 = LSTM(...)(in_1)
out_1 = Dense(...)(lstm_1)

in_2 = Input()
lstm_2 = LSTM(...)(in_2)
out_2 = Dense(...)(lstm_2)

model_1 = Model(input=in_1, output=out_1)
model_2 = Model(input=in_2, output=out_2)

model = Model(input = [in_1, in_2], output = [out_1, out_2])
model.compile(...)
model.fit(...)

model_1.predict(...)
model_2.predict(...)

关于python - 在单个 GPU 上并行拟合多个 Keras 模型,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/47885357/

相关文章:

python - 在 django admin 上使用 list_editable 时不显示文本字段

python - 如何从Python数据框中的DateTimeIndex中删除微秒?

python - 何时以及如何使用 Tornado?什么时候没用?

在主进程中从标准输入进行阻塞读取时,Python 子进程阻塞

python - 尝试覆盖 QTreeView.edit() 时出现最大递归错误

tensorflow - 如何在 Tensorboard 中隐藏 "Cond"张量? (附截图)

tensorflow - 与 keras 进行注意卷积

python - 对象检测 API 的 Tensorflow ConcatOp 错误

Python AIOHTTP.web 服务器多处理负载均衡器?

c - 在核心之间共享数据的最有效方式