我正在创建一个模型来对输入波形是否包含 I2C 线的 SDA 上升沿进行分类。
我的输入有 20000 个数据点和 100 个训练数据。
我最初找到了有关此处输入的答案 Keras 1D CNN: How to specify dimension correctly?
但是,我在激活函数中遇到错误:
ValueError: Error when checking target: expected activation_1 to have 3 dimensions, but got array with shape (100, 1)
我的模型是:
model.add(Conv1D(filters=n_filter,
kernel_size=input_filter_length,
strides=1,
activation='relu',
input_shape=(20000,1)))
model.add(BatchNormalization())
model.add(MaxPooling1D(pool_size=4, strides=None))
model.add(Dense(1))
model.add(Activation("sigmoid"))
adam = Adam(lr=learning_rate)
model.compile(optimizer= adam, loss='binary_crossentropy', metrics=['accuracy'])
model.fit(train_data, train_label,
nb_epoch=10,
batch_size=batch_size, shuffle=True)
score = np.asarray(model.evaluate(test_new_data, test_label, batch_size=batch_size))*100.0
我无法确定这里的问题。为什么激活函数需要 3D 张量。
最佳答案
问题在于,从keras 2.0
开始,应用于序列的Dense
层将将该层应用于每个时间步 - 因此给定一个序列将产生一个序列。因此,您的 Dense
实际上生成了一个 1 元素向量序列,这会导致您的问题(因为您的目标不是序列)。
有多种方法可以将序列简化为向量,然后对其应用Dense
:
全局池化
:您可以使用
GlobalPooling
层,例如GlobalAveragePooling1D
或GlobalMaxPooling1D
,例如:model.add(Conv1D(filters=n_filter, kernel_size=input_filter_length, strides=1, activation='relu', input_shape=(20000,1))) model.add(BatchNormalization()) model.add(GlobalMaxPooling1D(pool_size=4, strides=None)) model.add(Dense(1)) model.add(Activation("sigmoid"))
扁平化
:您可以使用
Flatten
层将整个序列折叠为单个向量:model.add(Conv1D(filters=n_filter, kernel_size=input_filter_length, strides=1, activation='relu', input_shape=(20000,1))) model.add(BatchNormalization()) model.add(MaxPooling1D(pool_size=4, strides=None)) model.add(Flatten()) model.add(Dense(1)) model.add(Activation("sigmoid"))
RNN
后处理:您还可以在序列顶部添加一个循环层,并使其仅返回最后一个输出:
model.add(Conv1D(filters=n_filter, kernel_size=input_filter_length, strides=1, activation='relu', input_shape=(20000,1))) model.add(BatchNormalization()) model.add(MaxPooling1D(pool_size=4, strides=None)) model.add(SimpleRNN(10, return_sequences=False)) model.add(Dense(1)) model.add(Activation("sigmoid"))
关于python - Keras 中一维 CNN 的激活函数错误,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/44112236/