machine-learning - 在 tensorflow 中提取子图

标签 machine-learning tensorflow

我有预训练的网络,我正在尝试获取它的一部分(子图)tf 图以及变量和保护程序对象。

我就是这样做的:

subgraph = tf.graph_util.extract_sub_graph(default_graph, list of nodes to preserve)
tf.reset_default_graph()
tf.import_graph_def(subgraph)
然而,这会删除所有变量(当我调用reset_default_graph时)。即使我明确地将变量的操作节点(仅“变量”类型操作)添加到“要保留的节点列表”中。

如何在保留变量值的同时保留较大图的子图? 是否需要添加一些新节点到“保留列表”?

我仍然不清楚图节点和变量之间的关系,教程仅提到变量的创建会在图中创建一些操作(节点)。

最佳答案

我认为你所做的看起来是正确的。正如您所说,变量只是一个输出特定值的张量的操作(图中的节点)。您应该能够将变量节点添加到列表中以保留它们,就像您已经做的那样。您可以使用 print(sess.graph_def) 确保您提供的名称正确吗?

关于machine-learning - 在 tensorflow 中提取子图,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/41346052/

相关文章:

machine-learning - keras 卷积神经网络 - 输出形状

machine-learning - 如何添加回归头

python - "ValueError: Trying to share variable $var, but specified dtype float32 and found dtype float64_ref"尝试使用 get_variable 时

python - Keras 损失没有减少

c++ - 在tensorflow c++中是否有将base64字符串解码为张量的函数?

python - 尝试在 Amazon AWS 实例上安装 keras 和 tensorflow

machine-learning - 有偏差的初始数据集主动学习

algorithm - 为什么过度拟合给出了错误的假设函数

tensorflow - 如何在 tensorflow 2.0 中计算 hessian?

python - GradienTape 收敛速度比 Keras.model.fit 慢得多