带有 'rows' 和索引的 ismember 的 Python 版本

标签 python matlab numpy

有人问过类似的问题,但没有一个答案能完全满足我的需要——有些答案允许多维搜索(也就是 matlab 中的“行”选项)但不返回索引。有些返回索引但不允许行。我的阵列非常大 (1M x 2),我已经成功地制作了一个有效的循环,但显然这很慢。在matlab中,内置的ismember函数大约需要10秒。

这是我要找的:

a=np.array([[4, 6],[2, 6],[5, 2]])

b=np.array([[1, 7],[1, 8],[2, 6],[2, 1],[2, 4],[4, 6],[4, 7],[5, 9],[5, 2],[5, 1]])

执行此操作的确切 matlab 函数是:

[~,index] = ismember(a,b,'rows')

在哪里

index = [6, 3, 9] 

最佳答案

import numpy as np

def asvoid(arr):
    """
    View the array as dtype np.void (bytes)
    This views the last axis of ND-arrays as bytes so you can perform comparisons on
    the entire row.
    http://stackoverflow.com/a/16840350/190597 (Jaime, 2013-05)
    Warning: When using asvoid for comparison, note that float zeros may compare UNEQUALLY
    >>> asvoid([-0.]) == asvoid([0.])
    array([False], dtype=bool)
    """
    arr = np.ascontiguousarray(arr)
    return arr.view(np.dtype((np.void, arr.dtype.itemsize * arr.shape[-1])))


def in1d_index(a, b):
    voida, voidb = map(asvoid, (a, b))
    return np.where(np.in1d(voidb, voida))[0]    

a = np.array([[4, 6],[2, 6],[5, 2]])
b = np.array([[1, 7],[1, 8],[2, 6],[2, 1],[2, 4],[4, 6],[4, 7],[5, 9],[5, 2],[5, 1]])

print(in1d_index(a, b))

打印

[2 5 8]

这相当于 Matlab 的 [3, 6, 9],因为 Python 使用基于 0 的索引。

一些注意事项:

  1. 索引按升序返回。他们不对应 到 a 的项目在 b 中的位置。
  2. asvoid 将适用于整数数据类型,但使用 asvoid 时要小心 在 float 数据类型上,因为 asvoid([-0.]) == asvoid([0.]) 返回 数组([False])
  3. asvoid 在连续数组上效果最好。如果数组不连续,数据将被复制到一个连续的数组中,这会降低性能。

尽管有警告,但为了速度,人们可能还是会选择使用 in1d_index:

def ismember_rows(a, b):
    # http://stackoverflow.com/a/22705773/190597 (ashg)
    return np.nonzero(np.all(b == a[:,np.newaxis], axis=2))[1]

In [41]: a2 = np.tile(a,(2000,1))
In [42]: b2 = np.tile(b,(2000,1))

In [46]: %timeit in1d_index(a2, b2)
100 loops, best of 3: 8.49 ms per loop

In [47]: %timeit ismember_rows(a2, b2)
1 loops, best of 3: 5.55 s per loop

因此 in1d_index 快了约 650 倍(对于长度在几千位的数组),但再次注意比较并不完全一致,因为 in1d_index 返回索引按递增顺序排列,而 ismember_rows 返回 a 出现在 b 中的顺序行中的索引。

关于带有 'rows' 和索引的 ismember 的 Python 版本,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/22699756/

相关文章:

python - 为什么模型的准确率高达 84%,但 AUC 却非常低(13%)?

ubuntu 亚马逊 ec2 中的 Matlab MCR

python - np.float 不匹配 np.float32 和 np.float64

python - 如何在字符级别对句子进行单热编码?

Linux 与 Windows 在 MatLab 中执行 lsqcurvefit 和 importdata

python - 将数据帧转换为 numpy 数组?

python - 在 selenium python 中通过 href 查找链接

python - 当我尝试通过 django 发送电子邮件时出现 Gmail SMTPAuthenticationError

python - Google App Engine - 数据存储 get_or_insert key_name 混淆

matlab - CPLEX API for MATLAB 中的分段线性约束