tensorflow - 在急切模式下对字符串张量调用 map 时,需要一个类似字节的对象,而不是 'Tensor'

标签 tensorflow tensorflow-datasets

我正在尝试使用 TF.dataset.map 移植此旧代码,因为我收到了弃用警告。

从 TFRecord 文件中读取一组自定义原型(prototype)的旧代码:

record_iterator = tf.python_io.tf_record_iterator(path=filename)
for record in record_iterator:
    example = MyProto()
    example.ParseFromString(record)

我正在尝试使用渴望模式和 map ,但出现此错误。

def parse_proto(string):
      proto_object = MyProto()
      proto_object.ParseFromString(string)

dataset = tf.data.TFRecordDataset(dataset_paths)
parsed_protos = raw_tf_dataset.map(parse_proto)


此代码有效:

for raw_record in raw_tf_dataset:                                                                                                                                         
    proto_object = MyProto()                                                                                                                                              
    proto_object.ParseFromString(raw_record.numpy())                                                                                                                                 


但是 map 给了我一个错误:
TypeError: a bytes-like object is required, not 'Tensor'


什么是使用参数映射的函数结果并将它们视为字符串的正确方法?

最佳答案

您需要从张量中提取字符串并在 map 中使用功能。以下是要在代码中实现的步骤。

  • 您必须使用 tf.py_function(get_path, [x], [tf.float32]) 装饰 map 功能.你可以找到更多关于 tf.py_function here .在 tf.py_function , 第一个参数是 map 的名称函数,第二个参数是要传递给 map 的元素函数和最终参数是返回类型。
  • 您可以使用 bytes.decode(file_path.numpy()) 获取字符串部分在 map 功能中。

  • 所以修改你的程序如下,
    parsed_protos = raw_tf_dataset.map(parse_proto)
    


    parsed_protos = raw_tf_dataset.map(lambda x: tf.py_function(parse_proto, [x], [function return type]))
    

    同时修改parse_proto如下,
    def parse_proto(string):
          proto_object = MyProto()
          proto_object.ParseFromString(string)
    


    def parse_proto(string):
          proto_object = MyProto()
          proto_object.ParseFromString(bytes.decode(string.numpy()))
    

    在下面的简单程序中,我们使用 tf.data.Dataset.list_files读取图像的路径。下一个 map我们正在使用 load_img 读取图像的函数然后做 tf.image.central_crop裁剪图像的中心部分的功能。

    代码 -
    %tensorflow_version 2.x
    import tensorflow as tf
    from keras.preprocessing.image import load_img
    from keras.preprocessing.image import img_to_array, array_to_img
    from matplotlib import pyplot as plt
    import numpy as np
    
    def load_file_and_process(path):
        image = load_img(bytes.decode(path.numpy()), target_size=(224, 224))
        image = img_to_array(image)
        image = tf.image.central_crop(image, np.random.uniform(0.50, 1.00))
        return image
    
    train_dataset = tf.data.Dataset.list_files('/content/bird.jpg')
    train_dataset = train_dataset.map(lambda x: tf.py_function(load_file_and_process, [x], [tf.float32]))
    
    for f in train_dataset:
      for l in f:
        image = np.array(array_to_img(l))
        plt.imshow(image)
    

    输出 -

    enter image description here

    希望这能回答你的问题。快乐学习。

    关于tensorflow - 在急切模式下对字符串张量调用 map 时,需要一个类似字节的对象,而不是 'Tensor',我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/57103273/

    相关文章:

    python - 在 Anaconda 上安装特定版本的 TensorFlow

    python - 张量板重复标志错误

    validation - 如何使用 tf.session.run() 进行测试(不更新网络参数)?

    tensorflow - Tensorflow 的 DirectoryIterator 是如何工作的?

    python-3.x - tensorflow 中 tf.data.Dataset 的填充

    tensorflow - 对于可变长度特征,使用 tf.train.SequenceExample 相对于 tf.train.Example 有何优点?

    python - 如何解决这些 tensorflow 警告?

    python - tensorflow.load 与下载 URL

    tensorflow - 并行性并没有减少数据集映射中的时间

    python - tf.data API无法打印所有批处理