我正在尝试将一组值传递给我的 TensorFlow Lite 模型。我在应用程序中输入这些值,把它们存入数组并传递给模型,但随后应用程序崩溃了。以下是相关的代码片段。
// Reads the 14 form fields, runs the model on them and shows the prediction.
inferButton.setOnClickListener(new View.OnClickListener() {
    @Override
    public void onClick(View view) {
        // One entry per model feature, in feature order. The original array
        // literal listed getText5 and getText10 twice, which shifted every
        // later feature and dropped the last two.
        EditText[] fields = {
            inputNumber1, inputNumber2, inputNumber3, inputNumber4, inputNumber5,
            inputNumber6, inputNumber7, inputNumber8, inputNumber9, inputNumber10,
            inputNumber11, inputNumber12, inputNumber13, inputNumber14
        };

        int[] attributes = new int[fields.length];
        try {
            for (int i = 0; i < fields.length; i++) {
                attributes[i] = Integer.parseInt(fields[i].getText().toString().trim());
            }
        } catch (NumberFormatException ex) {
            // Empty or non-numeric input: report it instead of crashing the app.
            outputNumber.setText("Please enter a whole number in every field.");
            return;
        }

        int prediction = doInference(attributes);
        // setText(int) would treat the value as a string resource id and throw
        // Resources.NotFoundException, so convert the number to text first.
        outputNumber.setText(String.valueOf(prediction));
    }
});
/**
 * Runs the TensorFlow Lite model on one sample and returns the predicted value.
 *
 * <p>Only the first 14 entries of {@code attributeArray} are used.
 *
 * <p>NOTE(review): the training script fits the model on MinMax-scaled features;
 * the values passed here are presumably raw, so the same scaling likely has to be
 * applied before inference. Confirm against the exported model.
 *
 * @param attributeArray the feature values, at least 14 entries long
 * @return the model output rounded to the nearest integer
 * @throws IllegalStateException if the interpreter was never created (model failed to load)
 * @throws IllegalArgumentException if fewer than 14 values are supplied
 */
public int doInference(int[] attributeArray) {
    final int featureCount = 14;
    if (tflite == null) {
        throw new IllegalStateException(
            "TensorFlow Lite interpreter is not initialised; the model file failed to load");
    }
    if (attributeArray == null || attributeArray.length < featureCount) {
        throw new IllegalArgumentException("Expected at least " + featureCount + " input values");
    }

    // A Keras Dense layer is float32 by default, and the model takes a batch
    // dimension, so the input is shaped [1][14] (assumes the .tflite file was
    // converted without integer quantisation; TODO confirm).
    float[][] inputVal = new float[1][featureCount];
    for (int i = 0; i < featureCount; i++) {
        inputVal[0][i] = attributeArray[i];
    }

    // Output shape is [1][1].
    float[][] outputVal = new float[1][1];
    tflite.run(inputVal, outputVal);

    // Keep the int return type callers rely on; round rather than truncate.
    return Math.round(outputVal[0][0]);
}
我得到的错误信息如下:
E/AndroidRuntime: FATAL EXCEPTION: main
Process: com.example.appdoctor, PID: 13022
java.lang.NullPointerException: Attempt to invoke virtual method 'void org.tensorflow.lite.Interpreter.run(java.lang.Object, java.lang.Object)' on a null object reference
at com.example.appdoctor.Prediction.doInference(Prediction.java:90)
at com.example.appdoctor.Prediction$1.onClick(Prediction.java:75)
at android.view.View.performClick(View.java:7125)
at android.view.View.performClickInternal(View.java:7102)
at android.view.View.access$3500(View.java:801)
at android.view.View$PerformClick.run(View.java:27336)
at android.os.Handler.handleCallback(Handler.java:883)
at android.os.Handler.dispatchMessage(Handler.java:100)
at android.os.Looper.loop(Looper.java:214)
at android.app.ActivityThread.main(ActivityThread.java:7356)
at java.lang.reflect.Method.invoke(Native Method)
at com.android.internal.os.RuntimeInit$MethodAndArgsCaller.run(RuntimeInit.java:492)
at com.android.internal.os.ZygoteInit.main(ZygoteInit.java:930)
我猜测可能是数组没有被正确填充,或者是我在 doInference() 方法中写错了,因为我并不十分了解自己的神经网络结构。
这是 python 代码:-
import glob
import os
from keras.models import Sequential, load_model
import numpy as np
import pandas as pd
from keras.layers import Dense
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder, MinMaxScaler
import matplotlib.pyplot as plt
import keras as k
import tensorflow as tf
from tensorflow import keras
from tensorflow import lite
# Load the kidney dataset and drop every row that contains a missing value.
df = pd.read_csv("kidney4.csv")
df = df.dropna(axis=0)

# Label-encode only the non-numeric columns. Comparing a dtype with `np.number`
# is not a reliable numeric check (integer columns do not compare equal), so
# use pandas' own dtype test instead.
for column in df.columns:
    if pd.api.types.is_numeric_dtype(df[column]):
        continue
    df[column] = LabelEncoder().fit_transform(df[column])

# Split into features and target.
X = df.drop(["classification"], axis=1)
y = df["classification"]

# Scale every feature with MinMaxScaler, keeping the column names and index.
column_names = X.columns
x_scaler = MinMaxScaler()
X = pd.DataFrame(x_scaler.fit_transform(X), columns=column_names, index=X.index)

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, shuffle=True)

# Single dense unit over the 14 input features, trained with SGD on MSE.
model = keras.Sequential([keras.layers.Dense(units=1, input_shape=[14])])
model.compile(optimizer='sgd', loss='mean_squared_error')
model.fit(X_train, y_train, epochs=500)

# Predict one sample. The model was trained on scaled features, so the sample
# must go through the same fitted scaler before it is passed to the model.
input_array = np.array([40, 8, 1, 2, 0, 2, 6, 10, 34, 40, 16, 23, 67, 25])
input_array_for_prediction = np.expand_dims(input_array, axis=0)
scaled_input = x_scaler.transform(
    pd.DataFrame(input_array_for_prediction, columns=column_names)
)
print(model.predict(scaled_input))
这是我的 model.summary():
这是我的完整 Java 代码:
import androidx.appcompat.app.AppCompatActivity;
import android.content.*;
import android.content.res.*;
import android.view.View;
import android.widget.*;
import org.tensorflow.lite.*;
import java.io.FileInputStream;
import java.io.IOException;
import java.nio.MappedByteBuffer;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.channels.FileChannel;
import android.os.Bundle;
public class Prediction extends AppCompatActivity {
EditText inputNumber1,inputNumber2,inputNumber3,inputNumber4,inputNumber5,inputNumber6,inputNumber7,inputNumber8,inputNumber9,inputNumber10,inputNumber11,inputNumber12,inputNumber13,inputNumber14;
Button inferButton;
TextView outputNumber;
Interpreter tflite;
@Override
protected void onCreate(Bundle savedInstanceState) {
super.onCreate(savedInstanceState);
setContentView(R.layout.activity_prediction);
inputNumber1=(EditText)findViewById(R.id.editText);
inputNumber2=(EditText)findViewById(R.id.editText2);
inputNumber3=(EditText)findViewById(R.id.editText3);
inputNumber4=(EditText)findViewById(R.id.editText4);
inputNumber5=(EditText)findViewById(R.id.editText5);
inputNumber6=(EditText)findViewById(R.id.editText6);
inputNumber7=(EditText)findViewById(R.id.editText7);
inputNumber8=(EditText)findViewById(R.id.editText8);
inputNumber9=(EditText)findViewById(R.id.editText9);
inputNumber10=(EditText)findViewById(R.id.editText10);
inputNumber11=(EditText)findViewById(R.id.editText11);
inputNumber12=(EditText)findViewById(R.id.editText12);
inputNumber13=(EditText)findViewById(R.id.editText13);
inputNumber14=(EditText)findViewById(R.id.editText14);
outputNumber=(TextView)findViewById(R.id.outputNumber);
inferButton=(Button)findViewById(R.id.predictButton);
try{
tflite=new Interpreter(loadModelFile());
}catch(Exception ex){
ex.printStackTrace();
}
inferButton.setOnClickListener(new View.OnClickListener(){
@Override
public void onClick(View view){
int getText1=Integer.parseInt(inputNumber1.getText().toString());
int getText2=Integer.parseInt(inputNumber2.getText().toString());
int getText3=Integer.parseInt(inputNumber3.getText().toString());
int getText4=Integer.parseInt(inputNumber4.getText().toString());
int getText5=Integer.parseInt(inputNumber5.getText().toString());
int getText6=Integer.parseInt(inputNumber6.getText().toString());
int getText7=Integer.parseInt(inputNumber7.getText().toString());
int getText8=Integer.parseInt(inputNumber8.getText().toString());
int getText9=Integer.parseInt(inputNumber9.getText().toString());
int getText10=Integer.parseInt(inputNumber10.getText().toString());
int getText11=Integer.parseInt(inputNumber11.getText().toString());
int getText12=Integer.parseInt(inputNumber12.getText().toString());
int getText13=Integer.parseInt(inputNumber13.getText().toString());
int getText14=Integer.parseInt(inputNumber14.getText().toString());
int attributes[]={getText1,getText2,getText3,getText4,getText5,getText5,getText6,getText7,getText8,getText9,getText10,getText10,getText11,getText12,getText13,getText14};
int prediction=doInference(attributes);
//float prediction=doInference(inputNumber.getText().toString());
outputNumber.setText(prediction);
}
});
}
public int doInference(int [] attributeArray){
int[] inputVal=new int[14];
for(int i=0;i<14;i++) {
inputVal[i] = attributeArray[i];
}
int[][] outputval=new int[1][1];
//Run inference passing the input shape and getting the output shape
tflite.run(inputVal, outputval);
//Inferred value is at [0][0]
int inferredValue=outputval[0][0];
return inferredValue;
}
private MappedByteBuffer loadModelFile() throws IOException {
AssetFileDescriptor fileDescriptor=this.getAssets().openFd("kidney_test_model.tflite");
FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor());
FileChannel fileChannel = inputStream.getChannel();
long startOffset = fileDescriptor.getStartOffset();
long declaredLength = fileDescriptor.getDeclaredLength();
return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
}
}
感谢任何帮助。谢谢。
最佳答案
我也遇到过同样的问题,直到检查了 build.gradle 文件才找到原因。你需要添加如下配置:
android {
...
// Store .tflite assets uncompressed in the APK so the model file can be
// opened with AssetManager.openFd() and memory-mapped at runtime.
aaptOptions {
noCompress "tflite"
}
}
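将其放在你的 build.gradle 文件中,然后同步(Sync)项目。这样打包时 .tflite 文件不会被压缩;否则 openFd() 无法打开被压缩的资源文件,模型加载失败,tflite 保持为 null,最终就会出现上面的 NullPointerException。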
关于java - Tensorflow Lite Android 应用程序崩溃并出现 NullPointerException 'void org.tensorflow.lite.Interpreter.run(java.lang.Object, java.lang.Object)',我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/60580404/