python - 以多线程方式加载多个 npz 文件

标签 python multithreading numpy

我有几个 .npz 文件。所有 .npz 文件都具有相同的结构:每个文件仅包含两个变量,且变量名始终相同。截至目前,我只需循环所有 .npz 文件,检索两个变量值并将它们附加到某个全局变量中:

# Let's assume there are 100 npz files
x_train = []
y_train = []
for npz_file_number in range(100):
    data = dict(np.load('{0:04d}.npz'.format(npz_file_number)))
    x_train.append(data['x'])
    y_train.append(data['y'])

需要一段时间,瓶颈是CPU。 xy 变量附加到 x_trainy_train 变量的顺序并不重要。

有什么方法可以多线程加载多个 .npz 文件吗?

最佳答案

我对 @Brent Washburne 的评论感到惊讶,并决定自己尝试一下。我认为一般问题有两个:

首先,读取数据通常受 IO 限制,因此编写多线程代码通常不会产生很高的性能增益。其次,由于语言本身的设计,在Python中进行共享内存并行化本身就很困难。与原生 c 相比,开销要大得多。

但是让我们看看我们能做什么。

# some imports
import numpy as np
import glob
from multiprocessing import Pool
import os

# creating some temporary data
tmp_dir = os.path.join('tmp', 'nptest')
if not os.path.exists(tmp_dir):
    os.makedirs(tmp_dir)
    for i in range(100):
        x = np.random.rand(10000, 50)
        file_path = os.path.join(tmp_dir, '%05d.npz' % i)
        np.savez_compressed(file_path, x=x)

def read_x(path):
    with np.load(path) as data:
        return data["x"]

def serial_read(files):
    x_list = list(map(read_x, files))
    return x_list

def parallel_read(files):
    with Pool() as pool:
        x_list = pool.map(read_x, files)
    return x_list

好了,东西准备得够多了。让我们了解一下时间。

files = glob.glob(os.path.join(tmp_dir, '*.npz'))

%timeit x_serial = serial_read(files)
# 1 loops, best of 3: 7.04 s per loop

%timeit x_parallel = parallel_read(files)
# 1 loops, best of 3: 3.56 s per loop

np.allclose(x_serial, x_parallel)
# True

它实际上看起来像是一个不错的加速。我使用两个真实核心和两个超线程核心。

<小时/>

要一次运行所有内容并为其计时,您可以执行以下脚本:

from __future__ import print_function
from __future__ import division

# some imports
import numpy as np
import glob
import sys
import multiprocessing
import os
import timeit

# creating some temporary data
tmp_dir = os.path.join('tmp', 'nptest')
if not os.path.exists(tmp_dir):
    os.makedirs(tmp_dir)
    for i in range(100):
        x = np.random.rand(10000, 50)
        file_path = os.path.join(tmp_dir, '%05d.npz' % i)
        np.savez_compressed(file_path, x=x)

def read_x(path):
    data = dict(np.load(path))
    return data['x']

def serial_read(files):
    x_list = list(map(read_x, files))
    return x_list

def parallel_read(files):
    pool = multiprocessing.Pool(processes=4)
    x_list = pool.map(read_x, files)
    return x_list


files = glob.glob(os.path.join(tmp_dir, '*.npz'))
#files = files[0:5] # to test on a subset of the npz files

# Timing:
timeit_runs = 5

timer = timeit.Timer(lambda: serial_read(files))
print('serial_read: {0:.4f} seconds averaged over {1} runs'
      .format(timer.timeit(number=timeit_runs) / timeit_runs,
      timeit_runs))
# 1 loops, best of 3: 7.04 s per loop

timer = timeit.Timer(lambda: parallel_read(files))
print('parallel_read: {0:.4f} seconds averaged over {1} runs'
      .format(timer.timeit(number=timeit_runs) / timeit_runs,
      timeit_runs))
# 1 loops, best of 3: 3.56 s per loop

# Examples of use:
x = serial_read(files)
print('len(x): {0}'.format(len(x))) # len(x): 100
print('len(x[0]): {0}'.format(len(x[0]))) # len(x[0]): 10000
print('len(x[0][0]): {0}'.format(len(x[0][0]))) # len(x[0]): 10000
print('x[0][0]: {0}'.format(x[0][0])) # len(x[0]): 10000
print('x[0].nbytes: {0} MB'.format(x[0].nbytes / 1e6)) # 4.0 MB

关于python - 以多线程方式加载多个 npz 文件,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/35328085/

相关文章:

python - 如何从 Selenium 中的多个类的多个元素中获取数据?

python - 创建 dict 对象时对象未正确形成

java - 在单独的线程中运行逻辑仍然锁定 UI - Java

c# - Thread.Interrupt 等效于任务 TPL

python - NumPy 读取文件并即时过滤行

python - 在 Python 中使用 SQL Server 文件流

python - 如何在Tensorflow 2.x Keras自定义层中使用多个输入?

java - Android - 如何等待 SnapshotListener 完成?

python - 如何在训练数据集上使用 SMAPE 评估指标?

python-3.x - 如何遍历数组对每个像素应用阈值