python - 是否可以使用 numpy 压缩除 N 维以外的所有维度?

标签 python numpy

我想知道是否有一种方法可以将大小为 1 的所有维度压缩到一个数组中,并且不压缩 N 个维度(即使这些维度的大小为 1)。

为什么?假设我有一个接收一个数组的函数,它返回数组及其转置的矩阵乘积,但数组的形状是未知的(最大 2 个 dims,大小 > 1,但可以有更多大小为 1 的 dims)

可能的矩阵形状示例:

A.shape -> (M,N)
B.shape -> (M,N,1[...,1])
C.shape -> (M,1[...,1])

我希望始终具有 A 的形状 (ndim = 2) 以便执行矩阵乘积。

我可以使用 np.squeeze(X),仅此而已,但在 C 的情况下,这会导致以下问题:

import numpy as np

def my_function(arr):
    arr = np.squeeze(arr)
    return np.dot(arr, arr.transpose())

x = np.arange(1, 6)  # shape (5,)
x = x.reshape((x.size, 1, 1))  # shape (5, 1, 1)
y = my_function(x)
print(y)
# Actual y.shape -> () [is a number]
# Expected y.shape -> (5, 5) [matrix]

我希望 np.squeeze() 函数有一个参数 axis_to_keep。你知道是否有办法轻松实现这一目标?我知道一些方法,但我需要最有效的方法,因为我必须多次执行这些操作。

最佳答案

使用 axes_to_keep 参数进行挤压

这是一个用于通用 n-dim 数组的请求 axes_to_keep 参数,可将这些轴保持在原位 -

def squeeze_generic(a, axes_to_keep):
    out_s = [s for i,s in enumerate(a.shape) if i in axes_to_keep or s!=1]
    return a.reshape(out_s)

样本运行-

In [105]: a = np.random.rand(3,4,5,1,1,6,1)

In [106]: squeeze_generic(a, axes_to_keep=(3,4)).shape
Out[106]: (3, 4, 5, 1, 1, 6)

In [107]: squeeze_generic(a, axes_to_keep=(3,4,6)).shape
Out[107]: (3, 4, 5, 1, 1, 6, 1)

# For cases when axes_to_keep lists axes that aren't singleton
In [108]: squeeze_generic(a, axes_to_keep=(0,1)).shape
Out[108]: (3, 4, 5, 6)

解决您的问题以保留前两个轴

因此,要解决您保留前两个轴的特定情况,它将是 -

squeeze_generic(a, axes_to_keep=range(2))

让我们看一下示例案例 -

In [55]: a = np.random.rand(3,5)

In [56]: squeeze_generic(a, axes_to_keep=range(2)).shape
Out[56]: (3, 5)

In [57]: a = np.random.rand(3,5,1)

In [58]: squeeze_generic(a, axes_to_keep=range(2)).shape
Out[58]: (3, 5)

In [59]: a = np.random.rand(3,1)

In [60]: squeeze_generic(a, axes_to_keep=range(2)).shape
Out[60]: (3, 1)

如果保证第二个轴之后的所有轴都是单轴(长度轴=1)(如果有的话),那么一个简单的 reshape 也可以完成这项工作-

a.reshape(a.shape[0],-1)

关于python - 是否可以使用 numpy 压缩除 N 维以外的所有维度?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/57472104/

相关文章:

python - 列表理解表达式错误中的 List remove() 方法

python - Python/Django 中具有拆分模型的循环模块依赖关系

python - 用于将一个文档包含到另一个文档中的 sphinx 指令是什么?

python - 在 Pandas 中按行中的值过滤列

python - 无法安装最新版本的 pandas (1.0.3)

python - 通过索引进行条件性 numpy 数组修改

python - 属性错误 : module 'numpy' has no attribute '__version__'

python - 滚动操作性能缓慢创建新列

python - '值错误 : Nothing can be done for the type <class 'numpy.core.records.recarray' > at the moment' error

Python,规则网格上的邻居