Python matplotlib,图像数据的无效形状

标签 python python-3.x matplotlib pytorch

目前我有这个代码来显示三个图像:

imshow(image1, title='1')
imshow(image2, title='2')
imshow(image3, title='3')

它工作正常。但我试图将它们全部放在一行而不是列中。

这是我尝试过的代码:
f = plt.figure()
f.add_subplot(1,3,1)
plt.imshow(image1)
f.add_subplot(1,3,2)
plt.imshow(image2)
f.add_subplot(1,3,3)
plt.imshow(image3)

它抛出

TypeError: can't convert CUDA tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.



如果我做
f = plt.figure()
f.add_subplot(1,3,1)
plt.imshow(image1.cpu())
f.add_subplot(1,3,2)
plt.imshow(image2.cpu())
f.add_subplot(1,3,3)
plt.imshow(image3.cpu())

它抛出

TypeError: Invalid shape (1, 3, 128, 128) for image data



我应该如何解决这个问题,或者有更简单的方法来实现它?

最佳答案

matplotlib 函数 'imshow' 将 3 channel 图片作为 (h, w, 3) 获取,正如您在 documentation 中看到的那样.

似乎您传递了图像的三个 channel (第二维)的“批次”单个图像(第一维)(h 和 w 是第三和第四维)。

您需要 reshape 或查看您的图像(转换为 cpu 后,尝试使用:

image1.squeeze().permute(1,2,0)

结果将是所需形状(128、128、3)的图像。

挤压()函数将删除第一维。 premute() 函数将调换维度,其中第一个将移至第三个位置,而另外两个将移至开头。

另外,请查看此处以进一步讨论 GPU 和 CPU 问题:
link

希望有帮助。

关于Python matplotlib,图像数据的无效形状,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/61480762/

相关文章:

python - argparse 中的可选 nargs 变量

python - 平滑像素化二值图像python代码的边缘

python - 字典操作

python - 如何从整数列表中提取单个数字?

python - MatPlotLib 动态时间轴

python - 删除 matplotlib 子图中的多余图

python - 添加 "search"同时保持分页和排序?

python - 如何在子目录中创建一个 "just fits"文件的 block 设备?

python - 在 Matplotlib 中同时对同一图中的两条曲线进行动画处理

Python Mechanize 说现有控件不存在