matlab - 多维数组的元素矩阵乘法

标签 matlab multidimensional-array sum elementwise-operations numpy-einsum

我想在 MATLAB 中实现逐分量矩阵乘法,可以使用 numpy.einsum 在 Python 中如下:

import numpy as np
M = 2
N = 4
I = 2000
J = 300

A = np.random.randn(M, M, I)
B = np.random.randn(M, M, N, J, I)
C = np.random.randn(M, J, I)

# using einsum
D = np.einsum('mki, klnji, lji -> mnji', A, B, C)

# naive for-loop
E = np.zeros(M, N, J, I)
for i in range(I):
    for j in range(J):
        for n in range(N):
            E[:,n,j,i] = B[:,:,i] @ A[:,:,n,j,i] @ C[:,j,i]

print(np.sum(np.abs(D-E))) # expected small enough

到目前为止,我使用 i 的 for 循环, j , 和 n ,但我不想,至少 n 的 for 循环.

最佳答案

选项 1:从 MATLAB 调用 numpy
假设您的系统已设置 according to the documentation ,并且您已经安装了 numpy 包,您可以执行以下操作(在 MATLAB 中):

np = py.importlib.import_module('numpy');

M = 2;
N = 4;
I = 2000;
J = 300;

A = matpy.mat2nparray( randn(M, M, I) );
B = matpy.mat2nparray( randn(M, M, N, J, I) );
C = matpy.mat2nparray( randn(M, J, I) );

D = matpy.nparray2mat( np.einsum('mki, klnji, lji -> mnji', A, B, C) );
在哪里 matpy可以找到here .
选项 2: native MATLAB
这里最重要的部分是正确排列排列,因此我们需要跟踪我们的维度。我们将使用以下顺序:
I(1) J(2) K(3) L(4) M(5) N(6)
现在,我将解释如何获得正确的置换顺序(我们以 A 为例): einsum预计维度顺序为 mki ,根据我们的编号是5 3 1 .这告诉我们 A 的第一个维度需要是第 5 个,第 2 个需要是第 3 个,第 3 个需要是第 1 个(简称 1->5, 2->3, 3->1)。这也意味着“无源维度”(即那些没有原始维度成为它们的维度;在本例中为 2 4 6)应该是单例的。使用 ipermute这写起来真的很简单:
pA = ipermute(A, [5,3,1,2,4,6]);
在上面的例子中,1->5表示我们写 5首先,其他两个维度也是如此(产生 [5,3,1])。然后我们只需在末尾添加单例 (2,4,6) 即可得到 [5,3,1,2,4,6] .最后:
A = randn(M, M, I);
B = randn(M, M, N, J, I);
C = randn(M, J, I);

% Reference dim order: I(1) J(2) K(3) L(4) M(5) N(6)
pA = ipermute(A, [5,3,1,2,4,6]); % 1->5, 2->3, 3->1; 2nd, 4th & 6th are singletons
pB = ipermute(B, [3,4,6,2,1,5]); % 1->3, 2->4, 3->6, 4->2, 5->1; 5th is singleton
pC = ipermute(C, [4,2,1,3,5,6]); % 1->4, 2->2, 3->1; 3rd, 5th & 6th are singletons

pD = sum( ...
  permute(pA .* pB .* pC, [5,6,2,1,3,4]), ... 1->5, 2->6, 3->2, 4->1; 3rd & 4th are singletons
  [5,6]);
(见帖子底部关于 sum 的注释。)
在 MATLAB 中执行此操作的另一种方法,as mentioned by @AndrasDeak , 如下:
rD = squeeze(sum(reshape(A, [M, M, 1, 1, 1, I]) .* ...
                 reshape(B, [1, M, M, N, J, I]) .* ...
... % same as:   reshape(B, [1, size(B)]) .* ...
... % same as:   shiftdim(B,-1) .* ...
                 reshape(C, [1, 1, M, 1, J, I]), [2, 3]));
另见: squeeze , reshape , permute , ipermute , shiftdim .

这是一个完整的示例,显示测试这些方法是否等效:
function q55913093
M = 2;
N = 4;
I = 2000;
J = 300;

mA = randn(M, M, I);
mB = randn(M, M, N, J, I);
mC = randn(M, J, I);

%% Option 1 - using numpy:
np = py.importlib.import_module('numpy');

A = matpy.mat2nparray( mA );
B = matpy.mat2nparray( mB );
C = matpy.mat2nparray( mC );

D = matpy.nparray2mat( np.einsum('mki, klnji, lji -> mnji', A, B, C) );

%% Option 2 - native MATLAB:
%%% Reference dim order: I(1) J(2) K(3) L(4) M(5) N(6)

pA = ipermute(mA, [5,3,1,2,4,6]); % 1->5, 2->3, 3->1; 2nd, 4th & 6th are singletons
pB = ipermute(mB, [3,4,6,2,1,5]); % 1->3, 2->4, 3->6, 4->2, 5->1; 5th is singleton
pC = ipermute(mC, [4,2,1,3,5,6]); % 1->4, 2->2, 3->1; 3rd, 5th & 6th are singletons

pD = sum( permute( ...
  pA .* pB .* pC, [5,6,2,1,3,4]), ... % 1->5, 2->6, 3->2, 4->1; 3rd & 4th are singletons
  [5,6]);

rD = squeeze(sum(reshape(mA, [M, M, 1, 1, 1, I]) .* ...
                 reshape(mB, [1, M, M, N, J, I]) .* ...
                 reshape(mC, [1, 1, M, 1, J, I]), [2, 3]));

%% Comparisons:
sum(abs(pD-D), 'all')
isequal(pD,rD)
运行上面我们得到的结果确实是等价的:
>> q55913093
ans =
   2.1816e-10 
ans =
  logical
   1
注意这两种调用sum的方法是在最近的版本中引入的,因此如果您的 MATLAB 相对较旧,您可能需要替换它们:
S = sum(A,'all')   % can be replaced by ` sum(A(:)) `
S = sum(A,vecdim)  % can be replaced by ` sum( sum(A, dim1), dim2) `

根据评论中的要求,这是比较方法的基准:
function t = q55913093_benchmark(M,N,I,J)

if nargin == 0
  M = 2;
  N = 4;
  I = 2000;
  J = 300;
end

% Define the arrays in MATLAB
mA = randn(M, M, I);
mB = randn(M, M, N, J, I);
mC = randn(M, J, I);

% Define the arrays in numpy
np = py.importlib.import_module('numpy');
pA = matpy.mat2nparray( mA );
pB = matpy.mat2nparray( mB );
pC = matpy.mat2nparray( mC );

% Test for equivalence
D = cat(5, M1(), M2(), M3());
assert( sum(abs(D(:,:,:,:,1) - D(:,:,:,:,2)), 'all') < 1E-8 );
assert( isequal (D(:,:,:,:,2), D(:,:,:,:,3)));

% Time
t = [ timeit(@M1,1), timeit(@M2,1), timeit(@M3,1)]; 

function out = M1()
  out = matpy.nparray2mat( np.einsum('mki, klnji, lji -> mnji', pA, pB, pC) );
end

function out = M2()
  out = permute( ...
          sum( ...
            ipermute(mA, [5,3,1,2,4,6]) .* ...
            ipermute(mB, [3,4,6,2,1,5]) .* ...
            ipermute(mC, [4,2,1,3,5,6]), [3,4]...
          ), [5,6,2,1,3,4]...
        );  
end

function out = M3()
out = squeeze(sum(reshape(mA, [M, M, 1, 1, 1, I]) .* ...
                  reshape(mB, [1, M, M, N, J, I]) .* ...
                  reshape(mC, [1, 1, M, 1, J, I]), [2, 3]));
end

end
在我的系统上,这会导致:
>> q55913093_benchmark
ans =
    1.3964    0.1864    0.2428
这意味着第二种方法更可取(至少对于默认输入大小)。

关于matlab - 多维数组的元素矩阵乘法,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/55913093/

相关文章:

javascript - for/循环中的计数器变量。为什么一个程序可以运行而另一个程序却失败?

sum - 使用 AVX 一次性完成 4 个水平 double 求和

c++ - 如何在 matlab 中使用 kd-tree 文件交换和 mex?

python - 在 Numpy 中将离散值的一维数组转换为连续值的 n 维数组

c - 从文件创建 1 和 0 的矩阵 NxN 的 C 程序中的段错误

python - 切片除第 n 个之外的每个项目

matlab - 矩阵的 det 在 matlab 中返回 0

image - MATLAB imread bmp 图像不正确

matlab - 如何在 Matlab 中使用谷歌翻译?

Python 一行