python - 如何在 PyOpenCL 中覆盖数组元素

标签 python arrays indexing pyopencl

我想用另一个数组覆盖 PyOpenCL 数组的一部分。 这么说吧

import numpy as np, pyopencl.array as cla
a = cla.zeros(queue,(3,3),'int8')
b = cla.ones(queue,(2,2),'int8')

现在我想做一些类似 a[0:2,0:2] = b 的事情,并希望得到

1 1 0
1 1 0
0 0 0

出于速度原因,如果不将所有内容复制到主机,我该如何做到这一点?

最佳答案

Pyopencl 数组能够做到这一点 - 在回答此问题时的范围非常有限 - 使用 numpy 语法(即您具体如何编写它),限制是:您只能使用沿第一个轴切片。

import numpy as np, pyopencl.array as cla

a = cla.zeros(queue,(3,3),'int8')
b = cla.ones(queue,(2,3),'int8')
# note b is 2x3 here 
a[0:2]=b #<-works
a[0:2,0:2]=b[:,0:2] #<-Throws an error about non-contiguity

因此,a[0:2,0:2] = b 将不起作用,因为目标切片数组具有非连续数据。

我知道的唯一解决方案(因为 pyopencl.array 类中没有任何内容能够处理切片数组/非连续数据),就是编写自己的 openCL 内核来执行“手工”复制。

这是我用来在所有数据类型的 1D 或 2D pyopencl 数组上进行复制的一段代码:

import numpy as np, pyopencl as cl, pyopencl.array as cla
ctx = cl.create_some_context()
queue = cl.CommandQueue(ctx)
kernel = cl.Program(ctx, """__kernel void copy(
            __global char *dest,      const int offsetd, const int stridexd, const int strideyd,
            __global const char *src, const int offsets, const int stridexs, const int strideys,
            const int word_size) {

            int write_idx = offsetd + get_global_id(0) + get_global_id(1) * stridexd + get_global_id(2) * strideyd ;
            int read_idx  = offsets + get_global_id(0) + get_global_id(1) * stridexs + get_global_id(2) * strideys;
            dest[write_idx] =  src[read_idx];

            }""").build()

def copy(dest,src):
    assert dest.dtype == src.dtype
    assert dest.shape == src.shape
    if len(dest.shape) == 1 :
        dest.shape=(dest.shape[0],1)
        src.shape=(src.shape[0],1)
        dest.strides=(dest.strides[0],0)
        src.strides=(src.strides[0],0)
    kernel.copy(queue, (src.dtype.itemsize,src.shape[0],src.shape[1]), None, dest.base_data, np.uint32(dest.offset), np.uint32(dest.strides[0]), np.uint32(dest.strides[1]), src.base_data, np.uint32(src.offset), np.uint32(src.strides[0]), np.uint32(src.strides[1]), np.uint32(src.dtype.itemsize))


a = cla.zeros(queue,(3,3),'int8')
b = cla.ones(queue,(2,2),'int8')

copy(a[0:2,0:2],b)
print(a)

关于python - 如何在 PyOpenCL 中覆盖数组元素,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/32527196/

相关文章:

sql-server - SQL Server 2008 中的索引与唯一键

python - wxPython 网格函数在第一次传递到 wx.CallAfter 之后突然停止“可调用”

python - 使用滚动窗口从数据帧创建 "buffer"矩阵?

python - 如何使用 tweepy 获取关注者数量

c++ - 如何将二维数组传递和返回给函数

java - 在返回语句和平均问题中找不到符号

python - 在 Windows 上的 Click 命令行界面上修改 Usage 字符串

java - 将图形网格与二维数组网格相匹配

java - 是什么导致了 java.lang.ArrayIndexOutOfBoundsException 以及如何防止它?

database - 如果表上有索引,如何处理表更新/插入