python - 从 Tensorflow 过滤 "empty"值

标签 python tensorflow

我编写这段代码是为了过滤 Dataset 中 <= 6 的值。

import tensorflow as tf
import tensorflow.contrib.data as ds

def make_graph():
    inits = []
    filter_value = tf.constant([6], dtype=tf.int64)
    source = ds.Dataset.range(10)
    batched = source.batch(3)
    batched_iter = batched.make_initializable_iterator()
    batched_next = batched_iter.get_next()
    inits.append(batched_iter.initializer)
    predicate = tf.less_equal(batched_next, filter_value, name="less_than_filter")
    true_coordinates = tf.where(predicate)
    reshaped = tf.reshape(true_coordinates, [-1])
    # need to turn bools into 1 and 0 elsewhere
    found = tf.gather(params=batched_next, indices=reshaped)

    return found, inits # prepend final tensor

def run_graph(final_tensor, initializers, rounds):
    with tf.Session() as sess:
        init_ops = (tf.local_variables_initializer(), tf.global_variables_initializer())
        sess.run(init_ops)
        summary_writer = tf.summary.FileWriter(graph=sess.graph, logdir=".")
        while rounds > 0:
            for i in initializers:
                sess.run(i)
            try:
                while True:
                    final_result = sess.run(final_tensor)
                    p```pythrint("Got result: {r}".format(r=final_result))
            except tf.errors.OutOfRangeError:
                print("Got out of range error")
            rounds -=1

        summary_writer.flush()

def run():
    final_tensor, initializers = make_graph()
    run_graph(final_tensor=final_tensor,
              initializers=initializers,
              rounds=1)

if __name__ == "__main__":
    run()

然而,结果如下:

Got result: [0 1 2]
Got result: [3 4 5]
Got result: [6]
Got result: []
Got out of range error

有没有办法过滤这个空的 Tensor? 我试着集思广益,也许用 tf.while 循环,但我不确定我是否遗漏了某些东西或这样的操作(即 OpKernel 通过不根据其值生成输出来“丢弃”输入)在 Tensorflow 中是不可能的。

最佳答案

在批处理之前只保留 <= 6 的值:

dataset = ds.Dataset.range(10)
dataset = dataset.filter( lambda v : v <= 6 )
dataset = dataset.batch(3)
batched_iter = dataset.make_initializable_iterator()

这将生成仅包含所需数据的批处理。请注意,通常最好在构建批处理之前过滤掉不需要的数据。这样,迭代器就不会生成空张量。

关于python - 从 Tensorflow 过滤 "empty"值,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/46092081/

相关文章:

python - 如何在Tensorflow中删除张量中的重复值?

python - Tensorflow中sess.run(c)和c.eval()的区别

python - ValueError : Input 0 of layer "sequential" is incompatible with the layer: expected shape=(None, 196, 196, 3),调整大小后发现 shape=(None, 196, 3)

python - Windows 缺少 Python.h

python - 在 Python 2.5 中的嵌套 For Range 循环中传递参数

python - 执行功能的算法仅占所有情况的 12%

python - 从数据帧获取索引列错误

tensorflow - TF物体检测: return subset of inference payload

使用大型查找表的 Python 类

python - 类型错误 : Unrecognized keyword arguments: {'show_accuracy' : True} #yelp challenge dataset