tensorflow - 连接 ConvLSTM2D 模型和表格模型的更好方法

标签 tensorflow machine-learning keras lstm

我建立了一个模型,该模型将时间序列的 3 个图像以及 5 个数字信息作为输入,并生成时间序列的下三个图像。
我通过以下方式完成了这项工作:

  • 构建用于处理图像的 ConvLSTM2D 模型(与 Keras 文档 here 中列出的示例非常相似)。输入尺寸=(3x128x128x3)
  • 为具有几个 Dense 层的表格数据构建一个简单的模型。输入大小=(1,5)
  • 连接这两个模型
  • 有一个 Conv3D 模型可以生成接下来的 3 个图像

  • LSTM 模型产生大小为 393216 (3x128x128x8) 的输出。现在我必须将表格模型的输出设置为 49,152,以便在下一层输入大小为 442368 (3x128x128x9)。因此,表格模型的 Dense 层的这种不必要的膨胀使得原本高效的 LSTM 模型表现得非常糟糕。
    有没有更好的方法来连接两个模型?有没有办法在表格模型的 Dense 层中只输出 10?
    该模型:
    x_input = Input(shape=(None, 128, 128, 3))
    x = ConvLSTM2D(32, 3, strides = 1, padding='same', dilation_rate = 2,return_sequences=True)(x_input)
    x = BatchNormalization()(x)
    x = ConvLSTM2D(16, 3, strides = 1, padding='same', dilation_rate = 2,return_sequences=True)(x)
    x = BatchNormalization()(x)
    x = ConvLSTM2D(8, 3, strides = 1, padding='same', dilation_rate = 2,return_sequences=True)(x)
    x = BatchNormalization()(x)
    x = Flatten()(x)
    # x = MaxPooling3D()(x)
    
    x_tab_input = Input(shape=(5))
    x_tab = Dense(100, activation="relu")(x_tab_input)
    x_tab = Dense(49152, activation="relu")(x_tab)
    x_tab = Flatten()(x_tab)
    
    concat = Concatenate()([x, x_tab])
    
    output = Reshape((3,128,128,9))(concat)
    output = Conv3D(filters=3, kernel_size=(3, 3, 3), activation='relu', padding="same")(output)
    model = Model([x_input, x_tab_input], output)
    model.compile(loss='mae', optimizer='rmsprop')
    
    型号概要:
    Model: "functional_3"
    ______________________________________________________________________________________________________________________________________________________
    Layer (type)                                     Output Shape                     Param #           Connected to                                      
    ======================================================================================================================================================
    input_4 (InputLayer)                             [(None, None, 128, 128, 3)]      0                                                                   
    ______________________________________________________________________________________________________________________________________________________
    conv_lst_m2d_9 (ConvLSTM2D)                      (None, None, 128, 128, 32)       40448             input_4[0][0]                                     
    ______________________________________________________________________________________________________________________________________________________
    batch_normalization_9 (BatchNormalization)       (None, None, 128, 128, 32)       128               conv_lst_m2d_9[0][0]                              
    ______________________________________________________________________________________________________________________________________________________
    conv_lst_m2d_10 (ConvLSTM2D)                     (None, None, 128, 128, 16)       27712             batch_normalization_9[0][0]                       
    ______________________________________________________________________________________________________________________________________________________
    batch_normalization_10 (BatchNormalization)      (None, None, 128, 128, 16)       64                conv_lst_m2d_10[0][0]                             
    ______________________________________________________________________________________________________________________________________________________
    input_5 (InputLayer)                             [(None, 5)]                      0                                                                   
    ______________________________________________________________________________________________________________________________________________________
    conv_lst_m2d_11 (ConvLSTM2D)                     (None, None, 128, 128, 8)        6944              batch_normalization_10[0][0]                      
    ______________________________________________________________________________________________________________________________________________________
    dense (Dense)                                    (None, 100)                      600               input_5[0][0]                                     
    ______________________________________________________________________________________________________________________________________________________
    batch_normalization_11 (BatchNormalization)      (None, None, 128, 128, 8)        32                conv_lst_m2d_11[0][0]                             
    ______________________________________________________________________________________________________________________________________________________
    dense_1 (Dense)                                  (None, 49152)                    4964352           dense[0][0]                                       
    ______________________________________________________________________________________________________________________________________________________
    flatten_3 (Flatten)                              (None, None)                     0                 batch_normalization_11[0][0]                      
    ______________________________________________________________________________________________________________________________________________________
    flatten_4 (Flatten)                              (None, 49152)                    0                 dense_1[0][0]                                     
    ______________________________________________________________________________________________________________________________________________________
    concatenate (Concatenate)                        (None, None)                     0                 flatten_3[0][0]                                   
                                                                                                        flatten_4[0][0]                                   
    ______________________________________________________________________________________________________________________________________________________
    reshape_2 (Reshape)                              (None, 3, 128, 128, 9)           0                 concatenate[0][0]                                 
    ______________________________________________________________________________________________________________________________________________________
    conv3d_2 (Conv3D)                                (None, 3, 128, 128, 3)           732               reshape_2[0][0]                                   
    ======================================================================================================================================================
    Total params: 5,041,012
    Trainable params: 5,040,900
    Non-trainable params: 112
    ______________________________________________________________________________________________________________________________________________________
    

    最佳答案

    我同意你说的巨大Dense层(具有数百万个参数)可能会阻碍模型的性能。而不是用 Dense 来膨胀表格数据层,您宁愿选择以下两种方法之一。

    选项 1:平铺x_tab张量,使其与您想要的形状相匹配。这可以通过以下步骤来实现:
    首先,没有必要把ConvLSTM2D弄平。的编码张量:

    x_input = Input(shape=(3, 128, 128, 3))
    x = ConvLSTM2D(32, 3, strides = 1, padding='same', dilation_rate = 2,return_sequences=True)(x_input)
    x = BatchNormalization()(x)
    x = ConvLSTM2D(16, 3, strides = 1, padding='same', dilation_rate = 2,return_sequences=True)(x)
    x = BatchNormalization()(x)
    x = ConvLSTM2D(8, 3, strides = 1, padding='same', dilation_rate = 2,return_sequences=True)(x)
    x = BatchNormalization()(x)  # Shape=(None, None, 128, 128, 8) 
    # Commented: x = Flatten()(x)
    
    其次,您可以使用一个或多个 Dense 处理表格数据。层。例如:
    dim = 10
    x_tab_input = Input(shape=(5))
    x_tab = Dense(100, activation="relu")(x_tab_input)
    x_tab = Dense(dim, activation="relu")(x_tab)
    # x_tab = Flatten()(x_tab)  # Note: Flattening a 2D tensor leaves the tensor unchanged
    
    第三,我们包装tensorflow操作tf.tileLambda层,有效地创建张量的副本 x_tab以便它匹配所需的形状:
    def repeat_tabular(x_tab):
        h = x_tab[:, None, None, None, :]  # Shape=(bs, 1, 1, 1, dim)
        h = tf.tile(h, [1, 3, 128, 128, 1])  # Shape=(bs, 3, 128, 128, dim)
        return h
    x_tab = Lambda(repeat_tabular)(x_tab)
    
    最后,我们连接 x和瓷砖 x_tab沿最后一个轴的张量(您也可以考虑沿第一个轴连接,对应于 channel 的维度)
    concat = Concatenate(axis=-1)([x, x_tab])  # Shape=(3,128,128,8+dim)
    output = concat
    output = Conv3D(filters=3, kernel_size=(3, 3, 3), activation='relu', padding="same")(output)
    # ...
    
    请注意,这个解决方案可能有点幼稚,因为模型没有将图像的输入序列编码为低维表示,限制了网络的感受野,并可能导致性能下降。

    选项 2:类似于自动编码器和 U-Net ,可能需要将您的图像序列编码为低维表示,以丢弃不需要的变化(例如噪声),同时保留有意义的信号(例如推断序列的下 3 个图像所需)。这可以通过以下方式实现:
    首先,将输入的图像序列编码为低维二维张量。例如,类似于以下内容:
    x_input = Input(shape=(None, 128, 128, 3))
    x = ConvLSTM2D(32, 3, strides = 1, padding='same', dilation_rate = 2,return_sequences=True)(x_input)
    x = BatchNormalization()(x)
    x = ConvLSTM2D(16, 3, strides = 1, padding='same', dilation_rate = 2,return_sequences=True)(x)
    x = BatchNormalization()(x)
    x = ConvLSTM2D(8, 3, strides = 1, padding='same', dilation_rate = 2, return_sequences=False)(x)
    x = BatchNormalization()(x)
    x = Flatten()(x)
    x = Dense(64, activation='relu')(x)
    
    注意最后一个ConvLSTM2D不返回序列。您可能想要探索不同的编码器以达到这一点(例如,您也可以在这里使用池化层)。
    其次,使用 Dense 处理您的表格数据层。例如:
    dim = 10
    x_tab_input = Input(shape=(5))
    x_tab = Dense(100, activation="relu")(x_tab_input)
    x_tab = Dense(dim, activation="relu")(x_tab)
    
    第三,连接前两个流中的数据:
    concat = Concatenate(axis=-1)([x, x_tab])
    
    四、使用Dense + Reshape层将连接的向量投影到一系列低分辨率图像中:
    h = Dense(3 * 32 * 32 * 3)(concat)
    output = Reshape((3, 32, 32, 3))(h)
    
    output的形状允许将图像上采样为 (128, 128, 3) 的形状,但它是任意的(例如,您可能还想在这里进行实验)。
    最后,申请一个或几个Conv3DTranspose层以获得所需的输出(例如 3 张形状为 (128, 128, 3) 的图像)。
    output = tf.keras.layers.Conv3DTranspose(filters=50, kernel_size=(3, 3, 3),
                                             strides=(1, 2, 2), padding='same',
                                             activation='relu')(output)
    output = tf.keras.layers.Conv3DTranspose(filters=3, kernel_size=(3, 3, 3),
                                             strides=(1, 2, 2), padding='same',
                                             activation='relu')(output)  # Shape=(None, 3, 128, 128, 3)
    
    讨论了转置卷积层背后的基本原理 here .本质上,Conv3DTranspose层与正常卷积相反 - 它允许将低分辨率图像上采样为高分辨率图像。

    关于tensorflow - 连接 ConvLSTM2D 模型和表格模型的更好方法,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/65963752/

    相关文章:

    python - 方法对象在神经网络代码中不可下标

    python - Mac 上的 TensorFlow 安装错误

    python - TensorFlow 神经网络输出线性函数

    python - 断言错误 : Tried to export a function which references untracked resource

    python - 如何将这些数据加载到 LSTM 中?

    python - 神经网络预测足球结果

    python - 为什么 tensorflow 将 one_hot 值加倍?

    math - 贝叶斯曲线拟合模型

    machine-learning - 证明多元线性回归模型效率的最佳 RMSE(均方根误差)值范围是多少?

    python - 名称错误 : name 'history' is not defined