python - TF 中的索引操作

标签 python python-2.7 tensorflow

有没有办法在 tensorflow 中索引操作?特别是,我对通过 tf.while_loop 的迭代器变量进行索引感兴趣。

更具体地说,假设我有 my_ops = [op1, op2]。我想要:

my_ops = [...]
i = tf.constant(0)
c = lambda i: tf.less(i, 10)
b = lambda i: my_ops[i](...)
r = tf.while_loop(c, b, [i])

不幸的是,这不起作用,因为 python 数组仅支持整数索引。

最佳答案

我相信这是不可能的。但是,您可以改为使用 tf.stack堆叠操作的输出张量,然后使用 tf.gather以获得所需的输出。

这里有一个例子:

import tensorflow as tf


def condition(i, x):
    return tf.less(i, 10)


def body_1(my_ops):
    def b(i, x):
        stacked_results = tf.stack([op(x) for op in my_ops])
        gather_idx = tf.mod(i, 2)
        return [i + 1, tf.gather(stacked_results, gather_idx)]

    return b


def body_2(my_ops):
    def b(i, x):
        nb_ops = len(my_ops)
        pred_fn_pairs = [(tf.equal(tf.mod(i, nb_ops), 0), lambda: my_ops[0](x)),
                         (tf.equal(tf.mod(i, nb_ops), 1), lambda: my_ops[1](x))]
        result = tf.case(pred_fn_pairs)
        return [i + 1, result]

    return b


my_ops = [lambda x: tf.Print(x + 1, [x, 1]),
          lambda x: tf.Print(x + 2, [x, 2])]
i = tf.constant(0)
x = tf.constant(0)
r = tf.while_loop(condition, body_2(my_ops), [i, x])  # See the difference with body_1

with tf.Session() as sess:
    i, x = sess.run(r)
    print(x)  # Prints 15 = 5*2 + 5*1

关于python - TF 中的索引操作,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/53033780/

相关文章:

python - 无法从模板中的 for 循环中获得我想要的内容?

python - 定义函数的全局行为?好的做法

python - pyomo环境下添加约束

python - 如何导入模块 PyImgur Python

python-3.x - 在 Keras 功能 API 中使用嵌入层

python - Tensorflow dynamic_rnn 弃用

Python 从列表及其变体创建正则表达式

python - Python 的 "from <module> import <symbol>"的动态等价物

python - 递归访问嵌套字典的路径和值

python - Keras 访问自定义损失函数中的各个值