python - 如何在 Tensorflow 中以正确的方式使用自定义/非默认 tf.Graph?

标签 python session graph parallel-processing tensorflow

我是 Tensorflow 的新手,我正在阅读 https://www.amazon.com/TensorFlow-Machine-Learning-Cookbook-McClure/dp/1786462168 .我在 tf.Session 中注意到的一个参数是 graph。我喜欢完全控制流程,我想知道如何正确使用 tf.Graphtf.Session 以及如何向特定图形添加计算? 此外,人们向 Tensorflow 中的特定图形添加操作的规范语法是什么(如果有)?

类似于下面的内容:

t = np.linspace(0,2*np.pi)
fig, ax = plt.subplots()
ax.scatter(x=t, y=np.sin(t))

相比于:

plt.scatter(x=t, y=np.sin(t))

如何使用 tf.Graph() 获得同样的灵 active ?

G = tf.Graph()

t_query = np.linspace(0,2*np.pi,50)
pH_t = tf.placeholder(tf.float32, shape=t_query.shape)

def simple_sinewave(t, name=None):
    return tf.sin(t, name=name)

with tf.Session() as sess:
    r = sess.run(simple_sinewave(pH_t), feed_dict={pH_t:t_query})
# array([  0.00000000e+00,   1.27877161e-01,   2.53654599e-01,
# ...
#         -1.27877384e-01,   1.74845553e-07], dtype=float32)

现在尝试指定一个 graph 参数:

with tf.Session(graph=G) as sess:
    r = sess.run(simple_sinewave(pH_t), feed_dict={pH_t:t_query})
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-51-d73a1f0963e3> in <module>()
     26 #         -1.27877384e-01,   1.74845553e-07], dtype=float32)
     27 with tf.Session(graph=G) as sess:
---> 28     r = sess.run(simple_sinewave(pH_t), feed_dict={pH_t:t_query})

... RuntimeError: session 图为空。在调用 run() 之前向图中添加操作。

使用 David Parks 更新回答这个问题:

# Custom Function
def simple_sinewave(t, name=None):
    return tf.sin(t, name=name)

# Synth graph
G = tf.Graph()

# Build Graph
with G.as_default():
    t_query = np.linspace(0,2*np.pi,50)
    pH_t = tf.placeholder(tf.float32, shape=t_query.shape)

# Run session using Graph
with tf.Session(graph=G) as sess:
    r = sess.run(simple_sinewave(pH_t), feed_dict={pH_t:t_query})
r
# array([  0.00000000e+00,   1.27877161e-01,   2.53654599e-01,
#          3.75266999e-01,   4.90717560e-01,   5.98110557e-01,
# ...
#         -4.90717530e-01,  -3.75267059e-01,  -2.53654718e-01,
#         -1.27877384e-01,   1.74845553e-07], dtype=float32)

奖励:Tensorflow 中是否有用于命名占位符变量的特定命名法?像 pd.DataFrame 一样像 df_data

最佳答案

你通常这样做:

with tf.Graph().as_default():
  # build your model
    with tf.Session() as sess:
      sess.run(...)

我有时会使用多个图表来分别运行与训练集分开的测试集,这与您上面的示例类似,您可以这样做:

g = tf.Graph()
with g.as_default():
  # build your model
  with tf.Session() as sess:
    sess.run(...)

正如您还指出的那样,您可以避免使用 with 而只是执行 sess = tf.Session(graph=g),并且您必须关闭您的 session 自己。大多数用例将通过使用 python 的 with

来简化

当您有两个图时,无论何时使用该图,您都将每个 as_default() 设置为默认图。

例子:

g1 = tf.Graph()
g2 = tf.Graph()

with g1.as_default():
  # do stuff like normal, g1 is the graph that will be used
  with tf.Session() as session_on_g1:
    session_on_g1.run(...)

with g2.as_default():
  # do stuff like normal, g2 is the graph that will be used
  with tf.Session() as session_on_g2:
    session_on_g2.run(...)

关于python - 如何在 Tensorflow 中以正确的方式使用自定义/非默认 tf.Graph?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/43574928/

相关文章:

python - Transcrypt:将客户端 JS 对象转换为字典?

python - Style() 和 ThemedStyle() 之间的冲突

python - TemplateDoesNotExist at/Django?

php - 如何在 Symfony 中注销被禁止的用户?

django - 如何在 Django 中根据域名或 TLD 设置 urlpatterns?

python - Tensorflow:在类中创建图形并在外部运行

algorithm - 最大二分匹配方法中的错误

algorithm - 在哪里可以找到图形输入资源/文件?

python - 检查值是否位于区间内

algorithm - 访问所有节点的最短路径