python - 如何将这个函数向量化?

标签 python numpy

我有一个具有以下属性的 NumPy 数组:

  • 形状:(9986080, 2)
  • 数据类型:np.float32

  • 我有一个方法可以遍历数组的范围,执行一个操作,然后将结果输入到新数组中:
    def foo(arr):
        new_arr = np.empty(arr.size, dtype=np.uint64)
        for i in range(arr.size):
            x, y = arr[i]
            e, n = ''
            if x < 0:
                e = '1'
            else:
                w = '2'
            if y > 0:
                n = '3'
            else:
                s = '4'
            new_arr[i] = int(f'{abs(x)}{e}{abs(y){n}'.replace('.', ''))
    

    最佳答案

    我同意 Iguananaut 的评论,即这种数据结构似乎有点奇怪。我最大的问题是尝试矢量化将字符串中的整数放在一起然后将其重新转换为整数真的很棘手。尽管如此,这肯定有助于加速该功能:

    def foo(arr):
        x_values = arr[:,0]
        y_values = arr[:,1]
        ones = np.ones(arr.shape[0], dtype=np.uint64)
        e = np.char.array(np.where(x_values < 0, ones, ones * 2))
        n = np.char.array(np.where(y_values < 0, ones * 3, ones * 4))
        x_values = np.char.array(np.absolute(x_values))
        y_values = np.char.array(np.absolute(y_values))
        x_values = np.char.replace(x_values, '.', '')
        y_values = np.char.replace(y_values, '.', '')
        new_arr = np.char.add(np.char.add(x_values, e), np.char.add(y_values, n))
        return new_arr.astype(np.uint64)
    
    在这里,输入数组的 x 和 y 值首先被拆分。然后我们使用矢量化计算来确定哪里 en应该是 1 或 2、3 或 4。最后一行使用标准列表推导式进行字符串合并位,这对于超大数组来说仍然慢得令人不快,但比常规 for 循环快。同样矢量化之前的计算应该会大大加快函数的速度。
    编辑:
    我之前弄错了。 Numpy 确实有一种使用 np.char.add() 方法处理字符串连接的好方法。这需要转换 x_valuesy_values使用 np.char.array() 到 Numpy 字符数组.同样出于某种原因,np.char.add()方法只需要两个数组作为输入,所以需要先拼接x_valuesey_valuesn然后连接这些结果。尽管如此,这会向量化计算并且应该非常快。由于您正在执行的操作相当奇怪,代码仍然有点笨拙,但我认为这将帮助您大大加快功能。

    关于python - 如何将这个函数向量化?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/63601507/

    相关文章:

    python - 文件夹存在时 os.path.isdir 返回 false?

    arrays - numpy.add.at 比就地添加慢?

    python - pandas 数据帧到 numpy 数组,没有烦人的 dtype ="blah blah"

    python - 需要在 python 中比较 1.5GB 左右的非常大的文件

    python - 将数组行从逗号拆分为列

    python matplotlib : unable to call FuncAnimation from inside a function

    python - Django Rest Framework默认图像字段值返回验证

    python - Plotly:如何在绘图中添加垂直线?

    python - Timeit, NameError : global name is not defined. 但我没有使用全局变量

    python - 如何使用 numpy 使 for 循环更快