machine-learning - 来自 for 循环的 Keras LSTM,使用具有自定义层数的函数式 API

标签 machine-learning keras lstm keras-layer

我正在尝试通过 keras 功能 API 构建一个网络,提供两个列表,其中包含 LSTM 层和 FC(密集)层的单元数量。我想分析 20 个连续的段(批处理),每个段包含 fs 个时间步长和 2 个值(每个时间步长 2 个特征)。这是我的代码:

Rec = [4,4,4]  
FC = [8,4,2,1]    
def keras_LSTM(Rec,FC,fs, n_witness, lr=0.04, optimizer='Adam'):
    model_LSTM = Input(batch_shape=(20,fs,n_witness))
    return_state_bool=True
    for i in range(shape(Rec)[0]):
        nRec = Rec[i]
        if i == shape(Rec)[0]-1:
            return_state_bool=False
        model_LSTM = LSTM(nRec, return_sequences=True,return_state=return_state_bool,
                     stateful=True, input_shape=(None,n_witness),            
                     name='LSTM'+str(i))(model_LSTM)
    for j in range(shape(FC)[0]):
        nFC = FC[j]
        model_LSTM = Dense(nFC)(model_LSTM)
        model_LSTM = LeakyReLU(alpha=0.01)(model_LSTM)
    nFC_final = 1
    model_LSTM = Dense(nFC_final)(model_LSTM)
    predictions = LeakyReLU(alpha=0.01)(model_LSTM)

    full_model_LSTM = Model(inputs=model_LSTM, outputs=predictions)
    model_LSTM.compile(optimizer=keras.optimizers.Adam(lr=lr, beta_1=0.9, beta_2=0.999,
                    epsilon=1e-8, decay=0.066667, amsgrad=False), loss='mean_squared_error')
    return full_model_LSTM

model_new = keras_LSTM(Rec, FC, fs=fs, n_witness=n_wit)
model_new.summary()

编译时出现以下错误:

ValueError:图形已断开连接:无法获取层“input_1”处的张量 Tensor("input_1:0", shape=(20, 2048, 2), dtype=float32) 的值。访问之前的以下层没有出现问题:[]

我其实不太明白,但怀疑它可能与输入有关?

最佳答案

我通过修改代码第 4 行解决了这个问题,如下所示:

x = model_LSTM = Input(batch_shape=(20,fs,n_witness))

以及第 21 行,如下所示:

full_model_LSTM = Model(inputs=x, outputs=predictions)

关于machine-learning - 来自 for 循环的 Keras LSTM,使用具有自定义层数的函数式 API,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/54241764/

相关文章:

tensorflow - 检查点 keras 模型 : TypeError: can't pickle _thread. 锁定对象

keras - 如何在具有共享嵌入层和负采样的 keras 中实现 word2vec CBOW?

python - 属性错误: 'NoneType' object has no attribute 'update'

python - 实现多对多回归任务

keras - 时间序列数据的 BatchNormalization 层的 axis 参数设置什么?

python - 通过 python3 中连接的集群的多数投票进行标记

r - R 的 MLR 中的预测函数产生的结果与预测不一致

python - scipy.optimize.minimize 在神经网络中的使用

machine-learning - 使用 scikit-learn 进行文本分类 : how to get a new document's representation from a pickle model

python - 为什么我的 keras 模型有这么多参数?