python - 访问张量中各个元素的更好方法

标签 python python-3.x tensorflow

我正在尝试访问张量a的元素,索引在张量b中定义。

a=tf.constant([[1,2,3,4],[5,6,7,8]])
b=tf.constant([0,1,1,0])

我想要的输出是

out = [1 6 7 4]

我尝试过的:

out=[]
for i in range(a.shape[1]):
    out.append(a[b[i],i])

out=tf.stack(out) #[1 6 7 4]

这给出了正确的输出,但我正在寻找一种更好、更紧凑的方法来实现它。

a 的形状类似于 (2,None) 时,我的逻辑也不起作用,因为我无法使用 range(a.shape[1 ]),如果答案也包括这个案例,会对我有帮助

谢谢

最佳答案

您可以使用tf.one_hot()tf.boolean_mask()

import tensorflow as tf
import numpy as np

a_tf = tf.placeholder(shape=(2,None),dtype=tf.int32)
b_tf = tf.placeholder(shape=(None,),dtype=tf.int32)

index = tf.one_hot(b_tf,a_tf.shape[0])
out = tf.boolean_mask(tf.transpose(a_tf),index)

a=np.array([[1,2,3,4],[5,6,7,8]])
b=np.array([0,1,1,0])
with tf.Session() as sess:
    print(sess.run(out,feed_dict={a_tf:a,b_tf:b}))

# print
[1 6 7 4]

关于python - 访问张量中各个元素的更好方法,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/54455169/

相关文章:

python - 在 Emacs 中更改 python 函数参数格式

python mysqldb 一个连接的多个游标

tensorflow - 无法卡住 Tensorflow 工作流程中的 Keras 层

python-3.x - pd.to_numeric 不工作

python - Numpy 索引赋值的 Tensorflow 等价物

python - 为什么结果打印 b'hello,Python!' ,当我使用tensorflow?

python - 如何使用 TDD 创建现有对象的数据库表示?

python - pyopenCL,openCL,无法在GPU上构建程序

python - 如何为 Python 2 和 3 上传通用 Python Wheel?

python - O(n) 复杂度算法,无需 remove() 方法即可从未排序的列表中删除值的实例