python - Tensorflow 将图加载到特定范围

标签 python tensorflow

我正在使用 Network 类中的以下方法将预训练网络加载到 Tensorflow 中(因此调用 self.xyz)。首先,调用 define_network(),然后初始化其他变量和优化器,然后调用 load_model()

然而,尽管使用了tf.variable_scope(self.name),图中的变量还是被加载到变量的通用空间中。这是有问题的,因为我有这个类的两个实例,每个实例都加载相同的网络,并且我想将其分离到不同的范围中。

如何将变量加载到特定范围?

附注请随时纠正我的代码中的任何错误!

  def load_model(self):
    with tf.variable_scope(self.name) as scope:
      self.saver.restore(self.sess, self.model_path)
      print("Loaded model from {}".format(self.model_path))

  def define_model(self):   

    with tf.variable_scope(self.name) as scope:
      self.saver = tf.train.import_meta_graph(self.model_path + '.meta')
      print("Loaded model from {}".format(self.model_path + '.meta'))
      graph = tf.get_default_graph()

      self.inputs = []
      inp_names = ['i_hand1:0', 'i_hand2:0', 'i_flop1:0', 'i_flop2:0', 'i_flop3:0',
                   'i_turn:0', 'i_river:0', 'i_other:0', 'i_allowed_mod:0', 'keras_learning_phase:0']
      for inp in inp_names:
        self.inputs.append(tf.get_default_graph().get_tensor_by_name(inp))

      self.outputs = tf.get_default_graph().get_tensor_by_name("Tanh:0") 
      self.add_output_conversions()

      all_vars = tf.trainable_variables()
      for var in all_vars:
        self.var[var.name] = var

最佳答案

我认为你的问题可以通过添加一个参数来解决

self.saver = tf.train.import_meta_graph(self.model_path + '.meta', 'import_scope'=self.name)

这是reference

关于python - Tensorflow 将图加载到特定范围,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/42017833/

相关文章:

python - 使用 Tkinter 事件更改标签的外观

python - Scikit Learn 中不确定的随机森林文档

python - 我可以将我的 TensorFlow Python 环境从 Windows 转移到 Ubuntu 吗?

machine-learning - 如何在没有绑定(bind)权值的情况下实现卷积连接?

c++ - 在 TensorFlow C++ 中传递 RunOptions

python - 如何调试使用基本身份验证处理程序的 urllib2 请求

python - peewee.OperationalError : too many SQL variables on upsert of only 150 rows * 8 columns

python - 使用石斑鱼列出唯一值

c# - 在 C# 中获取 Keras 训练模型预测

tensorflow - `tf.nn.batch_normalization` 和 `tf.nn.fused_batch_norm` 之间的区别?