python - TensorFlow 中 Max Pooling 2D Layer 的输出张量是什么?

标签 python tensorflow max-pooling

我试图了解有关 tensorflow 的一些基础知识 我在阅读最大池化 2D 层的文档时卡住了:https://www.tensorflow.org/tutorials/layers#pooling_layer_1

这是指定 max_pooling2d 的方式:

pool1 = tf.layers.max_pooling2d(inputs=conv1, pool_size=[2, 2], strides=2)

其中 conv1 具有形状为 [batch_size, image_width, image_height, channels] 的张量,在本例中具体为 [batch_size, 28, 28, 32 ]

所以我们的输入是一个张量,其形状为:[batch_size, 28, 28, 32]

我对最大池化 2D 层的理解是,它将应用大小为 pool_size(在本例中为 2x2)的过滤器,并以 stride(也为 2x2)移动滑动窗口).这意味着图像的 widthheight 都将减半,即我们最终将得到每个 channel 14x14 像素(总共 32 个 channel ),这意味着我们的输出是形状为张量:[batch_size, 14, 14, 32]

但是,根据上面的链接,输出张量的形状是[batch_size, 14, 14, 1]:

Our output tensor produced by max_pooling2d() (pool1) has a shape of 
[batch_size, 14, 14, 1]: the 2x2 filter reduces width and height by 50%.

我在这里错过了什么?

32 是如何转换为 1 的?

他们稍后在这里应用相同的逻辑:https://www.tensorflow.org/tutorials/layers#convolutional_layer_2_and_pooling_layer_2

但是这次是正确的,即 [batch_size, 14, 14, 64] 变成了 [batch_size, 7, 7, 64] ( channel 数是一样的).

最佳答案

是的,使用 strides=2x2 的 2x2 max pool 会将数据减少一半,输出深度不会改变。这是我从你给定的测试代码,输出形状是 (14, 14, 32),也许有什么问题?

#!/usr/bin/env python

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets('./MNIST_data/', one_hot=True)

conv1 = tf.placeholder(tf.float32, [None,28,28,32])
pool1 = tf.layers.max_pooling2d(inputs=conv1, pool_size=[2,2], strides=2)
print pool1.get_shape()

输出是:

Extracting ./MNIST_data/train-images-idx3-ubyte.gz
Extracting ./MNIST_data/train-labels-idx1-ubyte.gz
Extracting ./MNIST_data/t10k-images-idx3-ubyte.gz
Extracting ./MNIST_data/t10k-labels-idx1-ubyte.gz
(?, 14, 14, 32)

关于python - TensorFlow 中 Max Pooling 2D Layer 的输出张量是什么?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/43453712/

相关文章:

machine-learning - 在某些情况下,在 CNN 中省略池化层是否有意义?

python - 创建副本时如何避免改变原始全局变量

python-3.x - Wasserstein GAN 批评者训练歧义

python - Tensorflow 中的动态时间扭曲实现

machine-learning - keras vgg 16 形状错误

tensorflow - 自定义池化层 - minmax pooling - Keras - Tensorflow

python - 为什么我的 Python 类声称我有 2 个参数而不是 1 个?

python - 基于国家/地区的 super 用户访问

python - 使用 Plotly 和 Choropleth 绘制多边形

tensorflow - 如何阅读 tensorflow 非极大值抑制方法源码?