python - Tensorflow GetNext() failed because the iterator has not been initialized

Tags: python tensorflow

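TensorFlow recommends using tf.data.Dataset to import data. Can it be used for both training and validation if the validation images have a different size from the training images?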

import tensorflow as tf
import generator
import glob
import cv2

BATCH_SIZE = 4
filenames_train = glob.glob("/home/user/Datasets/MsCoco/train2017/*.jpg")
filenames_valid = glob.glob("/home/user/Datasets/Set5_14/*.png")

# Use a custom OpenCV function to read the image instead of the standard
# TensorFlow `tf.read_file()` operation.
def _read_py_function(filename):
  image_decoded = cv2.imread(filename, cv2.IMREAD_COLOR)
  image_blurred_decoded = cv2.GaussianBlur(image_decoded, (1, 1), 0)
  return image_decoded, image_blurred_decoded

# Use standard TensorFlow operations to resize the image to a fixed shape.
def _resize_function(image_decoded, image_blurred_decoded):
  image_decoded.set_shape([None, None, None])
  image_blurred_decoded.set_shape([None, None, None])
  image_resized = tf.cast(tf.image.resize_images(image_decoded, [288, 288]),tf.uint8)
  image_blurred = tf.cast(tf.image.resize_images(image_blurred_decoded, [72, 72]),tf.uint8)
  return image_resized, image_blurred

def _cast_function(image_decoded, image_blurred_decoded):
  image_resized = tf.cast(image_decoded,tf.uint8)
  image_blurred = tf.cast(image_blurred_decoded,tf.uint8)
  return image_resized, image_blurred

dataset_train = tf.data.Dataset.from_tensor_slices(filenames_train)
dataset_train = dataset_train.map(
    lambda filename: tuple(tf.py_func(
        _read_py_function, [filename], [tf.uint8, tf.uint8])))
dataset_train = dataset_train.map(_resize_function)
#dataset_train = dataset_train.shuffle(buffer_size=10000)
dataset_train = dataset_train.repeat()
dataset_train = dataset_train.batch(BATCH_SIZE)

# validation dataset
dataset_valid = tf.data.Dataset.from_tensor_slices(filenames_valid)
dataset_valid = dataset_valid.map(
    lambda filename: tuple(tf.py_func(
        _read_py_function, [filename], [tf.uint8, tf.uint8])))
dataset_valid = dataset_valid.map(_cast_function)
dataset_valid = dataset_valid.batch(BATCH_SIZE)

handle = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(handle, dataset_train.output_types)
next_element = iterator.get_next()

training_iterator = dataset_train.make_one_shot_iterator()
validation_iterator = dataset_valid.make_initializable_iterator()


my_transformator = generator.johnson(tf.cast(next_element[1],tf.float32))
images_transformed = my_transformator.new_images
images_transformed_uint = tf.cast(images_transformed,tf.uint8)

loss_square = tf.square(tf.cast(next_element[0],tf.float32)-images_transformed)
loss_sum = tf.reduce_sum(loss_square)
loss_norm = tf.cast(tf.shape(next_element[0])[0]*tf.shape(next_element[0])[1]*tf.shape(next_element[0])[2]*tf.shape(next_element[0])[3],tf.float32)
loss = tf.reduce_sum(loss_square)/loss_norm

solver = tf.train.AdamOptimizer(learning_rate=0.001,beta1=0.5).minimize(loss)

config = tf.ConfigProto()
config.gpu_options.allow_growth = True
with tf.Session(config=config) as sess:
    sess.run(tf.global_variables_initializer())
    training_handle = sess.run(training_iterator.string_handle())
    validation_handle = sess.run(validation_iterator.string_handle())
    for i in range(200000):
        curr_norm,curr_loss_sum, _, curr_loss, curr_labels, curr_transformed, curr_loss_square  = sess.run([loss_norm,loss_sum, solver,loss,next_element,images_transformed_uint, loss_square], feed_dict={handle: training_handle})
        if i%1000 == 0:
            curr_labels, curr_transformed = sess.run([next_element, images_transformed_uint], feed_dict={handle: validation_handle})

If I run this code, I get the following error:

FailedPreconditionError (see above for traceback): GetNext() failed because the iterator has not been initialized. Ensure that you have run the initializer operation for this iterator before getting the next element. [[Node: IteratorGetNext = IteratorGetNext[output_shapes=[<unknown>, <unknown>], output_types=[DT_UINT8, DT_UINT8], _device="/job:localhost/replica:0/task:0/device:CPU:0"]]]

As you can see in the code, I do not resize the images in the validation dataset. These validation images come in different sizes.
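Separately from the initializer error, mixed image sizes matter for `dataset_valid.batch(BATCH_SIZE)`: batching stacks the elements of a batch into a single tensor, so every image in one batch must have the same shape, and TensorFlow raises an error when it tries to build a batch from differently sized images. The sketch below is a NumPy analogue, not TensorFlow code; the image sizes and the helper names `stack_batch` and `pad_batch` are hypothetical. It shows the constraint and two common workarounds: a batch size of 1, or zero-padding every image to the largest height and width in the batch (in TensorFlow, `tf.data.Dataset.padded_batch` pads this way).

```python
import numpy as np

# Hypothetical validation images of different sizes (H x W x 3).
images = [np.zeros((4, 6, 3), np.uint8), np.ones((5, 3, 3), np.uint8)]

def stack_batch(batch):
    """Analogue of plain batching: stack elements, which requires equal shapes."""
    return np.stack(batch)

def pad_batch(batch):
    """Analogue of padded batching: zero-pad each image to the largest H and W, then stack."""
    max_h = max(img.shape[0] for img in batch)
    max_w = max(img.shape[1] for img in batch)
    padded = [
        np.pad(img, ((0, max_h - img.shape[0]), (0, max_w - img.shape[1]), (0, 0)),
               mode="constant")
        for img in batch
    ]
    return np.stack(padded)

# Stacking differently shaped images fails, just like batching them would.
try:
    stack_batch(images)
    stack_error = None
except ValueError as err:
    stack_error = err
print("cannot stack:", stack_error)

# Workaround 1: batch size 1, so each batch holds a single image.
singles = [stack_batch([img]) for img in images]
# Workaround 2: pad to a common shape before stacking.
padded = pad_batch(images)
print([s.shape for s in singles], padded.shape)
```

Batch size 1 keeps the original pixels untouched, which suits evaluation on small sets of differently sized images; padding allows larger batches but adds zero borders that a loss computed over the whole tensor would also see.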

Best answer

You simply forgot to initialize validation_iterator.

Just add sess.run(validation_iterator.initializer) before the for loop.
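A minimal sketch of the fix, assuming TensorFlow 1.13+ or 2.x driven through `tf.compat.v1` in graph mode; the toy integer datasets are hypothetical stand-ins for `dataset_train` and `dataset_valid`. The one-shot training iterator needs no initialization, while the initializable validation iterator only yields elements after `validation_iterator.initializer` has run. Because the validation dataset has no `repeat()`, one initialization covers a single pass; once it is exhausted, `get_next()` raises `tf.errors.OutOfRangeError`. The sketch therefore re-runs the initializer before each evaluation, which rewinds the iterator to the first batch.

```python
import numpy as np
import tensorflow as tf

tf1 = tf.compat.v1
tf1.disable_eager_execution()  # TF 1.x-style graph mode (placeholders, Session)

# Hypothetical toy data standing in for the question's image datasets.
dataset_train = tf.data.Dataset.from_tensor_slices(np.arange(8, dtype=np.int64)).repeat().batch(4)
dataset_valid = tf.data.Dataset.from_tensor_slices(np.arange(100, 106, dtype=np.int64)).batch(4)

# Feedable iterator: the string fed to `handle` selects which iterator get_next() reads.
handle = tf1.placeholder(tf.string, shape=[])
iterator = tf1.data.Iterator.from_string_handle(handle, tf1.data.get_output_types(dataset_train))
next_element = iterator.get_next()

training_iterator = tf1.data.make_one_shot_iterator(dataset_train)        # no initializer needed
validation_iterator = tf1.data.make_initializable_iterator(dataset_valid)  # must be initialized

train_batches, valid_first_batches = [], []
with tf1.Session() as sess:
    training_handle = sess.run(training_iterator.string_handle())
    validation_handle = sess.run(validation_iterator.string_handle())
    for step in range(3):
        train_batches.append(sess.run(next_element, feed_dict={handle: training_handle}))
        # The missing step: initialize (here also rewind) the validation iterator
        # before reading from it; without it get_next() raises FailedPreconditionError.
        sess.run(validation_iterator.initializer)
        valid_first_batches.append(sess.run(next_element, feed_dict={handle: validation_handle}))

print([b.tolist() for b in train_batches])        # [[0, 1, 2, 3], [4, 5, 6, 7], [0, 1, 2, 3]]
print([b.tolist() for b in valid_first_batches])  # [[100, 101, 102, 103], [100, 101, 102, 103], [100, 101, 102, 103]]
```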

Source: the original question on Stack Overflow: https://stackoverflow.com/questions/48443203/
