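I created an infinite iterator that maps the natural numbers onto all lattice points in a spiral, similar to the Ulam spiral:
Here is the code. I have made it as fast as I can, and it does not use a single if condition: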
from itertools import islice, repeat

def ulamish_spiral_gen():
    xc = yc = length = 0
    yield 0, 0
    while True:
        length += 1
        yield from zip(range(xc + 1, (xc := xc + length) + 1, 1), repeat(yc))
        yield from zip(repeat(xc), range(yc + 1, (yc := yc + length) + 1, 1))
        length += 1
        yield from zip(range(xc - 1, (xc := xc - length) - 1, -1), repeat(yc))
        yield from zip(repeat(xc), range(yc - 1, (yc := yc - length) - 1, -1))

def ulamish_spiral(n):
    return list(islice(ulamish_spiral_gen(), n))
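Here is a quick sanity check, sketched independently of the question. It re-declares the generator so it runs on its own (":=" needs Python 3.8+). It then confirms two properties you would expect of such a spiral: the first n points are all distinct, and each step moves to one of the four neighbouring lattice points. The helper name check_spiral is hypothetical.

```python
from itertools import islice, repeat

def ulamish_spiral_gen():
    # Copy of the question's generator so this sketch is self-contained
    xc = yc = length = 0
    yield 0, 0
    while True:
        length += 1
        yield from zip(range(xc + 1, (xc := xc + length) + 1, 1), repeat(yc))
        yield from zip(repeat(xc), range(yc + 1, (yc := yc + length) + 1, 1))
        length += 1
        yield from zip(range(xc - 1, (xc := xc - length) - 1, -1), repeat(yc))
        yield from zip(repeat(xc), range(yc - 1, (yc := yc - length) - 1, -1))

def check_spiral(n):
    """Hypothetical helper: first n points are distinct and consecutive points are 4-neighbours."""
    pts = list(islice(ulamish_spiral_gen(), n))
    distinct = len(set(pts)) == n
    adjacent = all(abs(x1 - x2) + abs(y1 - y2) == 1
                   for (x1, y1), (x2, y2) in zip(pts, pts[1:]))
    return distinct and adjacent

print(list(islice(ulamish_spiral_gen(), 9)))
# [(0, 0), (1, 0), (1, 1), (0, 1), (-1, 1), (-1, 0), (-1, -1), (0, -1), (1, -1)]
print(check_spiral(10_000))  # True
```

The first nine points fill the 3×3 square around the origin, which is what the walk should produce after two turns in each direction.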
I would like to know how to memoize the output of the infinite iterator, so that list(islice(ulamish_spiral_gen(), n)) is only called when n is greater than the last n.
Something like this:
COMPUTED = []

def ulamish_spiral(n):
    global COMPUTED
    if n > len(COMPUTED):
        COMPUTED = list(islice(ulamish_spiral_gen(), n))
    return COMPUTED[:n]
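This is simple, but the first len(COMPUTED) items have already been computed. Only the items in range(len(COMPUTED), n) are new, yet each call recomputes all of the earlier items as well. So I tried reusing the same generator object and requesting only the next n - len(COMPUTED) items, and that worked.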
However, doing so actually makes the code slower:
COMPUTED = []
ULAMISH_GEN = ulamish_spiral_gen()

def ulamish_spiral(n):
    if n > (l := len(COMPUTED)):
        COMPUTED.extend(islice(ULAMISH_GEN, n - l))
    return COMPUTED[:n]
In [225]: %timeit COMPUTED.clear(); ULAMISH_GEN = ulamish_spiral_gen(); ulamish_spiral(8192)
928 µs ± 8.96 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
In [226]: %timeit COMPUTED.clear(); ULAMISH_GEN = ulamish_spiral_gen(); ulamish_spiral(1024); ulamish_spiral(2048); ulamish_spiral(4096); ulamish_spiral(8192)
993 µs ± 18.2 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
In [227]: %timeit COMPUTED.clear(); ULAMISH_GEN = ulamish_spiral_gen(); ulamish_spiral(1024); ulamish_spiral(2048); ulamish_spiral(4096); ulamish_spiral(8192); ulamish_spiral(16384)
2.14 ms ± 106 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
In [228]: %timeit COMPUTED.clear(); ULAMISH_GEN = ulamish_spiral_gen(); ulamish_spiral(16384)
2 ms ± 88.7 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
In [229]: COMPUTED.clear(); ULAMISH_GEN = ulamish_spiral_gen(); ulamish_spiral(1024); ulamish_spiral(2048); ulamish_spiral(16384) == list(islice(ulamish_spiral_gen(), 16384))
Out[229]: True
How can I skip the items that have already been computed and make the code faster?
Best answer
You have already handled the "memoize"/"skip" part. Here are faster infinite iterators. Times for the first 16384 coordinates (as in your benchmark):
1.01 ± 0.00 ms gen_Kelly_4
1.01 ± 0.00 ms gen_Kelly_5
1.03 ± 0.00 ms gen_Kelly_3
1.06 ± 0.00 ms gen_Kelly_2
1.21 ± 0.00 ms gen_Kelly_1
1.57 ± 0.00 ms gen_original
Python: 3.11.4 (main, Jun 24 2023, 10:18:04) [GCC 13.1.1 20230429]
You chain all the zip iterators together with your own generator. My gen_Kelly_1 uses chain.from_iterable for that instead.
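Here is a minimal sketch of the difference. It uses a toy parts() generator (hypothetical: two hand-written sides instead of the full spiral). Both versions yield the same items. With chain.from_iterable, CPython's C implementation of chain pulls the items out of each zip directly. In the yield from version, every item is delivered through the Python generator frame.

```python
from itertools import chain, repeat

def parts():
    # Toy stand-in: yields whole iterators (one per side), not their items
    yield zip(range(1, 3), repeat(0))   # (1, 0), (2, 0)
    yield zip(repeat(2), range(1, 3))   # (2, 1), (2, 2)

def with_yield_from():
    # Question's style: every item is delivered through this generator
    for part in parts():
        yield from part

flat = chain.from_iterable(parts())   # answer's style: chain flattens the iterators
print(list(flat))               # [(1, 0), (2, 0), (2, 1), (2, 2)]
print(list(with_yield_from()))  # [(1, 0), (2, 0), (2, 1), (2, 2)]
```

The parts() generator is only resumed once per side, not once per point, which is where the saving comes from.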
You use range to create int objects over and over again. My gen_Kelly_2 instead keeps them in a list and reuses them.
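Here is a small sketch of the reuse idea. It assumes CPython, where only ints from -5 to 256 are cached, so the example uses larger values. The names xs, bottom and top are illustrative only. Ints stored in a list are handed out again as the very same objects, even when the list feeds several different sides.

```python
from itertools import repeat

# Coordinates stored once in a list (values above 256, so CPython's
# small-int cache plays no role here)
xs = list(range(1000, 1005))

# Two different sides built from the same stored list
bottom = list(zip(xs, repeat(-1)))
top = list(zip(xs, repeat(3)))

# The x values are the very same int objects, not merely equal ones
print(all(p[0] is q[0] for p, q in zip(bottom, top)))  # True

# By contrast, in CPython each pass over a range() creates fresh int objects
# for values outside -5..256; that is the allocation the stored list avoids.
```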
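My gen_Kelly_3 additionally reuses the repeat iterators. gen_Kelly_4 reverses my list instead of using a reverse iterator. gen_Kelly_5 removes the now-duplicated code (i goes through 0, 1, -1, 2, -2, 3, -3, etc.):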
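from itertools import chain, repeat

def gen_Kelly_5():
    def parts():
        i = 0
        range = []  # shadows the builtin on purpose: the reused coordinate list
        while True:
            rep = repeat(i)
            yield zip(rep, range)   # side with x fixed at i
            range.append(i)
            range.reverse()
            yield zip(range, rep)   # side with y fixed at i
            i = (i<1) - i           # 0, 1, -1, 2, -2, ...
    return chain.from_iterable(parts())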
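To see why gen_Kelly_5 needs no separate code for the two directions, here is a small trace. It is a sketch that repeats the generator so it runs on its own. The helpers i_values and list_states are hypothetical and only replay the update rule. Each append + reverse leaves the list in walking order. That state is used first as the x values of a horizontal side, then as the y values of the following vertical side.

```python
from itertools import chain, islice, repeat

def gen_Kelly_5():
    # The answer's gen_Kelly_5, repeated so this sketch is self-contained
    def parts():
        i = 0
        range = []  # deliberately shadows the builtin; holds the reused coordinates
        while True:
            rep = repeat(i)
            yield zip(rep, range)    # vertical side: x stays at i
            range.append(i)
            range.reverse()
            yield zip(range, rep)    # horizontal side: y stays at i
            i = (i < 1) - i
    return chain.from_iterable(parts())

def i_values(k):
    """First k values taken by i under the update i = (i < 1) - i."""
    i, out = 0, []
    for _ in range(k):
        out.append(i)
        # (i < 1) is True (1) for i <= 0 and False (0) for i >= 1
        i = (i < 1) - i
    return out

def list_states(k):
    """Hypothetical trace helper: the shared list after each append + reverse."""
    i, coords, states = 0, [], []
    for _ in range(k):
        coords.append(i)
        coords.reverse()
        states.append(list(coords))
        i = (i < 1) - i
    return states

print(i_values(7))     # [0, 1, -1, 2, -2, 3, -3]
print(list_states(5))  # [[0], [1, 0], [-1, 0, 1], [2, 1, 0, -1], [-2, -1, 0, 1, 2]]
print(list(islice(gen_Kelly_5(), 9)))
# [(0, 0), (1, 0), (1, 1), (0, 1), (-1, 1), (-1, 0), (-1, -1), (0, -1), (1, -1)]
```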
Full code (Attempt This Online!):
from timeit import timeit
from statistics import mean, stdev
from itertools import islice, repeat, cycle, chain, count
import sys

def gen_original():
    xc = yc = length = 0
    yield 0, 0
    while True:
        length += 1
        yield from zip(range(xc + 1, (xc := xc + length) + 1, 1), repeat(yc))
        yield from zip(repeat(xc), range(yc + 1, (yc := yc + length) + 1, 1))
        length += 1
        yield from zip(range(xc - 1, (xc := xc - length) - 1, -1), repeat(yc))
        yield from zip(repeat(xc), range(yc - 1, (yc := yc - length) - 1, -1))

def gen_Kelly_1():
    def parts():
        xc = yc = length = 0
        yield (0, 0),
        while True:
            length += 1
            yield zip(range(xc + 1, (xc := xc + length) + 1, 1), repeat(yc))
            yield zip(repeat(xc), range(yc + 1, (yc := yc + length) + 1, 1))
            length += 1
            yield zip(range(xc - 1, (xc := xc - length) - 1, -1), repeat(yc))
            yield zip(repeat(xc), range(yc - 1, (yc := yc - length) - 1, -1))
    return chain.from_iterable(parts())

def gen_Kelly_2():
    def parts():
        i = 0
        range = []
        while True:
            yield zip(repeat(-i), reversed(range))
            range.insert(0, -i)
            yield zip(range, repeat(-i))
            i += 1
            yield zip(repeat(i), range)
            range.append(i)
            yield zip(reversed(range), repeat(i))
    return chain.from_iterable(parts())

def gen_Kelly_3():
    def parts():
        i = 0
        range = []
        while True:
            rep = repeat(-i)
            yield zip(rep, reversed(range))
            range.insert(0, -i)
            yield zip(range, rep)
            i += 1
            rep = repeat(i)
            yield zip(rep, range)
            range.append(i)
            yield zip(reversed(range), rep)
    return chain.from_iterable(parts())

def gen_Kelly_4():
    def parts():
        i = 0
        range = []
        while True:
            rep = repeat(-i)
            yield zip(rep, range)
            range.append(-i)
            range.reverse()
            yield zip(range, rep)
            i += 1
            rep = repeat(i)
            yield zip(rep, range)
            range.append(i)
            range.reverse()
            yield zip(range, rep)
    return chain.from_iterable(parts())

def gen_Kelly_5():
    def parts():
        i = 0
        range = []
        while True:
            rep = repeat(i)
            yield zip(rep, range)
            range.append(i)
            range.reverse()
            yield zip(range, rep)
            i = (i<1) - i
    return chain.from_iterable(parts())

funcs = gen_original, gen_Kelly_1, gen_Kelly_2, gen_Kelly_3, gen_Kelly_4, gen_Kelly_5
n = 16384

# Correctness
expect = list(islice(funcs[0](), n))
for f in funcs[1:]:
    result = list(islice(f(), n))
    assert result == expect

# Speed
times = {f: [] for f in funcs}

def stats(f):
    ts = [t * 1e3 for t in sorted(times[f])[:10]]
    return f'{mean(ts):6.2f} ± {stdev(ts):4.2f} ms '

for _ in range(1000):
    for f in funcs:
        t = timeit(lambda: list(islice(f(), n)), number=1)
        times[f].append(t)
for f in sorted(funcs, key=stats):
    print(stats(f), f.__name__)
print('\nPython:', sys.version)
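The question's incremental memoization and the answer's faster generator can be combined. Here is a sketch under that assumption, using a hypothetical SpiralCache class that is not from either post. Each instance owns one generator and the list of points it has produced. Starting over means creating a new instance rather than clearing and rebinding module-level globals.

```python
from itertools import chain, islice, repeat

def gen_Kelly_5():
    # The answer's fastest generator, repeated so the sketch is self-contained
    def parts():
        i = 0
        range = []
        while True:
            rep = repeat(i)
            yield zip(rep, range)
            range.append(i)
            range.reverse()
            yield zip(range, rep)
            i = (i < 1) - i
    return chain.from_iterable(parts())

class SpiralCache:
    """Hypothetical helper: the question's incremental memoization, as a class."""

    def __init__(self, factory=gen_Kelly_5):
        self._gen = factory()   # the single generator that keeps advancing
        self._points = []       # every point produced so far

    def __len__(self):
        return len(self._points)

    def first(self, n):
        missing = n - len(self._points)
        if missing > 0:
            # Pull only the points that have not been produced yet
            self._points.extend(islice(self._gen, missing))
        return self._points[:n]  # note: the slice copies n items on every call

cache = SpiralCache()
a = cache.first(1024)
b = cache.first(16384)   # generates only the 15360 new points
c = cache.first(2048)    # served entirely from the cache
print(len(cache))                                  # 16384
print(b == list(islice(gen_Kelly_5(), 16384)))     # True
print(a == b[:1024] and c == b[:2048])             # True
```

The final slice still copies n items per call. So, as the question's own timings suggest, caching mainly helps when smaller prefixes are requested repeatedly, not when each call asks for a larger n.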
Regarding "python - How to optimize the ulam spiral infinite iterator?", a similar question can be found on Stack Overflow: https://stackoverflow.com/questions/76755607/