python - LSTM 上的 Keras 注意力层

标签 python keras lstm

我正在使用 keras 1.0.1 我正在尝试在 LSTM 之上添加一个注意力层。这是我目前所拥有的,但它不起作用。

input_ = Input(shape=(input_length, input_dim))
lstm = GRU(self.HID_DIM, input_dim=input_dim, input_length = input_length, return_sequences=True)(input_)
att = TimeDistributed(Dense(1)(lstm))
att = Reshape((-1, input_length))(att)
att = Activation(activation="softmax")(att)
att = RepeatVector(self.HID_DIM)(att)
merge = Merge([att, lstm], "mul")
hid = Merge("sum")(merge)

last = Dense(self.HID_DIM, activation="relu")(hid)

网络应在输入序列上应用 LSTM。然后应该将 LSTM 的每个隐藏状态输入到一个全连接层,在该层上应用 Softmax。为每个隐藏维度复制 softmax,并逐元素乘以 LSTM 隐藏状态。然后应该对生成的向量进行平均。

编辑:这可以编译,但我不确定它是否按照我认为它应该做的去做。

input_ = Input(shape=(input_length, input_dim))
lstm = GRU(self.HID_DIM, input_dim=input_dim, input_length = input_length, return_sequences=True)(input_)
att = TimeDistributed(Dense(1))(lstm)
att = Flatten()(att)
att = Activation(activation="softmax")(att)
att = RepeatVector(self.HID_DIM)(att)
att = Permute((2,1))(att)
mer = merge([att, lstm], "mul")
hid = AveragePooling1D(pool_length=input_length)(mer)
hid = Flatten()(hid)

最佳答案

您分享的第一段代码不正确。除了一件事,第二段代码看起来是正确的。不要使用 TimeDistributed,因为权重将相同。使用具有非线性激活的常规 Dense 层。


    input_ = Input(shape=(input_length, input_dim))
    lstm = GRU(self.HID_DIM, input_dim=input_dim, input_length = input_length, return_sequences=True)(input_)
    att = Dense(1, activation='tanh')(lstm_out )
    att = Flatten()(att)
    att = Activation(activation="softmax")(att)
    att = RepeatVector(self.HID_DIM)(att)
    att = Permute((2,1))(att)
    mer = merge([att, lstm], "mul")

现在您有了权重调整状态。你如何使用它取决于你。我见过的大多数 Attention 版本,只需在时间轴上将它们相加,然后将输出用作上下文。

关于python - LSTM 上的 Keras 注意力层,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/36812351/

相关文章:

python - key 错误 : "Couldn' t delete link (Can't delete self)"

python - TkInter, slider : how to trigger the event only when the iteraction is complete?

python - 对列进行排序并选择每组 pandas 数据框中的前 n 行

python - 使用linux终端修改python var

python - LSTM 网络在几次迭代后开始生成垃圾

python - 使用 Keras 绘制学习率的准确度时,Matplotlib 返回空图

python - Django-Admin 站点的内置 TreeView ?

tensorflow - Tensorflow text_generation 教程中有状态 GRU 的误导性训练数据混洗

tensorflow - 多维 lstm tensorflow

keras - LSTM 单元的数量与要训练的序列长度之间是否存在关系?