python - Tensorflow: extracting images and labels from a TFRecords file

Tags: python tensorflow tensorflow-datasets

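I have a TFRecords file that contains images together with their labels, names, sizes, and so on. My goal is to extract the labels and the images as NumPy arrays.

I do the following to load the file: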

import tensorflow as tf  # TF 1.x API (tf.Session, tf.parse_single_example)

def extract_fn(data_record):
    features = {
        # Extract features using the keys set during creation
        "image/class/label":    tf.FixedLenFeature([], tf.int64),
        "image/encoded":        tf.VarLenFeature(tf.string),
    }
    sample = tf.parse_single_example(data_record, features)
    #sample = tf.cast(sample["image/encoded"], tf.float32)
    return sample

filename = r"path\train-00-of-10"  # raw string: in a plain literal "\t" is a tab
dataset = tf.data.TFRecordDataset(filename)
dataset = dataset.map(extract_fn)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()

with tf.Session() as sess:
    try:
        while True:
            data_record = sess.run(next_element)
            print(data_record)
    except tf.errors.OutOfRangeError:
        pass  # all records have been read
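A side note on the path: written as a plain string literal, "path\train-00-of-10" contains the escape sequence \t, which Python turns into a tab character, so the file name that reaches TFRecordDataset is not the one intended. A minimal check in plain Python (the path itself is the question's placeholder):

```python
# Plain literal: Python reads "\t" as a TAB character, not backslash + "t".
broken = "path\train-00-of-10"
# Raw string: the backslash is kept literally.
fixed = r"path\train-00-of-10"
# Forward slashes avoid escapes entirely and are usually accepted on Windows too.
portable = "path/train-00-of-10"

print("\t" in broken)       # True: the file name now contains a tab
print("\\" in fixed)        # True: a real backslash separator
print(broken.split("\t"))   # ['path', 'rain-00-of-10']
```

Either the raw string or the forward-slash form keeps the separator intact.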

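The image is stored as a string. How can I convert it to float32? I tried sample = tf.cast(sample["image/encoded"], tf.float32), but that doesn't work. I want data_record to be a list containing the image as a NumPy array and the label as an np.int32 number. How can I do that?

Right now data_record looks like this: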

{'image/encoded': SparseTensorValue(indices=array([[0]]), values=array([b'\xff\xd8\ ... 8G\xff\xd9'], dtype=object), dense_shape=array([1])), 'image/class/label': 394}

I don't know how to handle it. Any help would be appreciated.
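For reference, the printed value can already be unpacked on the NumPy side: 'image/encoded' is a SparseTensorValue whose single value is the still-encoded image file, and the leading \xff\xd8 and trailing \xff\xd9 in the output are the JPEG start-of-image and end-of-image markers. The sketch below uses a namedtuple stand-in for tf.SparseTensorValue so it runs without TensorFlow, a shortened hypothetical payload, and a hypothetical unpack_record helper; the bytes it returns still have to be decoded before they are pixels.

```python
from collections import namedtuple

import numpy as np

# Stand-in for tf.SparseTensorValue (a namedtuple in TF 1.x) so this sketch
# runs without TensorFlow; the field names match the printed output.
SparseTensorValue = namedtuple("SparseTensorValue", ["indices", "values", "dense_shape"])

# Hypothetical record shaped like the printed data_record; the real JPEG
# payload is shortened to its start (FF D8) and end (FF D9) markers.
data_record = {
    "image/encoded": SparseTensorValue(
        indices=np.array([[0]]),
        values=np.array([b"\xff\xd8...\xff\xd9"], dtype=object),
        dense_shape=np.array([1]),
    ),
    "image/class/label": 394,
}


def unpack_record(record):
    """Return (encoded_bytes, label) from one record returned by sess.run."""
    sparse = record["image/encoded"]
    # The VarLenFeature holds one string per record here, so take element 0.
    if sparse.dense_shape.tolist() != [1]:
        raise ValueError("expected exactly one encoded image per record")
    encoded = sparse.values[0]
    label = np.int32(record["image/class/label"])
    return encoded, label


encoded, label = unpack_record(data_record)
# JPEG files begin with FF D8 and end with FF D9.
is_jpeg = encoded.startswith(b"\xff\xd8") and encoded.endswith(b"\xff\xd9")
print(label, is_jpeg)  # 394 True
```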

EDIT

If I print sample and sample['image/encoded'] inside extract_fn(), I get the following:

print(sample) = {'image/encoded': <tensorflow.python.framework.sparse_tensor.SparseTensor object at 0x7fe41ec15978>, 'image/class/label': <tf.Tensor 'ParseSingleExample/ParseSingleExample:3' shape=() dtype=int64>}

print(sample['image/encoded']) = SparseTensor(indices=Tensor("ParseSingleExample/ParseSingleExample:0", shape=(?, 1), dtype=int64), values=Tensor("ParseSingleExample/ParseSingleExample:1", shape=(?,), dtype=string), dense_shape=Tensor("ParseSingleExample/ParseSingleExample:2", shape=(1,), dtype=int64))

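It looks like the image is a sparse tensor, and tf.image.decode_image throws an error. What is the correct way to extract the image as a tf.float32 tensor?

Best answer

I believe you stored the images encoded as JPEG, PNG, or some other format. So, when reading them, you have to decode them: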

def extract_fn(data_record):
    features = {
        # Extract features using the keys set during creation
        "image/class/label":    tf.FixedLenFeature([], tf.int64),
        "image/encoded":        tf.VarLenFeature(tf.string),
    }
    sample = tf.parse_single_example(data_record, features)
    image = tf.image.decode_image(sample['image/encoded'], dtype=tf.float32) 
    label = sample['image/class/label']
    return image, label

...

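with tf.Session() as sess:
    try:
        while True:
            image, label = sess.run(next_element)
            # IMAGE_SHAPE: the (height, width, channels) of your images
            image = image.reshape(IMAGE_SHAPE)
    except tf.errors.OutOfRangeError:
        pass  # the dataset is exhausted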
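If you decode with the default uint8 dtype instead, the conversion that dtype=tf.float32 performs can be reproduced in NumPy: as far as I know, decode_image routes non-uint8 dtypes through tf.image.convert_image_dtype, which rescales 8-bit values from 0..255 to 0.0..1.0. In the sketch below, IMAGE_SHAPE and the to_float_image helper are hypothetical; substitute the real shape of your images.

```python
import numpy as np

# Hypothetical (height, width, channels); substitute your images' real shape.
IMAGE_SHAPE = (2, 2, 3)


def to_float_image(pixels_uint8, shape=IMAGE_SHAPE):
    """Scale uint8 pixels from 0..255 to 0.0..1.0 float32, then reshape.

    Intended to mirror decode_image(..., dtype=tf.float32) for 8-bit images.
    """
    pixels = np.asarray(pixels_uint8, dtype=np.uint8)
    image = pixels.astype(np.float32) / np.float32(255.0)
    # reshape only succeeds if the element count matches height * width * channels
    return image.reshape(shape)


# Flat buffer standing in for the pixels of a decoded 2x2 RGB image.
flat = np.array([0, 51, 102, 153, 204, 255] * 2, dtype=np.uint8)
img = to_float_image(flat)
print(img.shape, img.dtype)  # (2, 2, 3) float32
```

Note that reshape only relabels the layout; it fails loudly if the images in the file do not all have the same size.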

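Update: It seems you are getting the data as a single-cell value in a sparse tensor. Try converting it back to dense and inspecting it before and after decoding: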

def extract_fn(data_record):
    features = {
        # Extract features using the keys set during creation
        "image/class/label":    tf.FixedLenFeature([], tf.int64),
        "image/encoded":        tf.VarLenFeature(tf.string),
    }
    sample = tf.parse_single_example(data_record, features)
    label = sample['image/class/label']
    # A string SparseTensor needs a string default value (older TF 1.x defaults to 0)
    dense = tf.sparse_tensor_to_dense(sample['image/encoded'], default_value=b"")

    # Comment this out if you get an error, and inspect just `dense`.
    # `dense` has shape [1]; decode_image expects a scalar string, hence dense[0].
    image = tf.image.decode_image(dense[0], dtype=tf.float32)

    return dense, image, label
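The shape issue behind the update can be seen without TensorFlow. tf.sparse_tensor_to_dense turns the single-cell SparseTensor (indices [[0]], dense_shape [1]) into a one-element vector, while tf.image.decode_image expects a 0-D string, so the element has to be taken out (dense[0], or tf.reshape(dense, [])) before decoding. The sketch below is a NumPy analogue with a hypothetical sparse_to_dense_1d helper and a namedtuple stand-in for tf.SparseTensorValue; it is not TensorFlow's implementation.

```python
from collections import namedtuple

import numpy as np

# Stand-in for tf.SparseTensorValue so this sketch runs without TensorFlow.
SparseTensorValue = namedtuple("SparseTensorValue", ["indices", "values", "dense_shape"])


def sparse_to_dense_1d(sp, default=b""):
    """NumPy analogue of tf.sparse_tensor_to_dense for a 1-D string SparseTensor."""
    dense = np.full(tuple(sp.dense_shape), default, dtype=object)
    for (i,), value in zip(sp.indices, sp.values):
        dense[i] = value
    return dense


# Same layout as the printed value: a single cell at index 0, dense_shape [1].
sp = SparseTensorValue(
    indices=np.array([[0]]),
    values=np.array([b"\xff\xd8...\xff\xd9"], dtype=object),
    dense_shape=np.array([1]),
)
dense = sparse_to_dense_1d(sp)
print(dense.shape)  # (1,): a one-element vector, not a scalar
encoded = dense[0]  # the 0-D value an image decoder expects
```

If every record holds exactly one encoded image, parsing "image/encoded" with tf.FixedLenFeature([], tf.string) instead of tf.VarLenFeature should yield a scalar string directly and avoid the sparse tensor altogether; this assumes the records were written with a single value for that key.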

Source: the Stack Overflow question "python - Tensorflow: extracting images and labels from a TFRecords file": https://stackoverflow.com/questions/54723912/
