我正在使用 CK+ 数据集进行面部表情识别,并通过 datagen.flow_from_directory
传递面部图像和标签,以提取面部特征并映射到标签。
标签作为分类值传递,范围从 0 到 7。同样的值似乎作为 one-hot 编码形式传递。我的问题是我可以将标签值作为独热编码值进行广播。
我收到以下错误:
ValueError:无法将输入数组从形状 (32,8) 广播到形状 (32)
代码如下:
import scipy
import os, shutil
from tensorflow.keras.preprocessing.image import ImageDataGenerator
img_width, img_height = 224, 224
datagen = ImageDataGenerator(rescale=1./255)
batch_size = 32
def extract_features(directory, sample_count):
features = np.zeros(shape=(sample_count, 7, 7, 512)) # Must be equal to the output of the convolutional base
labels = np.zeros(shape=(sample_count))
print(sample_count, 7, 7, 512)
# Preprocess data - flow_from_directory allows us to extract
#... features and labels directly from a directory
generator = datagen.flow_from_directory(directory,
target_size=(img_width,img_height),
batch_size = batch_size,
class_mode='categorical')
i = 0
for inputs_batch, labels_batch in generator:
features_batch = conv_base.predict(inputs_batch)
features[i * batch_size: (i + 1) * batch_size] = features_batch
labels[i * batch_size: (i + 1) * batch_size] = labels_batch
i += 1
if i * batch_size >= sample_count:
break
return features, labels
我得到以下形状:
Found 209 images belonging to 8 classes.
Input batch shape: (32, 224, 224, 3)
Features batch shape: (32, 7, 7, 512)
Features shape: (209, 7, 7, 512)
Labels batch shape: (32, 8)
所以我很困惑为什么features_batch
可以广播,但labels_batch
不能。
我尝试了几件事,其中包括:
1)展平标签数组 - 这没有意义,但只是为了查看并获取跨行和列的完整元素计数32*8=259(如预期)。
2)我尝试仅使用 labels[i]=labels_batch
和 labels=labels_batch
,它只返回最后几个标签
(17,209-(6*32)=17 剩下的)。
3) 我还尝试从 this question 插入另一个解决方案。 通过这样做:
for c in range(0,7):
labels[i * batch_size: (i + 1) * batch_size, [c]] = labels_batch
但出现以下错误:
ValueError: Error when checking input: expected input_3 to have 4 dimensions, but got array with shape (32, 8)
我觉得我缺少的东西很简单,但我似乎无法弄清楚。有人有什么想法吗?
谢谢!
最佳答案
你的标签应该是形状labels = np.zeros(shape=(sample_count, num_classes))
而不是labels = np.zeros(shape=(sample_count))
生成器的标签分配应该是
labels[i * batch_size: (i + 1) * batch_size,:] = labels_batch
关于python - 从生成器广播标签数据时出现问题,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/59401615/