python - 如何(有效地)在 TensorFlow 中应用 channel 级全连接层

标签 python tensorflow deep-learning autoencoder

我又来找你了,我在为我可以开始工作的事情挠头,但真的很慢。我希望你能帮助我优化它。

我正在尝试在 TensorFlow 中实现一个卷积自动编码器,在编码器和解码器之间有很大的潜在空间。通常,人们会使用全连接层将编码器连接到解码器,但由于这个潜在空间具有高维性,因此这样做会产生太多的特征,使其在计算上不可行。

我在 this paper 中找到了解决此问题的好方法.他们称之为“ channel 式全连接层”。它基本上是每个 channel 的全连接层。

我正在努力实现并让它工作,但图表的生成需要很长时间。到目前为止,这是我的代码:

def _network(self, dataset, isTraining):
        encoded = self._encoder(dataset, isTraining)
        with tf.variable_scope("fully_connected_channel_wise"):
            shape = encoded.get_shape().as_list()
            print(shape)
            channel_wise = tf.TensorArray(dtype=tf.float32, size=(shape[-1]))
            for i in range(shape[-1]):  # last index in shape should be the output channels of the last conv
                channel_wise = channel_wise.write(i, self._linearLayer(encoded[:,:,i], shape[1], shape[1]*4, 
                                  name='Channel-wise' + str(i), isTraining=isTraining))
            channel_wise = channel_wise.concat()
            reshape = tf.reshape(channel_wise, [shape[0], shape[1]*4, shape[-1]])
        reconstructed = self._decoder(reshape, isTraining)
        return reconstructed

那么,关于为什么要花这么长时间,有什么想法吗?这实际上是一个范围(2048),但所有线性层都非常小(4x16)。我是不是以错误的方式处理这个问题?

谢谢!

最佳答案

您可以在 Tensorflow 中查看他们对该论文的实现情况。 这是他们的“ channel 级全连接层”的实现。

def channel_wise_fc_layer(self, input, name): # bottom: (7x7x512)
    _, width, height, n_feat_map = input.get_shape().as_list()
    input_reshape = tf.reshape( input, [-1, width*height, n_feat_map] )
    input_transpose = tf.transpose( input_reshape, [2,0,1] )

    with tf.variable_scope(name):
        W = tf.get_variable(
                "W",
                shape=[n_feat_map,width*height, width*height], # (512,49,49)
                initializer=tf.random_normal_initializer(0., 0.005))
        output = tf.batch_matmul(input_transpose, W)

    output_transpose = tf.transpose(output, [1,2,0])
    output_reshape = tf.reshape( output_transpose, [-1, height, width, n_feat_map] )

    return output_reshape

https://github.com/jazzsaxmafia/Inpainting/blob/8c7735ec85393e0a1d40f05c11fa1686f9bd530f/src/model.py#L60

主要思想是使用 tf.batch_matmul 函数。

但是,tf.batch_matmul 在最新版本的 Tensorflow 中被删除了,您可以使用 tf.matmul 来替换它。

关于python - 如何(有效地)在 TensorFlow 中应用 channel 级全连接层,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/47556001/

相关文章:

tensorflow - 仅在 TF 中解包

python - TensorFlow 的参数无效错误(形状不兼容)

python - Pandas 无法打开 csv 文件 FileNotFoundError

python - 滚动平均图例标签

python - 存储倒排索引

TensorFlow 服务 : "No assets to save/writes" when exporting models

python - numpy 的 argsort 可以给相等的元素相同的等级吗?

python - 凯拉斯输入/输出

python - Tensorflow:保存/恢复变量加载不正确

tensorflow - 为什么 DeepLabV3+ 生成的所有图像都变成黑色?