python - 在 Python 中按索引对矩阵元素求和

标签 python numpy matrix sum indices

我有两个矩阵(相同的行和列):一个具有浮点值,在另一个矩阵中按索引分组。因此,我想要一个字典或一个列表,其中包含每个索引的元素总和。 索引始终从 0 开始。

A = np.array([[0.52,0.25,-0.45,0.13],[-0.14,-0.41,0.31,-0.41]])
B = np.array([[1,3,1,2],[3,0,2,2]])

RESULT = {0: -0.41, 1: 0.07, 2: 0.03, 3: 0.11}

我找到了这个解决方案,但我正在寻找更快的解决方案。 我正在处理具有 784 x 300 个单元格的矩阵,此算法需要大约 28 毫秒才能完成。

import numpy as np

def matrix_sum_by_indices(indices,matrix):
    a = np.hstack(indices)
    b = np.hstack(matrix)
    sidx = a.argsort()
    split_idx = np.flatnonzero(np.diff(a[sidx])>0)+1
    out = np.split(b[sidx], split_idx)
    return [sum(x) for x in out]

如果你能帮我找到一个更好更简单的解决方案来解决这个问题,我将不胜感激!

编辑:我犯了一个错误,完成时间在 300*10 矩阵中约为 8 毫秒,但在 784x300 中约为 28 毫秒。

EDIT2:我的 A 元素是 float64,所以 bincount 给我 ValueError。

最佳答案

您可以在此处使用 bincount:

a = np.array([[0.52,0.25,-0.45,0.13],[-0.14,-0.41,0.31,-0.41]])
b = np.array([[1,3,1,2],[3,0,2,2]])

N = b.max() + 1
id = b + (N*np.arange(b.shape[0]))[:, None] # since you can't apply bincount to a 2D array
np.sum(np.bincount(id.ravel(), a.ravel()).reshape(a.shape[0], -1), axis=0)

输出:

array([-0.41,  0.07,  0.03,  0.11])

作为函数:

def using_bincount(indices, matrx):
    N = indices.max() + 1
    id = indices + (N*np.arange(indices.shape[0]))[:, None] # since you can't apply bincount to a 2D array
    return np.sum(np.bincount(id.ravel(), matrx.ravel()).reshape(matrx.shape[0], -1), axis=0)

此示例的时间:

In [5]: %timeit using_bincount(b, a)
31.1 µs ± 1.74 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

In [6]: %timeit matrix_sum_by_indices(b, a)
61.3 µs ± 2.62 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

In [88]: %timeit scipy.ndimage.sum(a, b, index=[0,1,2,3])
54 µs ± 218 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

(scipy.ndimage.sum 在更大的样本上应该更快)

关于python - 在 Python 中按索引对矩阵元素求和,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/51268386/

相关文章:

python - python 正确从文本文件中读取多个列表

r - R中协方差矩阵计算错误,通过calc.relimp()

matlab - MATLAB 中的函数矩阵向量化

python - 从 Python 中的数字列表创建包含五个元素的列表

python - 在 Pygame 表面内显示 cv2.VideoCapture 图像

python - Matplotlib 中的无花果大小(以像素为单位)

python - 使用 python 2.6 导入 numpy

python - 在给定阈值内提取高度相关变量的最佳方法是什么

python - 如何检查矩阵是否稀疏

python - 如何在Python中引发异常?