python - Tensorflow Reshape error with custom pooling/unpooling layers

Tags: python tensorflow machine-learning keras computer-vision

I'm trying to implement a smaller-scale version of SegNet as described in this paper (https://arxiv.org/pdf/1511.00561.pdf), but I'm adapting it to detect edges.

Dataset: I'm using the BSDS500 boundary dataset. I cropped and rotated the images so that they are 320x480x3 instead of 321x481x3.

Input shapes for the 200 training images and 100 validation images:

x_train: (200, 320, 480, 3)
x_val: (100, 320, 480, 3)
y_train: (200, 153600)
y_val: (100, 153600)
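A plausible reason the 320x480 crop matters (an inference; the question does not state one): with 'SAME' padding each 2x2 pool produces ceil(n / 2) rows or columns, while unpool2D simply multiplies by 2, so the decoder only restores the original size when every level divides evenly. The labels, in turn, are the 320x480 edge maps flattened to 153600 values. A small pure-Python check, with hypothetical helper names:

```python
import math


def same_pool_size(n, stride=2):
    # Output length of a 'SAME'-padded pooling with this stride: ceil(n / stride).
    return math.ceil(n / stride)


def unpool_size(n, ksize=2):
    # unpool2D multiplies the pooled size by ksize.
    return n * ksize


def round_trips(n, levels=2):
    # Pool `levels` times, unpool `levels` times, and compare with the start.
    m = n
    for _ in range(levels):
        m = same_pool_size(m)
    for _ in range(levels):
        m = unpool_size(m)
    return m == n


# 320x480 survives the model's two pool/unpool stages; 321x481 would not.
print(round_trips(320), round_trips(480))  # True True
print(round_trips(321), round_trips(481))  # False False
# Each label vector is one 320x480 edge map flattened.
print(320 * 480)  # 153600
```

For example, 321 pools to 161 and then 81, which unpools to 162 and then 324.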

Framework: I'm using Keras with the TensorFlow backend.

These are the functions I use for my custom pooling and unpooling layers:

import tensorflow as tf

def pool_argmax2D(x, pool_size=(2,2), strides=(2,2)):
    padding = 'SAME'
    # Build the 4-D window and strides directly from the (height, width) tuples.
    ksize = [1, pool_size[0], pool_size[1], 1]
    strides = [1, strides[0], strides[1], 1]
    # Returns the pooled values and the flattened index of each maximum.
    output, argmax = tf.nn.max_pool_with_argmax(
        x,
        ksize = ksize,
        strides = strides,
        padding = padding
    )

    return [output, argmax]
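For reference, the index layout of the argmax mask can be mimicked in pure Python. This is a sketch under stated assumptions: a single sample stored as image[y][x][c], a 2x2 window with stride 2, even height and width, and the per-sample index (y * width + x) * channels + c. TensorFlow documents argmax as a flattened index into the input; whether a batch offset is folded in depends on the version and on flags such as include_batch_in_index in newer releases, and this sketch makes no attempt to match TensorFlow's tie-breaking between equal values. The function name is hypothetical.

```python
def max_pool_2x2_with_argmax(image):
    """2x2, stride-2 max pooling of ONE sample stored as image[y][x][c].

    Returns (pooled, argmax); argmax[i][j][c] is the flat index
    (y * width + x) * channels + c of the winning input element.
    Assumes height and width are even, so no padding is needed.
    """
    height, width, channels = len(image), len(image[0]), len(image[0][0])
    pooled, argmax = [], []
    for i in range(0, height, 2):
        pooled_row, argmax_row = [], []
        for j in range(0, width, 2):
            values, indices = [], []
            for c in range(channels):
                best, best_idx = None, None
                # Scan the window in row-major order; the first maximum found wins.
                for y in (i, i + 1):
                    for x in (j, j + 1):
                        v = image[y][x][c]
                        if best is None or v > best:
                            best, best_idx = v, (y * width + x) * channels + c
                values.append(best)
                indices.append(best_idx)
            pooled_row.append(values)
            argmax_row.append(indices)
        pooled.append(pooled_row)
        argmax.append(argmax_row)
    return pooled, argmax


# One sample, 2x4 pixels, 1 channel.
image = [[[1], [5], [2], [0]],
         [[3], [4], [7], [6]]]
pooled, argmax = max_pool_2x2_with_argmax(image)
print(pooled)  # [[[5], [7]]]
print(argmax)  # [[[1], [6]]]
```

Each pooled element keeps a pointer into the tensor that was pooled, so the mask always has that tensor's depth (32 channels for mask_1 in the model below).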


def unpool2D(pool, argmax, ksize=(2,2)):
    with tf.variable_scope("unpool"):
        input_shape = tf.shape(pool)
        # Each spatial dimension grows by the pooling factor; depth is unchanged.
        output_shape = [input_shape[0], 
                        input_shape[1] * ksize[0], 
                        input_shape[2] * ksize[1], 
                        input_shape[3]]

        # Number of elements in `pool`; `argmax` must contain exactly as many.
        flat_input_size = tf.cumprod(input_shape)[-1]
        flat_output_shape = tf.cast([output_shape[0], 
                            output_shape[1] * output_shape[2] * output_shape[3]], tf.int64)

        pool_ = tf.reshape(pool, [flat_input_size])
        batch_range = tf.reshape(tf.range(tf.cast(output_shape[0], tf.int64), dtype=tf.int64),
                                shape=[input_shape[0], 1, 1, 1])

        # Batch index of every pooled element.
        b = tf.ones_like(argmax) * batch_range
        b = tf.reshape(b, [flat_input_size, 1])

        # Per-sample position of every maximum (the modulo drops any batch offset).
        ind_ = tf.reshape(argmax, [flat_input_size, 1]) % flat_output_shape[1]
        ind_ = tf.concat([b, ind_], 1)
        # Scatter the pooled values back to their positions; all other entries stay 0.
        ret = tf.scatter_nd(ind_, pool_, shape=flat_output_shape)
        ret = tf.reshape(ret, output_shape)
        return ret
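The scatter step in unpool2D can be mirrored the same way. In this hypothetical helper the inputs are already flattened per sample, each sample's output starts as zeros, and every pooled value is written to its argmax position modulo the per-sample size (the role played by flat_output_shape[1]). tf.scatter_nd adds values that land on the same index; with non-overlapping 2x2 windows no index repeats, so plain assignment is equivalent here.

```python
def unpool_flat(pooled_flat, argmax_flat, batch, out_h, out_w, channels):
    """Pure-Python analogue of unpool2D on already-flattened inputs.

    pooled_flat / argmax_flat: one list per sample; both must have equal length.
    Returns one flat list of length out_h * out_w * channels per sample.
    """
    per_sample = out_h * out_w * channels  # plays the role of flat_output_shape[1]
    out = []
    for b in range(batch):
        values, indices = pooled_flat[b], argmax_flat[b]
        # In unpool2D this mismatch surfaces as the tf.reshape error instead.
        if len(values) != len(indices):
            raise ValueError(f"{len(values)} pooled values vs {len(indices)} argmax indices")
        row = [0] * per_sample
        for v, idx in zip(values, indices):
            # The modulo strips a batch offset if the indices include one.
            row[idx % per_sample] = v
        out.append(row)
    return out


print(unpool_flat([[5, 7], [2, 3]], [[1, 6], [9, 14]], batch=2, out_h=2, out_w=4, channels=1))
# [[0, 5, 0, 0, 0, 0, 7, 0], [0, 2, 0, 0, 0, 0, 3, 0]]
```

The second sample's indices carry an offset of 8 (one full 2x4x1 sample), which the modulo removes. Passing four pooled values with only two indices raises the ValueError, the pure-Python counterpart of a reshape that cannot match element counts.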

Here is the code for the model:

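from keras.layers import Input, Conv2D, BatchNormalization, Activation, Lambda, Flatten
from keras.models import Model

batch_size = 4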
kernel = 3
pool_size=(2,2)
img_shape = (320,480,3)


inputs = Input(shape=img_shape, name='main_input')

conv_1 = Conv2D(32, (kernel, kernel), padding="same")(inputs)
conv_1 = BatchNormalization()(conv_1)
conv_1 = Activation("relu")(conv_1)
conv_2 = Conv2D(32, (kernel, kernel), padding="same")(conv_1)
conv_2 = BatchNormalization()(conv_2)
conv_2 = Activation("relu")(conv_2)

pool_1, mask_1 = Lambda(pool_argmax2D, arguments={'pool_size': pool_size, 'strides': pool_size})(conv_2)

conv_3 = Conv2D(64, (kernel, kernel), padding="same")(pool_1)
conv_3 = BatchNormalization()(conv_3)
conv_3 = Activation("relu")(conv_3)
conv_4 = Conv2D(64, (kernel, kernel), padding="same")(conv_3)
conv_4 = BatchNormalization()(conv_4)
conv_4 = Activation("relu")(conv_4)

pool_2, mask_2 = Lambda(pool_argmax2D, arguments={'pool_size': pool_size, 'strides': pool_size})(conv_4)

conv_5 = Conv2D(64, (kernel, kernel), padding="same")(pool_2)
conv_5 = BatchNormalization()(conv_5)
conv_5 = Activation("relu")(conv_5)

unpool_1 = Lambda(unpool2D, output_shape = (160,240,64), arguments={'ksize':pool_size, 'argmax': mask_2})(conv_5)

conv_6 = Conv2D(64, (kernel, kernel), padding="same")(unpool_1)
conv_6 = BatchNormalization()(conv_6)
conv_6 = Activation("relu")(conv_6)
conv_7 = Conv2D(64, (kernel, kernel), padding="same")(conv_6)
conv_7 = BatchNormalization()(conv_7)
conv_7 = Activation("relu")(conv_7)

unpool_2 = Lambda(unpool2D, output_shape = (320,480,64), arguments={'ksize':pool_size, 'argmax': mask_1})(conv_7)

conv_8 = Conv2D(32, (kernel, kernel), padding="same")(unpool_2)
conv_8 = BatchNormalization()(conv_8)
conv_8 = Activation("relu")(conv_8)
conv_9 = Conv2D(32, (kernel, kernel), padding="same")(conv_8)
conv_9 = BatchNormalization()(conv_9)
conv_9 = Activation("relu")(conv_9)

conv_10 = Conv2D(1, (1, 1), padding="same")(conv_9)
conv_10 = BatchNormalization()(conv_10)

flatten_1 = Flatten()(conv_10)

outputs = Activation('softmax')(flatten_1)

model = Model(inputs=inputs, outputs=outputs)

The model compiles correctly when I run:

model.compile(optimizer='adam', loss='mean_absolute_error', metrics=['accuracy'])
model.summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
main_input (InputLayer)      (None, 320, 480, 3)       0         
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 320, 480, 32)      896       
_________________________________________________________________
batch_normalization_1 (Batch (None, 320, 480, 32)      128       
_________________________________________________________________
activation_1 (Activation)    (None, 320, 480, 32)      0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 320, 480, 32)      9248      
_________________________________________________________________
batch_normalization_2 (Batch (None, 320, 480, 32)      128       
_________________________________________________________________
activation_2 (Activation)    (None, 320, 480, 32)      0         
_________________________________________________________________
lambda_1 (Lambda)            [(None, 160, 240, 32), (N 0         
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 160, 240, 64)      18496     
_________________________________________________________________
batch_normalization_3 (Batch (None, 160, 240, 64)      256       
_________________________________________________________________
activation_3 (Activation)    (None, 160, 240, 64)      0         
_________________________________________________________________
conv2d_4 (Conv2D)            (None, 160, 240, 64)      36928     
_________________________________________________________________
batch_normalization_4 (Batch (None, 160, 240, 64)      256       
_________________________________________________________________
activation_4 (Activation)    (None, 160, 240, 64)      0         
_________________________________________________________________
lambda_2 (Lambda)            [(None, 80, 120, 64), (No 0         
_________________________________________________________________
conv2d_5 (Conv2D)            (None, 80, 120, 64)       36928     
_________________________________________________________________
batch_normalization_5 (Batch (None, 80, 120, 64)       256       
_________________________________________________________________
activation_5 (Activation)    (None, 80, 120, 64)       0         
_________________________________________________________________
lambda_3 (Lambda)            (None, 160, 240, 64)      0         
_________________________________________________________________
conv2d_6 (Conv2D)            (None, 160, 240, 64)      36928     
_________________________________________________________________
batch_normalization_6 (Batch (None, 160, 240, 64)      256       
_________________________________________________________________
activation_6 (Activation)    (None, 160, 240, 64)      0         
_________________________________________________________________
conv2d_7 (Conv2D)            (None, 160, 240, 64)      36928     
_________________________________________________________________
batch_normalization_7 (Batch (None, 160, 240, 64)      256       
_________________________________________________________________
activation_7 (Activation)    (None, 160, 240, 64)      0         
_________________________________________________________________
lambda_4 (Lambda)            (None, 320, 480, 64)      0         
_________________________________________________________________
conv2d_8 (Conv2D)            (None, 320, 480, 32)      18464     
_________________________________________________________________
batch_normalization_8 (Batch (None, 320, 480, 32)      128       
_________________________________________________________________
activation_8 (Activation)    (None, 320, 480, 32)      0         
_________________________________________________________________
conv2d_9 (Conv2D)            (None, 320, 480, 32)      9248      
_________________________________________________________________
batch_normalization_9 (Batch (None, 320, 480, 32)      128       
_________________________________________________________________
activation_9 (Activation)    (None, 320, 480, 32)      0         
_________________________________________________________________
conv2d_10 (Conv2D)           (None, 320, 480, 1)       33        
_________________________________________________________________
batch_normalization_10 (Batc (None, 320, 480, 1)       4         
_________________________________________________________________
flatten_1 (Flatten)          (None, 153600)            0         
_________________________________________________________________
activation_10 (Activation)   (None, 153600)            0         
=================================================================
Total params: 205,893
Trainable params: 204,995
Non-trainable params: 898
_________________________________________________________________
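The parameter counts in this summary can be reproduced from the channel counts alone, which is a quick way to confirm which depth each Conv2D actually receives. A small sketch with hypothetical helper names; the (kernel, in, out) triples are read from the model code above:

```python
def conv2d_params(k, c_in, c_out):
    # k*k*c_in weights per filter, plus one bias per filter.
    return k * k * c_in * c_out + c_out


def batchnorm_params(c):
    # gamma and beta (trainable) plus moving mean and variance (non-trainable).
    return 4 * c


# (kernel, in_channels, out_channels) for conv2d_1 ... conv2d_10.
convs = [(3, 3, 32), (3, 32, 32), (3, 32, 64), (3, 64, 64), (3, 64, 64),
         (3, 64, 64), (3, 64, 64), (3, 64, 32), (3, 32, 32), (1, 32, 1)]
# Every Conv2D is followed by exactly one BatchNormalization over its outputs.
bn_channels = [c_out for _, _, c_out in convs]

total = sum(conv2d_params(*c) for c in convs) + sum(batchnorm_params(c) for c in bn_channels)
non_trainable = sum(2 * c for c in bn_channels)
print(total, total - non_trainable, non_trainable)  # 205893 204995 898
```

The Lambda layers add no parameters, and mask_1 reaches unpool_2 through `arguments=` rather than as a layer input, so nothing in this table compares its depth with that of conv_7.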

But when I try to fit the model:

history = model.fit(x=x_train, y=y_train, batch_size=batch_size, epochs=3, verbose=2, validation_data=(x_val,y_val))

I get this error:

InvalidArgumentError: Input to reshape is a tensor with 4915200 values, but the requested shape has 9830400
     [[{{node lambda_4/unpool/Reshape_3}} = Reshape[T=DT_INT64, Tshape=DT_INT32, _device="/job:localhost/replica:0/task:0/device:GPU:0"](lambda_1/MaxPoolWithArgmax:1, lambda_4/unpool/Reshape_2/shape)]]
     [[{{node lambda_4/unpool/strided_slice_6/_515}} = _Recv[client_terminated=false, recv_device="/job:localhost/replica:0/task:0/device:CPU:0", send_device="/job:localhost/replica:0/task:0/device:GPU:0", send_device_incarnation=1, tensor_name="edge_1479_lambda_4/unpool/strided_slice_6", tensor_type=DT_INT64, _device="/job:localhost/replica:0/task:0/device:CPU:0"]()]]
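Both numbers in the message can be reproduced by hand. The failing node is lambda_4/unpool/Reshape_3 and its input is lambda_1/MaxPoolWithArgmax:1, i.e. the argmax mask_1 being reshaped to flat_input_size, which unpool2D computes from the tensor it is applied to (conv_7). A worked check, assuming the batch size of 4 used in fit():

```python
batch_size = 4
pooled_h, pooled_w = 160, 240  # spatial size after the first pooling

argmax_channels = 32  # mask_1 comes from pooling conv_2, which has 32 filters
input_channels = 64   # conv_7, the tensor passed to unpool_2, has 64 filters

argmax_values = batch_size * pooled_h * pooled_w * argmax_channels
requested = batch_size * pooled_h * pooled_w * input_channels  # flat_input_size in unpool2D
print(argmax_values, requested)  # 4915200 9830400
```

The requested size is exactly twice the available one, the same 64 vs 32 ratio as the two depths.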

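I checked the shapes after every layer, and they are what I expect. I also tested the pooling/unpooling functions on sample tensors, and they produced the expected output. What am I doing wrong here?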

I've been struggling with this for a while, so any help is much appreciated!

Best Answer

Found the problem: mask_1 has 32 channels, while unpool_2 tries to reshape its output (built from conv_7) with 64 channels. I just rearranged a few things so that the depths line up.
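The answer does not say which layers were changed, so here is one hypothetical arrangement, checked with a small pure-Python walker (its name and plan format are assumptions). It tracks the depth through the conv/pool/unpool sequence, remembering each mask's depth at pooling time, and reports any unpool whose input depth differs from its mask's. In the rearranged plan conv_7 gets 32 filters, which in the Keras code would mean `Conv2D(32, ...)` for conv_7 and `output_shape = (320,480,32)` for the unpool_2 Lambda.

```python
def check_unpool_depths(plan):
    """Walk ("conv", filters) / ("pool", mask) / ("unpool", mask) steps and
    return (mask, input_depth, mask_depth) for every mismatched unpool."""
    depth = None
    mask_depth = {}
    problems = []
    for kind, arg in plan:
        if kind == "conv":
            depth = arg  # a Conv2D sets the depth to its filter count
        elif kind == "pool":
            mask_depth[arg] = depth  # the mask has the depth of the pooled tensor
        elif kind == "unpool" and mask_depth[arg] != depth:
            problems.append((arg, depth, mask_depth[arg]))
    return problems


original = [("conv", 32), ("conv", 32), ("pool", "mask_1"),
            ("conv", 64), ("conv", 64), ("pool", "mask_2"),
            ("conv", 64), ("unpool", "mask_2"),
            ("conv", 64), ("conv", 64), ("unpool", "mask_1"),
            ("conv", 32), ("conv", 32), ("conv", 1)]

# Hypothetical fix: conv_7 (the conv right before unpool_2) gets 32 filters.
rearranged = [("conv", 32), ("conv", 32), ("pool", "mask_1"),
              ("conv", 64), ("conv", 64), ("pool", "mask_2"),
              ("conv", 64), ("unpool", "mask_2"),
              ("conv", 64), ("conv", 32), ("unpool", "mask_1"),
              ("conv", 32), ("conv", 32), ("conv", 1)]

print(check_unpool_depths(original))    # [('mask_1', 64, 32)]
print(check_unpool_depths(rearranged))  # []
```

Giving conv_2 and the other first-stage layers 64 filters instead would also line the depths up; the walker only checks consistency, not which choice suits edge detection better.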

A similar question, python - Tensorflow Reshape error with custom pooling/unpooling layers, can be found on Stack Overflow: https://stackoverflow.com/questions/53392793/

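Related articles:

python - Given an xy touch position, how do I fill the transparent part of an RGBA image with a color?

python - Replacing multiple string patterns from a stopword list with re

python - ValueError : Cannot set tensor: Got value of type NOTYPE but expected type FLOAT64 for input 0, name: serving_default_z_raw_min:0

tensorflow - Shell hangs when running the standalone graph_app in hexagon nnlib

python - Scipy/Numpy FFT frequency analysis

python - How do I get information about a file selected with fileChooser in Kivy?

python - How do I preprocess the training set for fine-tuning VGG16 in Keras?

python - Storing and using a trained neural network

java - Using a TensorFlow 2.1.0 model built in Python in Java TensorFlow 1.15 | No operation named [input] in the graph

machine-learning - Is log loss the same in XGBoost and Sklearn?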