python - 无法在android studio中的解释器上运行tflite模型

标签 python android-studio tensorflow keras tensorflow-lite

我正在尝试在智能手机上的应用程序上运行 TensorFlow-lite 模型。首先,我使用 LSTM 使用数值数据训练模型,并使用 TensorFlow.Keras 构建模型层。我使用 TensorFlow V2.x 并将训练好的模型保存在服务器上。之后,模型由应用程序下载到智能手机的内存中,并使用“MappedByteBuffer”加载到解释器中。直到这里一切都工作正常。

问题在于解释器无法读取和运行模型。 我还在 build.gradle 上添加了所需的依赖项。

python中tflite模型的转换代码:

from tensorflow import keras
from keras.models import Sequential
from keras.layers import Dense, Dropout, LSTM
from tensorflow.keras import regularizers
#Create the network
model = Sequential()
model.add(LSTM(...... name = 'First_layer'))
model.add(Dropout(rate=Drop_out))
model.add(LSTM(...... name = 'Second_layer'))
model.add(Dropout(rate=Drop_out))

# compile model
model.compile(loss=keras.losses.mae, 
optimizer=keras.optimizers.Adam(learning_rate=learning_rate), metrics=["mae"])

# fit model
model.fit(.......)
#save the model
tf.saved_model.save(model,'saved_model')
print("Model  type", model1.dtype)# Model type is float32 and size around 2MB

#Convert saved model into TFlite
converter = tf.lite.TFLiteConverter.from_saved_model('saved_model')
tflite_model = converter.convert()

with open("Model.tflite, "wb") as f:
    f.write(tflite_model)
f.close()

我还尝试了使用 Keras 的其他转换方式

# converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
# tflite_model = converter.convert()

完成此步骤后,“Model.tflite”将被转换并下载到智能手机的内存中。

Android Studio代码:

  try {
        private Interpreter tflite = new Interpreter(loadModelFile());
        Log.d("Load_model", "Created a Tensorflow Lite of AutoAuth.");

    } catch (IOException e) {
        Log.e("Load_model", "IOException loading the tflite file");

    }

private MappedByteBuffer loadModelFile() throws IOException {
    String model_path = model_directory + model_name + ".tflite";
    Log.d(TAG, model_path);
    File file = new File(model_path);
    if(file!=null){
    FileInputStream inputStream = new FileInputStream(file);
    FileChannel fileChannel = inputStream.getChannel();
    return fileChannel.map(FileChannel.MapMode.READ_ONLY, 0, file.length());
    }else{
        return null;
    }
}

“loadModelFile()”函数工作正常,因为我使用 MNIST 数据集进行图像分类的另一个 tflite 模型对其进行了检查。问题只是解释器。

这也是build.gradle的内容:

android {
aaptOptions {
    noCompress "tflite"
}
 }
  android {
     defaultConfig {
        ndk {
            abiFilters 'armeabi-v7a', 'arm64-v8a'
        }
      }
    }

dependencies {
     implementation 'com.jakewharton:butterknife:8.8.1'
     implementation 'org.tensorflow:tensorflow-lite:0.1.2-nightly'
     annotationProcessor 'com.jakewharton:butterknife-compiler:8.8.1'
     implementation fileTree(dir: 'libs', include: ['*.jar'])
     //noinspection GradleCompatible
     implementation 'com.android.support:appcompat-v7:28.0.0'
    implementation 'com.android.support.constraint:constraint-layout:2.0.4'
    testImplementation 'junit:junit:4.12'
    androidTestImplementation 'com.android.support.test:runner:1.0.2'
    androidTestImplementation 'com.android.support.test.espresso:espresso-core:3.0.2'
    }

每当我运行 Android studio 时,都会出现以下错误之一: 1- enter image description here

2-

enter image description here

我浏览了许多资源和线程,并阅读了有关保存训练模型、TFlite 转换和解释器的信息。 我五天前就试图解决这个问题,但没有希望。谁能给出解决方案吗?

最佳答案

引用最新的 TfLite android 应用程序示例之一可能会有所帮助:Model Personalization App 。该演示应用使用迁移学习模型而不是 LSTM,但整体工作流程应该类似。

正如 Farmaker 在评论中提到的,尝试在 gradle 依赖项中使用 SNAPSHOT:

implementation 'org.tensorflow:tensorflow-lite:0.0.0-nightly-SNAPSHOT'

要正确加载模型,您可以尝试:

protected MappedByteBuffer loadMappedFile(String filePath) throws IOException {
    AssetFileDescriptor fileDescriptor = assetManager.openFd(this.directoryName + "/" + filePath);

    FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor());
    FileChannel fileChannel = inputStream.getChannel();
    long startOffset = fileDescriptor.getStartOffset();
    long declaredLength = fileDescriptor.getDeclaredLength();
    return fileChannel.map(MapMode.READ_ONLY, startOffset, declaredLength);
  }

此代码片段也可以在我上面发布的 GitHub 示例链接中找到。

关于python - 无法在android studio中的解释器上运行tflite模型,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/69796868/

相关文章:

python - 如何解决这些 tensorflow 警告?

python - Tensorflow无法训练模型

python - 使用 Keras ImageDataGenerator 时出现内存错误

Python 断言——改进了失败的自省(introspection)?

android-studio - 升级到 android studio 3.5 并 gradle 到 5.4.1 后出现 groovy.lang.MissingPropertyException

python - 深度学习中哪些算法可以验证列到矩阵的关系

android - Gradle 总是从最后一个 flavor 中的 buildType 中获取值

android - Pending Intent 请求代码可以为负数吗?

python - 函数变为 "unbound method"。为什么?

python - 属性类和属性装饰器有什么区别