我有一个也许是小问题,但我现在被困了很长一段时间。希望有人能帮助我。我目前正在使用 Kddcup99 数据集,我喜欢通过深度学习(CNN 网络)来训练该数据集
我有一个“数据集”类,其中包括 Panda Dataframe。因此我分为正常数据集和验证数据集。到目前为止,没有问题。 我将其加载到 Numpy 向量中,将其传输到 Tensor,然后将其定向到 DataLoader。
数据集类有两个重要的类用于迭代:
def __len__(self):
return len(self.val_df)
def __getitem__(self, index):
img, target = self.val_df[index][:-1], self.val_df[index][-1]
return img, target, index
DataLoader 字符串不在类中:
test_dataloader = DataLoader(datat.val_df, batch_size=10, shuffle=True)
在我的 Trainer 类中,我有一个 for 循环,它应该迭代数据加载器:
with torch.no_grad():
for data in dataloader:
inputs, labels, idx = data
inputs = inputs.to(self.device)
但是不会。我无法访问标签、索引等。
我现在的问题是:为什么? 如何通过数据加载器访问给定数据集中的标签、索引?
谢谢大家的帮助! 非常感谢。
最佳答案
DataLoader
的第一个参数是您要从中加载数据的数据集,通常是 Dataset
,但它不限于 Dataset
的任何实例。只要它定义了长度 (__len__
) 并且可以被索引(__getitem__
允许),它就可以接受。
您正在将 datat.val_df
传递给 DataLoader
,它可能是一个 NumPy 数组。 NumPy 数组具有长度并且可以索引,因此可以在 DataLoader 中使用。由于您直接传递该数组,因此永远不会调用数据集的 __getitem__
,但数组本身已建立索引,因此每个项目只是 data.val_df[index]
。
您必须使用数据集本身 (datat
),而不是使用 DataLoader
的基础数据:
test_dataloader = DataLoader(datat, batch_size=10, shuffle=True)
关于Python Dataset Class + PyTorch Dataloader : Stuck at __getitem__, 测试时如何获取索引、标签等?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/61868754/