python - 加速索引 "revert"

标签 python arrays numpy indexing

我有一个形状为 (n, 3) 的 numpy 数组 a,其中填充了从 0m 的整数>。 mn 都可以相当大。众所周知,从 0m 的每个整数有时只出现一次,但大多数情况下在 a 中的某处恰好出现两次。连续没有重复的索引。

我现在想构造“反向”索引,即两个形状为 (m, 2) 的数组 b_rowb_col每一行都包含 a 中的(一个或两个)行/列索引,其中 row_idx 出现在 a 中。

这有效:

import numpy

a = numpy.array([
    [0, 1, 2],
    [0, 1, 3],
    [2, 3, 4],
    [4, 5, 6],
    # ...
    ])

print(a)

b_row = -numpy.ones((7, 2), dtype=int)
b_col = -numpy.ones((7, 2), dtype=int)
count = numpy.zeros(7, dtype=int)
for k, row in enumerate(a):
    i = count[row]
    b_row[row, i] = k
    b_col[row, i] = [0, 1, 2]
    count[row] += 1

print(b_row)
print(b_col)
[[0 1 2]
 [0 1 3]
 [2 3 4]
 [4 5 6]]

[[ 0  1]
 [ 0  1]
 [ 0  2]
 [ 1  2]
 [ 2  3]
 [ 3 -1]
 [ 3 -1]]

[[ 0  0]
 [ 1  1]
 [ 2  0]
 [ 2  1]
 [ 2  0]
 [ 1 -1]
 [ 2 -1]]

但是由于a上的显式循环而很慢。

有关如何加快速度的任何提示?

最佳答案

这是一个解决方案:

import numpy as np

m = 7
a = np.array([
    [0, 1, 2],
    [0, 1, 3],
    [2, 3, 4],
    [4, 5, 6],
    # ...
    ])

print('a:')
print(a)

a_flat = a.flatten()  # Or a.ravel() if can modify original array
v1, idx1 = np.unique(a_flat, return_index=True)
a_flat[idx1] = -1
v2, idx2 = np.unique(a_flat, return_index=True)
v2, idx2 = v2[1:], idx2[1:]
rows1, cols1 = np.unravel_index(idx1, a.shape)
rows2, cols2 = np.unravel_index(idx2, a.shape)
b_row = -np.ones((m, 2), dtype=int)
b_col = -np.ones((m, 2), dtype=int)
b_row[v1, 0] = rows1
b_col[v1, 0] = cols1
b_row[v2, 1] = rows2
b_col[v2, 1] = cols2

print('b_row:')
print(b_row)
print('b_col:')
print(b_col)

输出:

a:
[[0 1 2]
 [0 1 3]
 [2 3 4]
 [4 5 6]]
b_row:
[[ 0  1]
 [ 0  1]
 [ 0  2]
 [ 1  2]
 [ 2  3]
 [ 3 -1]
 [ 3 -1]]
b_col:
[[ 0  0]
 [ 1  1]
 [ 2  0]
 [ 2  1]
 [ 2  0]
 [ 1 -1]
 [ 2 -1]]

编辑:

IPython 中用于比较的小基准。如@eozd所示由于 np.unique 在 O(n) 中运行,算法复杂度原则上更高,但对于实际大小来说,矢量化解决方案似乎仍然要快得多:

import numpy as np

def method_orig(a, m):
    b_row = -np.ones((m, 2), dtype=int)
    b_col = -np.ones((m, 2), dtype=int)
    count = np.zeros(m, dtype=int)
    for k, row in enumerate(a):
        i = count[row]
        b_row[row, i] = k
        b_col[row, i] = [0, 1, 2]
        count[row] += 1
    return b_row, b_col

def method_jdehesa(a, m):
    a_flat = a.flatten()  # Or a.ravel() if can modify original array
    v1, idx1 = np.unique(a_flat, return_index=True)
    a_flat[idx1] = -1
    v2, idx2 = np.unique(a_flat, return_index=True)
    v2, idx2 = v2[1:], idx2[1:]
    rows1, cols1 = np.unravel_index(idx1, a.shape)
    rows2, cols2 = np.unravel_index(idx2, a.shape)
    b_row = -np.ones((m, 2), dtype=int)
    b_col = -np.ones((m, 2), dtype=int)
    b_row[v1, 0] = rows1
    b_col[v1, 0] = cols1
    b_row[v2, 1] = rows2
    b_col[v2, 1] = cols2
    return b_row, b_col

n = 100000
c = 3
m = 200000

# Generate random input
# This does not respect "no doubled indices in row" but is good enough for testing
np.random.seed(100)
a = np.random.permutation(np.concatenate([np.arange(m), np.arange(m)]))[:(n * c)].reshape((n, c))

%timeit method_orig(a, m)
# 3.22 s ± 1.3 s per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit method_jdehesa(a, m)
# 108 ms ± 764 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

关于python - 加速索引 "revert",我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/50389518/

相关文章:

python - Python的httplib和urllib2有什么区别?

python - 如何找到哪个对象在for循环中导致错误?

php - mysqli 从数组更新

python - numpy.sum 的内部结构

python - Numpy:检查值是否为 NaT

c++ - 使用 Eigen::Map<Eigen::MatrixXd> 作为 Eigen::MatrixXd 类型的函数参数

python - 无法从 cPanel cron 作业调用 Python 3

python - 性能监控Openerp

c++ - 允许 n 维坐标的有效方法?

php - 填充数组php的算法