python - PyTorch DataLoader 对并行运行的批处理使用相同的随机种子

标签 python numpy parallel-processing pytorch dataloader

有一个bug在 PyTorch/Numpy 中,当与 DataLoader 并行加载批次时(即设置 num_workers > 1 ),每个 worker 使用相同的 NumPy 随机种子,导致应用的任何随机函数在并行化批次中都相同。
最小的例子:

import numpy as np
from torch.utils.data import Dataset, DataLoader

class RandomDataset(Dataset):
    def __getitem__(self, index):
        return np.random.randint(0, 1000, 2)

    def __len__(self):
        return 9
    
dataset = RandomDataset()
dataloader = DataLoader(dataset, batch_size=1, num_workers=3)

for batch in dataloader:
    print(batch)
如您所见,对于每个并行化的批次集 (3),结果是相同的:
# First 3 batches
tensor([[891, 674]])
tensor([[891, 674]])
tensor([[891, 674]])
# Second 3 batches
tensor([[545, 977]])
tensor([[545, 977]])
tensor([[545, 977]])
# Third 3 batches
tensor([[880, 688]])
tensor([[880, 688]])
tensor([[880, 688]])
解决此问题的推荐/最优雅的方法是什么?即让每批产生不同的随机化,而不管 worker 的数量。

最佳答案

这似乎有效,至少在 Colab 中是这样:

dataloader = DataLoader(dataset, batch_size=1, num_workers=3, 
    worker_init_fn = lambda id: np.random.seed(id) )
编辑:

it produces identical output (i.e. the same problem) when iterated over epochs. – iacob


迄今为止我发现的最佳解决方案:
...
dataloader = DataLoader(ds, num_workers=2, 
           worker_init_fn = lambda id: np.random.seed(id + epoch * 10 ))

for epoch in range ( 2 ):
    for batch in dataloader:
        print(batch)
    print()
仍然不能建议封闭形式,事情取决于然后调用的 var ( epoch )。理想情况下,它必须类似于 worker_init_fn = lambda id: np.random.seed(id + EAGER_EVAL(np.random.randint(10000) )其中 EAGER_EVAL 在 lambda 作为参数传递之前评估加载器构造上的种子。我想知道在python中是否可能。

关于python - PyTorch DataLoader 对并行运行的批处理使用相同的随机种子,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/67180955/

相关文章:

python - 如何比较持有 numpy.ndarray 的数据类的相等性(bool(a==b) 引发 ValueError)?

python - 对整个复数数组进行插值

c# - 我如何在 Parallel.ForEach 期间添加或更新此 .NET 集合?

python - ProcessPoolExecutor 和 ThreadPoolExecutor 有什么区别?

python - Pandas 解析 json 列并将现有列保留到新的数据框中

python - 如何优化具有带条件的嵌套列表的Python代码?

python - 不使用 pandas 的多列标签编码

linux - 在 Linux shell 脚本上并行化行(命令、进程...)

python - Kivy 中的动态网格,每个网格元素包含多个小部件

python - 深度python字典递归