python - Numpy 2d 和可能的 N d 通过元组数组进行索引

标签 python arrays numpy indexing

对于一维数组,可以通过 N 维整数数组对其进行索引,如下所示:

>>> rand = np.random.rand(9).astype(np.float32)
>>> rand
array([ 0.69786191,  0.09376735,  0.60141236,  0.35305005,  0.68340319,
        0.0746202 ,  0.11620298,  0.46607161,  0.90864712], dtype=float32)
>>> u = np.random.randint(0, 9, (2,2))
>>> u
array([[0, 6],
       [5, 6]])
>>> rand[u]
array([[ 0.69786191,  0.11620298],
       [ 0.0746202 ,  0.11620298]], dtype=float32)

但我不能对二维数组做同样的事情:

>>> rand2d = np.random.rand(9).astype(np.float32).reshape(3,3)
>>> rand2d
array([[ 0.83248657,  0.75025952,  0.87252802],
       [ 0.78049046,  0.92902303,  0.42035589],
       [ 0.80461669,  0.49386421,  0.56518084]], dtype=float32)
>>> u = np.random.randint(0, 3, (2,2,2))
>>> u
array([[[2, 2],
        [2, 2]],

       [[0, 2],
        [0, 1]]])
>>> rand2d[u]
array([[[[ 0.80461669,  0.49386421,  0.56518084],
         [ 0.80461669,  0.49386421,  0.56518084]],

        [[ 0.80461669,  0.49386421,  0.56518084],
         [ 0.80461669,  0.49386421,  0.56518084]]],

       [[[ 0.83248657,  0.75025952,  0.87252802],
         [ 0.80461669,  0.49386421,  0.56518084]],

        [[ 0.83248657,  0.75025952,  0.87252802],
         [ 0.78049046,  0.92902303,  0.42035589]]]], dtype=float32)

虽然我期望的结果是:

[[rand2d[2, 2], rand2d[2, 2]],
[rand2d[0, 2], rand2d[0, 1]]] ==
[[0.56518084, 0.56518084],
[0.87252802, 0.75025952]]

如何在不迭代的情况下实现这一目标?

最佳答案

直接来自example in the docs :

>>> 
>>> x
array([[ 0,  1,  2],
       [ 3,  4,  5],
       [ 6,  7,  8],
       [ 9, 10, 11]])
>>> 
>>> rows = np.array([[0,0],[3,3]])
>>> columns = np.array([[0,2],[0,2]])
>>> x[rows,columns]
array([[ 0,  2],
       [ 9, 11]])
>>> 

您可以看到它正在选择 (0,0)、(0,2) 和 (3,0),(3,2) 处的项目。

关于python - Numpy 2d 和可能的 N d 通过元组数组进行索引,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/27477295/

相关文章:

c - 读取文件时如何更改数组值

Java Swing Object[][] getData() 混淆

java - 编译出现问题: Symbol error in csv file split in java

python - Scrapy——抓取页面并抓取下一页

python - Tensorflow 抛出 "TypeError: unhashable type: ' list'"错误

Python 警告 : Retrying (Retry(total=4, 连接 = 无,读取 = 无,重定向 = 无,状态 = 无))

python - 如何权衡散点图中的点以进行拟合?

python - Pandas 数据框的自定义函数中的 Forex_python

python - 从 CSV 文件创建图形并使用 Django 和 Pandas Python 库呈现到浏览器

python - XPath 没有使用迭代器指向正确的 HTML 表元素(Scrapy)