这个问题是关于低级 Tensorflow 1.x API 的。给定一个 Tensor
到 Session.run()
,我不清楚 Tensorflow 如何遍历计算图。
假设我有一些这样的代码:
a = tf.constant(1.0)
b = tf.subtract(a, 1.0)
c = tf.add(b, 2.0)
d = tf.multiply(c,3)
sess = tf.Session()
sess.run(d)
减法、加法和乘法运算并不都存储在张量d
中,对吧?我知道 Tensor 对象有 graph 和 op 字段;这些字段是如何递归访问以获得计算d
所需的所有操作的?
编辑:添加输出
print(tf.get_default_graph().as_graph_def())
node {
name: "Const"
op: "Const"
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_FLOAT
tensor_shape {
}
float_val: 1.0
}
}
}
}
node {
name: "Sub/y"
op: "Const"
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_FLOAT
tensor_shape {
}
float_val: 1.0
}
}
}
}
node {
name: "Sub"
op: "Sub"
input: "Const"
input: "Sub/y"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
}
node {
name: "Add/y"
op: "Const"
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_FLOAT
tensor_shape {
}
float_val: 2.0
}
}
}
}
node {
name: "Add"
op: "Add"
input: "Sub"
input: "Add/y"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
}
node {
name: "Mul/y"
op: "Const"
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_FLOAT
tensor_shape {
}
float_val: 3.0
}
}
}
}
node {
name: "Mul"
op: "Mul"
input: "Add"
input: "Mul/y"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
}
versions {
producer: 38
}
最佳答案
这就是 Tensorflow 静态计算图的全部要点。当您构建图时,Tensorflow 会在后台隐式构建静态图。然后,当您执行图中的节点时,Tensorflow 知道导致该节点的确切操作集。这有几个好处:
- 节省计算,因为只有指向您想要的节点的子图才会被执行。
- 整个计算被分成小的可微部分。
- 模型的每个部分都可以在不同的设备上执行,从而实现巨大的加速。
使用此命令,查看每个节点的输入:
print(tf.get_default_graph().as_graph_def())
例如,如果您在小图上执行此操作,您将看到以下内容,从节点 d = tf.multiply(c,3)
开始:
name: "Mul"
op: "Mul"
input: "Add"
然后c = tf.add(b, 2.0)
:
name: "Add"
op: "Add"
input: "Sub"
然后b = tf.subtract(a, 1.0)
:
name: "Sub"
op: "Sub"
input: "Const"
最后a = tf.constant(1.0)
:
name: "Const"
op: "Const"
关于tensorflow - Tensorflow 1.x 如何遍历给定 Tensor 的计算图?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/57469488/