我有一个脚本可以累积(计算)两个文件中包含的字节。字节是类似于 C 的 unsigned char
0 到 255 之间的整数值。
此累加器脚本的目标是计算这两个文件中字节的联合计数或频率。可能会将其扩展到多个文件/维度。
这两个文件大小相同,但都很大,大约有 6 TB 左右。
我正在使用 numpy.uint64
值,因为我使用 Python 的 int
类型遇到溢出问题。
我有一个长度为 255**2
的一维累加器数组,用于存储联合计数。
我通过逐行到数组的偏移计算来计算偏移,以便增加正确索引处的联合频率。我以字节 block (n_bytes
) 的形式遍历这两个文件,解压它们,并增加频率计数器。
这是代码的粗略草图:
import numpy
import ctypes
import struct
buckets_per_signal_type = 2**(ctypes.c_ubyte(1).value * 8)
total_buckets = buckets_per_signal_type**2
buckets = numpy.zeros((total_buckets,), dtype=numpy.uint64)
# open file handles to two files (omitted for brevity...)
# buffer size that is known ahead of time to be a divisible
# unit of the original files
# (for example, here, reading in 2.4e6 bytes per loop iteration)
n_bytes = 2400000
total_bytes = 0L
# format used to unpack bytes
struct_format = "=%dB" % (n_bytes)
while True:
# read in n_bytes chunk from each file
first_file_bytes = first_file_handle.read(n_bytes)
second_file_bytes = second_file_handle.read(n_bytes)
# break if both file handles have nothing left to read
if len(first_file_bytes) == 0 and len(second_file_bytes) == 0:
break
# unpack actual bytes
first_bytes_unpacked = struct.unpack(struct_format, first_file_bytes)
second_bytes_unpacked = struct.unpack(struct_format, second_file_bytes)
for index in range(0, n_bytes):
first_byte = first_bytes_unpacked[index]
second_byte = second_bytes_unpacked[index]
offset = first_byte * buckets_per_signal_type + second_byte
buckets[offset] += 1
total_bytes += n_bytes
# repeat until both file handles are both EOF
# print out joint frequency (omitted)
与我使用 int
的版本相比,这非常慢,慢了一个数量级。原始作业在大约 8 小时内完成(由于溢出而错误地完成),而这个基于 numpy 的版本必须提前退出,因为它似乎需要大约 12-14 天才能完成。
要么 numpy 在这个基本任务上非常慢,要么我没有以类似于 Python 的方式使用 numpy 做累加器。我怀疑是后者,这就是为什么我向 SO 寻求帮助。
我读到了numpy.add.at
,但是我要添加到buckets
数组中的解压字节数组没有自然转换为“形状”的偏移值buckets
数组的“。
有没有一种方法可以存储和递增(长)整数数组,并且不会溢出,并且性能相当好?
我想我可以用 C 重写这个,但我希望 numpy 中有一些我忽略的东西可以快速解决这个问题。感谢您的建议。更新
我有旧版本的 numpy 和 scipy,不支持 numpy.add.at
。所以这是另一个需要研究的问题。
我将尝试以下操作,看看效果如何:
first_byte_arr = np.array(first_bytes_unpacked)
second_byte_arr = np.array(second_bytes_unpacked)
offsets = first_byte_arr * buckets_per_signal_type + second_byte_arr
np.add.at(buckets, offsets, 1L)
希望它运行得快一点!
更新二
使用 np.add.at
和 np.array
,这项工作大约需要 12 天才能完成。我现在将放弃 numpy 并返回使用 C 读取原始字节,其中运行时间更合理一些。谢谢大家的建议!
最佳答案
如果不尝试跟踪所有文件读取和 struct
代码,看起来您正在将 1
添加到 buckets
中的各种插槽中> 数组。这部分不应该花费几天的时间。
但为了了解存储桶的 dtype
如何影响该步骤,我将测试向随机索引分类中添加 1。
In [57]: idx = np.random.randint(0,255**2,10000)
In [58]: %%timeit buckets = np.zeros(255**2, dtype=np.int64)
...: for i in idx:
...: buckets[i] += 1
...:
9.38 ms ± 39.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
In [59]: %%timeit buckets = np.zeros(255**2, dtype=np.uint64)
...: for i in idx:
...: buckets[i] += 1
...:
71.7 ms ± 2.35 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
uint64
大约慢 8 倍。
如果没有重复项,我们可以直接执行buckets[idx] += 1
。但考虑到重复,我们必须使用 add.at
:
In [60]: %%timeit buckets = np.zeros(255**2, dtype=np.int64)
...: np.add.at(buckets, idx, 1)
...:
1.6 ms ± 348 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
In [61]: %%timeit buckets = np.zeros(255**2, dtype=np.uint64)
...: np.add.at(buckets, idx, 1)
...:
1.62 ms ± 15.5 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
有趣的是,在这种情况下,dtype uint64
不会影响计时。
您在评论中提到您尝试了列表累加器。我假设是这样的:
In [62]: %%timeit buckets = [0]*(255**2)
...: for i in idx:
...: buckets[i] += 1
...:
3.59 ms ± 44.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
这比数组的迭代版本更快。一般来说,数组上的迭代比列表上的迭代慢。 “整个数组”操作速度更快,例如 add.at
。
要验证 add.at
是否是迭代的正确替代品,请进行比较
In [63]: buckets0 = np.zeros(255**2, dtype=np.int64)
In [64]: for i in idx: buckets0[i] += 1
In [66]: buckets01 = np.zeros(255**2, dtype=np.int64)
In [67]: np.add.at(buckets01, idx, 1)
In [68]: np.allclose(buckets0, buckets01)
Out[68]: True
In [69]: buckets02 = np.zeros(255**2, dtype=np.int64)
In [70]: buckets02[idx] += 1
In [71]: np.allclose(buckets0, buckets02)
Out[71]: False
In [75]: bucketslist = [0]*(255**2)
In [76]: for i in idx: bucketslist[i] += 1
In [77]: np.allclose(buckets0, bucketslist)
Out[77]: True
关于python - 有没有一种方法可以不慢地增加 numpy 数组?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/46031105/