python - 如何查找嵌套在另一个数组中的数组的索引?

标签 python arrays numpy

<分区>

我正在尝试找到最有效的方法来获取另一个数组中嵌套数组的索引。

import numpy as np
#                     0     1      2     3
haystack = np.array([[1,3],[3,4,],[5,6],[7,8]])
needles  = np.array([[3,4],[7,8]])

鉴于 needles 中包含的数组,我想在 haystack 中找到它们的索引。在本例中为 1,3。

我想到了这个解决方案:

 indexes = [idx for idx,elem in enumerate(haystack) if elem in needles ]

这是错误的,因为实际上 elem 中的一个元素在 needles 中足以返回 idx .

有没有更快的选择?

最佳答案

此响应给出了类似问题的解决方案 Get intersecting rows across two 2D numpy arrays ,您使用非常高效的 np.in1d 函数,但是您可以通过为它提供两个数组的 View 来实现,这允许将它们作为 1d 数据数组进行处理。 在你的情况下,你可以做

A = np.array([[1,3],[3,4,],[5,6],[7,8]])
B = np.array([[3,4],[7,8]])
nrows, ncols = A.shape
dtype={'names':['f{}'.format(i) for i in range(ncols)],
       'formats':ncols * [A.dtype]}
indexes, = np.where(np.in1d(A.view(dtype), B.view(dtype)))

输出:

print(indexes)
> array([1, 3])

关于python - 如何查找嵌套在另一个数组中的数组的索引?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/56851363/

相关文章:

python - Django 模板 : Embed css from file

python - 在 for 循环中追加数组

python - 如何有条件地仅从底部对 numpy 数组进行子集化?

python - 使用 Flask REST API 将三个参数传递给 MySQL

python - 如何启动 django cms 项目

python - 为什么 python 不将异常记录到 syslog(但它确实记录了?)

c - 打印出字符串数组

java - 一次合并排序 3 个子数组

arrays - 用于将元素添加到数组的PowerShell函数

python - 色相色彩空间中的图像分割