python - 我怎样才能加快这个 Keras 注意力计算?

标签 python tensorflow keras vectorization

我已经为 AttentiveLSTMCellAttentiveLSTM(RNN) 编写了自定义 keras 层,以符合 keras 的 RNN 方法。这种注意机制由 Bahdanau 描述。其中,在编码器/解码器模型中,“上下文”向量是根据编码器的所有输出和解码器的当前隐藏状态创建的。然后,我在每个时间步将上下文向量附加到输入。

该模型用于制作对话代理,但在架构(类似任务)上与 NMT 模型非常相似。

但是,在添加这种注意力机制后,我的网络训练速度减慢了 5 倍,我真的很想知道如何以更高效的方式编写让速度减慢这么多的代码部分方法。

主要的计算在这里完成:

h_tm1 = states[0]  # previous memory state
c_tm1 = states[1]  # previous carry state

# attention mechanism

# repeat the hidden state to the length of the sequence
_stm = K.repeat(h_tm1, self.annotation_timesteps)

# multiplty the weight matrix with the repeated (current) hidden state
_Wxstm = K.dot(_stm, self.kernel_w)

# calculate the attention probabilities
# self._uh is of shape (batch, timestep, self.units)
et = K.dot(activations.tanh(_Wxstm + self._uh), K.expand_dims(self.kernel_v))

at = K.exp(et)
at_sum = K.sum(at, axis=1)
at_sum_repeated = K.repeat(at_sum, self.annotation_timesteps)
at /= at_sum_repeated  # vector of size (batchsize, timesteps, 1)

# calculate the context vector
context = K.squeeze(K.batch_dot(at, self.annotations, axes=1), axis=1)

# append the context vector to the inputs
inputs = K.concatenate([inputs, context])

AttentiveLSTMCellcall 方法中(一个时间步)。

可以找到完整的代码here .如果有必要我提供一些数据和与模型交互的方式,那么我可以这样做。

有什么想法吗?当然,如果这里有什么聪明的话,我会在 GPU 上进行训练。

最佳答案

我建议使用 relu 而不是 tanh 来训练您的模型,因为此操作的计算速度要快得多。这将节省您的计算时间,顺序为训练示例 * 每个示例的平均序列长度 * 轮数。

另外,我会评估附加上下文向量的性能改进,请记住,这会减慢您在其他参数上的迭代周期。如果它没有给您带来很大改进,则可能值得尝试其他方法。

关于python - 我怎样才能加快这个 Keras 注意力计算?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/49175581/

相关文章:

apache-spark - 将 Tensorflow 模型的预测输出保存到 hdfs 文件中

tensorflow - 如果我使用 Tensorflow 对象检测 API 调整图像大小,bbox 是否也会自动调整大小?

r - 了解 R 中 rnn 模型的 Keras 预测输出

machine-learning - keras 结合预训练模型

python - 文件/目录删除时提示

python - 如何在 google colab 中导入自定义模块?

python - () 与 [] 与 {} 之间有什么区别?

python - 检查一个字符串的字符是否按字母顺序升序并且它的升序是均匀间隔的python

python - 使用 load_model 加载经过 tensorflow.keras 训练的模型返回 JSON 解码错误,而未经训练的模型正常加载

python - 深度学习: Validation Loss Fluctuates Wildly Yet Training Loss is Stable