python - TensorFlow:按名称访问变量的内容

标签 python tensorflow

我有以下情况:

我已经构建、训练并保存了我的网络。现在,我正在尝试恢复网络并可视化权重矩阵。

我知道变量的所有名称,但我没有为变量分配 python 标记以传递给 session 进行评估。如何检索变量中的数据?

这是我的代码情况:

dataset_params = nn_params.mnist_dataset_params
design = nn_designs.mnist_net_A_design
## Build Housing Object
mnist_nn = nn_class.CNN(**dataset_params)
mnist_nn.build_net(design['design'])
mnist_nn.__setattr__('saved_path',saved_model)
mnist_nn_epoch_file = saved_model+'_epochs_completed.txt'
mnist_nn.__setattr__('epoch_file',mnist_nn_epoch_file)


# evaluate weight variables
session = tf.Session()
saver = tf.train.Saver()
session.run(tf.initialize_all_variables())
saver.restore(session,saved_model)




session.close()

为了提取权重,我应该将什么传递给 session ? (一个示例权重名称是:'conv_w_1')?

最佳答案

您可以使用 tf.get_collection() 来做到这一点获取所需变量的查找方法:

weight_var = tf.get_collection(tf.GraphKeys.VARIABLES, "conv_w_1")[0]

weight_var_value = session.run(weight_var)

关于python - TensorFlow:按名称访问变量的内容,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/36918831/

相关文章:

python - 如何使用 Keras 中保存的模型来预测和分类图像?

python - 检查 Python 字典中是否存在键/值对

python - 如何在 plotly 和 python 中使用 make_subplot 绘制多个图表

python - 如何使用 AutoGrad 包?

适用于 Windows 的 python gstreamer

python - 加载 TensorFlow 嵌入模型

python - 使用 Python 的 Subprocess 库避免 SSH 密码提示

c++ - 未能 Bazel 以 tensorflow 作为依赖项构建 C++ 项目

python - 如何在TensorFlow中使用朴素贝叶斯?

python - tensorflow :检查标量 bool 张量是否为真