例如,np.array([[1,2],[3,4]])[np.triu_indices(2)]
的形状为 (3,)
,是上三角条目的扁平列表。但是,如果我有一批 2x2 矩阵:
foo = np.repeat(np.array([[[1,2],[3,4]]]), 30, axis=0)
并且我想获得每个矩阵的上三角索引,尝试的天真的做法是:
foo[:,np.triu_indices(2)]
但是,这个对象实际上具有 (30,2,3,2)
形状(与我们预期的 (30,3)
相反,如果我们有批量提取上三角条目。
我们如何沿着批量维度广播元组索引?
最佳答案
获取元组并使用它们来索引最后两个暗淡 -
r,c = np.triu_indices(2)
out = foo[:,r,c]
或者,带有 Ellipsis
的单行代码适用于 3D
和 2D
数组 -
foo[(Ellipsis,)+np.triu_indices(2)]
它同样适用于 2D
数组 -
out = foo[r,c] # foo as 2D input array
<小时/>
遮蔽方式
3D阵列案例
我们还可以使用掩码进行基于掩码
的方式 -
foo[:,~np.tri(2,k=-1, dtype=bool)]
二维数组案例
foo[~np.tri(2,k=-1, dtype=bool)]
关于python - 如何沿批量维度广播 numpy 索引?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/58100302/