我想做什么
我正在尝试学习 TensorFlow 对象识别,并且像往常一样学习新事物,我在网上搜索教程。我不想涉及任何第三方云服务或 Web 开发框架,我只想学习使用原生 JavaScript、Python 和 TensorFlow 库来完成。
到目前为止我所拥有的
到目前为止,我已经关注了 a TensorFlow object detection tutorial (伴随着 5+ hour video )到我在 Tensorflow (python) 中训练模型并希望通过 TensorflowJS 将其转换为在浏览器中运行的地步。我也尝试过其他教程,但似乎没有找到一个解释如何在没有第三方云/工具和 React 的情况下执行此操作的教程。
我知道为了将这个模型与 tensorflow.js
一起使用我的目标是获得如下文件:
group1-shard1of2.bin
group1-shard2of2.bin
labels.json
model.json
我已经到了创建 TFRecord 文件并开始训练的地步:
py Tensorflow\models\research\object_detection\model_main_tf2.py --model_dir=Tensorflow\workspace\models\my_ssd_mobnet --pipeline_config_path=Tensorflow\workspace\models\my_ssd_mobnet\pipeline.config --num_train_steps=100
似乎在训练模型后,我只剩下:checkpoint
, ckpt-1.data-00000-of-00001
, ckpt-1.index
, pipeline.config
ssd_mobilenet_v2_fpnlite_320x320_coco17_tpu-8
我确信从这一步得到我需要的文件并不难,但老实说,我浏览了很多文档、教程和谷歌,没有看到没有第三方云服务的例子。也许它在文档中,我遗漏了一些明显的东西。
项目目录结构如下所示:
我在哪里寻找答案
出于某种原因,令人沮丧的是,我发现的每一个教程(包括上面链接的那个)使用预训练的 Tensorflow 模型通过 TensorFlowJS 进行对象检测都需要使用 IBM Cloud 和 ReactJS。也许他们都是从他们找到的一些教程中复制的,现在所有的教程都包含这个,我不知道。我所知道的是我正在 build 一个 Electron.js假设计算是在用户设备上进行的,桌面应用程序和对象检测不应该需要网络连接。澄清一下:我正在创建一个用户训练模型的应用程序,所以这不仅仅是一次转换的问题。我希望能够使用 Python Tensorflow 进行训练并将模型转换为在没有任何云 API 的情况下在 JavaScript Tensorflow 上运行。
所以我停止寻找教程并尝试直接查看文档 https://github.com/tensorflow/tfjs .
当您到达 section about importing pre-trained models , 它说:
Importing pre-trained models
We support porting pre-trained models from:
所以我点击了 Tensorflow SavedModel 的链接,这将我们带到了一个名为 tfjs-converter 的项目。 .该 repo 说:
This repository has been archived in favor of tensorflow/tfjs.
This repo will remain around for some time to keep history but all future PRs should be sent to tensorflow/tfjs inside the tfjs-core folder.
All history and contributions have been preserved in the monorepo.
这对我来说听起来有点像循环引用,因为它会将我引导到刚刚告诉我去这里的页面。所以此时您想知道整个库是否已被弃用,它会起作用还是什么?无论如何,我在这个 repo 中环顾四周,进入:https://github.com/tensorflow/tfjs-converter/tree/master/tfjs-converter
它说:
A 2-step process to import your model:
- A python pip package to convert a TensorFlow SavedModel or TensorFlow Hub module to a web friendly format. If you already have a converted model, or are using an already hosted model (e.g. MobileNet), skip this step.
- JavaScript API, for loading and running inference.
并且基本上说要创建一个 venv 并执行以下操作:
pip install tensorflowjs
tensorflowjs_converter \
--input_format=tf_saved_model \
--output_format=tfjs_graph_model \
--signature_name=serving_default \
--saved_model_tags=serve \
/mobilenet/saved_model \
/mobilenet/web_model
但是等等,检查点文件是否有“TensorFlow SavedModel”?这似乎不清楚,文档没有解释。所以我用谷歌搜索,找到文档,它说:You can save and load a model in the SavedModel format using the following APIs:
Low-level tf.saved_model API. This document describes how to use this API in detail. Save: tf.saved_model.save(model, path_to_dir)
链接的语法有点推断:
tf.saved_model.save(
obj, export_dir, signatures=None, options=None
)
举个例子:class Adder(tf.Module):
@tf.function(input_signature=[tf.TensorSpec(shape=[], dtype=tf.float32)])
def add(self, x):
return x + x
model = Adder()
tf.saved_model.save(model, '/tmp/adder')
但到目前为止,这一点都不熟悉。到目前为止,我不明白如何将训练过程的结果(检查点)加载到变量 model
中。所以我可以将它传递给这个函数。这段话似乎很重要:
Variables must be tracked by assigning them to an attribute of a tracked object or to an attribute of obj directly. TensorFlow objects (e.g. layers from
tf.keras.layers
, optimizers fromtf.train
) track their variables automatically. This is the same tracking scheme thattf.train.Checkpoint
uses, and an exported Checkpoint object may be restored as a training checkpoint by pointingtf.train.Checkpoint.restore
to theSavedModel
's "variables/" subdirectory.
这可能是答案,但我不太清楚就“恢复”而言意味着什么,或者我从那里去哪里,如果这甚至是正确的步骤。所有这些对于学习 TF 的人来说都非常令人困惑,这就是为什么我寻找一个可以做到这一点的教程,但同样,如果没有第三方云服务/React,我似乎找不到一个。
请帮我把点连起来。
最佳答案
您可以将模型转换为 TensorFlowJS 格式,而无需任何云服务。我已经列出了以下步骤。
I'm sure it's not hard to get from this step to the files I need.
您看到的检查点位于
tf.train.Checkpoint
格式( relevant source code that creates these checkpoints in the object detection model code )。这与 SavedModel 和 Keras 格式不同。我们将通过以下步骤:
Checkpoint (current) --> SavedModel --> TensorFlowJS
转换自 tf.train.Checkpoint
至 SavedModel
请查看 the script models/research/object_detection/export_inference_graph.py
将检查点文件转换为 SavedModel。下面的代码取自该脚本的文档。请调整您的特定项目的路径。
--input_type
应保持为 image_tensor
.python export_inference_graph.py \
--input_type image_tensor \
--pipeline_config_path path/to/ssd_inception_v2.config \
--trained_checkpoint_prefix path/to/model.ckpt \
--output_directory path/to/exported_model_directory
在输出目录中,您应该会看到一个 savedmodel 目录。我们将在下一步中使用它。转换
SavedModel
到 TensorFlowJS按照 https://github.com/tensorflow/tfjs/tree/master/tfjs-converter 处的说明进行操作,特别关注“TensorFlow SavedModel 示例”。示例转换代码复制如下。请修改您项目的输入和输出路径。
--signature_name
和 --saved_model_tags
可能需要改变,但希望不会。tensorflowjs_converter \
--input_format=tf_saved_model \
--output_format=tfjs_graph_model \
--signature_name=serving_default \
--saved_model_tags=serve \
/mobilenet/saved_model \
/mobilenet/web_model
使用 TensorFlowJS 模型I know in order to use this model with tensorflow.js my goal is to get files like:
- group1-shard1of2.bin
- group1-shard2of2.bin
- labels.json
- model.json
上面的步骤应该为你创建这些文件,虽然我不认为
labels.json
将被创建。我不确定该文件应该包含什么。 TensorFlowJS 将使用 model.json
构建推理图,它将从 .bin
加载权重文件。因为我们将 TensorFlow SavedModel 转换为 TensorFlowJS 模型,所以我们需要使用
tf.loadGraphModel()
加载 JS 模型。 .见 the tfjs converter page for more information .请注意,对于 TensorFlowJS,TensorFlow SavedModel 和 Keras SavedModel 之间存在差异。在这里,我们正在处理一个 TensorFlow SavedModel。
运行模型的 Javascript 代码可能超出了这个答案的范围,但我建议阅读 this TensorFlowJS tutorial .我在下面包含了一个具有代表性的 javascript 部分。
import * as tf from '@tensorflow/tfjs';
import {loadGraphModel} from '@tensorflow/tfjs-converter';
const MODEL_URL = 'model_directory/model.json';
const model = await loadGraphModel(MODEL_URL);
const cat = document.getElementById('cat');
model.execute(tf.browser.fromPixels(cat));
附加说明... Which sounds a bit like a circular reference to me,
TensorFlowJS 生态系统已在 tensorflow/tfjs 中得到整合GitHub 存储库。 tfjs-converter documentation现在住在那里。您可以向 https://github.com/tensorflow/tfjs 创建拉取请求修复 SavedModel 链接以指向 tensorflow/tfjs 存储库。
关于tensorflow - 如何在不涉及 IBM 云的情况下转换我用 Tensorflow (python) 训练的模型以用于 TensorflowJS(从我现在的步骤开始)?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/68350476/