场景
我有一个 4D ndarray,由多个 3D 图像/体素组成,尺寸为(体素、dim1、dim2、dim3),比如说(12 体素、96 像素、96 像素、96 像素)。我的目标是从m 个体素体积的中间采样一系列n 个切片。
我已经查看了 (advanced) indexing 上的 Numpy 文档,以及this answer这解释了广播,以及 this answer这解释了 newaxis
的插入由 numpy 编写,但我仍然无法理解我的场景中的底层行为。
问题
最初,我尝试使用以下代码一次性索引数组来实现上述目标:
import numpy as np
array = np.random.rand(12, 96, 96, 96)
n = 4
m_voxels = 6
samples_range = np.arange(0, m_voxels)
middle_slices = array.shape[1] // 2
middle_slices_range = np.arange(middle_slices - n // 2, middle_slices + n // 2)
samples_from_the_middle = array[samples_range, middle_slices_range, :, :]
我没有获得形状为 (6, 4, 96, 96) 的数组,而是遇到了以下 IndexError:
IndexError: shape mismatch: indexing arrays could not be broadcast together with shapes (6,) (4,)
当我尝试显式或分两步索引数组时,它按预期工作:
explicit_indexing = array[0:6, 46:50, :, :]
temp = array[samples_range]
samples_from_the_middle = temp[:, middle_slices_range, :, :]
explicit_indexing.shape # output: (6, 4, 96, 96)
samples_from_the_middle.shape # output: (6, 4, 96, 96)
或者,如本answer中所述,另一种方法是:
samples_from_the_middle = array[samples_range[:, np.newaxis], middle_slices_range, :, :]
samples_from_the_middle.shape # output: (6, 4, 96, 96)
我有以下问题:
- 为什么
np.arange
尽管我们实际上使用相同范围的整数进行索引,但显式索引(使用冒号)正常工作时,该方法无法产生预期结果? - 为什么添加
newaxis
第一个索引一维数组似乎可以解决问题?
任何见解将不胜感激。
最佳答案
因此,numpy 处理索引的方式不同,具体取决于您是否使用 slices ,这是当您执行 my_array[a:b]
或 numpy 数组时创建的内容。一个有用的思考方式是 cartesian products 。看看这个演示:
In [1]: import numpy as np
In [2]: x = np.array([[1,2,3],[4,5,6],[7,8,9]])
In [3]: x
Out[3]:
array([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
In [4]: x[0:3, 0:3]
Out[4]:
array([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
In [5]: x[np.arange(3), np.arange(3)]
Out[5]: array([1, 5, 9])
请注意,当我们使用切片时,我们会得到您想要的输出。当我们使用 numpy 数组时,我们得到的一维数组只有 3 个元素,而不是 9 个。为什么?这是因为切片会自动用于创建笛卡尔积。 Python 会自动为所有可能的值对生成 [0, 0], [0, 1], [0, 2], [1, 0], ...
形式的索引两片。
当使用 numpy 数组进行索引时,情况并非如此。相反,数组是逐元素匹配的。这意味着仅创建 [0, 0], [1, 1], [2, 2]
对,并且我们仅获得 3 个对角元素。这与 numpy 不将一维数组视为正确的行或列向量有关,除非我们明确声明数组有多少行和列。当我们这样做时,我们启用 numpy 来执行 broadcasting ,本质上,数组沿着长度为 1 的轴“重复”。这让我们可以做类似的事情
In [10]: x = np.array([1,2,3,4,5])
In [11]: y = np.array([6,7,8])
In [12]: from numpy import newaxis as nax
In [13]: x = x[:, nax]
In [14]: y = y[nax, :]
In [15]: x + y
Out[15]:
array([[ 7, 8, 9],
[ 8, 9, 10],
[ 9, 10, 11],
[10, 11, 12],
[11, 12, 13]])
您可以在其中看到我们完全获得了您在索引时寻找的行为! x
数组中的每个元素都与 y
数组中的每个元素配对。
现在我们可以按如下方式使用这些知识:
In [16]: x = np.array([[1,2,3],[4,5,6],[7,8,9]])
In [17]: x
Out[17]:
array([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
In [18]: x[0:3, 0:3]
Out[18]:
array([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
In [19]: x[np.arange(3), np.arange(3)]
Out[19]: array([1, 5, 9])
In [20]: x[np.arange(3)[:, nax], np.arange(3)[nax, :]]
Out[20]:
array([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
我们就完成了!
为了完整起见,请注意 numpy.ix_函数的存在正是为了帮你处理这个问题。这是一个例子:
In [21]: x = np.array([1,2,3,4,5])
In [22]: y = np.array([6,7,8])
In [23]: x, y = np.ix_(x,y)
In [24]: x
Out[24]:
array([[1],
[2],
[3],
[4],
[5]])
In [25]: y
Out[25]: array([[6, 7, 8]])
最后,所有这些都相当于使用 numpy.meshgrid函数,它显式使用x
和y
中每个可能的元素配对创建数组。但是,您不想将其用于索引,因为同时显式创建这些配对并将它们保存在 RAM 中非常浪费内存。最好让 numpy 为您发挥它的魔力。
关于python - 了解 4D ndarray 上高级多维索引的行为,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/76627832/