python - 人脸识别keras维数问题

标签 python python-3.x tensorflow machine-learning keras

我们正在尝试使用 keras 进行图像识别,但出现以下错误:ValueError: 检查输入时出错:预期 conv2d_93_input 有 4 个维度,但得到了形状为 (4999, 40) 的数组因此,我们使用 imread 读取图像数据,然后将其放入数组中,但出于某种原因,keras 需要第四维。

这就是我们读取文件的方式:

def generator(BatchSize):
text_file = open("/content/list_attr_celeba.txt", "r")
lines = text_file.readlines()
lines = lines[2:]
prew = 1
e = []
while True:
    for i in range(prew,prew+BatchSize):
        #print(i)
        lines[i] = lines[i].split()
        name = lines[i][0]
        lines[i] =  lines[i][1:]
        a = imread('/content/img_align_celeba/' + name)
        #b =  numpy.zeros(4,1)
        #print(a)
        e.append(numpy.array(a))            
        if i % BatchSize == 0 and i != 0:
            yield (numpy.array(lines[prew:i]),e)                
            e = []
            prew = i+1

这就是我们定义生成器和模型的方式

 gen = generator(5000)
 model = Sequential()
model.add(Conv2D(32, kernel_size=(5, 5), strides=(1, 1),
             activation='relu',
             input_shape=((170,140,3))))
model.add(MaxPooling2D(pool_size=(2, 2), strides=(2, 2)))
model.add(Conv2D(64, (5, 5), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Flatten())
model.add(Dense(1000, activation='relu'))
model.add(Dense(40, activation='sigmoid'))
model.compile(optimizer = 'adam', loss = 'binary_crossentropy', metrics = 
['accuracy'])

这就是我们的风格

model.fit_generator(gen, epochs=2, verbose=1,  max_queue_size=10, 
workers=1, use_multiprocessing=False, shuffle=False, initial_epoch=0, 
steps_per_epoch = 4)

最佳答案

将数据 reshape 为 (4999, 40,1),例如添加大小为 1 的维度,conv2d 需要 (batch_size, x, y, 过滤器)

a = numpy.array(a)
e.append(a.reshape((a.shape + (1,)))

关于python - 人脸识别keras维数问题,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/53905433/

相关文章:

python - 如何找到所有()一个 Pandas 数据框的正则表达式序列?

Tensorflow CTC Loss Sequence Length 参数

python-3.x - ubuntu、python、ssl.SSLError : [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed (_ssl. c:852)

python - 如何防止 Python 函数返回 None

python - 在 Ubuntu 中安装 GDAL Python 绑定(bind)以用作独立模块

python - 将非理想列表格式导出到 Excel

Python:如何对 Pandas 中的人之间的付款进行分组和求和?

python - 为什么我收到错误消息 : list index out of range in Python

javascript - 使用 Python-BeautifulSoup 和 urllib 抓取奇怪的 html 设置

python - Tensorflow 图像分类 Python 总是说相同的答案