c++ - OpenCL。矩阵乘法绕过一些工作项

标签 c++ visual-studio visual-c++ opencl matrix-multiplication

尝试在 OpenCL 中实现矩阵乘法时,我尝试编写自己的方法;但是似乎某些工作项的工作似乎被其他工作项覆盖了,我真的不知道该如何处理。

我真正确定的是问题出在 OpenCL 程序中。

我的主机代码是用 C/C++ 编写的。

程序构建并返回输出(错误,但程序成功退出)。

这是我的方法:

__kernel void matrixMultiplication(
         __global double* matrix1,
         __global double* matrix2,
         __global double* output,
         const unsigned int ROWS_M1, // ROWS_M1 = 3
         const unsigned int ROWS_M1, // COLS_M1 = 2
         const unsigned int ROWS_M2, // ROWS_M2 = 2
         const unsigned int ROWS_M2, // COLS_M2 = 4
         const unsigned int ROWS_M3, // ROWS_M3 = 3
         const unsigned int ROWS_M3) { // COLS_M3 = 4

    int i = get_global_id(0);
    int j = get_global_id(1);

    // for each value in the matrix1 (for each work-item)
    // and for each value in the "jth" row in the second matrix...
    // multiply the values and then add them according to the right offset.

    for(int k =0; k < COLS_M2; k++){
        int offsetM1 = (i*COLS_M1)+j;
        int offsetM2 = (j*COLS_M2)+k;
        int offsetM3 = (i*COLS_M3)+k;

        //output[i][k] += matrix1[i][j]*matrix2[j][k];
        output[offsetM3] += matrix1[offsetM1]*matrix2[offsetM2];
    }

}

为每个“const unsigned int”设置的值在代码中指定。

矩阵的值是:

矩阵1:

1 2
3 4
5 6

矩阵2:

2 3 4 5
6 7 8 9

给定输出:

12 14 16 18
24 28 32 36
36 42 48 54

期望的输出:

14 17 20 23
30 37 44 51
46 57 68 79

最佳答案

我认为您在索引方面做错了。 *offsetM3* 应该等于 *i\*COLS_M3+j**offsetM1* 应该等于 * i\*COLS_M1+k**offsetM2**k\*COLS_M2+j*

将矩阵写在纸上并进行数学运算,然后将矩阵写在数组中,就像在内存中一样,然后将它们相乘,然后您将看到索引模式。请记住,每个线程(工作项)都用于新矩阵的一个元素。如果您通过 for 循环更改新矩阵的索引,则您没有遵循一个矩阵元素一个工作项的逻辑,如果您希望这样做,您应该考虑另一种逻辑。 希望这有帮助

关于c++ - OpenCL。矩阵乘法绕过一些工作项,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/48372046/

相关文章:

c++ - 这段 C 代码有什么问题?

c++ - 在 C++ 中使用重载 * 运算符与括号结果相乘

c++ - 在 C++ 中实现 Set of octets 类

java - 有没有办法用java来执行一系列的命令行

C++ 数组对坐标周围的值求和

c++ - 搜索字符串数组

c++ - 无需单例即可将 QML 与 C++ 连接

c# - 如何在 WPF Visual Studio 设计器中显示占位符值,直到可以加载实际值

c# - 用于变量分配的 Visual Studio 自动缩进

c++ - 如何禁用 VC++ 中断代码块的特定 win32 异常类型?