python - 应用 Dropout 恢复 TensorFlow 模型

标签 python tensorflow

我在恢复使用 dropout 训练的 TF 模型时遇到问题。如何将 keep_prob 设置为 1.0

我在下面尝试的代码不起作用,我认为这是因为我在恢复模型时创建了一个新的tf.placeholder。但如何恢复 keep_prob 占位符?

这是我的恢复代码

import tensorflow as tf
import numpy as np

logs_path = ...


def readImage(filenames):
    filenameQ = tf.train.string_input_producer(filenames, shuffle=False)

    reader = tf.WholeFileReader() # Magic function
    key, value = reader.read(filenameQ)

    image = tf.image.decode_png(value)
    image.set_shape([101, 201, 1])
    return image

image = readImage([("../image-to-tfrecords/train/chef/chef%d.png" % i) for i in range(5000)])

merged_summary_op = tf.summary.merge_all()

class CNN:
    """
    Class to load saved CNN
    """
    def __init__(self, model_file, imgsize=None, visualize=True, saver=None, batch_size=100):
        self.model_file = model_file
        self.saver = saver
        self.batch_size = batch_size
        if imgsize:
            self.img_h = imgsize[0]
            self.img_w = imgsize[1]

    def predict(self, X):
        # OUTCOMMENTED THIS LINE:
        #keep_prob = tf.placeholder(tf.float32)  # dropout (keep probability)

        """
        Prediction Routine
        """
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            sess.run(tf.local_variables_initializer())

            train_writer = tf.summary.FileWriter(logs_path + '/train', sess.graph)

            graph = tf.get_default_graph()

            # restore the model
            self.saver = tf.train.import_meta_graph(self.model_file) #, input_map={"keep_prob_training:0": keep_prob}, import_scope='imported'
            self.saver.restore(sess, tf.train.latest_checkpoint('./tfmodels/cnn/'))

            x, y = tf.get_collection('inputs')

            # ADDED THE FOLLOWING LINE:
            keep_prob = tf.get_collection('dropout_train')[0]

            logits, predict_op = tf.get_collection('outputs')
            probs = tf.nn.softmax(logits)

            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)

            predictions = []

            train_writer.add_graph(sess.graph)

            for i in range(5000):
                batch_xs = sess.run(X)

                # Reshape batch_xs if only a single image is given
                #   (numpy is 4D: batch_size * heigth * width * channels)
                batch_xs = np.reshape(batch_xs, (-1, self.img_w * self.img_h))
                prediction = sess.run([predict_op], feed_dict={x: batch_xs, keep_prob: 1.0})

                predictions.append(prediction[0][0])

            train_writer.close()

            # finalize
            coord.request_stop()
            coord.join(threads)

        return predictions

    @staticmethod
    def load(model_file, imgsize=[201, 101]):
        """ Load TF metagraph """
        print "Loading Model from: " + model_file
        return FNN(model_file, imgsize)


def main():
    """ Main """
    # Load and predict
    model = CNN.load("tfmodels/cnn/tf.model.meta")
    model.predict(image)


if __name__ == '__main__':
    main()

更新

下面是来自张量板的图表。我比较了保存程序和恢复程序的图表,图表是相同的:)

logits

fc1

dropout

最佳答案

您正确地观察到了问题。 keep_prob = tf.placeholder(tf.float32)张量未连接到您使用 tf.train.import_meta_graph() 导入的图形,因此输入该张量对推理没有影响。

解决方案将取决于您构建初始模型的方式。您首先需要识别name用作 keep_prob 的张量在你的原始图表中。例如,如果您使用以下语句在原始图表的顶层创建它:

keep_prob = tf.placeholder(tf.float32, name="keep_prob_training")

...名称为 "keep_prob_training:0" 。但是,如果您没有传递明确的 name参数,那么名称将类似于 "Placeholder:0" , "Placeholder_1:0"等等。最可靠的判断方法是 print(keep_prob.name)在原来的程序中。

一旦你有了这个名字(为了具体起见,我假设是 "keep_prob_training:0"),你需要对 tf.train.import_meta_graph() 进行简单的修改。调用,以设置 input_map并连接您的新keep_prob导入图的张量。以下内容应该有效:

self.saver = tf.train.import_meta_graph(
    self.model_file, input_map={"keep_prob_training:0": keep_prob})

完成此操作后,喂 keep_prob张量将允许您控制在推理时应用的 dropout。

关于python - 应用 Dropout 恢复 TensorFlow 模型,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/42912234/

相关文章:

python - 使用 Sympy 实现用户输入的正确方法?

python - 为什么 Python 列表的内存使用量小于预期?

python - 如何在没有 pandas 中的 to_datetime 函数的情况下格式化列中的日期时间值?

Python: sqlite3 - 如何加速数据库的更新

machine-learning - 使用 Tensorflow 和 inception V3 预训练模型训练高清图像

python-3.x - Keras:初始化权重时模型未学习

python - 数据框到 frozenset

python-2.7 - 如何为 python 2.7 安装 tensorflow?

python - Tensorflow 组合图像候选区域的非极大值抑制

python - 在 tensorflow 中展开张量