我们在Tomcat7(java1.8)中部署tensorflow模型(seq2seq question answering),在调试时,我们只使用简单的java Application(public static void main() function)来测试模型推理结果。简单java应用中的推理结果与python原版相同。 但是当我们在Tomcat中启动整个包(WAR)时,得到的结果却大不相同,而推理代码/测试输入语句/模型文件都是一样的。
谁能给我们一些关于这个问题的提示?
- 简单的 java 应用程序(public static void main() 函数) 得到与 python tensorflow 版本推断结果相同的结果。我们将它们视为正确的。
- Tomcat 加载模型得到不同的结果。结果看起来是正常的句子,但考虑问题时回答的意思很差。
- 模型文件(protobuf)/java代码/测试输入语句,以上两种情况相同。
- Keepout 概率为 1.0f 以供推断。
模型加载函数:
@Override
public boolean reload(String modelURL) {
logger.info("tensorflow version:{}", TensorFlow.version());
try {
logger.info("start to download model path:{}", modelURL);
//TODO: download model
logger.info("start to load model path:{} tag:{}", MODEL_PATH, MODEL_TAG);
bundle = SavedModelBundle.load(MODEL_PATH, MODEL_TAG);
session = bundle.session();
logger.info("finish loading model!");
} catch(Exception e) {
logger.error("reload model exception:", e);
return false;
}
return true;
}
推理代码:
@Override
public String predict(String query, String candidateAnswer) {
if (StringUtils.isEmpty(query) || StringUtils.isEmpty(candidateAnswer)) {
logger.info(String.format("query:%s candidate:%s can't be empty or null!", query, candidateAnswer));
return null;
}
String queryPad = preprocess(query, SEQUENCE_MAX_LEN);
String candidatePad = preprocess(candidateAnswer, SEQUENCE_MAX_LEN);
try(Tensor queryTensor = Tensor.create(queryPad.getBytes());
Tensor queryLenTensor = Tensor.create(SEQUENCE_MAX_LEN);
Tensor candidateTensor = Tensor.create(candidatePad.getBytes());
Tensor candidateLenTensor = Tensor.create(SEQUENCE_MAX_LEN))
{
List<Tensor> result = session.runner()
.feed("source_tokens", queryTensor)
.feed("source_len", queryLenTensor)
.feed("source_candidate_tokens", candidateTensor)
.feed("source_candidate_len", candidateLenTensor)
.fetch("model/att_seq2seq/predicted_tokens_scalar")
.run();
Tensor predictedTensor = result.get(0);
String predictedTokens = new String(predictedTensor.bytesValue(), "UTF-8");
logger.info(String.format("biseq2seq model generate:\nquery:%s\ncandidate:%s\npredict_tokens:%s", query.trim(), candidateAnswer.trim(), predictedTokens));
return predictedTokens;
} catch (Exception e) {
logger.error("exception:", e);
}
return null;
}
最佳答案
是的,这是编码问题。当我们在简单的 java 应用程序 (public static void main()) 中启动模型时,它的默认编码是 UTF-8,同时调用 getBytes()。但是当我们在tomcat中启动模型时,它的编码方案是ISO-8859-1。
张量 queryTensor = Tensor.create(queryPad.getBytes("UTF-8"))
张量 candidateTensor = Tensor.create(candidatePad.getBytes("UTF-8"))
@Override
public String predict(String query, String candidateAnswer) {
if (StringUtils.isEmpty(query) || StringUtils.isEmpty(candidateAnswer)) {
logger.info(String.format("query:%s candidate:%s can't be empty or null!", query, candidateAnswer));
return null;
}
String queryPad = preprocess(query, SEQUENCE_MAX_LEN);
String candidatePad = preprocess(candidateAnswer, SEQUENCE_MAX_LEN);
try(Tensor queryTensor = Tensor.create(queryPad.getBytes("UTF-8"));
Tensor queryLenTensor = Tensor.create(SEQUENCE_MAX_LEN);
Tensor candidateTensor = Tensor.create(candidatePad.getBytes("UTF-8"));
Tensor candidateLenTensor = Tensor.create(SEQUENCE_MAX_LEN))
{
List<Tensor> result = session.runner()
.feed("source_tokens", queryTensor)
.feed("source_len", queryLenTensor)
.feed("source_candidate_tokens", candidateTensor)
.feed("source_candidate_len", candidateLenTensor)
.fetch("model/att_seq2seq/predicted_tokens_scalar")
.run();
Tensor predictedTensor = result.get(0);
String predictedTokens = new String(predictedTensor.bytesValue(), "UTF-8");
logger.info(String.format("biseq2seq model generate:\nquery:%s\ncandidate:%s\npredict_tokens:%s", query.trim(), candidateAnswer.trim(), predictedTokens));
return predictedTokens;
} catch (Exception e) {
logger.error("exception:", e);
}
return null;
}
关于java - Tomcat 中相同的 tensorflow 模型推理从简单的 Java 应用程序中得到不同的结果,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/46445343/