python - PyTorch DataLoader 将批处理作为列表返回,并将批处理作为唯一条目。从我的 DataLoader 获取张量的最佳方法是什么

标签 python pytorch dataloader

我目前遇到以下情况,我想使用 DataLoader 批处理 numpy 数组:

import numpy as np
import torch
import torch.utils.data as data_utils

# Create toy data
x = np.linspace(start=1, stop=10, num=10)
x = np.array([np.random.normal(size=len(x)) for i in range(100)])
print(x.shape)
# >> (100,10)

# Create DataLoader
input_as_tensor = torch.from_numpy(x).float()
dataset = data_utils.TensorDataset(input_as_tensor)
dataloader = data_utils.DataLoader(dataset,
                                   batch_size=100,
                                  )
batch = next(iter(dataloader))

print(type(batch))
# >> <class 'list'>

print(len(batch))
# >> 1

print(type(batch[0]))
# >> class 'torch.Tensor'>

我希望批处理已经是torch.Tensor。截至目前,我像这样索引批处理,batch[0]来获取张量,但我觉得这不太漂亮,并且使代码更难阅读。

我发现DataLoader采用一个名为collat​​e_fn的批处理函数。但是,设置 data_utils.DataLoader(..., collage_fn=lambda batch: batch[0]) 仅将列表更改为元组 (tensor([ 0.8454, ..., -0.5863 ]),) 其中唯一的条目是作为张量的批处理。

通过帮助我找出如何优雅地将批处理转换为张量(即使这包括告诉我可以对批处理中的单个条目进行索引是可以的),您会对我有很大帮助。

最佳答案

很抱歉给您的回答带来不便。

实际上,您不必从张量创建Dataset,您可以直接传递torch.Tensor,因为它实现了__getitem__并且__len__,所以这就足够了:

import numpy as np
import torch
import torch.utils.data as data_utils

# Create toy data
x = np.linspace(start=1, stop=10, num=10)
x = np.array([np.random.normal(size=len(x)) for i in range(100)])

# Create DataLoader
dataset = torch.from_numpy(x).float()
dataloader = data_utils.DataLoader(dataset, batch_size=100)
batch = next(iter(dataloader))

关于python - PyTorch DataLoader 将批处理作为列表返回,并将批处理作为唯一条目。从我的 DataLoader 获取张量的最佳方法是什么,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/58612401/

相关文章:

python - 花括号中的 for 循环是什么意思?

pytorch - 在pytorch中使用索引提取张量数据

python - 运行时错误: Assertion `cur_target >= 0 && cur_target < n_classes' failed

javascript - Dataloader 如何缓存和批处理数据库请求?

python-3.x - pytorch collat​​e_fn 拒绝样本并产生另一个

python - 3D CNN 在图像序列上的输入形状应该是什么?

python - 迭代器的性能优势?

python - 为什么模型有 key_name 在 google-app-engine 上没有 key().id()

python - 无法点击隐形选择python

pytorch - net.zero_grad() 与 optim.zero_grad() pytorch