在 Keras 中,为什么是 input_shape
当作为参数传递给像 Dense
这样的层时,不包括批处理维度但在 input_shape
时包含批处理维度传递给 build
模型的方法?
import tensorflow as tf
from tensorflow.keras.layers import Dense
if __name__ == "__main__":
model1 = tf.keras.Sequential([Dense(1, input_shape=[10])])
model1.summary()
model2 = tf.keras.Sequential([Dense(1)])
model2.build(input_shape=[None, 10]) # why [None, 10] and not [10]?
model2.summary()
这是 API 设计的明智选择吗?如果是,为什么?
最佳答案
您可以通过几种不同的方式指定模型的输入形状。例如,通过向模型的第一层提供以下参数之一:
batch_input_shape
:第一个维度是批量大小的元组。 input_shape
: 不包括批大小的元组,例如,批大小假定为 None
或 batch_size
,如果指定。 input_dim
: 一个标量,表示输入的维度。 在所有这些情况下,Keras 都是 internally storing一个属性
_batch_input_size
建立模型。关于
build
方法,我的猜测是这确实是一个有意识的选择——关于批量大小的信息可能对在某些(也许是未曾想到的)情况下构建模型有用。因此,包含批处理维度作为输入的框架 build
比没有的框架更通用和完整。尽管如此,我同意你将论点命名为 batch_input_shape
而不是 input_shape
将使一切更加一致。值得一提的是,用户很少需要调用
build
自己的方法。这在需要时在内部发生。如今,甚至可以 ignore input_shape
创建模型时的参数(尽管像 summary
这样的方法在模型构建之前将不起作用)。在这种情况下,Keras 能够从参数 x
推断输入形状。的 fit
.
关于python - 为什么当作为参数传递给 `input_shape` 层时, `Dense` 不包括批处理维度?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/64681232/