python - Keras model.summary() 结果 - 了解参数的数量

标签 python machine-learning neural-network keras theano

我有一个简单的 NN 模型,用于检测使用 Keras(Theano 后端)用 python 编写的 28x28px 图像中的手写数字:

model0 = Sequential()

#number of epochs to train for
nb_epoch = 12
#amount of data each iteration in an epoch sees
batch_size = 128

model0.add(Flatten(input_shape=(1, img_rows, img_cols)))
model0.add(Dense(nb_classes))
model0.add(Activation('softmax'))
model0.compile(loss='categorical_crossentropy', 
         optimizer='sgd',
         metrics=['accuracy'])

model0.fit(X_train, Y_train, batch_size=batch_size, nb_epoch=nb_epoch,
      verbose=1, validation_data=(X_test, Y_test))

score = model0.evaluate(X_test, Y_test, verbose=0)

print('Test score:', score[0])
print('Test accuracy:', score[1])

这运行良好,我的准确率约为 90%。然后我执行以下命令,通过执行 print(model0.summary()) 来获得网络结构的摘要。这将输出以下内容:

Layer (type)         Output Shape   Param #     Connected to                     
=====================================================================
flatten_1 (Flatten)   (None, 784)     0           flatten_input_1[0][0]            
dense_1 (Dense)     (None, 10)       7850        flatten_1[0][0]                  
activation_1        (None, 10)          0           dense_1[0][0]                    
======================================================================
Total params: 7850

我不明白他们是如何达到 7850 个总参数的,这实际上意味着什么?

最佳答案

参数的数量是 7850,因为对于每个隐藏单元,您有 784 个输入权重和一个带偏差的连接权重。这意味着每个隐藏单元都会为您提供 785 个参数。您有 10 个单位,因此总计为 7850。

这个附加偏差项的作用非常重要。它显着增加了模型的容量。您可以阅读详细信息,例如这里 Role of Bias in Neural Networks .

关于python - Keras model.summary() 结果 - 了解参数的数量,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/36946671/

相关文章:

python - 如果在 flask 上选中了复选框,如何获取

python - Tkinter 在文本中插入 json 数据

python - 我需要帮助将列表转换为 pandas 数据框

python - CNN 对所有输入数据预测相同的类别

python - 如何使用 lambda 函数在数据帧上使用 Pandas apply() ?

python - 如何在Python中识别过拟合和欠拟合

amazon-web-services - 如何在AWS sagemaker中运行预训练的模型?

C# Encog 神经网络——尽管神经网络的整体误差很低,但预期输出与实际误差相去甚远

python - 图像分类软件

python - lxml:Element addnext() 和 insert() 在处理 tail 时的区别