我一直在尝试使用 TFLite 来提高 Android 上的检测速度,但奇怪的是我的 .tflite 模型现在几乎只检测到 1 个类别。
我已经对重新训练 mobilenet 后得到的 .pb 模型进行了测试,结果很好,但是由于某种原因,当我将其转换为 .tflite 时,检测就差了...
对于再培训,我使用了 Tensorflow for poets 2 中的 retrain.py 文件
我正在使用以下命令重新训练、优化推理并将模型转换为 tflite:
python retrain.py \
--image_dir ~/tf_files/tw/ \
--tfhub_module https://tfhub.dev/google/imagenet/mobilenet_v1_100_224/feature_vector/1 \
--output_graph ~/new_training_dir/retrainedGraph.pb \
-–saved_model_dir ~/new_training_dir/model/ \
--how_many_training_steps 500
sudo toco \
--input_file=retrainedGraph.pb \
--output_file=optimized_retrainedGraph.pb \
--input_format=TENSORFLOW_GRAPHDEF \
--output_format=TENSORFLOW_GRAPHDEF \
--input_shape=1,224,224,3 \
--input_array=Placeholder \
--output_array=final_result \
sudo toco \
--input_file=optimized_retrainedGraph.pb \
--input_format=TENSORFLOW_GRAPHDEF \
--output_format=TFLITE \
--output_file=retrainedGraph.tflite \
--inference_type=FLOAT \
--inference_input_type=FLOAT \
--input_arrays=Placeholder \
--output_array=final_result \
--input_shapes=1,224,224,3
我在这里做错了什么吗?准确性的损失从何而来?
最佳答案
我在尝试将 .pb 模型转换为 .lite 时遇到了同样的问题。
事实上,我的准确率会从 95 降到 30!
事实证明,我犯的错误不是在 .pb 到 .lite 的转换过程中,也不是在执行此操作的命令中。但它实际上是在加载图像并对其进行预处理之前将其传递到 lite 模型并使用推断
interpreter.invoke()
命令。
您看到的以下代码是我所说的预处理的意思:
test_image=cv2.imread(file_name)
test_image=cv2.resize(test_image,(299,299),cv2.INTER_AREA)
test_image = np.expand_dims((test_image)/255, axis=0).astype(np.float32)
interpreter.set_tensor(input_tensor_index, test_image)
interpreter.invoke()
digit = np.argmax(output()[0])
#print(digit)
prediction=result[digit]
正如您所看到的,一旦使用“imread()”读取图像,就有两个关键命令/预处理完成:
i) 图像的大小应调整为训练期间使用的输入图像/张量的“input_height”和“input_width”值。在我的情况下 (inception-v3),“input_height”和“input_width”都是 299。 (阅读此值的模型文档或在用于训练或重新训练模型的文件中查找此变量)
ii) 上面代码中的下一个命令是:
test_image = np.expand_dims((test_image)/255, axis=0).astype(np.float32)
我从“公式”/模型代码中得到了这个:
test_image = np.expand_dims((test_image-input_mean)/input_std, axis=0).astype(np.float32)
阅读文档发现对于我的架构 input_mean = 0 和 input_std = 255。
当我对我的代码进行上述更改时,我获得了预期的准确度 (90%)。
希望这可以帮助。
关于TensorFlow 精简版 : High loss in accuracy after converting model to tflite,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/50938992/