machine-learning - 图像分类-pytorch

标签 machine-learning computer-vision pytorch transfer-learning

我正在尝试使用预训练模型来预测特征。 我得到的输出如下,但是,如何使用 torch.max() 来获取感兴趣的类。我尝试过的代码:

... loading model
input = transformation_sequence(sample).unsqueeze(0)
outputs = model(input)
_, predicted = torch.max(outputs,1) #this line returns error

#print of `outputs` variable
[tensor([[ 3.0654, -3.0650]]), tensor([[ 1.5634, -1.5672]]), tensor([[ 1.2867, -1.2888]]), tensor([[ 1.2974, -1.2928]]), tensor([[ 6.4537, -6.4487]]), tensor([[ 2.4851, -2.4710]]), tensor([[ 0.9855, -0.9809]]), tensor([[ 0.3995, -0.4033]]), tensor([[ 0.6301, -0.6276]]), tensor([[ 5.7082, -5.6931]]), tensor([[ 1.9354, -1.9365]]), tensor([[ 0.6091, -0.6074]]), tensor([[ 5.4509, -5.4417]]), tensor([[ 3.7231, -3.7115]]), tensor([[ 4.4494, -4.4361]]), tensor([[ 0.8867, -0.8902]]), tensor([[ 2.7410, -2.7402]]), tensor([[ 5.4919, -5.4909]]), tensor([[ 2.2687, -2.2744]]), tensor([[-0.9695,  0.9723]]), tensor([[ 1.5100, -1.5114]]), tensor([[-2.7077,  2.7140]]), tensor([[ 4.4661, -4.4734]]), tensor([[ 0.4846, -0.4821]]), tensor([[-2.9743,  2.9643]]), tensor([[ 1.3900, -1.3874]]), tensor([[ 7.6764, -7.6742]]), tensor([[ 0.5173, -0.5118]]), tensor([[ 1.3513, -1.3503]]), tensor([[ 2.5381, -2.5356]]), tensor([[ 4.9850, -5.0074]]), tensor([[-2.8397,  2.8484]]), tensor([[ 3.1010, -3.1137]]), tensor([[-0.2374,  0.2406]]), tensor([[ 0.5338, -0.5358]]), tensor([[ 3.4912, -3.4979]]), tensor([[ 1.1957, -1.1876]]), tensor([[ 1.1189, -1.1163]]), tensor([[ 3.6400, -3.6365]]), tensor([[-1.3123,  1.3132]])]

#list of error:

  _, predicted = torch.max(outputs,1)
TypeError: max() received an invalid combination of arguments - got (list, int), but expected one of:
 * (Tensor input)
 * (Tensor input, Tensor other, Tensor out)
 * (Tensor input, int dim, bool keepdim, tuple of Tensors out)

最佳答案

您的模型返回张量列表,而不是张量。它可以用 torch.cat 修复:

torch.max(torch.cat(outputs),1)

>>> torch.return_types.max(
values=tensor([3.0654, 1.5634, 1.2867]),
indices=tensor([0, 0, 0]))

关于machine-learning - 图像分类-pytorch,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/60106379/

相关文章:

Java Weka 按属性拆分实例

matlab - 循环邻域操作 : matlab color histogram

python - 为什么我的显卡不能与 PyTorch 一起使用?

machine-learning - Pytorch 中的 int8 数据类型

machine-learning - 需要帮助理解 CGAN 中的标签输入

python - Catboost 理解 - 分类值的转换

image-processing - 如何使超像素的标签在灰度图中局部一致?

image - 使用 findContours 时如何避免检测图像帧

python - 如何分段弯杆进行角度计算?

machine-learning - 我如何将 bool 张量输入到 tf.cond() 而不仅仅是一个 bool 值?