python - 并行 LSTM 分别处理输入的不同部分

标签 python tensorflow keras neural-network

需要对这段代码进行哪些更改才能获得如图所示的模型?

        model = Sequential()
        model.add(LSTM(256, input_shape=(self.timestamps, len(columns)), activation=keras.activations.tanh,
                       recurrent_activation=keras.activations.tanh))
        model.add(Dense(6, activation=keras.activations.tanh))
        model.summary()
        model.compile(loss=keras.losses.mean_squared_error, optimizer=keras.optimizers.RMSprop())

stdScore

例如,如果该模型的批量输入形状为(10,30,6),那么我希望input[:,:15,:]流入左侧 LSTM,input[:,15:,:] 流入右侧 LSTM。它是如何完成的?

最佳答案

首先,我建议使用 Keras Functional API 。它倾向于简化模型定义。

如果您想定义模型的两个输入,您可以这样定义模型:

from keras.layers import Input, LSTM, concatenate, Dense
from keras.models import Model

input_1 = Input(shape=(15, 6), name='input_1')
input_2 = Input(shape=(15, 6), name='input_2')
lstm1 = LSTM(256, name='lstm1')(input_1)
lstm2 = LSTM(256, name='lstm2')(input_2)
concat = concatenate([lstm1, lstm2]) 
output = Dense(6, activation='tanh', name='dense')(concat)
model = Model(inputs=[input_1, input_2], outputs=output)

如果您不想指定多个输入,则可以仅使用 Lambda 层来拆分输入:

from keras.layers import Input, LSTM, concatenate, Dense, Lambda
from keras.models import Model

input_ = Input(shape=(30, 6), name='input')
input_1 = Lambda(lambda x: x[:, :15, :])(input)
input_2 = Lambda(lambda x: x[:, 15:, :])(input)
lstm1 = LSTM(256, name='lstm1')(input_1)
lstm2 = LSTM(256, name='lstm2')(input_2)
concat = concatenate([lstm1, lstm2]) 
output = Dense(6, activation='tanh', name='dense')(concat)
model = Model(inputs=input_, outputs=output)

您将分别为每个示例调用 fit 函数,如下所示:

<小时/>

多个输入:

model.fit(x=[input_1, input_2], y=y)

或者

model.fit(x={'input_1': input_1, 'input_2': input_2}, y=y)

<小时/>

单输入:

model.fit(x=input_, y=y)

关于python - 并行 LSTM 分别处理输入的不同部分,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/54875137/

相关文章:

python - 为什么这个错误 "list index out of bound"?

java - TensorFlow:在 Android 中的推理之间初始化 RNN 状态

python - 在同一函数中的子进程调用之前运行 python 命令?

python - 如何将数据从numpy数组复制到另一个

python - OpenCV 使用不对应点的 solvePnPRansac 找到对象的位置

python - 等效于 csv 文件的 Keras ImageDataGenerator

python - Keras 输入形状,输入列表的简单数组

python - 用于 conv2d 和手动加载图像的 Keras input_shape

python - 如何将 python 生成器更改为 Keras Sequence 对象?

python - 将 Keras 和 Tensorflow 与 AMD GPU 结合使用