python - 使用 N 维索引列表查询 numpy 数组

标签 python arrays numpy

给定一个数据数组,有没有办法通过N维索引列表查询数据?

示例:

import numpy as np
data = np.array([[-14., 2.,  19.],
                 [-13., 1.,  20.],
                 [-15., 2.,  18.],
                 [-13., 0.,  19.],
                 [-15., 1.,  19.],
                 [-14., 0.,  19.],
                 [-14., 1.,  20.]])


# Uniformly shaped array: works
queries = np.array([[2, 4, 6, 0], [3, 6, 4, 5]])
print data[queries]

# Properly returns
#[[[-15.   2.  18.]
#  [-15.   1.  19.]
#  [-14.   1.  20.]
#  [-14.   2.  19.]]
#
# [[-13.   0.  19.]
#  [-14.   1.  20.]
#  [-15.   1.  19.]
#  [-14.   0.  19.]]]


# N-dimentional array fails
queries = np.array([[4, 6, 0], [3, 6, 4, 5]])
print data[queries]

# IndexError: arrays used as indices must be of integer (or boolean) type #
#
# Desired result:
#[[[-15.   1.  19.]
#  [-14.   1.  20.]
#  [-14.   2.  19.]]
#
# [[-13.   0.  19.]
#  [-14.   1.  20.]
#  [-15.   1.  19.]
#  [-14.   0.  19.]]]

最佳答案

查询中的两个元素具有不同的长度,因此它们存储为列表而不是 numpy 数组;同样,结果也会在内部存储为列表,并且使用 numpy 数组对 python 列表将不再有任何优势;你能做的最好的可能就是一个普通的 for 循环:

[data[query].tolist() for query in queries]

#[[[-15.0, 1.0, 19.0], 
#  [-14.0, 1.0, 20.0], 
#  [-14.0, 2.0, 19.0]],
#
# [[-13.0, 0.0, 19.0],
#  [-14.0, 1.0, 20.0],
#  [-15.0, 1.0, 19.0],
#  [-14.0, 0.0, 19.0]]]

或者,如果您想将结果部分保留为 numpy 数组:

[data[query] for query in queries]

#[array([[-15.,   1.,  19.],
#        [-14.,   1.,  20.],
#        [-14.,   2.,  19.]]), array([[-13.,   0.,  19.],
#        [-14.,   1.,  20.],
#        [-15.,   1.,  19.],
#        [-14.,   0.,  19.]])]

关于python - 使用 N 维索引列表查询 numpy 数组,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/43882341/

相关文章:

python - tensorflow 概率中的热切执行在第二次迭代时停止记录梯度

python - 如何在 Jupyter Notebook 或 JupyterLab 中使用破折号?

ios - 当数组不只有整数类型数据时,如何获得数组元素的平均数?

python - PyTorch 的 'ToPILImage' 问题

python - 在 Selenium/Python 中循环遍历 div 中的 div

python - subprocess.call() 中的 Pytest 模拟全局变量

Java Quicksort(数组值在重新分配时不改变值)

arrays - 如何在数组中存储带有空格的元素?

python - 基于 numpy 中的向量生成张量的元素

Python numpy 数组之间共享指针