python - 在 Numpy 中是否有更快的方法来做到这一点?

标签 python numpy numpy-ndarray array-broadcasting

我想在 numpy 中生成一个 3D 矩阵。代码是:

mean_value = np.array([1, 2, 3], dtype=np.float32)
h, w = 5, 5
b = np.ones((h, w, 1), dtype=np.float32) * np.reshape(mean_value, [1, 1, 3])
print(b.shape)  # (5, 5, 3)

有没有更快的方法来生成b?谢谢。

最佳答案

为了提高效率(内存、性能),只需使用 np.broadcast_to 进行广播对于 View 输出 -

np.broadcast_to(mean_value,(h,w,)+mean_value.shape)

作为一个 View ,它没有内存开销,因此在运行时几乎是免费的。

让我们验证一下性能部分 -

In [45]: mean_value = np.array([1, 2, 3], dtype=np.float32)
    ...: h, w = 5, 5

In [46]: %timeit np.broadcast_to(mean_value,(h,w,)+mean_value.shape)
100000 loops, best of 3: 3.21 µs per loop

In [47]: mean_value = np.random.rand(10000)
    ...: h, w = 5000, 5000

In [48]: %timeit np.broadcast_to(mean_value,(h,w,)+mean_value.shape)
100000 loops, best of 3: 3.22 µs per loop

和内存部分(作为 View )-

In [49]: np.shares_memory(mean_value,np.broadcast_to(mean_value,(h,w,)+mean_value.shape))
Out[49]: True

关于python - 在 Numpy 中是否有更快的方法来做到这一点?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/58759956/

相关文章:

python - 使用 matplotlib 将灰度图像转换为 RGB 热图图像

python - PyYAML 中的数组没有缩进或空格

python - python内置函数存储在哪里?

python - 改进一个简单的 spring 网络的 numpy 实现

python - 从另一个包含字典键的数组构造值数组

python - FaceNet 嵌入的余弦距离问题

python - 如何一致地展平 numpy 数组?

python - 跨 numpy 矩阵的映射函数

c++ - 如何定义自定义浮点类型 numpy dtypes (C-API)

python - 如何仅更改 numpy 数组字典中的一个值