Python Dataset Class + PyTorch Dataloader : Stuck at __getitem__, 测试时如何获取索引、标签等?

标签 python machine-learning dataset pytorch dataloader

我有一个也许是小问题,但我现在被困了很长一段时间。希望有人能帮助我。我目前正在使用 Kddcup99 数据集,我喜欢通过深度学习(CNN 网络)来训练该数据集

我有一个“数据集”类,其中包括 Panda Dataframe。因此我分为正常数据集和验证数据集。到目前为止,没有问题。 我将其加载到 Numpy 向量中,将其传输到 Tensor,然后将其定向到 DataLoader。

数据集类有两个重要的类用于迭代:

def __len__(self):
        return len(self.val_df)

def __getitem__(self, index):        
        img, target = self.val_df[index][:-1], self.val_df[index][-1]
        return img, target, index

DataLoader 字符串不在类中:

test_dataloader = DataLoader(datat.val_df, batch_size=10, shuffle=True)

在我的 Trainer 类中,我有一个 for 循环,它应该迭代数据加载器:

with torch.no_grad():
            for data in dataloader:
                inputs, labels, idx = data
                inputs = inputs.to(self.device)

但是不会。我无法访问标签、索引等。

我现在的问题是:为什么? 如何通过数据加载器访问给定数据集中的标签、索引?

谢谢大家的帮助! 非常感谢。

最佳答案

DataLoader 的第一个参数是您要从中加载数据的数据集,通常是 Dataset ,但它不限于 Dataset 的任何实例。只要它定义了长度 (__len__) 并且可以被索引(__getitem__ 允许),它就可以接受。

您正在将 datat.val_df 传递给 DataLoader,它可能是一个 NumPy 数组。 NumPy 数组具有长度并且可以索引,因此可以在 DataLoader 中使用。由于您直接传递该数组,因此永远不会调用数据集的 __getitem__,但数组本身已建立索引,因此每个项目只是 data.val_df[index]

您必须使用数据集本身 (datat),而不是使用 DataLoader 的基础数据:

test_dataloader = DataLoader(datat, batch_size=10, shuffle=True)

关于Python Dataset Class + PyTorch Dataloader : Stuck at __getitem__, 测试时如何获取索引、标签等?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/61868754/

相关文章:

machine-learning - caffe 中的 .net 文件是什么?

mysql - 声明没有返回我需要的内容

C# 数据集.关系 : How to use DataSet Relations?

python - 如何从Playstore中的应用程序获取权限信息?

python - 是否可以使用自定义 Python 窗口管理器来主题化 GUI 应用程序?

algorithm - 如何计算多标签分配的分类任务的成功率

tensorflow - 用于在 keras 中调用的自定义宏

python - Key_Value 计数过滤字典

Python 最低公共(public) CIDR

Python:使用 'Null' 作为 mysql.connector 的端口参数