我通过读取 TFRecords 创建了一个数据集,我映射了值,我想过滤数据集的特定值,但由于结果是一个带有张量的字典,我无法获得张量的实际值或用 tf.cond()
/tf.equal
检查它。我该怎么做?
def mapping_func(serialized_example):
feature = { 'label': tf.FixedLenFeature([1], tf.string) }
features = tf.parse_single_example(serialized_example, features=feature)
return features
def filter_func(features):
# this doesn't work
#result = features['label'] == 'some_label_value'
# neither this
result = tf.reshape(tf.equal(features['label'], 'some_label_value'), [])
return result
def main():
file_names = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
dataset = tf.contrib.data.TFRecordDataset(file_names)
dataset = dataset.map(mapping_func)
dataset = dataset.shuffle(buffer_size=10000)
dataset = dataset.filter(filter_func)
dataset = dataset.repeat()
iterator = dataset.make_one_shot_iterator()
sample = iterator.get_next()
最佳答案
我正在回答我自己的问题。我发现了问题!
我需要做的是像这样tf.unstack()
标签:
label = tf.unstack(features['label'])
label = label[0]
在我将它交给 tf.equal()
之前:
result = tf.reshape(tf.equal(label, 'some_label_value'), [])
我想问题是标签被定义为一个数组,其中一个元素的类型为字符串 tf.FixedLenFeature([1], tf.string)
,所以为了获得第一个和单个元素我必须解压它(这会创建一个列表)然后获取索引为 0 的元素,如果我错了请纠正我。
关于python - 如何按特定值过滤 tf.data.Dataset?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/48825785/