python - 运行时保存 TensorFlow 状态(权重)

标签 python tensorflow

我有一个运行拟合模型的 python 程序。 我没有实现一个保护程序,所以我想知道是否有一种方法可以在拟合运行时直接从内存(也许从临时文件?)恢复权重

最佳答案

我认为这应该可以使用命令 tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)

在不保存模型的情况下提取/恢复模型权重的代码如下所示。

import tensorflow as tf
import numpy as np

X_ = tf.placeholder(tf.float64, [None, 5], name="Input")
Y_ = tf.placeholder(tf.float64, [None, 1], name="Output")

X = np.random.randint(1,10,[10,5])
Y = np.random.randint(0,2,[10,1])

with tf.variable_scope("LogReg"):
    pred = tf.layers.dense(X_, 1, activation=tf.nn.sigmoid, name = 'fc1')
    loss = tf.losses.mean_squared_error(labels=Y_, predictions=pred)
    training_ops = tf.train.GradientDescentOptimizer(0.01).minimize(loss)

with tf.Session() as sess:

    all_vars= tf.global_variables()
    def get_var(name):
        for i in range(len(all_vars)):
            if all_vars[i].name.startswith(name):
                return all_vars[i]
        return None

    sess.run(tf.global_variables_initializer())    
    for i in range(200):
        sess.run([training_ops], feed_dict={X_: X,Y_: Y})
        Weight_Vars = sess.run([tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)], feed_dict={X_: X,Y_: Y})
        print(Weight_Vars)

关于python - 运行时保存 TensorFlow 状态(权重),我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/56687437/

相关文章:

python - 在keras中定义自定义损失函数

python - LSTM输入形状错误: Input 0 is incompatible with layer sequential_1

python-3.x - TPU分类器InvalidArgumentError : No OpKernel was registered to support Op 'CrossReplicaSum' with these attrs

python - 查找网格上两点之间的距离

python - 使用条件列表过滤 Pandas 中的 DataFrame

python - 如何在html页面上显示模型字段 - Django

python - Python和Haskell有C/C++的 float 不确定性问题吗?

python - 我应该使用 pip3 还是 pip?我应该删除旧软件包并在虚拟环境中重新安装它们吗?

machine-learning - 口音检测API?

tensorflow - 为什么 Tensorflow 中的分布策略不支持梯度裁剪?