Tensorflow - 数据集 API 中的字符串处理

标签 tensorflow tensorflow-datasets

我在 .txt 格式的目录中有 <text>\t<label> 文件。我正在使用 TextLineDataset API 来使用这些文本记录:

filenames = ["/var/data/file1.txt", "/var/data/file2.txt"]

dataset = tf.contrib.data.Dataset.from_tensor_slices(filenames)

dataset = dataset.flat_map(
    lambda filename: (
        tf.contrib.data.TextLineDataset(filename)
        .map(_parse_data)))

def _parse_data(line):   
    line_split = tf.string_split([line], '\t')
    features = {"raw_text": tf.string(line_split.values[0].strip().lower()),
                "label": tf.string_to_number(line_split.values[1], 
                    out_type=tf.int32)}
    parsed_features = tf.parse_single_example(line, features)
    return parsed_features["raw_text"], raw_features["label"]

我想对 raw_text 功能进行一些字符串清理/处理。当我尝试运行 line_split.values[0].strip().lower() 时,出现以下错误:

AttributeError: 'Tensor' object has no attribute 'strip'

最佳答案

对象 lines_split.values[0] 是一个 tf.Tensor 对象,表示从 line 的第 0 次拆分。它不是 Python 字符串,因此它没有 .strip().lower() 方法。相反,您必须将 TensorFlow 操作应用于张量以执行转换。

TensorFlow 目前没有很多 string operations ,但您可以使用 tf.py_func() 操作在 tf.Tensor 上运行一些 Python 代码:

def _parse_data(line):
    line_split = tf.string_split([line], '\t')

    raw_text = tf.py_func(
        lambda x: x.strip().lower(), line_split.values[0], tf.string)

    label = tf.string_to_number(line_split.values[1], out_type=tf.int32)

    return {"raw_text": raw_text, "label": label}

请注意,问题中的代码还有一些其他问题:
  • 不要使用 tf.parse_single_example() 。该操作仅用于解析 tf.train.Example Protocol Buffer 字符串;解析文本时不需要使用它,可以直接从 _parse_data() 返回提取的特征。
  • 使用 dataset.map() 而不是 dataset.flat_map() 。仅当映射函数的结果是 flat_map() 对象时才需要使用 Dataset(因此返回值需要扁平化为单个数据集)。当结果是一个或多个 map() 对象时,您必须使用 tf.Tensor
  • 关于Tensorflow - 数据集 API 中的字符串处理,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/47022987/

    相关文章:

    python - TF Hub微调错误: ValueError: Failed to find data adapter that can handle input

    tensorflow - 将 Keras model.fit 的 `steps_per_epoch` 与 TensorFlow 的 Dataset API 的 `batch()` 相结合

    python - 在训练周期的一部分之后运行评估

    python - 如何使用新的数据集 API 在 tensorflow 中仅循环使用占位符提供的数据一次

    python - 如何打印来自 tf.data 的数据集示例?

    python-3.x - Tensorflow 错误 "UnimplementedError: Cast string to float is not supported"- 使用估计器的线性分类器模型

    python - 按常数因子缩放张量中的行集

    python - Keras 中的逐像素加权损失函数 - TensorFlow 2.0

    python - 在 tensorflow 中展开张量

    python - 如何跟踪使用 CPU 与 GPU 进行深度学习的时间?