我在所有层中设置 trainable=False
,通过 Model
API 实现,但我想验证它是否有效。 model.count_params()
返回参数的总数,但是除了查看 model 的最后几行之外,有什么方法可以获得可训练参数的总数。总结()
?
最佳答案
from keras import backend as K
trainable_count = int(
np.sum([K.count_params(p) for p in set(model.trainable_weights)]))
non_trainable_count = int(
np.sum([K.count_params(p) for p in set(model.non_trainable_weights)]))
print('Total params: {:,}'.format(trainable_count + non_trainable_count))
print('Trainable params: {:,}'.format(trainable_count))
print('Non-trainable params: {:,}'.format(non_trainable_count))
上面的代码片段可以在layer_utils.print_summary()
的末尾找到定义,summary()
正在打电话。
编辑:更新版本的 Keras 有一个辅助函数 count_params()
为此目的:
from keras.utils.layer_utils import count_params
trainable_count = count_params(model.trainable_weights)
non_trainable_count = count_params(model.non_trainable_weights)
关于python - 如何获取 Keras 模型的可训练参数数量?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/45046525/