python - Python中多个张量的高效缩减

标签 python arrays algorithm numpy linear-algebra

我有四个多维张量 v[i,j,k], a[i,s,l], w[j,s,t ,m], Numpy 中的 x[k,t,n],我正在尝试计算给定的张量 z[l,m,n]通过:

z[l,m,n] = sum_{i,j,k,s,t} v[i,j,k] * a[i,s,l] * w[j,s ,t,m] * x[k,t,n]

所有张量都相对较小(比如总共不到 32k 个元素),但是我需要多次执行此计算,所以我希望函数的开销尽可能小。

我尝试像这样使用 numpy.einsum 来实现它:

z = np.einsum('ijk,isl,jstm,ktn', v, a, w, x)

但是速度很慢。我还尝试了以下 numpy.tensordot 调用序列:

z = np.zeros((a.shape[-1],w.shape[-1],x.shape[-1]))
for s in range(a.shape[1]):
  for t in range(x.shape[1]):
    res = np.tensordot(v, a[:,s,:], (0,0))
    res = np.tensordot(res, w[:,s,t,:], (0,0))
    z += np.tensordot(res, x[:,s,:], (0,0))

在一个双 for 循环内对 st 求和(st 都非常小,所以这不是什么大问题)。这工作得更好,但仍然没有我预期的那么快。我认为这可能是因为 tensordot 在获取实际产品之前需要在内部执行的所有操作(例如排列轴)。

我想知道在 Numpy 中是否有更有效的方法来实现这种操作。我也不介意在 Cython 中实现这部分,但我不确定使用什么算法才是正确的。

最佳答案

使用 np.tensordot在某些部分,您可以像这样对事物进行矢量化 -

# Perform "np.einsum('ijk,isl->jksl', v, a)"
p1 = np.tensordot(v,a,axes=([0],[0]))         # shape = jksl

# Perform "np.einsum('jksl,jstm->kltm', p1, w)"
p2 = np.tensordot(p1,w,axes=([0,2],[0,1]))    # shape = kltm

# Perform "np.einsum('kltm,ktn->lmn', p2, w)"
z = np.tensordot(p2,x,axes=([0,2],[0,1]))     # shape = lmn

运行时测试和验证输出 -

In [15]: def einsum_based(v, a, w, x):
    ...:     return np.einsum('ijk,isl,jstm,ktn', v, a, w, x) # (l,m,n)
    ...: 
    ...: def vectorized_tdot(v, a, w, x):
    ...:     p1 = np.tensordot(v,a,axes=([0],[0]))        # shape = jksl
    ...:     p2 = np.tensordot(p1,w,axes=([0,2],[0,1]))   # shape = kltm
    ...:     return np.tensordot(p2,x,axes=([0,2],[0,1])) # shape = lmn
    ...: 

案例#1:

In [16]: # Input params
    ...: i,j,k,l,m,n = 10,10,10,10,10,10
    ...: s,t = 3,3 # As problem states : "both s and t are very small".
    ...: 
    ...: # Input arrays
    ...: v = np.random.rand(i,j,k)
    ...: a = np.random.rand(i,s,l)
    ...: w = np.random.rand(j,s,t,m)
    ...: x = np.random.rand(k,t,n)
    ...: 

In [17]: np.allclose(einsum_based(v, a, w, x),vectorized_tdot(v, a, w, x))
Out[17]: True

In [18]: %timeit einsum_based(v,a,w,x)
10 loops, best of 3: 129 ms per loop

In [19]: %timeit vectorized_tdot(v,a,w,x)
1000 loops, best of 3: 397 µs per loop

案例 #2(更大的数据量):

In [20]: # Input params
    ...: i,j,k,l,m,n = 15,15,15,15,15,15
    ...: s,t = 3,3 # As problem states : "both s and t are very small".
    ...: 
    ...: # Input arrays
    ...: v = np.random.rand(i,j,k)
    ...: a = np.random.rand(i,s,l)
    ...: w = np.random.rand(j,s,t,m)
    ...: x = np.random.rand(k,t,n)
    ...: 

In [21]: np.allclose(einsum_based(v, a, w, x),vectorized_tdot(v, a, w, x))
Out[21]: True

In [22]: %timeit einsum_based(v,a,w,x)
1 loops, best of 3: 1.35 s per loop

In [23]: %timeit vectorized_tdot(v,a,w,x)
1000 loops, best of 3: 1.52 ms per loop

关于python - Python中多个张量的高效缩减,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/35838037/

相关文章:

python - 从 Python 2.7 更改为 Python 3.7 数据后获得了额外的字母?

python - hog() 得到了一个意外的关键字参数 'visualize'

python - python中所有嵌套字典的值总和

c - 是否有可能只打印字符串的 1 个字节?

java - 在Java中用数组表示的数字相乘?

python - Go Web 服务器请求产生自己的 goroutine?

Python 根据条件合并两个 Numpy 数组

生成 n 个变量的所有可能 bool 函数的算法

c++ - 算法优化

Java 生活游戏没有提供所需的输出?