I am trying to build a 3-layer RNN with Keras. Part of the code is here:
model = Sequential()
model.add(Embedding(input_dim = 91, output_dim = 128, input_length =max_length))
model.add(GRUCell(units = self.neurons, dropout = self.dropval, bias_initializer = bias))
model.add(GRUCell(units = self.neurons, dropout = self.dropval, bias_initializer = bias))
model.add(GRUCell(units = self.neurons, dropout = self.dropval, bias_initializer = bias))
model.add(TimeDistributed(Dense(target.shape[2])))
Then I got this error:
call() missing 1 required positional argument: 'states'
The error details are as follows:
~/anaconda3/envs/hw3/lib/python3.5/site-packages/keras/models.py in add(self, layer)
487 output_shapes=[self.outputs[0]._keras_shape])
488 else:
--> 489 output_tensor = layer(self.outputs[0])
490 if isinstance(output_tensor, list):
491 raise TypeError('All layers in a Sequential model '
~/anaconda3/envs/hw3/lib/python3.5/site-packages/keras/engine/topology.py in __call__(self, inputs, **kwargs)
601
602 # Actually call the layer, collecting output(s), mask(s), and shape(s).
--> 603 output = self.call(inputs, **kwargs)
604 output_mask = self.compute_mask(inputs, previous_mask)
605
Best answer
Don't use the Cell classes (i.e. GRUCell or LSTMCell) directly in Keras. They are computation cells that are wrapped by the corresponding layers. Use the Layer classes (i.e. GRU or LSTM) instead:
model.add(GRU(units = self.neurons, dropout = self.dropval, bias_initializer = bias))
model.add(GRU(units = self.neurons, dropout = self.dropval, bias_initializer = bias))
model.add(GRU(units = self.neurons, dropout = self.dropval, bias_initializer = bias))
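To see why calling a cell directly fails with the missing 'states' error, here is a minimal NumPy sketch (not Keras's actual code). A toy cell's call() takes the input of a single timestep plus the previous state. A toy layer creates the initial state and loops the cell over the time axis. The names ToyGRUCell and ToyRNNLayer are hypothetical, and the gate equations follow one common GRU formulation without dropout.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


class ToyGRUCell:
    """One GRU step. Illustrative only: no dropout, not Keras's implementation."""

    def __init__(self, input_dim, units, seed=0):
        rng = np.random.default_rng(seed)
        self.units = units
        # Input kernels, recurrent kernels and biases for the z, r and candidate gates.
        self.W = rng.normal(scale=0.1, size=(3, input_dim, units))
        self.U = rng.normal(scale=0.1, size=(3, units, units))
        self.b = np.zeros((3, units))

    def call(self, inputs, states):
        # inputs: one timestep, shape (batch, input_dim)
        # states: [h_prev], shape (batch, units); the cell cannot create this itself
        h_prev = states[0]
        z = sigmoid(inputs @ self.W[0] + h_prev @ self.U[0] + self.b[0])
        r = sigmoid(inputs @ self.W[1] + h_prev @ self.U[1] + self.b[1])
        h_cand = np.tanh(inputs @ self.W[2] + (r * h_prev) @ self.U[2] + self.b[2])
        h = z * h_prev + (1.0 - z) * h_cand
        return h, [h]


class ToyRNNLayer:
    """Creates the initial state and loops a cell over the time axis."""

    def __init__(self, cell, return_sequences=False):
        self.cell = cell
        self.return_sequences = return_sequences

    def __call__(self, x):
        # x: (batch, timesteps, features)
        batch, timesteps, _ = x.shape
        states = [np.zeros((batch, self.cell.units))]
        outputs = []
        for t in range(timesteps):
            out, states = self.cell.call(x[:, t, :], states)
            outputs.append(out)
        if self.return_sequences:
            return np.stack(outputs, axis=1)  # every timestep: (batch, timesteps, units)
        return outputs[-1]  # last timestep only: (batch, units)


x = np.ones((2, 5, 4))  # batch=2, timesteps=5, features=4
cell = ToyGRUCell(input_dim=4, units=8)
print(ToyRNNLayer(cell, return_sequences=True)(x).shape)  # (2, 5, 8)
print(ToyRNNLayer(cell)(x).shape)  # (2, 8)

try:
    cell.call(x)  # roughly what the traceback shows: call(inputs) with no states
except TypeError as err:
    print(err)  # ... missing 1 required positional argument: 'states'
```

The layer owns the loop and the state bookkeeping, so user code only ever passes it a whole sequence.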
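LSTM and GRU use their corresponding cells to perform the computation over all timesteps. Read this SO answer to learn more about the difference between them. When you stack several RNN layers on top of each other, you need to set their return_sequences argument to True so that each layer produces an output at every timestep, which the next RNN layer then consumes. Note that you may or may not want to do this on the last RNN layer (it depends on your architecture and the problem you are trying to solve):
model.add(GRU(units = self.neurons, dropout = self.dropval, bias_initializer = bias, return_sequences=True))
model.add(GRU(units = self.neurons, dropout = self.dropval, bias_initializer = bias, return_sequences=True))
model.add(GRU(units = self.neurons, dropout = self.dropval, bias_initializer = bias))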
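Applied to the model in the question, a sketch with tf.keras might look like the following. Because TimeDistributed(Dense(...)) expects a 3D input (batch, timesteps, features), the last GRU here also keeps return_sequences=True. NEURONS, DROPVAL, MAX_LENGTH and OUT_DIM are hypothetical stand-ins for self.neurons, self.dropval, max_length and target.shape[2], and "zeros" replaces the question's bias initializer. The second part shows that the cell classes can still be used when they are wrapped in the generic tf.keras.layers.RNN layer, which accepts a single cell or a list of cells to stack.

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras import Sequential
from tensorflow.keras.layers import GRU, GRUCell, RNN, Dense, Embedding, TimeDistributed

# Hypothetical stand-ins for self.neurons, self.dropval, max_length and target.shape[2].
NEURONS = 32
DROPVAL = 0.2
MAX_LENGTH = 10
OUT_DIM = 12

model = Sequential([
    tf.keras.Input(shape=(MAX_LENGTH,), dtype="int32"),
    Embedding(input_dim=91, output_dim=128),
    # Each GRU layer runs its GRUCell over all timesteps; return_sequences=True
    # hands the full (batch, timesteps, NEURONS) sequence to the next layer.
    GRU(NEURONS, dropout=DROPVAL, bias_initializer="zeros", return_sequences=True),
    GRU(NEURONS, dropout=DROPVAL, bias_initializer="zeros", return_sequences=True),
    # Kept as a sequence here because TimeDistributed(Dense) needs 3D input.
    GRU(NEURONS, dropout=DROPVAL, bias_initializer="zeros", return_sequences=True),
    TimeDistributed(Dense(OUT_DIM)),
])

x = np.random.randint(0, 91, size=(4, MAX_LENGTH))
print(model(x).shape)  # (4, 10, 12)

# Alternative that keeps the cell classes: the generic RNN layer wraps them
# (a list of cells is stacked) and loops over the timesteps.
stacked = RNN([GRUCell(NEURONS, dropout=DROPVAL) for _ in range(3)], return_sequences=True)
embedded = Embedding(input_dim=91, output_dim=128)(x)
print(stacked(embedded).shape)  # (4, 10, 32)
```

If the task instead needs one prediction per sequence, dropping return_sequences on the last GRU and replacing TimeDistributed(Dense(...)) with a plain Dense would be the usual choice.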
We found a similar question on Stack Overflow, python - Keras GRUCell missing 1 required positional argument: 'states': https://stackoverflow.com/questions/51254706/