python - 在 Tensorflow 中使用 3d 转置卷积进行上采样

标签 python tensorflow computer-vision deep-learning convolution

我在 Tensorflow 中定义 3D 转置卷积如下:

def weights(shape):
    return tf.Variable(tf.truncated_normal(shape, mean = 0.0, stddev=0.1))

def biases(shape):
    return tf.Variable(tf.constant(value = 0.1, shape = shape))

def trans_conv3d(x, W, output_shape, strides, padding):
    return tf.nn.conv3d_transpose(x, W, output_shape, strides, padding)

def transconv3d_layer(x, shape, out_shape, strides, padding):
   # shape: [depth, height, width, output_channels, in_channels].
   # output_shape: [batch, depth, height, width, output_channels]
    W = weights(shape)
    b = biases([shape[4]]) 
    return tf.nn.elu(trans_conv3d(x, W, out_shape, strides, padding) + b)

假设我有一个来自上一层的 4D 张量 x,形状为 [2, 1, 1, 1, 10],其中 batch = 2 code>、深度 = 1高度 = 1宽度 = 1in_channels = 10 如所规定here .

如何使用 transconv3d_layerx 进行上采样,在一系列层上,以获得最终形状,例如 [2, 100, 100, 100, 10] 或类似的东西?我不清楚如何通过转置层跟踪张量的形状。

最佳答案

使用方法如下:

input = tf.random_normal(shape=[2, 1, 1, 1, 10])
deconv1 = transconv3d_layer(input,
                            shape=[2, 3, 3, 10, 10],
                            out_shape=[2, 50, 50, 50, 10],
                            strides=[1, 1, 1, 1, 1],
                            padding='SAME')
deconv2 = transconv3d_layer(deconv1,
                            shape=[2, 3, 3, 10, 10],
                            out_shape=[2, 100, 100, 100, 10],
                            strides=[1, 1, 1, 1, 1],
                            padding='SAME')
# deconv3 ...

print(deconv1)  # Tensor("Elu:0", shape=(2, 50, 50, 50, 10), dtype=float32)
print(deconv2)  # Tensor("Elu_1:0", shape=(2, 100, 100, 100, 10), dtype=float32)

基本上,您应该将每个 out_shape 指定为您想要上采样的形状:(2, 50, 50, 50, 10), (2, 100, 100, 100, 10) , ...

为了清楚起见,以下是上面不同张量中的维度的含义:

input shape:  [batch, depth, height, width, in_channels]
filter shape: [depth, height, width, output_channels, in_channels]
output shape: [batch, depth, height, width, output_channels]

关于python - 在 Tensorflow 中使用 3d 转置卷积进行上采样,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/48327962/

相关文章:

python - python mako 模板是否支持循环上下文中的 connitue/break?

go - 如何使用 Golang 将数据转换为序列化 tf.Example(tensorflow tfrecords)

opencv - 比较F矩阵和E矩阵

c++ - 使用 Vlfeat 的 C API 进行快速换档

machine-learning - 计算机视觉和机器学习中特征描述符的解释

Python:ElementTree,获取一个元素的命名空间字符串

python - 减少n维numpy数组维数的有效方法

android - 保存/传输模型 - Android 上的 TensorFlow Lite 迁移学习

python - 使用 anaconda navigator 在 Windows 上导入 TensorFlow 时出错

python - 在 python 中从 opencv 写入 Gstreamer 管道