python - 在 TensorFlow 中提取子张量

标签 python matrix tensorflow tensor

我有一个 2 x 4 tensorA = [[0,1,0,1],[1,0,1,0]]。 我想从维度 d 中提取索引 i。 在 Torch 中我可以这样做:tensorA:select(d,i)

例如,tensorA:select(0,0) 将返回 [0,1,0,1] 并且 tensorA:select(1,1) 将返回 [1,0]

如何在 TensorFlow 中执行此操作? 我能找到的最简单的方法是:tf.gather(tensorA,indices=[i],axis=d)

但是为此使用聚集似乎有点矫枉过正。有谁知道更好的方法吗?

最佳答案

您可以使用以下配方:

将除 d 之外的所有轴替换为分号,并将值 i 放在 d 轴上,例如:

tensorA[0, :]  # same as tensorA:select(0,0)
tensorA[:, 1]  # same as tensorA:select(1,1)
tensorA[:, 0]  # same as tensorA:select(1,0)

但是,当我尝试这个时,我遇到了语法错误:

i = 1
selection = [:,i]  # this raises SyntaxError
tensorA[selection]

所以我用了一个切片,如下所示

i = 1
selection = [slice(0,2,1), i]
tensorA[selection]  # same as tensorA:select(1,i)

这个函数可以解决问题:

def select(t, axis, index):
    shape = K.int_shape(t)
    selection = [slice(shape[a]) if a != axis else index for a in 
                 range(len(shape))]
    return t[selection]

例如:

import numpy as np
t = K.constant(np.arange(60).reshape(2,5,6))
sub_tensor = select(t, 1, 1)
print(K.eval(sub_tensor)  

打印

[[6., 7., 8., 9., 10., 11.],

[36., 37., 38., 39., 40., 41.]]

关于python - 在 TensorFlow 中提取子张量,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/49158935/

相关文章:

python - 如何从 bazel 构建中排除包?

python - Tensorflow:如何赋予变量范围

python - 显示已删除的子框架 (wxPython)

android - 如何从图像中获取 RGB 矩阵?

c++ - R 矩阵到 Armadillo 的转换真的很慢

python-3.x - Tensorflow恢复模型: attempting to use uninitialized value

python - 在一行上打印 `numpy.ndarray`

python - Pandas 或 Numpy : how to get matching data entries to do data manipulation

python - 进程间共享锁

c - 如何使用修复中的参数,修复是指针?