python - RNN 分类器,如何显示权重以了解 neraul 网络如何做出决策

标签 python tensorflow recurrent-neural-network

我正在使用循环神经网络构建一个简单的分类器来将一段单词分类到不同的目录中。它有一个嵌入层,然后是 RNN,然后是 Dense 层,如下所示。

它可以正确预测,但除了预测之外,我怎么知道为什么 RNN 得到这个预测,例如,段落中每个单词的权重是多少。

是什么词让 RNN 认为它属于特定目录?

model = Sequential()

embedding_size = 300

model.add(Embedding(input_dim=num_words+1, output_dim=embedding_size, input_length=max_tokens, name='layer_embedding', weights=embedding_matrix],trainable=True))

return_sequences=True))

model.add(Bidirectional(GRU(32,return_sequences=True)))
model.add(Bidirectional(GRU(32,return_sequences=True)))
model.add(Bidirectional(GRU(32)))
model.add(Dense(numdense, activation='sigmoid'))

model.compile(loss='categorical_crossentropy', optimizer='rmsprop', metrics=['accuracy'])

最佳答案

This post on GitHub提出了一种在打印参数时查看参数名称的方法:

for e in zip(model.layers[0].trainable_weights, model.layers[0].get_weights()):
   print('Param %s:\n%s' % (e[0],e[1]))

关于python - RNN 分类器,如何显示权重以了解 neraul 网络如何做出决策,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/50792130/

相关文章:

python - TFRecordReader 在 session 关闭后保持文件锁定

python - InvalidArgumentError : Received a label value of 8825 which is outside the valid range of [0, 8825) SEQ2SEQ 模型

Python 多个随机种子

python - 我收到异常值 :No module named views in django

python - 将 cx_freeze 与 feedparser 导入一起使用时出错 : ModuleNotFoundError: No module named 'sgmllib'

python - 禁用 TensorFlow 调试信息

python - 将 zip 文件下载到本地驱动器并使用 python 2.5 将所有文件解压缩到目标文件夹

python - 银行交易分类的 Tensorflow 实现

python - lstm_9 层的输入 0 与发现 ndim=4 的 : expected ndim=3, 层不兼容。完整形状收到 : [None, 2, 4000, 256]

machine-learning - RNN分类不断输出相同的值