python - 如何从 dtype 为字符串的 tf.tensor 中获取字符串值

标签 python tensorflow tensorflow-datasets

我想使用 tf.data.Dataset.list_files 函数来提供我的数据集。
但是因为文件不是图片,所以需要手动加载。
问题是 tf.data.Dataset.list_files 将变量作为 tf.tensor 传递,我的 python 代码无法处理张量。

如何从 tf.tensor 获取字符串值。 dtype 是字符串。

train_dataset = tf.data.Dataset.list_files(PATH+'clean_4s_val/*.wav')
train_dataset = train_dataset.map(lambda x: load_audio_file(x))

def load_audio_file(file_path):
  print("file_path: ", file_path)
  # i want do something like string_path = convert_tensor_to_string(file_path)

文件路径是 Tensor("arg0:0", shape=(), dtype=string)

我使用 tensorflow 1.13.1 和 eager 模式。

提前致谢

最佳答案

您可以使用 tf.py_func包裹 load_audio_file() .

import tensorflow as tf

tf.enable_eager_execution()

def load_audio_file(file_path):
    # you should decode bytes type to string type
    print("file_path: ",bytes.decode(file_path),type(bytes.decode(file_path)))
    return file_path

train_dataset = tf.data.Dataset.list_files('clean_4s_val/*.wav')
train_dataset = train_dataset.map(lambda x: tf.py_func(load_audio_file, [x], [tf.string]))

for one_element in train_dataset:
    print(one_element)

file_path:  clean_4s_val/1.wav <class 'str'>
(<tf.Tensor: id=32, shape=(), dtype=string, numpy=b'clean_4s_val/1.wav'>,)
file_path:  clean_4s_val/3.wav <class 'str'>
(<tf.Tensor: id=34, shape=(), dtype=string, numpy=b'clean_4s_val/3.wav'>,)
file_path:  clean_4s_val/2.wav <class 'str'>
(<tf.Tensor: id=36, shape=(), dtype=string, numpy=b'clean_4s_val/2.wav'>,)

更新 TF 2

上述解决方案不适用于 TF 2(使用 2.2.0 测试),即使替换 tf.py_func 也是如此。与 tf.py_function , 给予

InvalidArgumentError: TypeError: descriptor 'decode' requires a 'bytes' object but received a 'tensorflow.python.framework.ops.EagerTensor'

要使其在 TF 2 中运行,请进行以下更改:

  • 删除 tf.enable_eager_execution() (渴望是 TF 2 中的 enabled by default,您可以使用 tf.executing_eagerly() 返回 True 进行验证)
  • 替换tf.py_functf.py_function
  • 替换 file_path 的所有函数内引用与 file_path.numpy()

关于python - 如何从 dtype 为字符串的 tf.tensor 中获取字符串值,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/56122670/

相关文章:

python - 将用户添加到组访问权限

python - Python 3.6 sum() 是否有 `start=0` 关键字参数?

tensorflow-datasets - 我该如何处理错误:-- unbalanced parenthesis at position 32

python - 根据 OptionMenu 的选择运行命令

python - Django-CMS 插件未显示在可用插件中

python - 多层前馈网络无法在 TensorFlow 中训练

tensorflow - 无法通过 pip3 安装 TensorFlow

python - TensorFlow 失败前提条件错误 : iterator has not been initialized

python - 迭代 tf.data.Dataset 的有效方法

python - 如何保存 Tensorflow 数据集