python - 如何按特定值过滤 tf.data.Dataset?

标签 python tensorflow tensorflow-datasets

我通过读取 TFRecords 创建了一个数据集,我映射了值,我想过滤数据集的特定值,但由于结果是一个带有张量的字典,我无法获得张量的实际值或用 tf.cond()/tf.equal 检查它。我该怎么做?

def mapping_func(serialized_example):
    feature = { 'label': tf.FixedLenFeature([1], tf.string) }
    features = tf.parse_single_example(serialized_example, features=feature)
    return features

def filter_func(features):
    # this doesn't work
    #result = features['label'] == 'some_label_value'
    # neither this
    result = tf.reshape(tf.equal(features['label'], 'some_label_value'), [])
    return result

def main():
    file_names = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
    dataset = tf.contrib.data.TFRecordDataset(file_names)
    dataset = dataset.map(mapping_func)
    dataset = dataset.shuffle(buffer_size=10000)
    dataset = dataset.filter(filter_func)
    dataset = dataset.repeat()
    iterator = dataset.make_one_shot_iterator()
    sample = iterator.get_next()

最佳答案

我正在回答我自己的问题。我发现了问题!

我需要做的是像这样tf.unstack()标签:

label = tf.unstack(features['label'])
label = label[0]

在我将它交给 tf.equal() 之前:

result = tf.reshape(tf.equal(label, 'some_label_value'), [])

我想问题是标签被定义为一个数组,其中一个元素的类型为字符串 tf.FixedLenFeature([1], tf.string),所以为了获得第一个和单个元素我必须解压它(这会创建一个列表)然后获取索引为 0 的元素,如果我错了请纠正我。

关于python - 如何按特定值过滤 tf.data.Dataset?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/48825785/

相关文章:

python - 如何在 Win 7 64 位上安装 python-Levenshtein 或其他 C/C++ 包

python - 为什么我在 Keras 中使用 multi_gpu_model 的训练速度比单 gpu 差?

tensorflow - 渴望 tf.GradientTape() 只返回无

python - Scikit-learn 管道中的 Keras 模型与早期停止

python - Tensorflow:从大于 2 GB 的 numpy 数组创建小批量

python - 从 Python GUI 运行 shell 脚本

python - python中如何继承ElementTree.Element类?

Python:定位文本框固定在角落并正确对齐

dataframe - 将 "columns"添加到 tf.data.dataset 的最佳方法

tensorflow - 将 Keras model.fit 的 `steps_per_epoch` 与 TensorFlow 的 Dataset API 的 `batch()` 相结合