python - "fc1 = tf.reshape(conv2, [-1, weights[' wd 1'].get_shape().as_list()[0]])"在做什么?

标签 python tensorflow2.0

def conv_net(x, weights, biases, dropout):
    # Layer 1 - 28*28*1 to 14*14*32
    conv1 = conv2d(x, weights['wc1'], biases['bc1'])
    conv1 = maxpool2d(conv1, k=2)

    # Layer 2 - 14*14*32 to 7*7*64
    conv2 = conv2d(conv1, weights['wc2'], biases['bc2'])
    conv2 = maxpool2d(conv2, k=2)

这个 get.shape().as_list()[0] 有什么作用?

fc1 = tf.reshape(conv2, [-1, weights['wd1'].get_shape().as_list()[0]])

最佳答案

为了更好的理解get.shape().as_list(),这里我提供一个简单的例子,希望对你有帮助。

import tensorflow as tf
c = tf.constant([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
Shape = c.get_shape().as_list()
print(Shape)  

输出:

[2, 3]

这意味着c2 行和3 列。 当我们打印 Shape = c.get_shape().as_list()[0] 时,它以列表格式 返回 c0th 元素(通常它将元组中的形状返回为 (2,3))

Shape = c.get_shape().as_list()[0] 

输出:

2

当我们使用 tf.reshape 时,它会返回一个新的张量,该张量的值与旧张量相同,顺序相同,但新形状由 shape 指定。

当我们传递 [-1] 的形状 时,它变平为一维。

tf.reshape(c, [-1])

输出:

<tf.Tensor: shape=(6,), dtype=float32, numpy=array([1., 2., 3., 4., 5., 6.], dtype=float32)>

fc1 = tf.reshape(conv2, [-1, weights['wd1'].get_shape().as_list()[0]])

总而言之,我们在这里将 conv2(即 2D)(即一维) 权重展平并提取 0th weights['wd1'] 的元素。

关于python - "fc1 = tf.reshape(conv2, [-1, weights[' wd 1'].get_shape().as_list()[0]])"在做什么?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/64739606/

相关文章:

Python mysql 连接器加载语句在错误的目录 Windows 中搜索

python - 如何使 html div 元素跟随垂直滚动(在 html/django 中)?

c++ - 在 SWIG 中将可变 char * 传递给 Python

tensorflow2.0 - Tensorflow object_detection API 微调 ckpt 分类

python - Tensorflow Lite,图像大小零误差

python - 由于 g++ 编译器错误,NS-3 构建失败

python - 重置类统计数据的想法

python - Tensorflow 2.0 意外 OOM

python - 如何在 Tensorflow 2.0 数据集中动态更改批量大小?

python - 如何在 Tensorflow (python) 中随机选取和屏蔽张量的一部分