python - 从自定义回调中获取 Keras 模型输入

标签 python tensorflow keras

我有一个非常简单的问题。我有一个为分类定义的 Keras 模型(TF 后端)。我想在训练期间将训练图像转储到我的模型中以进行调试。我正在尝试创建一个自定义回调来为此编写 Tensorboard 图像摘要。

但是如何在回调中获取真实的训练数据呢?

目前我正在尝试这个:

class TensorboardKeras(Callback):                                                                                                                                                                                                                                     
    def __init__(self, model, log_dir, write_graph=True):                                                                                                                                                                                                             
        self.model = model                                                                                                                                                                                                                                            
        self.log_dir = log_dir                                                                                                                                                                                                                                        
        self.session = K.get_session()                                                                                                                                                                                                                                

        tf.summary.image('input_image', self.model.input)                                                                                                                                                                                                             
        self.merged = tf.summary.merge_all()                                                                                                                                                                                                                          

        if write_graph:                                                                                                                                                                                                                                               
            self.writer = tf.summary.FileWriter(self.log_dir, K.get_session().graph)                                                                                                                                                                                  
        else:                                                                                                                                                                                                                                                         
            self.writer = tf.summary.FileWriter(self.log_dir)

    def on_batch_end(self, batch, logs=None):
        summary = self.session.run(self.merged, feed_dict={})                                                                                                                                                                                                         
        self.writer.add_summary(summary, batch)                                                                                                                                                                                                                       
        self.writer.flush()

但是我得到了错误: InvalidArgumentError(见上文的回溯):您必须为占位符张量“input_1”提供一个 dtype float 和形状 [?,224,224,3]

必须有一种方法可以查看哪些模型作为输入,对吧?

或者也许我应该尝试另一种方式来调试它?

最佳答案

你不需要回调。您需要做的就是实现一个生成图像及其标签作为元组的函数。 flow_from_directory 函数有一个名为 save_to_dir 的参数,它可以满足您的所有需求,如果不能,您可以执行以下操作:

def trainGenerator(batch_size,train_path, image_size)
    #preprocessing see https://keras.io/preprocessing/image/ for details
    image_datagen = ImageDataGenerator(horizontal_flip=True)
    #create image generator see https://keras.io/preprocessing/image/#flow_from_directory for details
    train_generator = image_datagen.flow_from_directory(
        train_path,
        class_mode = "categorical",
        target_size = image_size,
        batch_size = batch_size,
        save_prefix  = "augmented_train",
        seed = seed)

    for (batch_imgs, batch_labels) in train_generator: 
        #do other stuff such as dumping images or further augmenting images
    yield (batch_imgs,batch_labels)


t_generator = trainGenerator(32, "./train_data", (224,224,3))
model.fit_generator(t_generator,steps_per_epoch=10,epochs=1)

关于python - 从自定义回调中获取 Keras 模型输入,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/52783781/

相关文章:

python - pip安装flask-user时出现"Failed building wheel for py-bcrypt"

python - Pygame--pygame不能运行乒乓球碰撞游戏

python - Flask - 数据创建

python - Tensorflow - 范围明智回归损失

python - tensorflow 中的自定义卷积函数

python - 使用 tf.estimator 提前停止,如何?

python - 使用 TensorFlow 模型时图像分类的准确性没有提高

python - 内存 SQL 查询

python - 使用干净和退化的照片训练深度学习模型

python - Keras 值错误 : This model has never been called