java - tensorflow float32的转换与java对象不兼容

标签 java android python-3.x tensorflow

错误: 原因:java.lang.IllegalArgumentException:无法在 FLOAT32 类型的 TensorFlowLite 张量和 java.lang.String 类型的 Java 对象(与 TensorFlowLite 类型 STRING 兼容)之间进行转换。

我已经从我的数据集构建了一个神经网络,并且有 2 层,然后我将模型保存为 h5,然后使用 tf.keras 模型和转换将其转换为 tflite,但是当我将其部署在应用程序中时,它给了我上面的内容错误

我尝试过输入多种类型的数组和数组列表

错误: 原因:java.lang.IllegalArgumentException:无法在 FLOAT32 类型的 TensorFlowLite 张量和 java.lang.String 类型的 Java 对象(与 TensorFlowLite 类型 STRING 兼容)之间进行转换。

model.add(layers.Dense(500, input_dim=3, activation='relu'))
model.add(layers.Dense(1, activation= "relu"))
model.summary() #Print model Summary
model.compile(loss='mean_squared_error',optimizer='adam')
model.fit(X_train,Y_train,epochs=1000,validation_split=0.3)

我如何转换:-

from tensorflow.contrib import lite
converter = lite.TFLiteConverter.from_keras_model_file( 'Model.h5')
tfmodel = converter.convert()
open ("model.tflite" , "wb") .write(tfmodel)

Android 的实现

ArrayList<String> list = new ArrayList<>();
list.add("-0.5698444");
list.add("-0.57369368");
list.add("-1.31490297");






try (Interpreter interpreter = new Interpreter(mappedByteBuffer)) {
    interpreter.run(list, "output");
}


private MappedByteBuffer loadModelFile() throws IOException {
    AssetFileDescriptor fileDescriptor = getAssets().openFd("model.tflite");
    FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor());
    FileChannel fileChannel = inputStream.getChannel();
    long startOffset = fileDescriptor.getStartOffset();
    long declaredLength = fileDescriptor.getDeclaredLength();
    return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
}

最佳答案

发现错误了。在行中,

try (Interpreter interpreter = new Interpreter(mappedByteBuffer)) {
    interpreter.run(list, "output");
 }

interpreter.run() 的第二个参数必须是 float[] 而不是 “output”。模型运行时,float[] 将填充类概率。

interpreter.run() 方法提供输入和输出的正确方法:

Interpreter interpreter = new Interpreter(mappedByteBuffer)

float[][] inputs = new float[1][num_features]
// populate the inputs float array above
float[][] outputs = new float[1][num_classes]

interpreter.run( inputs , outputs )

float[] classProb = outputs[0]

关于java - tensorflow float32的转换与java对象不兼容,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/56323246/

相关文章:

java - 需要帮助逆向工程算法

android - Ionic2 RC4,在生产模式下找不到插件

python - 为什么我的 'button.bind' 不会调用 'makeChoice' ?

python-3.x - Python 无法访问我的 USB 设备 - 权限不足

python - 当另一个线程处于事件状态时,为什么不在全局对象上调用__del__?

java - 如何使用BufferedReader一次又一次地读取同一个txt文件?

java - 具有不同布局参数的 RecyclerView.Adapter 项目

java - 如何替换Java中第一次出现的字符串

android - 如果 url 字符串包含 "~"符号,则改造调用错误的 url 会导致 404 错误

javascript - 我的加速度计和canvas 2d 做错了什么?安卓