python - 如何获取TensorFlow层中的神经元数量?

标签 python tensorflow

假设我尝试将池化层的输出连接到密集层。为了做到这一点,我需要展平池张量。考虑下面的层:

def conv_layer(input, in_channels, out_channels, name="conv"):
    w = tf.get_variable("W", initializer=tf.truncated_normal([3, 3, in_channels, out_channels], stddev=0.1))
    b = tf.get_variable("B", initializer=tf.constant(0.1, shape=[out_channels]))
    conv = tf.nn.conv2d(input, w, strides=[1,1,1,1], padding="SAME")
    act = tf.nn.relu(conv + b)
    return act

def pool_layer(input, name="pool"):
    pool = tf.nn.max_pool(input, ksize=[1,2,2,1], strides=[1,2,2,1], padding="SAME")
    return pool

def dense_layer(input, size_in, size_out, name="dense"):
    w = tf.get_variable("W", initializer=tf.truncated_normal([size_in, size_out], stddev=0.1))
    b = tf.get_variable("B", initializer=tf.constant(0.1, shape=[size_out]))
    act = tf.nn.relu(tf.matmul(input, w) + b)
    return act

我正在使用它们创建一个网络:

def cnn_model(x):
    x_image = tf.reshape(x, [-1, nseries, present_window, 1])
    conv1 = conv_layer(x_image, 1, 32, "conv1")
    pool1 = pool_layer(conv1, "pool1")
    conv2 = conv_layer(pool1, 32, 64, "conv2")
    pool2 = pool_layer(conv2, "pool2")
    nflat = 17*15*64 # hard-coded
    flat  = tf.reshape(pool2, [-1, nflat])
    yhat = dense_layer(flat, nflat, future_window, "dense1")
    return yhat

如您所见,我对变量 nflat 进行了硬编码。如何避免这种情况?

最佳答案

如果它是张量pool.get_shape()应该适用于Keras或Tensorflow。

这实际上会返回一个包含每个维度大小的元组,因此您需要从中进行选择,在您的情况下可能是第二个。

如果输入实际上是您的输入(没有任何其他层),那么您为什么要进行最大池化?您不是在寻找dropout吗?

如果批量大小是可变的,您确实会发现一个问题,因为无法告诉模型reshape的大小

关于python - 如何获取TensorFlow层中的神经元数量?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/47662661/

相关文章:

tensorflow - Keras/tensorflow - 限制内核数(intra_op_parallelism_threads 不起作用)

python - Tensorflow 支持哪个版本的 python?

python - 使用 Keras Tensorflow 2.0 获取梯度

javascript - 如何使用 Node.js (tfjs-node) 从 Tensorflow.js 中的检查点重新启动模型训练?

python - 如何向从 CSV 文件创建的每个命名元组添加 ID 属性?

python - 两张透明图像之间的差异

python - 递归地将子文件夹中的文件读取到列表中,并将每个子文件夹的文件合并到每个子文件夹的一个 csv 中

python - 属性错误: 'tuple' object has no attribute 'layer' when trying transfer learning with keras

python - 从 xpath 中删除信息?

python - 如何覆盖 sys.stdin 以复制输入流