python - 如何只添加到 Python 中数组的对角线?

标签 python numpy scipy

我在 Python 上有一个数组,如下所示:

array([[ 0.57733218,  0.09794384,  0.44497735],
       [ 0.87061284,  0.10253493,  0.56643557],
       [ 0.76358739,  0.44902046,  0.86064797]])

我想将标量值 20 添加到数组的对角线上,使得输出为:

array([[ 20.57733218,  0.09794384,  0.44497735],
       [ 0.87061284,  20.10253493,  0.56643557],
       [ 0.76358739,  0.44902046,  20.86064797]])

由于我可能还要处理非常大的矩阵数组, 按照此 thread 的公认解决方案中的建议,通过赋值操作执行此对角线加法的最有效方法是什么? ?

最佳答案

一种方法是在具有适当步长的扁平切片上分配 -

In [233]: a
Out[233]: 
array([[ 0.57733218,  0.09794384,  0.44497735],
       [ 0.87061284,  0.10253493,  0.56643557],
       [ 0.76358739,  0.44902046,  0.86064797]])

In [234]: a.flat[::a.shape[1]+1] += 20

In [235]: a
Out[235]: 
array([[ 20.57733218,   0.09794384,   0.44497735],
       [  0.87061284,  20.10253493,   0.56643557],
       [  0.76358739,   0.44902046,  20.86064797]])

我们也可以使用ndarray.ravel()来得到扁平化 View 然后赋值-

a.ravel()[::a.shape[1]+1] += 20

另一种方法是使用 np.einsum 让我们看到对角元素 -

In [269]: a
Out[269]: 
array([[ 0.57733218,  0.09794384,  0.44497735],
       [ 0.87061284,  0.10253493,  0.56643557],
       [ 0.76358739,  0.44902046,  0.86064797]])

In [270]: d = np.einsum('ii->i', a)

In [271]: d += 20

In [272]: a
Out[272]: 
array([[ 20.57733218,   0.09794384,   0.44497735],
       [  0.87061284,  20.10253493,   0.56643557],
       [  0.76358739,   0.44902046,  20.86064797]])

基准测试

In [285]: a = np.random.rand(10000,10000)

# @Willem Van Onsem's soln
In [286]: %timeit np.fill_diagonal(a, a.diagonal() + 20)
10000 loops, best of 3: 159 µs per loop

In [287]: %timeit a.flat[::a.shape[1]+1] += 20
10000 loops, best of 3: 179 µs per loop

In [288]: %timeit a.ravel()[::a.shape[1]+1] += 20
100000 loops, best of 3: 18.2 µs per loop

In [289]: %%timeit
     ...: d = np.einsum('ii->i', a)
     ...: d += 20
100000 loops, best of 3: 18.5 µs per loop

关于python - 如何只添加到 Python 中数组的对角线?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/48170804/

相关文章:

python - 如何将字符串转换为具有相关时区的 python 日期时间对象?

python - Pyspark 使用另一列中的值替换 Spark 数据帧列中的字符串

python - 将 python ndarray 转换为 matlab 文件

python - 如何在 Python 中操作 wav 文件数据?

javascript - Django : Files not passed in the request even after setting enctype=multipart/form-data?

python - 如何将代码从 Git 更新到 Docker 容器

python - 为 PySpark 捆绑 Python3 包导致缺少导入

python - 将 nD numpy 数组折叠成一维数组

python - scipy.stats.ttest_1samp 的基于排列的替代方案

python - 在 Scikit-learn 分类器中找到最常见的术语