python - 恢复训练好的 tensorflow 模型,编辑与节点关联的值,然后保存

标签 python machine-learning tensorflow deep-learning batch-normalization

我已经用 tensorflow 训练了一个模型,并在训练过程中使用了批量归一化。批量归一化要求用户传递一个名为 is_training 的 bool 值,以设置模型是处于训练阶段还是测试阶段。

训练模型时,is_training 被设置为常量,如下所示

is_training = tf.constant(True, dtype=tf.bool, name='is_training')

我已经保存了训练好的模型,文件包括检查点、.meta 文件、.index 文件和 .data。我想恢复模型并使用它进行推理。 该模型无法重新训练。因此,我想恢复现有模型,将 is_training 的值设置为 False,然后保存模型。 如何编辑与该节点关联的 bool 值,并再次保存模型?

最佳答案

您可以使用 tf.train.import_meta_graphinput_map 参数将图张量重新映射到更新值。

config = tf.ConfigProto(allow_soft_placement=True)
with tf.Session(config=config) as sess:
    # define the new is_training tensor
    is_training = tf.constant(False, dtype=tf.bool, name='is_training')

    # now import the graph using the .meta file of the checkpoint
    saver = tf.train.import_meta_graph(
    '/path/to/model.meta', input_map={'is_training:0':is_training})

    # restore all weights using the model checkpoint 
    saver.restore(sess, '/path/to/model')

    # save updated graph and variables values
    saver.save(sess, '/path/to/new-model-name')

关于python - 恢复训练好的 tensorflow 模型,编辑与节点关联的值,然后保存,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/45734182/

相关文章:

python - tensorflow 卷积神经网络的人脸识别准确率仅为0.05

tensorflow - 了解更高维度的密集层的输出

Ubuntu 上的 Python file.read()

python - 如果数据框中的十进制值适合一个范围,则用不同的字符串替换它们

machine-learning - 朴素贝叶斯是否应该对词汇表中的所有单词进行乘法运算

python - scikit-learn 中的sample_weight 与class_weight 相比如何?

python - 解码 Unicode 字符串;这是什么意思,我该如何避免呢?

python - 根据另一个数据帧中的 boolean 值设置一个数据帧中的值

machine-learning - Cleartk:初始化 [class org.cleartk.classifier.jar.DefaultSequenceDataWriterFactory] ​​时出错,需要字段 'dataWriterClassName'

python - 为什么自动编码器与编码器 + 解码器的预测不同?