python - Tensorflow仅初始化特定范围

标签 python tensorflow

他在那里,

我有一个关于控制初始化哪个变量范围的问题,或者至少是在运行期间使用哪个变量范围的问题。

以这段简单的代码为例

import numpy as np
import tensorflow as tf

with tf.variable_scope('0') as scope:
    place_holder_batch_x = tf.Variable(np.random.rand(11,6), dtype=tf.float64)
    place_holder_batch_y = tf.Variable(np.random.rand(8,5), dtype=tf.float64)
    rnn_cell = tf.nn.rnn_cell.BasicRNNCell(3)
    z = place_holder_batch_x*2

with tf.variable_scope('1') as scope:
    place_holder_batch_x = tf.Variable(np.random.rand(10,5), dtype=tf.float64)
    place_holder_batch_y = tf.Variable(np.random.rand(9,6), dtype=tf.float64)
    rnn_cell = tf.nn.rnn_cell.BasicRNNCell(4)
    z = place_holder_batch_x*2

init = tf.global_variables_initializer()

sess = tf.Session()
sess.run(init)
print(sess.run(z).shape)

如果我按原样运行它,我将获得变量 z 的形状,如变量范围“1”中定义的那样。 但是我如何指定在 session 期间使用哪个变量范围呢?我在 stackoverflow 或文档中找不到任何答案...

当然,我可以将两个 z 重命名为 z1 和 z2...但我想保持两个作用域看起来很相似并使用相同名称的情况...

最佳答案

试试这个:

import numpy as np
import tensorflow as tf


g1 = tf.Graph()
with g1.as_default() as g:
    with tf.variable_scope('0') as scope:
        place_holder_batch_x = tf.Variable(np.random.rand(11,6), dtype=tf.float64)
        place_holder_batch_y = tf.Variable(np.random.rand(8,5), dtype=tf.float64)
        rnn_cell = tf.nn.rnn_cell.BasicRNNCell(3)
        z = place_holder_batch_x*2
g2 = tf.Graph()
with g2.as_default() as g:
    with tf.variable_scope('1') as scope:
        place_holder_batch_x = tf.Variable(np.random.rand(10,5), dtype=tf.float64)
        place_holder_batch_y = tf.Variable(np.random.rand(9,6), dtype=tf.float64)
        rnn_cell = tf.nn.rnn_cell.BasicRNNCell(4)
        z = place_holder_batch_x*2

tf.reset_graph_default()

graph_to_be_used = g1

with tf.session(graph = graph_to_be_used) as sess:
    init = tf.global_variables_initializer()

    sess.run(init)
    print(sess.run(z).shape)

关于python - Tensorflow仅初始化特定范围,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/46648536/

相关文章:

neural-network - 我们应该为 adam 优化器做学习率衰减吗

python - Odoo 12 我无法更改计算字段的值

python - 字体菜单 PyQt5 文本编辑器

python - cv2.estimateRigidTransform 最小点数?

tensorflow - 如何将 logits 转换为 tensorflow 中二元分类的概率?

python - 我应该只使用 "exactly same"输入形状进行迁移学习吗?

python - 如何在训练操作的每一步之间修剪密集层的权重

python - Nose 多进程问题

python - Cloud ML 引擎和 Scikit-Learn : 'LatentDirichletAllocation' object has no attribute 'predict'

tensorflow - 如何使用FFT和神经网络对声音进行分类?我应该使用 CNN 还是 RNN?