python - How can I optimize an Ulam spiral infinite iterator?

Tags: python python-3.x performance optimization

I created an infinite iterator that maps the natural numbers to all lattice points in a spiral pattern, similar to the Ulam spiral.

The code is below. I have made it as fast as I can, and it does not use a single if condition:

from itertools import islice, repeat

def ulamish_spiral_gen():
    xc = yc = length = 0
    yield 0, 0
    while True:
        length += 1
        yield from zip(range(xc + 1, (xc := xc + length) + 1, 1), repeat(yc))
        yield from zip(repeat(xc), range(yc + 1, (yc := yc + length) + 1, 1))
        length += 1
        yield from zip(range(xc - 1, (xc := xc - length) - 1, -1), repeat(yc))
        yield from zip(repeat(xc), range(yc - 1, (yc := yc - length) - 1, -1))

def ulamish_spiral(n):
    return list(islice(ulamish_spiral_gen(), n))
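To make the ordering concrete, here is a small sanity check, offered as a sketch rather than as part of the original post. It repeats the generator from above, prints the first nine points, and checks that the first (2k+1)^2 points cover the square from -k to k exactly once. The helper name covers_square is made up for this illustration.

```python
from itertools import islice, product, repeat

def ulamish_spiral_gen():
    # Same generator as in the question, repeated so this block runs on its own.
    xc = yc = length = 0
    yield 0, 0
    while True:
        length += 1
        yield from zip(range(xc + 1, (xc := xc + length) + 1, 1), repeat(yc))
        yield from zip(repeat(xc), range(yc + 1, (yc := yc + length) + 1, 1))
        length += 1
        yield from zip(range(xc - 1, (xc := xc - length) - 1, -1), repeat(yc))
        yield from zip(repeat(xc), range(yc - 1, (yc := yc - length) - 1, -1))

first9 = list(islice(ulamish_spiral_gen(), 9))
print(first9)
# [(0, 0), (1, 0), (1, 1), (0, 1), (-1, 1), (-1, 0), (-1, -1), (0, -1), (1, -1)]

def covers_square(k):
    # Hypothetical helper: do the first (2k+1)^2 points fill [-k, k] x [-k, k]
    # with no point visited twice?
    side = 2 * k + 1
    points = list(islice(ulamish_spiral_gen(), side * side))
    square = set(product(range(-k, k + 1), repeat=2))
    return len(set(points)) == len(points) and set(points) == square

print(all(covers_square(k) for k in range(6)))
```

The spiral starts at the origin, moves right, then turns counterclockwise, and each pair of sides is one step longer than the previous pair.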

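I would like to know how to memoize the output of the infinite iterator, so that list(islice(ulamish_spiral_gen(), n)) is called only when n is greater than the largest n requested so far.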

Something like this:

COMPUTED = []

def ulamish_spiral(n):
    global COMPUTED
    if n > len(COMPUTED):
        COMPUTED = list(islice(ulamish_spiral_gen(), n))
    return COMPUTED[:n]

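This is simple, but the first len(COMPUTED) items have already been computed. Only the items in range(len(COMPUTED), n) still need to be computed, yet the call recomputes everything from the start. So I tried reusing the same generator object and requesting only the next n - len(COMPUTED) items, and that worked.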

But doing so actually makes the code slower:

COMPUTED = []
ULAMISH_GEN = ulamish_spiral_gen()
def ulamish_spiral(n):
    if n > (l := len(COMPUTED)):
        COMPUTED.extend(islice(ULAMISH_GEN, n - l))
    return COMPUTED[:n]

In [225]: %timeit COMPUTED.clear(); ULAMISH_GEN = ulamish_spiral_gen(); ulamish_spiral(8192)
928 µs ± 8.96 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

In [226]: %timeit COMPUTED.clear(); ULAMISH_GEN = ulamish_spiral_gen(); ulamish_spiral(1024); ulamish_spiral(2048); ulamish_spiral(4096); ulamish_spiral(8192)
993 µs ± 18.2 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

In [227]: %timeit COMPUTED.clear(); ULAMISH_GEN = ulamish_spiral_gen(); ulamish_spiral(1024); ulamish_spiral(2048); ulamish_spiral(4096); ulamish_spiral(8192); ulamish_spiral(16384)
2.14 ms ± 106 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

In [228]: %timeit COMPUTED.clear(); ULAMISH_GEN = ulamish_spiral_gen(); ulamish_spiral(16384)
2 ms ± 88.7 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

In [229]: COMPUTED.clear(); ULAMISH_GEN = ulamish_spiral_gen(); ulamish_spiral(1024); ulamish_spiral(2048); ulamish_spiral(16384) == list(islice(ulamish_spiral_gen(), 16384))
Out[229]: True

How can I skip the items that have already been computed and make the code faster?

Best answer

You have already covered the "memoize"/"skip" part. Here are faster infinite iterators. Times for the first 16384 coordinates (as in your benchmark):

  1.01 ± 0.00 ms  gen_Kelly_4
  1.01 ± 0.00 ms  gen_Kelly_5
  1.03 ± 0.00 ms  gen_Kelly_3
  1.06 ± 0.00 ms  gen_Kelly_2
  1.21 ± 0.00 ms  gen_Kelly_1
  1.57 ± 0.00 ms  gen_original

Python: 3.11.4 (main, Jun 24 2023, 10:18:04) [GCC 13.1.1 20230429]

You combine all the zip iterators with your own generator. My gen_Kelly_1 uses chain.from_iterable for that instead.
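The difference between the two ways of joining the zip iterators can be seen in a smaller setting. The following is a minimal sketch, not the benchmarked code: with_yield_from and with_chain are made-up names, and the toy sequence they produce is arbitrary. Both yield the same items. In the second version the outer generator only hands over whole zip objects, and chain.from_iterable does the item-by-item forwarding.

```python
from itertools import chain, islice, repeat

def with_yield_from():
    # Every single item is passed along by this Python generator.
    n = 0
    while True:
        n += 1
        yield from zip(range(n), repeat(n))

def with_chain():
    # The inner generator yields one zip object per segment;
    # chain.from_iterable then pulls the items out of each zip in turn.
    def parts():
        n = 0
        while True:
            n += 1
            yield zip(range(n), repeat(n))
    return chain.from_iterable(parts())

a = list(islice(with_yield_from(), 1000))
b = list(islice(with_chain(), 1000))
print(a == b)  # True
print(a[:6])   # [(0, 1), (0, 2), (1, 2), (0, 3), (1, 3), (2, 3)]
```

Because chain.from_iterable is lazy, it works with the infinite parts() generator: it asks for the next zip only after the current one is exhausted.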

You use the range function to create the int objects over and over again. My gen_Kelly_2 instead keeps them in a list and reuses them.

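My gen_Kelly_3 additionally reuses the repeat iterators, gen_Kelly_4 reverses my list in place instead of using reversed iterators, and gen_Kelly_5 removes the now duplicated code (i goes through 0, 1, -1, 2, -2, 3, -3, and so on):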

from itertools import chain, repeat

def gen_Kelly_5():
    def parts():
        i = 0
        range = []
        while True:
            rep = repeat(i)
            yield zip(rep, range)
            range.append(i)
            range.reverse()
            yield zip(range, rep)
            i = (i<1) - i
    return chain.from_iterable(parts())
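The update i = (i<1) - i is compact enough that a short trace may help. This sketch isolates just that expression; the function name alternating is hypothetical. A comparison such as i < 1 evaluates to True or False, which behave as 1 and 0 in arithmetic, so a non-positive i becomes 1 - i and a positive i becomes -i.

```python
def alternating(count):
    # Collect the first `count` values that i takes in gen_Kelly_5.
    i = 0
    values = []
    for _ in range(count):
        values.append(i)
        i = (i < 1) - i  # (i < 1) is 1 for i <= 0 and 0 for i >= 1
    return values

print(alternating(7))  # [0, 1, -1, 2, -2, 3, -3]
```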

Full code (Attempt This Online!):

from timeit import timeit
from statistics import mean, stdev
from itertools import islice, repeat, cycle, chain, count
import sys

def gen_original():
    xc = yc = length = 0
    yield 0, 0
    while True:
        length += 1
        yield from zip(range(xc + 1, (xc := xc + length) + 1, 1), repeat(yc))
        yield from zip(repeat(xc), range(yc + 1, (yc := yc + length) + 1, 1))
        length += 1
        yield from zip(range(xc - 1, (xc := xc - length) - 1, -1), repeat(yc))
        yield from zip(repeat(xc), range(yc - 1, (yc := yc - length) - 1, -1))


def gen_Kelly_1():
    def parts():
        xc = yc = length = 0
        yield (0, 0),
        while True:
            length += 1
            yield zip(range(xc + 1, (xc := xc + length) + 1, 1), repeat(yc))
            yield zip(repeat(xc), range(yc + 1, (yc := yc + length) + 1, 1))
            length += 1
            yield zip(range(xc - 1, (xc := xc - length) - 1, -1), repeat(yc))
            yield zip(repeat(xc), range(yc - 1, (yc := yc - length) - 1, -1))
    return chain.from_iterable(parts())


def gen_Kelly_2():
    def parts():
        i = 0
        range = []
        while True:
            yield zip(repeat(-i), reversed(range))
            range.insert(0, -i)
            yield zip(range, repeat(-i))
            i += 1
            yield zip(repeat(i), range)
            range.append(i)
            yield zip(reversed(range), repeat(i))
    return chain.from_iterable(parts())


def gen_Kelly_3():
    def parts():
        i = 0
        range = []
        while True:
            rep = repeat(-i)
            yield zip(rep, reversed(range))
            range.insert(0, -i)
            yield zip(range, rep)
            i += 1
            rep = repeat(i)
            yield zip(rep, range)
            range.append(i)
            yield zip(reversed(range), rep)
    return chain.from_iterable(parts())


def gen_Kelly_4():
    def parts():
        i = 0
        range = []
        while True:
            rep = repeat(-i)
            yield zip(rep, range)
            range.append(-i)
            range.reverse()
            yield zip(range, rep)
            i += 1
            rep = repeat(i)
            yield zip(rep, range)
            range.append(i)
            range.reverse()
            yield zip(range, rep)
    return chain.from_iterable(parts())


def gen_Kelly_5():
    def parts():
        i = 0
        range = []
        while True:
            rep = repeat(i)
            yield zip(rep, range)
            range.append(i)
            range.reverse()
            yield zip(range, rep)
            i = (i<1) - i
    return chain.from_iterable(parts())


funcs = gen_original, gen_Kelly_1, gen_Kelly_2, gen_Kelly_3, gen_Kelly_4, gen_Kelly_5

n = 16384

# Correctness
expect = list(islice(funcs[0](), n))
for f in funcs[1:]:
    result = list(islice(f(), n))
    assert result == expect

# Speed
times = {f: [] for f in funcs}
def stats(f):
    ts = [t * 1e3 for t in sorted(times[f])[:10]]
    return f'{mean(ts):6.2f} ± {stdev(ts):4.2f} ms '
for _ in range(1000):
    for f in funcs:
        t = timeit(lambda: list(islice(f(), n)), number=1)
        times[f].append(t)
for f in sorted(funcs, key=stats):
    print(stats(f), f.__name__)
print('\nPython:', sys.version)

Regarding "python - How can I optimize an Ulam spiral infinite iterator?", a similar question can be found on Stack Overflow: https://stackoverflow.com/questions/76755607/
