python-3.x - 是否可以控制函数如何将 np.array 向量化为参数?

标签 python-3.x function numpy vectorization

我试图利用以 np.arrays 作为参数的函数的并行处理,这只是我得到的经常性问题的简化示例。
当我编写一个接收 np.arrays 作为要矢量化的参数的函数时,我无法控制它应该如何矢量化输入。例如,在这种情况下,我试图对 W 和 X 数组进行张量乘法,但 X 以“1”作为其第一个坐标进行聚合。当然,我会发送一个 np.array,里面有很多 X。不幸的是,解释器没有向量化 X 试图进行聚合,所以它崩溃了。
所需的结果是注释行(解释器正确矢量化输入的位置),但我想知道是否有任何方法可以控制函数的矢量化以使工作类似于第一行。

def h(X,W):
    return np.tensordot( np.r_[1,X],W, (0,0) )      # dimensions error!
    #return W[0] + np.tensordot( X, W[1:], (0,0) )  # desired result
    

W = np.array([0,1,2])   # plane coeffs

X0 = np.arange(4+1)
X1 = np.arange(5+1)
X0,X1 = np.meshgrid(X0,X1)
X = np.array([X0,X1])

fx = h(X,W)

fig = plt.figure()
ax = fig.add_subplot(111,projection='3d')
ax.plot_wireframe( X0, X1, fx,  linewidth=0.5,color='b')
plt.show()
PD:示例已完成,它将绘制平面 z = 0 + 1x + 2y 的 4 x 5 网格

最佳答案

你的变量:

In [41]: W = np.array([0,1,2])   # plane coeffs
    ...: 
    ...: X0 = np.arange(4+1)
    ...: X1 = np.arange(5+1)
    ...: X0,X1 = np.meshgrid(X0,X1)
    ...: X = np.array([X0,X1])
In [42]: W.shape
Out[42]: (3,)
In [43]: X.shape
Out[43]: (2, 6, 5)
meshgrid您可能想使用 index='ij' .但这与以下计算无关
In [44]: temp=W[0]+np.tensordot(X, W[1:], (0,0))
In [45]: temp.shape
Out[45]: (6, 5)
一种在不分离 W[0] 的情况下执行相同操作的方法, 是展开X至 (3,6,5)
In [46]: Xnew = np.concatenate((np.ones((1,6,5),X.dtype),X), axis=0)
In [47]: Xnew.shape
Out[47]: (3, 6, 5)
In [48]: temp2 = np.tensordot(Xnew,W, (0,0))
In [49]: temp2.shape
Out[49]: (6, 5)
In [50]: np.allclose(temp,temp2)
Out[50]: True
通常当我们谈论 numpy矢量化我们的意思是用 numpy 方法中的编译迭代替换 python 级别的迭代。在这里,我认为第二个没有帮助。它只是替换了一个单独的添加,同时增加了 dot 的大小。产品 - 所以有 concatenate 的附加时间加上更多的时间在dot .如果我们用更复杂的任务替换复杂任务的几次迭代,“矢量化”是不值得的。
这实际上更快:
temp3 = W[0]+W[1]*X0+W[2]*X1
tensordot申请 reshapetranspose到参数以将问题减少到对 np.dot 的单次调用.然后它可能会 reshape 结果。einsum优于tensordot (至少在速度方面):
temp=W[0]+np.einsum('ijk,i->jk',X,W[1:])

关于python-3.x - 是否可以控制函数如何将 np.array 向量化为参数?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/68313283/

相关文章:

python-3.x - 将大型数据集加载/feed_dicting 到 Tensorflow session 中

python - 聚合组产生 Pandas 数据框

python - 尝试以标识为列表的奇怪数据格式打印每个第 n 个元素

Javascript - 返回语句被误解

python - 在 numpy 中减去数组并用 pylab 绘图

python - Python 对数下降曲线上的梯度下降

python - 使用 python 将文本发送到带有逗号分隔符的列

Javascript 将函数分配给现有变量

javascript - jQuery 动画之间的延迟

python - Numpy 索引 - 关于奇怪行为/不一致的问题