python - Keras 中具有掩蔽支持的均值或最大池化

标签 python masking keras lstm pooling

...
print('Build model...')
model = Sequential()
model.add(Embedding(max_features, 128))
model.add(LSTM(size, return_sequences=True, dropout_W=0.2 dropout_U=0.2)) 
model.add(GlobalAveragePooling1D())
model.add(Dense(1))
model.add(Activation('sigmoid'))
....

在将此均值或最大值向量提供给 Keras 中的密集层之前,我需要能够在 LSTM 层之后的样本中的所有时间步长中获取向量的平均值或最大值。

我认为 timedistributedmerge 能够做到这一点,但它已被弃用。使用 return_sequences=True 我可以获得 LSTM 层之后样本中所有时间步长的向量。但是,GlobalAveragePooling1D() 与掩蔽不兼容,它考虑所有时间步长,而我只需要非掩蔽时间步长。

我看到推荐 Lambda 层的帖子,但这些也没有考虑屏蔽。任何帮助将不胜感激。

最佳答案

Jacoxu 的回答是正确的。但是,如果您使用的是 keras 的 tensorflow 后端,则 Tensor 类型不支持 dimshuffle 函数,请改用此方法。

def call(self, x, mask=None):
    if mask is not None:
        # mask (batch, time)
        mask = K.cast(mask, K.floatx())
        # mask (batch, x_dim, time)
        mask = K.repeat(mask, x.shape[-1])
        # mask (batch, time, x_dim)
        mask = tf.transpose(mask, [0,2,1])
        x = x * mask
    return K.sum(x, axis=1) / K.sum(mask, axis=1)

关于python - Keras 中具有掩蔽支持的均值或最大池化,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/39510809/

相关文章:

python - Intellij/Pycharm 无法调试 Python 模块

python - 如何使用 Selenium 从搜索结果中提取 Google 链接的 href?

iphone - 如何屏蔽 UIView

c# - 相当于C#中VB6.0的MaskColor属性

c# - c# System.Drawing 中的 Alpha 掩蔽?

python - 从对象数组中获取所有对象的特定属性的数组

python - 无法连接 Keras Lambda 层

python - 如何使用经过训练的 Keras GRU 模型预测新的数据系列?

python - 回归问题的 Hyperas 损失函数

python - Python 的 itertools.repeat 的目的是什么?