python - GRU 和 RNN 实现之间的不一致

标签 python tensorflow recurrent-neural-network

我正在尝试使用 Tensorflow 实现一些自定义 GRU 单元。我需要堆叠这些单元格,并且我想继承 tensorflow.keras.layers.GRU 。但是,在查看源代码时,我注意到只能将 units 参数传递给 GRU__init__,而 RNN有一个参数,它是一个 RNNcell 列表,并利用它来堆叠那些调用 StackedRNNCells 的单元。同时,GRU仅创建一个GRUCell

对于我试图实现的论文,我实际上需要堆栈GRUCell。为什么RNNGRU的实现不同?

最佳答案

在搜索这些类的文档以添加链接时,我注意到一些可能会让您困惑的事情:(目前,就在官方 TF 2.0 发布之前)两个 GRUCell在 TensorFlow 中的实现!有一个 tf.nn.rnn_cell.GRUCell 和一个 tf.keras.layers.GRUCell 。看起来像 tf.nn.rnn_cell 中的那个已弃用,您应该使用 Keras。

据我所知,GRUCell有相同的 __call__()方法签名为 tf.keras.layers.LSTMCell tf.keras.layers.SimpleRNNCell ,并且它们都继承自 Layer RNN 文档给出了有关 __call__() 的一些要求您传递给其的对象的方法 cell论证必须这样做,但我的猜测是,所有这三个都应该满足这些要求。您应该能够使用相同的 RNN框架并向其传递 GRUCell 的列表对象而不是 LSTMCellSimpleRNNCell .

我现在无法测试这个,所以我不确定你是否传递了 GRUCell 的列表物体或只是 GRU 对象进入RNN ,但我认为其中之一应该有效。

关于python - GRU 和 RNN 实现之间的不一致,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/55831278/

相关文章:

python - 除周末 Python 之外的日期时间序列

python - 在 TensorFlow 中使用循环填充矩阵

python - 当您的模型不能过度拟合一小部分数据时,这意味着什么?

neural-network - 使用 Keras 的 LSTM 网络中的验证损失和准确性

python - seq-to-seq LSTM 在低频简单正弦波上的性能不佳

Python - PayPal 定期付款 - 未显示协议(protocol)详细信息

python - 组合 python 列表元素,其中值为 1 加上偏移量(屏蔽)

python - 列表元素大小写转换

python-2.7 - tf.transpose 是否也会更改内存(如 np.ascontiguousarray)?

tensorflow - 如何将图像传递给模型以在 Tensorflow 中进行分类