目前我有这个代码来显示三个图像:
imshow(image1, title='1')
imshow(image2, title='2')
imshow(image3, title='3')
它工作正常。但我试图将它们全部放在一行而不是列中。
这是我尝试过的代码:
f = plt.figure()
f.add_subplot(1,3,1)
plt.imshow(image1)
f.add_subplot(1,3,2)
plt.imshow(image2)
f.add_subplot(1,3,3)
plt.imshow(image3)
它抛出
TypeError: can't convert CUDA tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.
如果我做
f = plt.figure()
f.add_subplot(1,3,1)
plt.imshow(image1.cpu())
f.add_subplot(1,3,2)
plt.imshow(image2.cpu())
f.add_subplot(1,3,3)
plt.imshow(image3.cpu())
它抛出
TypeError: Invalid shape (1, 3, 128, 128) for image data
我应该如何解决这个问题,或者有更简单的方法来实现它?
最佳答案
matplotlib 函数 'imshow' 将 3 channel 图片作为 (h, w, 3) 获取,正如您在 documentation 中看到的那样.
似乎您传递了图像的三个 channel (第二维)的“批次”单个图像(第一维)(h 和 w 是第三和第四维)。
您需要 reshape 或查看您的图像(转换为 cpu 后,尝试使用:
image1.squeeze().permute(1,2,0)
结果将是所需形状(128、128、3)的图像。
挤压()函数将删除第一维。 premute() 函数将调换维度,其中第一个将移至第三个位置,而另外两个将移至开头。
另外,请查看此处以进一步讨论 GPU 和 CPU 问题:
link
希望有帮助。
关于Python matplotlib,图像数据的无效形状,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/61480762/