android - 解释 TF lite 在对象检测 API 上的输出

标签 android python tensorflow object-detection-api

我正在使用对象检测 API 来训练我的自定义数据以解决 2 类问题。 我正在使用 SSD Mobilenet v2。我正在将模型转换为 TF lite,并尝试在 python 解释器上执行它。 分数和类(class)的值(value)让我有些困惑,我无法对此做出有效的解释。我得到以下分数值。

[[ 0.9998122 0.2795332 0.7827836 1.8154384 -1.1171713 0.152002 -0.90076405 1.6943774 -1.1098632 0.6275915]]

我得到以下类值:

[[ 0.1.742706 0.5762139 -0.23641224 -2.1639721 -0.6644413 -0.60925585 0.5485272 -0.9775026 1.4633082]]

如何获得大于 1 或小于 0 的分数,例如-1.10986321.6943774。 此外,理想情况下,类应该是整数 12,因为这是一个 2 类对象检测问题

我正在使用以下代码



    import numpy as np
    import tensorflow as tf
    import cv2

    # Load TFLite model and allocate tensors.
    interpreter = tf.contrib.lite.Interpreter(model_path="C://Users//Admin//Downloads//tflitenew//detect.tflite")
    interpreter.allocate_tensors()

    # Get input and output tensors.
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()

    print(input_details)
    print(output_details)
    input_shape = input_details[0]['shape']
    print(input_shape)
    # change the following line to feed into your own data.
    #input_data = np.array(np.random.random_sample(input_shape), dtype=np.float32)

    input_data = cv2.imread("C:/Users/Admin/Pictures/fire2.jpg")
    #input_data = cv2.imread("C:/Users/Admin/Pictures/images4.jpg")
    #input_data = cv2.imread("C:\\Users\\Admin\\Downloads\\FlareModels\\lessimages\\video5_image_178.jpg")
    input_data = cv2.resize(input_data, (300, 300)) 

    input_data = np.expand_dims(input_data, axis=0)
    input_data = (2.0 / 255.0) * input_data - 1.0
    input_data=input_data.astype(np.float32)
    interpreter.reset_all_variables()
    interpreter.set_tensor(input_details[0]['index'], input_data)
    interpreter.invoke()
    output_data_scores = []
    output_data_scores = interpreter.get_tensor(output_details[2]['index'])
    print(output_data_scores)

    output_data_class = []
    output_data_class = interpreter.get_tensor(output_details[1]['index'])
    print(output_data_class)

最佳答案

看起来问题是由错误的输入图像 channel 顺序引起的。 Opencv imread 读取“BGR”格式的图像。您可以尝试添加

input_data = cv2.cvtColor(input_data,  cv2.COLOR_BGR2RGB)

获取“RGB”格式的图像,然后查看结果是否合理。

引用:ref

关于android - 解释 TF lite 在对象检测 API 上的输出,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/55492231/

相关文章:

Android——如何在使用具有自定义背景的 ListView 时删除那一点额外的填充?

Python - 正则表达式查找所有重复模式,后跟可变长度的字符

c++ - 哪个库用于视频和音频录制?

tensorflow - 使用 Tensorflow 的 Connectionist 时间分类 (CTC) 实现

tensorflow - "rewind" tensorflow 训练步骤

android - 如何通过 Intent 打开或展开状态栏?

android - 我可以通过 android 中的 findViewById() 方法直接访问包含布局的 child 吗?

android - FCM : Displaying emoji or UTF-8 text on iOS that is sent from Android device

python - 如何将 IP 地址转换为可用作字典键的 32 位地址

optimization - 在 TensorFlow 中停止梯度优化器