python - 如何查找 numpy 数组中的每条步行

标签 python arrays numpy optimization

我试图通过数组找到长度为 n 的每一个“行走”。在这种情况下,步行被定义为数组中相邻元素(水平、对角线或垂直)的长度为 n 的序列,以便重复该点。例如,2x2 矩阵

[1 2]
[4 8]

路径长度为 2: (1, 2), (1, 4), (1, 8), (2, 1), (2, 4), (2, 8) ...
长度为 3 的步行:(1, 2, 4)、(1, 2, 8)、(1, 4, 2)、(1, 4, 8) ...等等

如何在 python/numpy 中快速实现小型(5x5)矩阵的算法,可能使用我目前不知道的数学的某些方面?

当前实现缓慢:

from copy import deepcopy

def get_walks(arr, n):
    n = n-1
    dim_y = len(arr)
    dim_x = len(arr[0])

    # Begin with every possibly starting location
    walks = [[(y, x)] for y in range(dim_y) for x in range(dim_x)]

    # Every possible direction to go in
    directions = [(0,1), (1,1), (1,0), (1, -1), (0, -1), (-1,-1), (-1, 0), (-1, 1)]

    temp_walks = []
    for i in range(n):
        # Go through every single current walk and add every 
        # possible next move to it, making sure to not repeat any points
        #
        # Do this n times
        for direction in directions:
            for walk in walks:
                y, x = walk[-1]
                y, x = y+direction[0], x+direction[1]
                if -1 < y < dim_y and -1 < x < dim_x and (y, x) not in walk:
                    temp_walks.append(walk + [(y, x)])

        # Overwrite current main walks list with the temporary one and start anew
        walks = deepcopy(temp_walks)
        temp_walks = []

    return walks

最佳答案

我想出了一个递归解决方案。由于您只想处理小问题,因此这种方法是可行的。我没有为 python 3 安装 numpy,所以这只能保证按原样适用于 python 2(但它应该相当兼容)。另外,我很确定我的实现远非最佳。

当检查我的输出与你的输出时,我发现对于 3x3 的情况,我得到 200 条路径,而你得到 160 条。查看路径,我认为你的代码有一些错误,而你是唯一缺少路径的人(而不是我有额外的)。这是我的版本:

import numpy as np
import timeit

def get_walks_rec(shape,inpath,ij,n):
    # add n more steps to mypath, with dimensions shape
    # procedure: call shorter walks for allowed neighbouring sites


    mypath = inpath[:]
    mypath.append(ij)

    # return if this is the last point
    if n==0:
        return mypath

    i0 = ij[0]
    j0 = ij[1]

    neighbs = [(i,j) for i in (i0-1,i0,i0+1) for j in (j0-1,j0,j0+1) if 0<=i<shape[0] and 0<=j<shape[1] and (i,j)!=(i0,j0)]
    subpaths = [get_walks_rec(shape,mypath,neighb,n-1) for neighb in neighbs]

    # flatten out the sublists for higher levels
    if n>1:
        flatpaths = []
        map(flatpaths.extend,subpaths)
    else:
        flatpaths = subpaths

    return flatpaths

# front-end for recursive function, called only once
def get_walks_rec_caller(mat,n):
    # collect all the paths starting from each point of the matrix

    sh = mat.shape
    imat,jmat = np.meshgrid(np.arange(sh[0]),np.arange(sh[1]))
    tmppaths = [get_walks_rec(sh,[],ij,n-1) for ij in zip(imat.ravel(),jmat.ravel())]

    # flatten the list of lists of paths to a single list of paths
    allpaths = []
    map(allpaths.extend,tmppaths)

    return allpaths

# input
mat = np.random.rand(3,3)
nmax = 3

# original:
walks_old = get_walks(mat,nmax)

# new recursive:
walks_new = get_walks_rec_caller(mat,nmax)

# timing:
number = 1000
print(timeit.timeit('get_walks(mat,nmax)','from __main__ import get_walks,mat,nmax',number=number))
print(timeit.timeit('get_walks_rec_caller(mat,nmax)','from __main__ import get_walks_rec_caller,mat,nmax',number=number))

对于最大路径长度为 3 的 3x3 情况,使用 timeit 运行 1000 次,你的运行时间为 1.81 秒,而我的运行时间为 0.53 秒(并且你丢失了 20% 的路径)。对于最大长度为 4 的 4x4 情况,100 次运行需要 2.1 秒(你的),而 0.67 秒(我的)。

一个示例路径,它存在于我的路径中,但似乎在你的路径中缺失:

[(0, 0), (0, 1), (0, 0)]

关于python - 如何查找 numpy 数组中的每条步行,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/35955750/

相关文章:

arrays - 如何找到数组的平均值 - Swift

Python 的 sympy 求解器返回四次方程的坏根

python - 从字节文件中查找拜耳模式格式

python - HTML 代码在模板中显示为文本

python - 在 python 中执行此字符串模式替换的最快方法是什么?

python - 多核cpu中内核线程和用户线程之间的区别?

python - 创建具有一个数组形状和列表中的值的 numpy 数组

python - 更改字符串的时间格式

javascript - 这是什么类型的 JavaScript 数据类型?

java - 将对象添加到 LinkedList java 数组