tensorflow - 为什么 Keras 层的方法调用中的批量大小为 None?

标签 tensorflow keras

我正在 Keras 中实现一个自定义层。如果我打印传递给 call 的输入的形状方法,我得到 None作为第一个元素。这是为什么?第一个元素不应该是批量大小吗?

def call(self, x):
    print(x.shape)  # (None, ...)

当我调用 model.fit ,我正在传递批量大小
batch_size = 50
model.fit(x_train, y_train, ..., batch_size=batch_size)

那么,方法是什么时候call居然叫?在 call 方法中获取批量大小的推荐方法是什么? ?

最佳答案

None表示它是动态形状。它可以取任何值,具体取决于您选择的批量大小。

默认情况下定义模型时,它被定义为支持您可以选择的任何批量大小。这就是 None方法。在 TensorFlow 1.*模型的输入是 tf.placeholder() 的一个实例.

如果您不使用 keras.InputLayer()使用指定的批量大小,您将获得第一个维度 None默认情况下:

import tensorflow as tf

model = tf.keras.models.Sequential()
model.add(tf.keras.layers.Dense(units=2, input_shape=(2, )))
print(model.inputs[0].get_shape().as_list()) # [None, 2]
print(model.inputs[0].op.type == 'Placeholder') # True

当你使用 keras.InputLayer()使用指定的批量大小,您可以定义具有固定批量大小的输入占位符:

import tensorflow as tf

model = tf.keras.models.Sequential()
model.add(tf.keras.layers.InputLayer((2,), batch_size=50))
model.add(tf.keras.layers.Dense(units=2, input_shape=(2, )))
print(model.inputs[0].get_shape().as_list()) # [50, 2]
print(model.inputs[0].op.type == 'Placeholder') # True

当您将批大小指定为 model.fit()方法这些输入占位符已经定义,您不能修改它们的形状。 model.fit() 的批量大小仅用于拆分您提供给批处理的数据。

如果您使用批量大小定义输入层 2然后将批处理大小的不同值传递给 model.fit()方法你会得到ValueError :

import tensorflow as tf
import numpy as np

model = tf.keras.models.Sequential()
model.add(tf.keras.layers.InputLayer((2,), batch_size=2)) # <--batch_size==2
model.add(tf.keras.layers.Dense(units=2, input_shape=(2, )))
model.compile(optimizer=tf.keras.optimizers.Adam(),
              loss='categorical_crossentropy')
x_train = np.random.normal(size=(10, 2))
y_train = np.array([[0, 1] for _ in range(10)])

model.fit(x_train, y_train, batch_size=3) # <--batch_size==3

这将引发: ValueError: The批量大小 argument value 3 is incompatible with the specified batch size of your Input Layer: 2

关于tensorflow - 为什么 Keras 层的方法调用中的批量大小为 None?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/55890678/

相关文章:

python - tf.keras.backend.get_session() 和 keras.backend.get_session() 返回不同的 session 对象

python - 我可以在 Tensorflow 联合学习 (TFF) 的 keras 模型中使用 class_weight

python - 如何将来自 python 函数的数据排入 TensorFlow 队列

python - Keras SimpleRNN 混淆

initialization - 属性错误 : module 'keras' has no attribute 'initializers'

python - 合并 hdf5 检查点文件

tensorflow - Keras 类型错误 : can't pickle _thread. RLock 对象

python - Tensorflow session 不执行函数

python - 将图层添加到 keras 预训练模型中

Keras功能API : Combine CNN model with a RNN to to look at sequences of images