tensorflow - Tensorflow 1.x 如何遍历给定 Tensor 的计算图?

标签 tensorflow

这个问题是关于低级 Tensorflow 1.x API 的。给定一个 TensorSession.run(),我不清楚 Tensorflow 如何遍历计算图。

假设我有一些这样的代码:

a = tf.constant(1.0)
b = tf.subtract(a, 1.0)
c = tf.add(b, 2.0)
d = tf.multiply(c,3)

sess = tf.Session()
sess.run(d)

减法、加法和乘法运算并不都存储在张量d中,对吧?我知道 Tensor 对象有 graph 和 op 字段;这些字段是如何递归访问以获得计算d所需的所有操作的?

编辑:添加输出

print(tf.get_default_graph().as_graph_def())
node {
  name: "Const"
  op: "Const"
  attr {
    key: "dtype"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_FLOAT
        tensor_shape {
        }
        float_val: 1.0
      }
    }
  }
}
node {
  name: "Sub/y"
  op: "Const"
  attr {
    key: "dtype"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_FLOAT
        tensor_shape {
        }
        float_val: 1.0
      }
    }
  }
}
node {
  name: "Sub"
  op: "Sub"
  input: "Const"
  input: "Sub/y"
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
}
node {
  name: "Add/y"
  op: "Const"
  attr {
    key: "dtype"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_FLOAT
        tensor_shape {
        }
        float_val: 2.0
      }
    }
  }
}
node {
  name: "Add"
  op: "Add"
  input: "Sub"
  input: "Add/y"
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
}
node {
  name: "Mul/y"
  op: "Const"
  attr {
    key: "dtype"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_FLOAT
        tensor_shape {
        }
        float_val: 3.0
      }
    }
  }
}
node {
  name: "Mul"
  op: "Mul"
  input: "Add"
  input: "Mul/y"
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
}
versions {
  producer: 38
}

最佳答案

这就是 Tensorflow 静态计算图的全部要点。当您构建图时,Tensorflow 会在后台隐式构建静态图。然后,当您执行图中的节点时,Tensorflow 知道导致该节点的确切操作集。这有几个好处:

  1. 节省计算,因为只有指向您想要的节点的子图才会被执行。
  2. 整个计算被分成小的可微部分。
  3. 模型的每个部分都可以在不同的设备上执行,从而实现巨大的加速。

使用此命令,查看每个节点的输入:

print(tf.get_default_graph().as_graph_def())

例如,如果您在小图上执行此操作,您将看到以下内容,从节点 d = tf.multiply(c,3) 开始:

name: "Mul"
op: "Mul"
input: "Add"

然后c = tf.add(b, 2.0):

name: "Add"
op: "Add"
input: "Sub"

然后b = tf.subtract(a, 1.0):

name: "Sub"
op: "Sub"
input: "Const"

最后a = tf.constant(1.0):

name: "Const"
op: "Const"

关于tensorflow - Tensorflow 1.x 如何遍历给定 Tensor 的计算图?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/57469488/

相关文章:

tensorflow - 如何显式广播张量以匹配 tensorflow 中的另一个形状?

tensorflow - Tensorflow RNN 是否完全实现了 Elman 网络?

python - 训练因 ResourceExausted 错误而中断

python - Tensorflow 2 没有在 gpu 上运行

machine-learning - 如何在 TensorFlow 1.0 中使用 ValidationMonitor 作为估算器?

python - 为什么单独执行 softmax 和 crossentropy 与使用 softmax_cross_entropy_with_logits 一起执行它们会产生不同的结果?

python - 值错误 : Data cardinality is ambiguous

python - Tensorflow - 使用自定义比较器对张量进行排序

tensorflow - Keras/TF 2019 限制 GPU 内存使用?

Tensorflow 在 GPT 2 Git 版本中没有属性 "sort"?