python - 如何访问 Tensorflow 中循环单元的权重?

标签 python machine-learning tensorflow reinforcement-learning

提高深度 Q 学习任务稳定性的一种方法是为网络维护一组更新缓慢的目标权重,并用于计算 Q 值目标。作为学习过程中不同时间的结果,前向传递中使用了两组不同的权重。对于普通的 DQN,这并不难实现,因为权重是可以在 feed_dict 中设置的 tensorflow 变量,即:

sess = tf.Session()
input = tf.placeholder(tf.float32, shape=[None, 5])
weights = tf.Variable(tf.random_normal(shape=[5,4], stddev=0.1)
bias = tf.Variable(tf.constant(0.1, shape=[4])
output = tf.matmul(input, weights) + bias
target = tf.placeholder(tf.float32, [None, 4])
loss = ...

...

#Here we explicitly set weights to be the slowly updated target weights
sess.run(output, feed_dict={input: states, weights: target_weights, bias: target_bias})

# Targets for the learning procedure are computed using this output.

....

#Now we run the learning procedure, using the most up to date weights,
#as well as the previously computed targets
sess.run(loss, feed_dict={input: states, target: targets})

我想在 DQN 的循环版本中使用这种目标网络技术,但我不知道如何访问和设置循环单元内使用的权重。具体来说,我使用的是 tf.nn.rnn_cell.BasicLSTMCell,但我想知道如何对任何类型的循环单元执行此操作。

最佳答案

BasicLSTMCell 不会将其变量作为其公共(public) API 的一部分公开。我建议您查看这些变量在图表中的名称并提供这些名称(这些名称不太可能更改,因为它们在检查点中,更改这些名称会破坏检查点兼容性)。

或者,您可以制作一份 BasicLSTMCell 的副本,它会公开变量。我认为这是最干净的方法。

关于python - 如何访问 Tensorflow 中循环单元的权重?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/40831896/

相关文章:

python - 在路径中创建缺少的目录

python - 如何从 Python 中的字符串中提取开始时间?

machine-learning - Caffe中Tiling层的用途是什么

tensorflow - Keras image_dataset_from_directory 未找到图像

tensorflow - tf.contrib.lookup.index_table_from_tensor 的选项

python - 具有不同序列长度的多对多序列预测

python - Numpy:用同一行中其他元素的最大值替换一行中的每个元素

python - 在 Google App Engine 上的 Django 中获取没有数据库访问权限的登录用户的 key ?

python - 类型错误 : '_IncompatibleKeys' object is not callable

matlab - 使用单输入 GPML 进行大规模回归