python - ValueError : Error when checking input: expected time_distributed_46_input to have 5 dimensions, 但得到形状为 (200, 200, 3) 的数组

标签 python image tensorflow machine-learning keras

我正在修改 Timedistributed 层,但遇到了困难。我正在尝试创建一个非常简单的模型,该模型采用 200 x 200 rgb 图像并读取上面写的字符。

我不断收到以下错误,但不知道如何修复它:

ValueError: Error when checking input: expected time_distributed_46_input to have 5 dimensions, but got array with shape (200, 200, 3)

这是我的 keras 代码:

num_timesteps = len(chars) # length of sequence
img_width = 200
img_height = 200
img_channels = 3

def model():
    # define CNN model
    cnn = Sequential()
    cnn.add(Conv2D(64, (3,3), activation='relu', padding='same', input_shape=(img_width,img_height,img_channels)))
    cnn.add(MaxPooling2D(pool_size=(3, 3)))
    cnn.add(Flatten())

    # define LSTM model
    model = Sequential()
    model.add(TimeDistributed(cnn, input_shape=(num_timesteps, img_width,img_height,img_channels)))
    model.add(LSTM(num_timesteps))
    model.add(Dense(26))

    model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
    return model

然后我像这样拟合模型:

model().fit_generator(generator=images_generator(), steps_per_epoch=20, epochs=2)

我在哪里生成图像,如下所示:

def image_sample():
    rand_str = random_str()
    blank=Image.new("RGB", (200,200),(255,255,255))
    font = ImageFont.truetype("StatePlate.ttf", 100)
    draw = ImageDraw.Draw(blank)
    draw.text((30, 40),rand_str,(0,0,0), font=font)
    draw = ImageDraw.Draw(blank)
#     datagen = ImageDataGenerator(rotation_range=90)
#     datagen.fit(blank)
    return (np.asarray(blank), one_hot_char(rand_str))

def one_hot_char(char):
    zeros = np.zeros(len(chars))
    zeros[chars.index(char)] = 1
    return zeros

def images_generator():
    yield image_sample()

感谢任何帮助!谢谢。

最佳答案

目前,生成器返回单个图像。生成器生成的输入应具有形状:[batch_size, num_timesteps, img_width, img_height, img_channels]

此虚拟数据的快速修复方法是将 np.asarray(blank) 更改为 np.asarray([[blank]])

关于python - ValueError : Error when checking input: expected time_distributed_46_input to have 5 dimensions, 但得到形状为 (200, 200, 3) 的数组,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/47925346/

相关文章:

python - Matplotlib:如何将时间戳与 broken_barh 一起使用?

java - 无法在 Java 上加载图像 (BlueJ)

python - Django 通过 iOS 上传图片时图像旋转不正确(EXIF 问题)

c# - RichTextBox无法插入图片

python - 错误 : Failed to load the native TensorFlow runtime

tensorflow - 批量训练使用更新总和?或平均更新?

Tensorflow,用 tf.train.Saver 保存了什么?

python - 用python解析dns配置文件

python - 在 DataFrameGroupBy 对象的组内进行切片

python - 有没有一种优雅的方法可以通过迭代循环 N 次列表(如 itertools.cycle 但限制循环)?