python - numpy 对角线函数很慢

标签 python numpy

我想在 python 中实现 connect 4 游戏作为一个业余爱好项目,但我不知道为什么在对角线上搜索匹配项如此缓慢。 当使用 psstats 分析我的代码时,我发现这是瓶颈。 我想构建一个计算机敌人来分析游戏中数千个 future 步骤,因此性能是一个问题。

有谁知道如何提高以下代码的性能? 我选择 numpy 来做这件事,因为我认为这会加快速度。 问题是,我找不到避免 for 循环的方法。

import numpy as np

# Finds all the diagonal and off-diagonal-sequences in a 7x6 numpy array
def findseq(sm,seq=2,redyellow=1):
    matches=0
    # search in the diagonals
    # diags stores all the diagonals and off diagonals as rows of a matrix
    diags=np.zeros((1,6),dtype=np.int8)
    for k in range(-5,7):   
        t=np.zeros(6,dtype=np.int8)
        a=np.diag(sm,k=k).copy()
        t[:len(a)] += a
        s=np.zeros(6,dtype=np.int8)
        a=np.diag(np.fliplr(sm),k=k).copy()
        s[:len(a)] += a
        diags=np.concatenate(( diags,t[None,:],s[None,:]),axis=0)
    diags=np.delete(diags,0,0)
    # print(diags)
    # now, search for sequences
    Na=np.size(diags,axis=1)
    n=np.arange(Na-seq+1)[:,None]+np.arange(seq)
    seqmat=np.all(diags[:,n]==redyellow,axis=2)
    matches+=seqmat.sum()

    return matches

def randomdebug():
    # sm=np.array([[0,0,0,0,0,0,0],[0,0,0,0,0,0,0],[0,0,0,0,0,0,0],[0,0,0,0,0,0,0],[0,0,0,0,0,0,0],[0,0,2,1,1,0,0]])
    sm=np.random.randint(0,3,size=(6,7))
    return sm

# in my main program, I need to do this thousands of times
matches=[]
for i in range(1000):
    sm=randomdebug()
    matches.append(findseq(sm,seq=3,redyellow=1))
    matches.append(findseq(sm,seq=3,redyellow=2))
    # print(sm)
    # print(findseq(sm,seq=3))

这是psstats

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
     2000    1.965    0.001    4.887    0.002 Frage zu diag.py:4(findseq)
151002/103002    0.722    0.000    1.979    0.000 {built-in method numpy.core._multiarray_umath.implement_array_function}
    48000    0.264    0.000    0.264    0.000 {method 'diagonal' of 'numpy.ndarray' objects}
    48072    0.251    0.000    0.251    0.000 {method 'copy' of 'numpy.ndarray' objects}
    48000    0.209    0.000    0.985    0.000 twodim_base.py:240(diag)
    48000    0.179    0.000    1.334    0.000 <__array_function__ internals>:177(diag)
    50000    0.165    0.000    0.165    0.000 {built-in method numpy.zeros}

我是 python 新手,所以请想象一个标签“hopeless noob”;-)

最佳答案

正如 Andrey 在评论中所述,代码正在调用许多需要额外内存分配的 np 函数。我相信这就是瓶颈。

我建议预先计算所有对角线的索引,因为它们在您的情况下不会改变太多(矩阵形状保持不变,我猜序列可能会改变)。然后您可以使用它们快速解决对角线问题:

import numpy as np


known_diagonals = dict()
def diagonal_indices(h: int, w: int, length: int = 3) -> np.array:
    '''
    Returns array (shape diagonal_count x length) of diagonal indices
    of a flatten array
    '''
    # one of many ways to store precomputed function output
    # cleaner way would probably be to do this outside this function
    diagonal_indices_key = (h, w, length)
    if diagonal_indices_key in known_diagonals:
        return known_diagonals[diagonal_indices_key]
    
    diagonals_count = (h + 1 - length) * (w + 1 - length) * 2

    # default value is meant to ease process with cumsum:
    # adding h + 1 selects an index 1 down and 1 right, h - 1 index 1 down 1 left
    # firts half dedicated to right down diagonals
    diagonals = np.full((diagonals_count, length), w + 1, dtype=np.int32)
    # second half dedicated to left down diagonals
    diagonals[diagonals_count//2::] = w - 1

    # this could have been calculated mathematicaly
    flat_indices = np.arange(w * h).reshape((h, w))
    # print(flat_indices)

    # selects rectangle offseted by l - 1 from right and down edges
    diagonal_starts_rd = flat_indices[:h + 1 - length, :w + 1 - length]
    # selects rectangle offseted by l - 1 from left and down edges
    diagonal_starts_ld = flat_indices[:h + 1 - length, -(w + 1 - length):]
    
    # sets starts
    diagonals[:diagonals_count//2, 0] = diagonal_starts_rd.flatten()
    diagonals[diagonals_count//2::, 0] = diagonal_starts_ld.flatten()

    # sum triplets left to right
    # diagonals contains triplets (or vector of other length) of (start, h+-1, h+-1). cumsum makes diagonal indices
    diagonals = diagonals.cumsum(axis=1)

    # save ouput
    known_diagonals[diagonal_indices_key] = diagonals

    return diagonals

# Finds all the diagonal and off-diagonal-sequences in a 7x6 numpy array
def findseq(sm: np.array, seq: int = 2, redyellow: int = 1) -> int:
    matches = 0
    diagonals = diagonal_indices(*sm.shape, seq)

    seqmat = np.all(sm.flatten()[diagonals] == redyellow, axis=1)
    matches += seqmat.sum()

    return matches

def randomdebug():
    # sm=np.array([[0,0,0,0,0,0,0],[0,0,0,0,0,0,0],[0,0,0,0,0,0,0],[0,0,0,0,0,0,0],[0,0,0,0,0,0,0],[0,0,2,1,1,0,0]])
    sm=np.random.randint(0,3,size=(6,7))
    return sm

# in my main program, I need to do this thousands of times
matches=[]
for i in range(1000):
    sm=randomdebug()
    matches.append(findseq(sm,seq=3,redyellow=1))
    matches.append(findseq(sm,seq=3,redyellow=2))
    # print(sm)
    # print(findseq(sm,seq=3))

关于python - numpy 对角线函数很慢,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/74178272/

相关文章:

javascript - Python 和 Javascript 正则表达式有什么不同?

python - 如何在Python中将txt文件作为数据加载?

python - 在 python 或 numpy 中合并记录

python - 为同一进程创建新控制台

python - thor websocket 负载测试 - 在 Web 套接字握手请求中添加自定义 http header

python - 序列化大型 scipy 稀疏矩阵的最佳方法是什么?

python exe文件在Windows xp上启动时崩溃

python - 如何将 numpy 数组呈现到 pygame 表面?

Python 字典多值

python - 返回 pandas df 中值的运行计数