python-3.x - Working with multiple large datasets where each dataset contains multiple values - Pytorch

Tags: python-3.x neural-network pytorch large-files

I am training a neural network and have more than 15 GB of data in a folder. The folder holds multiple pickle files, and each file contains two lists, each with multiple values. It looks like this: dataset_folder:\

  • file_1.pickle
  • file_2.pickle
  • ...
  • ...
  • file_n.pickle

Each file_*.pickle contains two variable-length lists (list x and list y).

How can I load all of this data to train the model without running into memory problems?

Best Answer

The idea is to implement a custom dataset by subclassing the Dataset class that PyTorch provides. You need to implement three methods so that the PyTorch data loader can work with your data (a bare skeleton of the class is sketched right after this list):

  • __len__
  • __getitem__
  • __init__
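
Just to show the overall shape that the sections below fill in, a bare skeleton of such a class (using the same class name dataset as in the full example at the end) looks like this:

    from torch.utils.data import Dataset

    class dataset(Dataset):
        def __init__(self):
            # scan every file once and build an index of where each sample lives
            ...

        def __getitem__(self, idx):
            # map the global idx to (file, local index) and load that file lazily
            ...

        def __len__(self):
            # total number of samples across all files
            ...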

Let's go through how to implement each of them.

  • __init__

    def __init__(self):

      # Original data has the following format
      """
      dict_object = 
      {
          "x":[],
          "y":[]
      }
      """
      DIRECTORY = "data/raw"
      self.directory = DIRECTORY
      self.dataset_file_name = os.listdir(DIRECTORY)
      self.dataset_length = 0
      self.prefix_sum_idx = list()
      # Loop over each file and calculate the length of the overall dataset
      # you might need to check if file_name is a file
      for file_name in self.dataset_file_name:
        with open(f'{DIRECTORY}/{file_name}', "rb") as openfile:
          dict_object = pickle.load(openfile)
          # assumes x and y are paired, i.e. len(x) == len(y)
          curr_page_sum = len(dict_object["x"])
          self.prefix_sum_idx.append(curr_page_sum)
          self.dataset_length += curr_page_sum
      # prefix sum so we know which file each global index falls into
      for i in range(1, len(self.prefix_sum_idx)):
        self.prefix_sum_idx[i] = self.prefix_sum_idx[i] + self.prefix_sum_idx[i - 1]

      assert self.prefix_sum_idx[-1] == self.dataset_length
      self.x = []
      self.y = []


As shown above, the main idea is to use a prefix sum so that the whole dataset only has to be scanned once. The logic is that whenever you later need to access a specific index, you only have to look at prefix_sum_idx to see which file that idx falls into.

[prefix sum illustration]

In the figure above, suppose we need to access index 150. Thanks to the prefix sum, we can now tell that index 150 lives in the second .pickle file. We still need a fast mechanism to find where idx falls within prefix_sum_idx; this is explained in __getitem__.
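
Before looking at __getitem__, here is a concrete sketch of that lookup (the per-file sample counts below are made up purely for illustration):

    import bisect

    # hypothetical number of samples in each pickle file
    counts = [100, 150, 50]                 # file_1, file_2, file_3

    # running prefix sum over the counts -> [100, 250, 300]
    prefix_sum_idx = []
    total = 0
    for c in counts:
        total += c
        prefix_sum_idx.append(total)

    idx = 150
    file_idx = bisect.bisect_right(prefix_sum_idx, idx)   # 1 -> the second file
    local_idx = (idx - prefix_sum_idx[file_idx - 1]) if file_idx > 0 else idx
    print(file_idx, local_idx)                            # 1 50
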
  • __getitem__

    def read_pickle_file(self, idx):
      # load the whole pickle file that holds the requested index
      file_name = self.dataset_file_name[idx]
      with open(f'{self.directory}/{file_name}', "rb") as openfile:
        dict_object = pickle.load(openfile)

      self.x = dict_object['x']
      self.y = dict_object['y']  # or apply whatever per-file logic you need here


    def __getitem__(self, idx):

      # Similar to C++ std::upper_bound - O(log n)
      file_idx = bisect.bisect_right(self.prefix_sum_idx, idx)

      self.read_pickle_file(file_idx)
      # translate the global idx into an index local to the loaded file
      local_idx = (idx - self.prefix_sum_idx[file_idx - 1]) if file_idx > 0 else idx

      return self.x[local_idx], self.y[local_idx]

    

Check the bisect_right() documentation for the details of how it works, but it simply returns the rightmost position in a sorted list at which the given element can be inserted while keeping the list sorted. In our approach, we are only interested in answering the question "which file should I access to get the appropriate data?". More importantly, it does so in O(log n).
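
For example, with the hypothetical prefix sums [100, 250, 300] from above, bisect_right maps boundary indices to the correct file, which is exactly why the upper-bound variant is the right tool here:

    import bisect

    prefix_sum_idx = [100, 250, 300]

    # idx 99 is the last sample of the first file, idx 100 the first sample of the second
    print(bisect.bisect_right(prefix_sum_idx, 99))    # 0 -> first file
    print(bisect.bisect_right(prefix_sum_idx, 100))   # 1 -> second file

    # bisect_left would send idx 100 back to the first file, which only holds indices 0-99
    print(bisect.bisect_left(prefix_sum_idx, 100))    # 0
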
  • __len__

    def __len__(self):
      return self.dataset_length


To get the length of the dataset, we loop over every file and accumulate the result, as shown in __init__.

The complete code example is as follows:

import pickle
import os
import bisect
import torch
from torch.utils.data import Dataset, DataLoader

class dataset(Dataset):
  def __init__(self):

    # Original data has the following format
    """
    dict_object = 
    {
        "x":[],
        "y":[]
    }
    """
    DIRECTORY = "data/raw"
    self.directory = DIRECTORY
    self.dataset_file_name = os.listdir(DIRECTORY)
    self.dataset_length = 0
    self.prefix_sum_idx = list()
    # Loop over each file and calculate the length of the overall dataset
    # you might need to check if file_name is a file
    for file_name in self.dataset_file_name:
      with open(f'{DIRECTORY}/{file_name}', "rb") as openfile:
        dict_object = pickle.load(openfile)
        # assumes x and y are paired, i.e. len(x) == len(y)
        curr_page_sum = len(dict_object["x"])
        self.prefix_sum_idx.append(curr_page_sum)
        self.dataset_length += curr_page_sum
    # prefix sum so we know which file each global index falls into
    for i in range(1, len(self.prefix_sum_idx)):
      self.prefix_sum_idx[i] = self.prefix_sum_idx[i] + self.prefix_sum_idx[i - 1]

    assert self.prefix_sum_idx[-1] == self.dataset_length
    self.x = []
    self.y = []

  def read_pickle_file(self, idx):
    # load the whole pickle file that holds the requested index
    file_name = self.dataset_file_name[idx]
    with open(f'{self.directory}/{file_name}', "rb") as openfile:
      dict_object = pickle.load(openfile)

    self.x = dict_object['x']
    self.y = dict_object['y']  # or apply whatever per-file logic you need here

  def __getitem__(self, idx):
    # Similar to C++ std::upper_bound - O(log n)
    file_idx = bisect.bisect_right(self.prefix_sum_idx, idx)

    self.read_pickle_file(file_idx)
    # translate the global idx into an index local to the loaded file
    local_idx = (idx - self.prefix_sum_idx[file_idx - 1]) if file_idx > 0 else idx

    return self.x[local_idx], self.y[local_idx]

  def __len__(self):
    return self.dataset_length


large_dataset = dataset()
train_size = int(0.8 * len(large_dataset))
validation_size = len(large_dataset) - train_size

train_dataset, validation_dataset = torch.utils.data.random_split(large_dataset, [train_size, validation_size])
validation_loader = DataLoader(validation_dataset, batch_size=64, num_workers=4, shuffle=False)
train_loader = DataLoader(train_dataset, batch_size=64, num_workers=4, shuffle=False)
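
With the loaders in place, training proceeds as with any other DataLoader. A minimal sketch, continuing the script above and assuming each x and y value is a single number (the model, loss and optimizer below are placeholders, not part of the original answer):

    import torch.nn as nn

    model = nn.Linear(1, 1)                                   # placeholder model
    criterion = nn.MSELoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

    for epoch in range(10):
        for x_batch, y_batch in train_loader:
            optimizer.zero_grad()
            prediction = model(x_batch.float().unsqueeze(1))  # shape (batch, 1)
            loss = criterion(prediction, y_batch.float().unsqueeze(1))
            loss.backward()
            optimizer.step()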

Regarding python-3.x - Working with multiple large datasets where each dataset contains multiple values - Pytorch, we found a similar question on Stack Overflow: https://stackoverflow.com/questions/72487788/
