python - 在 TensorFlow 中保存和恢复函数

标签 python tensorflow machine-learning tensorflow-probability

我正在 TensorFlow 中开发一个 VAE 项目,其中编码器/解码器网络是内置函数。这个想法是能够保存,然后加载经过训练的模型并使用编码器功能进行采样。

恢复模型后,我无法运行解码器函数并返回恢复后的训练变量,出现“未初始化值”错误。我认为这是因为该函数要么创建一个新函数,要么覆盖现有函数,要么以其他方式。但我不知道如何解决这个问题。这是一些代码:

class VAE(object):    
    def __init__(self, restore=True):
        self.session = tf.Session()
        if restore:
            self.restore_model()
            self.build_decoder = tf.make_template('decoder', self._build_decoder)

@staticmethod
def _build_decoder(z, output_size=768, hidden_size=200,
                  hidden_activation=tf.nn.elu, output_activation=tf.nn.sigmoid):
    x = tf.layers.dense(z, hidden_size, activation=hidden_activation)
    x = tf.layers.dense(x, hidden_size, activation=hidden_activation)
    logits = tf.layers.dense(x, output_size, activation=output_activation)
    return distributions.Independent(distributions.Bernoulli(logits), 2)

def sample_decoder(self, n_samples):
    prior = self.build_prior(self.latent_dim)
    samples = self.build_decoder(prior.sample(n_samples), self.input_size).mean()
    return self.session.run([samples])

def restore_model(self):
    print("Restoring")
    self.saver = tf.train.import_meta_graph(os.path.join(self.save_dir, "turbolearn.meta"))
    self.saver.restore(self.sess, tf.train.latest_checkpoint(self.save_dir))
    self._restored = True

想要运行 samples = vae.sample_decoder(5)

在我的日常训练中,我跑:

        if self.checkpoint:
            self.saver.save(self.session, os.path.join(self.save_dir, "myvae"), write_meta_graph=True)

更新

根据下面的建议答案,我更改了还原方法

self.saver = tf.train.Saver()
self.saver.restore(self.session, tf.train.latest_checkpoint(self.save_dir))

但是现在在创建 Saver() 对象时出现值错误:

ValueError: No variables to save

最佳答案

tf.train.import_meta_graph 恢复图形,意味着重建存储到文件的网络架构。另一方面,对 tf.train.Saver.restore 的调用只会将文件中的变量值恢复到 session 中的当前图形(如果文件中的某些值属于当前事件图中不存在的变量)。

因此,如果您已经在代码中构建了网络层,则无需调用 tf.train.import_meta_graph。否则这可能会给您带来麻烦。

不确定其余代码的样子,但这里有一些建议。首先构建图形,然后创建 session ,如果适用,最后恢复。你的 init 可能看起来像这样

def __init__(self, restore=True):
    self.build_decoder = tf.make_template('decoder', self._build_decoder)
    self.session = tf.Session()
    if restore:
        self.restore_model()

但是,如果您只是恢复编码器,并重新构建解码器,则可以最后构建解码器。但是不要忘记在使用前初始化它的变量。

关于python - 在 TensorFlow 中保存和恢复函数,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/53246331/

相关文章:

python - 如何使用正则表达式从这样的字符串中提取两个值?

python - 如何从 Qt 对话框获取自定义信号

image - 是否可以将 coreml 模型中输入张量的类型从多数组更改为图像?

python - keras 中的 GPU 波动性非常低

machine-learning - 在SVM中,支持向量不能作为训练样本吗?

python - python/django 中的每小时日志轮转

python - Pandas scatter_matrix : Labels vertical (x) and horizontal (y) without being cut-off

python - 如何在TF2.0中创建自定义渐变的keras图层?

matlab - 来自图像或视频库的人体检测

machine-learning - 如何在 Keras 中使用其他 GPU 和 TensorFlow 后端?