python - 图像生成器缺少 unet keras 的位置参数

标签 python machine-learning keras generator conv-neural-network

当我尝试训练模型时,我不断收到以下代码的以下错误:TypeError: fit_generator() missing 1 required positional argument: 'generator'。对于我的生活,我无法弄清楚是什么导致了这个错误。 x_train 是形状为 (400, 256, 256, 3) 的 rgb 图像,对于 y_train 我有 10 个输出类使其形状为 (400, 256, 256, 10)。这里出了什么问题?

如有必要,可以通过以下链接下载数据: https://www49.zippyshare.com/v/5pR3GPv3/file.html

import skimage
from skimage.io import imread, imshow, imread_collection, concatenate_images
from skimage.transform import resize
from skimage.morphology import label
import numpy as np
import matplotlib.pyplot as plt
from keras.models import Model
from keras.layers import Input, merge, Convolution2D, MaxPooling2D, UpSampling2D, Reshape, core, Dropout
from keras.optimizers import Adam
from keras.callbacks import ModelCheckpoint, LearningRateScheduler
from keras import backend as K
from sklearn.metrics import jaccard_similarity_score
from shapely.geometry import MultiPolygon, Polygon
import shapely.wkt
import shapely.affinity
from collections import defaultdict
from keras.preprocessing.image import ImageDataGenerator
from keras.utils.np_utils import to_categorical
from keras import utils as np_utils
import os
from keras.preprocessing.image import ImageDataGenerator
gen = ImageDataGenerator()
#Importing image and labels
labels = skimage.io.imread("ede_subset_293_wegen.tif")
images = skimage.io.imread("ede_subset_293_20180502_planetscope.tif")[...,:-1]


#scaling image
img_scaled = images / images.max()

#Make non-roads 0
labels[labels == 15] = 0

#Resizing image and mask and labels
img_scaled_resized = img_scaled[:6400, :6400 ]
print(img_scaled_resized.shape)
labels_resized = labels[:6400, :6400]
print(labels_resized.shape)

#splitting images
split_img = [
    np.split(array, 25, axis=0) 
    for array in np.split(img_scaled_resized, 25, axis=1)
]

split_img[-1][-1].shape

#splitting labels
split_labels = [
    np.split(array, 25, axis=0) 
    for array in np.split(labels_resized, 25, axis=1)
]

#Convert to np.array
split_labels = np.array(split_labels)
split_img = np.array(split_img)

train_images = np.reshape(split_img, (625, 256, 256, 3))
train_labels = np.reshape(split_labels, (625, 256, 256, 10))

train_labels = np_utils.to_categorical(train_labels, 10)

#Create train test and val
x_train = train_images[:400,:,:,:]
x_val = train_images[400:500,:,:,:]
x_test = train_images[500:625,:,:,:]
y_train = train_labels[:400,:,:]
y_val = train_labels[400:500,:,:]
y_test = train_labels[500:625,:,:]

# Create image generator (credit to Ioannis Nasios)
data_gen_args = dict(rotation_range=5,
                     width_shift_range=0.1,
                     height_shift_range=0.1,
                     validation_split=0.2)
image_datagen = ImageDataGenerator(**data_gen_args)

seed = 1
batch_size = 100

def XYaugmentGenerator(X1, y, seed, batch_size):
    genX1 = gen.flow(X1, y, batch_size=batch_size, seed=seed)
    genX2 = gen.flow(y, X1, batch_size=batch_size, seed=seed)
    while True:
        X1i = genX1.next()
        X2i = genX2.next()

        yield X1i[0], X2i[0]


# Train model
Model.fit_generator(XYaugmentGenerator(x_train, y_train, seed, batch_size), steps_per_epoch=np.ceil(float(len(x_train)) / float(batch_size)),
                validation_data = XYaugmentGenerator(x_val, y_val,seed, batch_size), 
                validation_steps = np.ceil(float(len(x_val)) / float(batch_size))
, shuffle=True, epochs=20)

最佳答案

您的代码中有一些错误,但考虑到您的错误:

TypeError: fit_generator() missing 1 required positional argument: 'generator'

这是因为fit_generator调用了XYaugmentGenerator,但内部没有调用增广生成器。

gen.flow(...

不会工作,因为 gen 没有声明。您应该将 image_datagen 重命名为 gen 为:

gen = ImageDataGenerator(**data_gen_args)

或者,将 gen 替换为 image_datagen

genX1 = image_datagen.flow(X1, y, batch_size=batch_size, seed=seed)
genX2 = image_datagen.flow(y, X1, batch_size=batch_size, seed=seed)

关于python - 图像生成器缺少 unet keras 的位置参数,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/53410205/

相关文章:

machine-learning - 朴素贝叶斯分类器中的 RMSE

python - 给定矩阵 M x N 的 k 近邻

python - 从 Google App Engine python 发送 iOS 推送通知

python - Keras:模型过度拟合?

react-native - 未知激活 : swish

deep-learning - Keras语义分割加权损失像素图

neural-network - 添加 dropout 后,我​​的神经网络比以前更容易过拟合。这是怎么回事?

python - AWS Boto3 获取未经身份验证的文件列表

Python - 如何断言未使用特定参数调用模拟对象?

r - 如何解释该分类 TreeMap 中的预测?