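python - Confused about saving/restoring trained weights and biases in Tensorflow

Tags: python tensorflow

I'm training a convolutional neural network in Tensorflow. My code runs to completion without errors. That said, I can't work out exactly how to save the weights and biases the network has learned (this matters because I train on a server and want to do simpler visualization work locally).

I initialize my weights and biases like this: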

weights = {
    'wConv1': tf.Variable(tf.random_normal([5, 5, 1, 3], 0, 0.25),    name='wC1'),
    'wConv2': tf.Variable(tf.random_normal([5, 5, 3, 32], 0, 0.25),   name='wC2'),
    'wConv3': tf.Variable(tf.random_normal([5, 5, 32, 64], 0, 0.25),  name='wC3'),
    'wConv4': tf.Variable(tf.random_normal([5, 5, 64, 128], 0, 0.25), name='wC4'),
    'wConv5': tf.Variable(tf.random_normal([5, 5, 128, 64], 0, 0.25), name='wC5'),
    'wConv6': tf.Variable(tf.random_normal([5, 5, 64, 32], 0, 0.25),  name='wC6'),
    'wConv7': tf.Variable(tf.random_normal([5, 5, 32, 16], 0, 0.25),  name='wC7'),
    'wOUT':   tf.Variable(tf.random_normal([5, 5, 16, 1], 0, 0.25),   name='wCOUT'),
}

biases = {
    'bConv1': tf.Variable(tf.random_normal([3]),   name='bC1'),
    'bConv2': tf.Variable(tf.random_normal([32]),  name='bC2'),
    'bConv3': tf.Variable(tf.random_normal([64]),  name='bC3'),
    'bConv4': tf.Variable(tf.random_normal([128]), name='bC4'),
    'bConv5': tf.Variable(tf.random_normal([64]),  name='bC5'),
    'bConv6': tf.Variable(tf.random_normal([32]),  name='bC6'),
    'bConv7': tf.Variable(tf.random_normal([16]),  name='bC7'),
    'bOUT':   tf.Variable(tf.random_normal([1]),   name='bCOUT'),
}

Then, once however many epochs I'm running are finished, I save everything with:

saver = tf.train.Saver({"weights": weights, "biases": biases})
save_path = saver.save(sess, "./output/trained.ckpt")

Now, on my own machine, I have an evaluation script in which I try to load the weights:

with sess.as_default():
    saver = tf.train.import_meta_graph('output.ckpt.meta')
    saver.restore(sess, tf.train.latest_checkpoint('./'))
    a = tf.all_variables()
    sess.run(tf.global_variables_initializer())
    b = sess.run(pred, feed_dict={x: input[:, :, :, 30, :]})

Now the issue is, when I load in "a" I get a mess, with what appears to be many copies of my bias and weight variables:

[<tf.Variable 'wC1:0' shape=(5, 5, 1, 3) dtype=float32_ref>,
<tf.Variable 'wC2:0' shape=(5, 5, 3, 32) dtype=float32_ref>,
<tf.Variable 'wC3:0' shape=(5, 5, 32, 64) dtype=float32_ref>,
<tf.Variable 'wC4:0' shape=(5, 5, 64, 128) dtype=float32_ref>,
<tf.Variable 'wC5:0' shape=(5, 5, 128, 64) dtype=float32_ref>,
<tf.Variable 'wC6:0' shape=(5, 5, 64, 32) dtype=float32_ref>,
<tf.Variable 'wC7:0' shape=(5, 5, 32, 16) dtype=float32_ref>,
<tf.Variable 'wCOUT:0' shape=(5, 5, 16, 1) dtype=float32_ref>,
<tf.Variable 'bC1:0' shape=(3,) dtype=float32_ref>,
<tf.Variable 'bC2:0' shape=(32,) dtype=float32_ref>,
<tf.Variable 'bC3:0' shape=(64,) dtype=float32_ref>,
<tf.Variable 'bC4:0' shape=(128,) dtype=float32_ref>,
<tf.Variable 'bC5:0' shape=(64,) dtype=float32_ref>,
<tf.Variable 'bC6:0' shape=(32,) dtype=float32_ref>,
<tf.Variable 'bC7:0' shape=(16,) dtype=float32_ref>,
<tf.Variable 'bCOUT:0' shape=(1,) dtype=float32_ref>,
<tf.Variable 'beta1_power:0' shape=() dtype=float32_ref>,
<tf.Variable 'beta2_power:0' shape=() dtype=float32_ref>,
<tf.Variable 'wC1/Adam:0' shape=(5, 5, 1, 3) dtype=float32_ref>,
<tf.Variable 'wC1/Adam_1:0' shape=(5, 5, 1, 3) dtype=float32_ref>,
<tf.Variable 'wC2/Adam:0' shape=(5, 5, 3, 32) dtype=float32_ref>,
<tf.Variable 'wC2/Adam_1:0' shape=(5, 5, 3, 32) dtype=float32_ref>,
<tf.Variable 'wC3/Adam:0' shape=(5, 5, 32, 64) dtype=float32_ref>,
<tf.Variable 'wC3/Adam_1:0' shape=(5, 5, 32, 64) dtype=float32_ref>,
<tf.Variable 'wC4/Adam:0' shape=(5, 5, 64, 128) dtype=float32_ref>,
<tf.Variable 'wC4/Adam_1:0' shape=(5, 5, 64, 128) dtype=float32_ref>,
<tf.Variable 'wC5/Adam:0' shape=(5, 5, 128, 64) dtype=float32_ref>,
<tf.Variable 'wC5/Adam_1:0' shape=(5, 5, 128, 64) dtype=float32_ref>,
<tf.Variable 'wC6/Adam:0' shape=(5, 5, 64, 32) dtype=float32_ref>,
<tf.Variable 'wC6/Adam_1:0' shape=(5, 5, 64, 32) dtype=float32_ref>,
<tf.Variable 'wC7/Adam:0' shape=(5, 5, 32, 16) dtype=float32_ref>,
<tf.Variable 'wC7/Adam_1:0' shape=(5, 5, 32, 16) dtype=float32_ref>,
<tf.Variable 'wCOUT/Adam:0' shape=(5, 5, 16, 1) dtype=float32_ref>,
<tf.Variable 'wCOUT/Adam_1:0' shape=(5, 5, 16, 1) dtype=float32_ref>,
<tf.Variable 'bC1/Adam:0' shape=(3,) dtype=float32_ref>,
<tf.Variable 'bC1/Adam_1:0' shape=(3,) dtype=float32_ref>,
<tf.Variable 'bC2/Adam:0' shape=(32,) dtype=float32_ref>,
<tf.Variable 'bC2/Adam_1:0' shape=(32,) dtype=float32_ref>,
<tf.Variable 'bC3/Adam:0' shape=(64,) dtype=float32_ref>,
<tf.Variable 'bC3/Adam_1:0' shape=(64,) dtype=float32_ref>,
<tf.Variable 'bC4/Adam:0' shape=(128,) dtype=float32_ref>,
<tf.Variable 'bC4/Adam_1:0' shape=(128,) dtype=float32_ref>,
<tf.Variable 'bC5/Adam:0' shape=(64,) dtype=float32_ref>,
<tf.Variable 'bC5/Adam_1:0' shape=(64,) dtype=float32_ref>,
<tf.Variable 'bC6/Adam:0' shape=(32,) dtype=float32_ref>,
<tf.Variable 'bC6/Adam_1:0' shape=(32,) dtype=float32_ref>,
<tf.Variable 'bC7/Adam:0' shape=(16,) dtype=float32_ref>,
<tf.Variable 'bC7/Adam_1:0' shape=(16,) dtype=float32_ref>,
<tf.Variable 'bCOUT/Adam:0' shape=(1,) dtype=float32_ref>,
<tf.Variable 'bCOUT/Adam_1:0' shape=(1,) dtype=float32_ref>,
<tf.Variable 'wC1:0' shape=(5, 5, 1, 3) dtype=float32_ref>,
<tf.Variable 'wC2:0' shape=(5, 5, 3, 32) dtype=float32_ref>,
<tf.Variable 'wC3:0' shape=(5, 5, 32, 64) dtype=float32_ref>,
<tf.Variable 'wC4:0' shape=(5, 5, 64, 128) dtype=float32_ref>,
<tf.Variable 'wC5:0' shape=(5, 5, 128, 64) dtype=float32_ref>,
<tf.Variable 'wC6:0' shape=(5, 5, 64, 32) dtype=float32_ref>,
<tf.Variable 'wC7:0' shape=(5, 5, 32, 16) dtype=float32_ref>,
<tf.Variable 'wCOUT:0' shape=(5, 5, 16, 1) dtype=float32_ref>,
<tf.Variable 'bC1:0' shape=(3,) dtype=float32_ref>,
<tf.Variable 'bC2:0' shape=(32,) dtype=float32_ref>,
<tf.Variable 'bC3:0' shape=(64,) dtype=float32_ref>,
<tf.Variable 'bC4:0' shape=(128,) dtype=float32_ref>,
<tf.Variable 'bC5:0' shape=(64,) dtype=float32_ref>,
<tf.Variable 'bC6:0' shape=(32,) dtype=float32_ref>,
<tf.Variable 'bC7:0' shape=(16,) dtype=float32_ref>,
<tf.Variable 'bCOUT:0' shape=(1,) dtype=float32_ref>,
<tf.Variable 'beta1_power:0' shape=() dtype=float32_ref>,
<tf.Variable 'beta2_power:0' shape=() dtype=float32_ref>,
<tf.Variable 'wC1/Adam:0' shape=(5, 5, 1, 3) dtype=float32_ref>,
<tf.Variable 'wC1/Adam_1:0' shape=(5, 5, 1, 3) dtype=float32_ref>,
<tf.Variable 'wC2/Adam:0' shape=(5, 5, 3, 32) dtype=float32_ref>,
<tf.Variable 'wC2/Adam_1:0' shape=(5, 5, 3, 32) dtype=float32_ref>,
<tf.Variable 'wC3/Adam:0' shape=(5, 5, 32, 64) dtype=float32_ref>,
<tf.Variable 'wC3/Adam_1:0' shape=(5, 5, 32, 64) dtype=float32_ref>,
<tf.Variable 'wC4/Adam:0' shape=(5, 5, 64, 128) dtype=float32_ref>,
<tf.Variable 'wC4/Adam_1:0' shape=(5, 5, 64, 128) dtype=float32_ref>,
<tf.Variable 'wC5/Adam:0' shape=(5, 5, 128, 64) dtype=float32_ref>,
<tf.Variable 'wC5/Adam_1:0' shape=(5, 5, 128, 64) dtype=float32_ref>,
<tf.Variable 'wC6/Adam:0' shape=(5, 5, 64, 32) dtype=float32_ref>,
<tf.Variable 'wC6/Adam_1:0' shape=(5, 5, 64, 32) dtype=float32_ref>,
<tf.Variable 'wC7/Adam:0' shape=(5, 5, 32, 16) dtype=float32_ref>,
<tf.Variable 'wC7/Adam_1:0' shape=(5, 5, 32, 16) dtype=float32_ref>,
<tf.Variable 'wCOUT/Adam:0' shape=(5, 5, 16, 1) dtype=float32_ref>,
<tf.Variable 'wCOUT/Adam_1:0' shape=(5, 5, 16, 1) dtype=float32_ref>,
<tf.Variable 'bC1/Adam:0' shape=(3,) dtype=float32_ref>,
<tf.Variable 'bC1/Adam_1:0' shape=(3,) dtype=float32_ref>,
<tf.Variable 'bC2/Adam:0' shape=(32,) dtype=float32_ref>,
<tf.Variable 'bC2/Adam_1:0' shape=(32,) dtype=float32_ref>,
<tf.Variable 'bC3/Adam:0' shape=(64,) dtype=float32_ref>,
<tf.Variable 'bC3/Adam_1:0' shape=(64,) dtype=float32_ref>,
<tf.Variable 'bC4/Adam:0' shape=(128,) dtype=float32_ref>,
<tf.Variable 'bC4/Adam_1:0' shape=(128,) dtype=float32_ref>,
<tf.Variable 'bC5/Adam:0' shape=(64,) dtype=float32_ref>,
<tf.Variable 'bC5/Adam_1:0' shape=(64,) dtype=float32_ref>,
<tf.Variable 'bC6/Adam:0' shape=(32,) dtype=float32_ref>,
<tf.Variable 'bC6/Adam_1:0' shape=(32,) dtype=float32_ref>,
<tf.Variable 'bC7/Adam:0' shape=(16,) dtype=float32_ref>,
<tf.Variable 'bC7/Adam_1:0' shape=(16,) dtype=float32_ref>,
<tf.Variable 'bCOUT/Adam:0' shape=(1,) dtype=float32_ref>,
<tf.Variable 'bCOUT/Adam_1:0' shape=(1,) dtype=float32_ref>,
<tf.Variable 'wC1:0' shape=(5, 5, 1, 3) dtype=float32_ref>,
<tf.Variable 'wC2:0' shape=(5, 5, 3, 32) dtype=float32_ref>,
<tf.Variable 'wC3:0' shape=(5, 5, 32, 64) dtype=float32_ref>,
<tf.Variable 'wC4:0' shape=(5, 5, 64, 128) dtype=float32_ref>,
<tf.Variable 'wC5:0' shape=(5, 5, 128, 64) dtype=float32_ref>,
<tf.Variable 'wC6:0' shape=(5, 5, 64, 32) dtype=float32_ref>,
<tf.Variable 'wC7:0' shape=(5, 5, 32, 16) dtype=float32_ref>,
<tf.Variable 'wCOUT:0' shape=(5, 5, 16, 1) dtype=float32_ref>,
<tf.Variable 'bC1:0' shape=(3,) dtype=float32_ref>,
<tf.Variable 'bC2:0' shape=(32,) dtype=float32_ref>,
<tf.Variable 'bC3:0' shape=(64,) dtype=float32_ref>,
<tf.Variable 'bC4:0' shape=(128,) dtype=float32_ref>,
<tf.Variable 'bC5:0' shape=(64,) dtype=float32_ref>,
<tf.Variable 'bC6:0' shape=(32,) dtype=float32_ref>,
<tf.Variable 'bC7:0' shape=(16,) dtype=float32_ref>,
 <tf.Variable 'bCOUT:0' shape=(1,) dtype=float32_ref>,
 <tf.Variable 'beta1_power:0' shape=() dtype=float32_ref>,
 <tf.Variable 'beta2_power:0' shape=() dtype=float32_ref>,
 <tf.Variable 'wC1/Adam:0' shape=(5, 5, 1, 3) dtype=float32_ref>,
 <tf.Variable 'wC1/Adam_1:0' shape=(5, 5, 1, 3) dtype=float32_ref>,
 <tf.Variable 'wC2/Adam:0' shape=(5, 5, 3, 32) dtype=float32_ref>,
 <tf.Variable 'wC2/Adam_1:0' shape=(5, 5, 3, 32) dtype=float32_ref>,
 <tf.Variable 'wC3/Adam:0' shape=(5, 5, 32, 64) dtype=float32_ref>,
 <tf.Variable 'wC3/Adam_1:0' shape=(5, 5, 32, 64) dtype=float32_ref>,
 <tf.Variable 'wC4/Adam:0' shape=(5, 5, 64, 128) dtype=float32_ref>,
 <tf.Variable 'wC4/Adam_1:0' shape=(5, 5, 64, 128) dtype=float32_ref>,
 <tf.Variable 'wC5/Adam:0' shape=(5, 5, 128, 64) dtype=float32_ref>,
 <tf.Variable 'wC5/Adam_1:0' shape=(5, 5, 128, 64) dtype=float32_ref>,
 <tf.Variable 'wC6/Adam:0' shape=(5, 5, 64, 32) dtype=float32_ref>,
 <tf.Variable 'wC6/Adam_1:0' shape=(5, 5, 64, 32) dtype=float32_ref>,
 <tf.Variable 'wC7/Adam:0' shape=(5, 5, 32, 16) dtype=float32_ref>,
 <tf.Variable 'wC7/Adam_1:0' shape=(5, 5, 32, 16) dtype=float32_ref>,
 <tf.Variable 'wCOUT/Adam:0' shape=(5, 5, 16, 1) dtype=float32_ref>,
 <tf.Variable 'wCOUT/Adam_1:0' shape=(5, 5, 16, 1) dtype=float32_ref>,
 <tf.Variable 'bC1/Adam:0' shape=(3,) dtype=float32_ref>,
 <tf.Variable 'bC1/Adam_1:0' shape=(3,) dtype=float32_ref>,
 <tf.Variable 'bC2/Adam:0' shape=(32,) dtype=float32_ref>,
 <tf.Variable 'bC2/Adam_1:0' shape=(32,) dtype=float32_ref>,
 <tf.Variable 'bC3/Adam:0' shape=(64,) dtype=float32_ref>,
 <tf.Variable 'bC3/Adam_1:0' shape=(64,) dtype=float32_ref>,
 <tf.Variable 'bC4/Adam:0' shape=(128,) dtype=float32_ref>,
 <tf.Variable 'bC4/Adam_1:0' shape=(128,) dtype=float32_ref>,
 <tf.Variable 'bC5/Adam:0' shape=(64,) dtype=float32_ref>,
 <tf.Variable 'bC5/Adam_1:0' shape=(64,) dtype=float32_ref>,
 <tf.Variable 'bC6/Adam:0' shape=(32,) dtype=float32_ref>,
 <tf.Variable 'bC6/Adam_1:0' shape=(32,) dtype=float32_ref>,
 <tf.Variable 'bC7/Adam:0' shape=(16,) dtype=float32_ref>,
 <tf.Variable 'bC7/Adam_1:0' shape=(16,) dtype=float32_ref>,
 <tf.Variable 'bCOUT/Adam:0' shape=(1,) dtype=float32_ref>,
 <tf.Variable 'bCOUT/Adam_1:0' shape=(1,) dtype=float32_ref>]

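My question is, how can I save ONLY the trained weights and biases in Tensorflow and then load them later on for testing purposes?

Best answer

Before answering the exact question, let me first address your concern: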

the issue is, when I load in "a" I get a mess, with what appears to be many copies of my bias and weight variables

In the evaluation script you load the training meta graph:

saver = tf.train.import_meta_graph('output.ckpt.meta')

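That graph was built for training, so besides the explicit weight and bias variables you defined it also contains variables that belong to the optimization process (the ones whose names end in Adam or Adam_1, plus beta1_power and beta2_power). Executing the line above defines them again in your evaluation script, even though inference does not need them.

An alternative is to define the exact graph you want to use for inference, which may differ a little from the training graph. In your case, simply don't define the optimizer.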
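To see which entries in a listing like the one in the question are model parameters and which come from the optimizer, the names can be filtered with plain Python. This is only a sketch: `is_optimizer_variable` is a hypothetical helper, and it assumes the naming pattern visible in the question's output (Adam slot variables ending in `/Adam` or `/Adam_1`, plus `beta1_power` and `beta2_power`). Other optimizers use other names.

```python
def is_optimizer_variable(name):
    """Heuristic based on the names shown in the question's listing."""
    base = name.split(":")[0]      # drop the ":0" output index
    last = base.split("/")[-1]     # last component of the scoped name
    return last in ("Adam", "Adam_1", "beta1_power", "beta2_power")


# A short excerpt of the names printed in the question.
names = [
    "wC1:0", "bC1:0",
    "beta1_power:0", "beta2_power:0",
    "wC1/Adam:0", "wC1/Adam_1:0",
    "bC1/Adam:0", "bC1/Adam_1:0",
]

model_names = [n for n in names if not is_optimizer_variable(n)]
optimizer_names = [n for n in names if is_optimizer_variable(n)]

print(model_names)      # ['wC1:0', 'bC1:0']
print(optimizer_names)
```

Only the two model parameters survive the filter; everything else in the excerpt is optimizer state.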

Now to your question:

My question is, how can I save ONLY the trained weights and biases in Tensorflow and then load them later on for testing purposes?

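Judging from your code, you are essentially doing that already. The other variables you see stem from what is described above.

One thing to point out: make sure you do not initialize the variables after restoring them, because running the initializer overwrites the restored values. If you keep your current code, initialize first and then restore. If you change the inference graph so that it no longer includes the optimizer, you don't need to initialize any variables at all.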
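A minimal sketch of that advice, under several assumptions. It uses the `tf.compat.v1` namespace so that it can run on TensorFlow 2.x, whereas the question calls `tf.train.Saver` directly. `build_variables` and `model_variables` are hypothetical helpers, the model is reduced to one layer, and the loss is a toy one that exists only so that Adam creates its extra variables. The saver is given a flat list of the weight and bias variables, and the evaluation graph defines no optimizer and runs no initializer.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

tf1 = tf.compat.v1
tf1.disable_eager_execution()


def build_variables():
    # Hypothetical one-layer stand-in for the question's weights/biases dicts.
    weights = {"wConv1": tf.Variable(tf1.random_normal([5, 5, 1, 3], 0, 0.25), name="wC1")}
    biases = {"bConv1": tf.Variable(tf1.random_normal([3]), name="bC1")}
    return weights, biases


def model_variables(weights, biases):
    # Saver accepts a flat list of variables, so flatten both dicts.
    return list(weights.values()) + list(biases.values())


ckpt_path = os.path.join(tempfile.mkdtemp(), "trained.ckpt")

# Training side: the optimizer adds Adam variables to this graph.
with tf.Graph().as_default():
    weights, biases = build_variables()
    loss = (tf.reduce_sum(tf.square(weights["wConv1"]))
            + tf.reduce_sum(tf.square(biases["bConv1"])))  # toy loss
    train_op = tf1.train.AdamOptimizer(0.01).minimize(loss)
    saver = tf1.train.Saver(model_variables(weights, biases))
    with tf1.Session() as sess:
        sess.run(tf1.global_variables_initializer())
        sess.run(train_op)
        trained_w = sess.run(weights["wConv1"])
        saver.save(sess, ckpt_path)

# Names stored in the checkpoint: only what was passed to the Saver.
saved_names = [name for name, _ in tf.train.list_variables(ckpt_path)]

# Evaluation side: same variables, no optimizer, no initializer.
with tf.Graph().as_default():
    weights, biases = build_variables()
    saver = tf1.train.Saver(model_variables(weights, biases))
    with tf1.Session() as sess:
        saver.restore(sess, ckpt_path)
        restored_w = sess.run(weights["wConv1"])

print(saved_names)
print(np.allclose(trained_w, restored_w))
```

Because the evaluation graph is rebuilt in code rather than imported from the `.meta` file, the Adam variables never appear in it, and restoring is the only step that assigns values to the variables.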

Regarding python - Confused about saving/restoring trained weights and biases in Tensorflow, a similar question can be found on Stack Overflow: https://stackoverflow.com/questions/46716070/
