python - Tensorflow 图节点是交换的

标签 python python-3.x tensorflow tensorboard object-detection-api

我用微调预训练模型训练了一个模型 ssd_mobilenet_v2_coco_2018 .在这里,我使用了完全相同的pipeline.config 文件进行训练,该文件在ssd_mobilenet_v2_coco_2018 中可用。预训练文件夹。
我只删除了 batch_norm_trainable: true标记并更改了类的数量 (4)。
用我的自定义数据集训练模型后,我发现了 concatconcat_1节点相互交换。
预训练模型有| concat | 1x1917x1x4 |训练后它变成| concat | 1x1917x5 |我附上了两个张量板图形可视化图像。第一张图片是预训练图 ssd_mobilenet_v2_coco_2018 .
enter image description here
enter image description here

节点交换可以在图像的最右上角看到。与预训练图一样,Postprocess layer联系 concat_1Squeeeze联系 concat .但经过训练后,图形显示完全相反。赞 Prosprocess layer联系 concatSqueeeze联系 concat_1 .
此外,我还在预训练的模型图中发现 Preprocessor接受输入 ToFloat而在训练后,图表显示 Cast 作为 Preprocessor 的输入.
我已将输入作为 tfrecords 提供给模型.

最佳答案

最有可能的区别不在于图中,而在于节点的名称,即节点 concatconcat_1左侧是与 resp 相同的节点。 concat_1concat在右边。

问题是,当您没有为节点提供明确的名称时,tensorflow 需要提出一个名称,并且它的命名约定相当没有创意。第一次需要命名节点时,它会使用它的类型。当它再次遇到这种情况时,只需添加 _ + 名称的增加数字。

拿这个例子:

import tensorflow as tf

x = tf.placeholder(tf.float32, (1,), name='x')
y = tf.placeholder(tf.float32, (1,), name='y')
z = tf.placeholder(tf.float32, (1,), name='z')

xy = tf.concat([x, y], axis=0)  # named 'concat'
xz = tf.concat([x, z], axis=0)  # named 'concat_1'

该图如下所示:

enter image description here

现在,如果我们构造相同的图,但这次创建 xz之前 xy ,我们得到如下图:

enter image description here

所以图表并没有真正改变——只有名字改变了。这可能就是您的情况:创建了相同的操作,但顺序不同。

无状态节点的名称更改的事实,例如 concat不重要,因为例如在加载保存的模型时不会错误路由权重。尽管如此,如果命名稳定性对您很重要,您可以为您的操作指定明确的名称或将它们放在不同的范围内:
xy = tf.concat([x, y], axis=0, name='xy')
xz = tf.concat([x, z], axis=0, name='xz')

enter image description here

如果变量切换名称,则问题更大。这就是为什么tf.get_variable的原因之一-- 这会强制变量具有名称并在发生名称冲突时引发错误 -- 是前 TF2 时代处理变量的首选方式。

关于python - Tensorflow 图节点是交换的,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/61043978/

相关文章:

python - 如何解析位姿估计 tflite 模型的热图输出?

python - 仅取消堆叠或旋转某些列

python - 在python调试中计算单词中的字母

python - 在 python 中重写 __setattr__ 和 __getattribute__ 时,无法更新可变实例属性

python-3.x - 通过网格搜索调整模型

python - 如何在 TensorBoard 中单独运行我的 TensorFlow 代码?

python - Pyinstaller 将 opencv 从 Windows 10 分发到 Windows <10,缺少 ucrt dlls api-ms-win-crt

python - 从 REST API 获取数据并将其存储在 HDFS/HBase 中

python - 使用 super() 访问第二个基类的方法

python - 值错误: invalid literal for int() with base 10 on Alexnet