python - 可视化特征图: IndexError: too many indices for array

标签 python matplotlib error-handling conv-neural-network index-error

在学习完本教程之后,我将尝试可视化特征图。
我的模型如下所示:

model.summary()
Model: "model_3"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_5 (InputLayer)            [(None, 224, 224, 3) 0                                            
__________________________________________________________________________________________________
efficientnet-b0 (Functional)    (None, 7, 7, 1280)   4049564     input_5[0][0]                    
__________________________________________________________________________________________________
flatten_4 (Flatten)             (None, 62720)        0           efficientnet-b0[0][0]            
__________________________________________________________________________________________________
branch_0_Dense_16000 (Dense)    (None, 256)          16056576    flatten_4[0][0]                  
__________________________________________________________________________________________________
branch_1_Dense_16000 (Dense)    (None, 256)          16056576    flatten_4[0][0]                  
__________________________________________________________________________________________________
branch_2_Dense_16000 (Dense)    (None, 256)          16056576    flatten_4[0][0]                  
__________________________________________________________________________________________________
branch_3_Dense_16000 (Dense)    (None, 256)          16056576    flatten_4[0][0]                  
__________________________________________________________________________________________________
branch_4_Dense_16000 (Dense)    (None, 256)          16056576    flatten_4[0][0]                  
__________________________________________________________________________________________________
branch_5_Dense_16000 (Dense)    (None, 256)          16056576    flatten_4[0][0]                  
__________________________________________________________________________________________________
branch_6_Dense_16000 (Dense)    (None, 256)          16056576    flatten_4[0][0]                  
__________________________________________________________________________________________________
branch_0_output (Dense)         (None, 35)           8995        branch_0_Dense_16000[0][0]       
__________________________________________________________________________________________________
branch_1_output (Dense)         (None, 35)           8995        branch_1_Dense_16000[0][0]       
__________________________________________________________________________________________________
branch_2_output (Dense)         (None, 35)           8995        branch_2_Dense_16000[0][0]       
__________________________________________________________________________________________________
branch_3_output (Dense)         (None, 35)           8995        branch_3_Dense_16000[0][0]       
__________________________________________________________________________________________________
branch_4_output (Dense)         (None, 35)           8995        branch_4_Dense_16000[0][0]       
__________________________________________________________________________________________________
branch_5_output (Dense)         (None, 35)           8995        branch_5_Dense_16000[0][0]       
__________________________________________________________________________________________________
branch_6_output (Dense)         (None, 35)           8995        branch_6_Dense_16000[0][0]       
__________________________________________________________________________________________________
concatenate_4 (Concatenate)     (None, 245)          0           branch_0_output[0][0]            
                                                                 branch_1_output[0][0]            
                                                                 branch_2_output[0][0]            
                                                                 branch_3_output[0][0]            
                                                                 branch_4_output[0][0]            
                                                                 branch_5_output[0][0]            
                                                                 branch_6_output[0][0]            
__________________________________________________________________________________________________
reshape_4 (Reshape)             (None, 7, 35)        0           concatenate_4[0][0]              
==================================================================================================
Total params: 116,508,561
Trainable params: 116,466,545
Non-trainable params: 42,016
我现在想可视化索引为10的图层:10 branch_0_output (None, 35)
3 branch_0_Dense_16000 (None, 256)
4 branch_1_Dense_16000 (None, 256)
5 branch_2_Dense_16000 (None, 256)
6 branch_3_Dense_16000 (None, 256)
7 branch_4_Dense_16000 (None, 256)
8 branch_5_Dense_16000 (None, 256)
9 branch_6_Dense_16000 (None, 256)
10 branch_0_output (None, 35)
11 branch_1_output (None, 35)
12 branch_2_output (None, 35)
13 branch_3_output (None, 35)
14 branch_4_output (None, 35)
15 branch_5_output (None, 35)
16 branch_6_output (None, 35)
我遵循了教程中所述的代码,对图像进行了预处理,现在我想绘制此层的35个(?)特征图:
我在教程中使用了代码,并修改了平方数,这里是1,但是我尝试了几次:
# plot all 35 maps
square = 1
ix = 1
for _ in range(square):
    for _ in range(square):
        # specify subplot and turn of axis
        ax = pyplot.subplot(square, square, ix)
        ax.set_xticks([])
        ax.set_yticks([])
        # plot filter channel in grayscale
        pyplot.imshow(feature_maps[0, :, :, ix-1], cmap='gray')
        ix += 1
# show the figure
pyplot.show()
无论尝试多少电话,我都会收到以下错误消息:
---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
<ipython-input-28-4c1f464f6978> in <module>()
      9                 ax.set_yticks([])
     10                 # plot filter channel in grayscale
---> 11                 pyplot.imshow(feature_maps[0, :, ix-1], cmap='gray')
     12                 ix += 1
     13 # show the figure

IndexError: too many indices for array
有人可以帮我修改一下吗?
非常感谢!

最佳答案

该错误在第11行显示too many indices for array。之所以发生这种情况,是因为您在要素图中错误地传递了索引。在这里,您要在正方形为1的情况下尝试在1 * 1网格中绘制35张 map 。
假设您需要绘制64映射,那么我们将使用square = 8,然后输出将是一个8 * 8的网格。

关于python - 可视化特征图: IndexError: too many indices for array,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/63012225/

相关文章:

python - 从列表中分离整数和字符串

python - Matplotlib 缩放时绘图分辨率不佳

r - R glm.nb预测三个变量返回错误

oracle - PL/SQL Oracle错误处理

ruby-on-rails-3 - i18n自定义验证错误处理

python - 我怎样才能将一个有n列的矩阵转换成只有一列的矩阵?

python - 调用实例类的方法时,“int”对象不可调用

python - 绘制直方图,使条形高度总和为 1(概率)

python - 生成以 y 轴作为相对频率的直方图?

python - 哪个语法规则匹配 def foo(a, *, b=10) 复合语句?