python - 如何从Pytorch中的高IO数据集读取,该数据集随着时间的推移而增长

标签 python pytorch

我使用Tensorflow,但我是为用户编写文档,这些文档通常会在深度学习框架中有所不同。

当使用不适合本地文件系统(TB +)的数据集时,我从远程数据存储中采样数据,然后将采样示例本地写入Tensorflow标准tfrecords格式。

在训练的第一个时期,我将仅采样几个值,因此,一个局部数据的时期非常小,我对此进行了训练。在第2阶段,我重新检查采样子过程(现在有更多)产生了哪些数据文件,并在下一个阶段对扩展的本地数据文件集进行训练。每个时期重复该过程。这样,我可以建立样本的本地缓存,并在填满本地存储时可以驱逐较旧的样本。局部样本缓存大约在模型最需要方差时(朝训练的后半部分)增长。

在Python/Tensorflow中,至关重要的是,我不要在Python训练循环过程中反序列化数据,因为Python GIL无法支持数据传输速率(300-600 MB/秒,数据是原始的科学不可压缩的),因此不能保证GPU的性能当Python GIL无法快速服务于训练循环时,它会遭受苦难。

将样本从子进程(python多重处理)写入tfrecords文件中,从而允许tensorflow的 native TFRecordsDataset在Python之外进行反序列化,因此我们避开了Python GIL问题,并且我可以使具有高IO数据速率的GPU饱和。

I would like to know how I would address this issue in Pytorch. I'm writing about the sampling strategy that's being used, and want to provide specific recommendations to users of both Tensorflow and PyTorch, but I don't know the PyTorch preprocessing ecosystem well enough to write with sufficient detail.



旁注:支持这些数据传输速率的唯一纯基于Python的解决方案可能是带有System V共享内存和多处理功能的Python 3.8,但我还没有尝试过,因为对它的支持还不够(很快就可以了) )。现有的多处理解决方案是不够的,因为它们需要在训练循环过程中进行反序列化,从而在反序列化期间以高IO速率锁定GIL。

最佳答案

实际上,您可以使用torch.utils.data.DataLoader在子流程中轻松反序列化数据。通过将num_workers参数设置为1或更大的值,可以使用它们自己的python解释器和GIL生成子进程。

loader = torch.utils.data.DataLoader(your_dataset, num_workers=n, **kwargs)
for epoch in range(epochs):
    for batch_idx, data in enumerate(loader):
         # loader in the main process does not claim GIL at this point
Dataloader需要torch.utils.data.Dataset才能从中获取数据。在您的情况下,实现适当的子类可能不是一件容易的事。如果您需要为每个纪元重新创建Dataset实例,则可以执行以下操作。
for epcoh in range(epochs):
    dset = get_new_dataset()
    loader = torch.utils.data.DataLoader(dset, num_workers=n, **kwargs)
    for batch_idx, data in enumerate(loader):
        # Do training

甚至更好
dset = get_new_dataset()
loader = torch.utils.data.DataLoader(dset, num_workers=n, **kwargs)

for epcoh in range(epochs):
    last_batch_idx =  (len(dset)-1) // loader.batch_size
    for batch_idx, data in enumerate(loader):
        # Prepare next loader in advance to avoid blocking
        if batch_idx == last_batch_idx:
            dset = get_new_dataset()
            loader = torch.utils.data.DataLoader(dset, num_workers=n, **kwargs)
        # Do training

附带说明一下,请注意,在大多数情况下,受GIL影响的是受CPU约束的操作,而不是受I/O约束的操作,即threading可用于任何纯粹的I/O繁重的操作,甚至不需要subprocess。有关更多信息,请引用此question和此Wikipedia article

关于python - 如何从Pytorch中的高IO数据集读取,该数据集随着时间的推移而增长,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/60119934/

相关文章:

python - 算法中 "combine"函数的最佳方法?

python - 简化链式比较

python - 在 Windows 中从命令行运行稳定扩散时出错

python - 关于 Tensorflow 和 PyTorch 中的自定义操作

python - PyTorch 中复杂矩阵的行列式

python - 从 SQL 数据库中的 OHLC 数据中选择 7、14、20、50、200 天的价格。

python - 如何写入带有小数点后固定数字的 astropy Table 对象值?

python - 计算全连接层的尺寸?

python - 将剩余连接添加到简单的 CNN

python - 是否可以只更新 pypi 索引中的详细信息,而不重新创建包?