java - Tomcat 中相同的 tensorflow 模型推理从简单的 Java 应用程序中得到不同的结果

标签 java python tomcat tensorflow

我们在Tomcat7(java1.8)中部署tensorflow模型(seq2seq question answering),在调试时,我们只使用简单的java Application(public static void main() function)来测试模型推理结果。简单java应用中的推理结果与python原版相同。 但是当我们在Tomcat中启动整个包(WAR)时,得到的结果却大不相同,而推理代码/测试输入语句/模型文件都是一样的。

谁能给我们一些关于这个问题的提示?

  1. 简单的 java 应用程序(public static void main() 函数) 得到与 python tensorflow 版本推断结果相同的结果。我们将它们视为正确的。
  2. Tomcat 加载模型得到不同的结果。结果看起来是正常的句子,但考虑问题时回答的意思很差。
  3. 模型文件(protobuf)/java代码/测试输入语句,以上两种情况相同。
  4. Keepout 概率为 1.0f 以供推断。

模型加载函数:

@Override
public boolean reload(String modelURL) {
    logger.info("tensorflow version:{}", TensorFlow.version());
    try {
        logger.info("start to download model path:{}", modelURL);
        //TODO: download model
        logger.info("start to load model path:{} tag:{}", MODEL_PATH, MODEL_TAG);
        bundle = SavedModelBundle.load(MODEL_PATH, MODEL_TAG);
        session = bundle.session();
        logger.info("finish loading model!");

    } catch(Exception e) {
        logger.error("reload model exception:", e);
        return false;
    }

    return true;
}

推理代码:

    @Override
public String predict(String query, String candidateAnswer) {
    if (StringUtils.isEmpty(query) || StringUtils.isEmpty(candidateAnswer)) {
        logger.info(String.format("query:%s candidate:%s can't be empty or null!", query, candidateAnswer));
        return null;
    }
    String queryPad = preprocess(query, SEQUENCE_MAX_LEN);
    String candidatePad = preprocess(candidateAnswer, SEQUENCE_MAX_LEN);

    try(Tensor queryTensor = Tensor.create(queryPad.getBytes());
        Tensor queryLenTensor = Tensor.create(SEQUENCE_MAX_LEN);
        Tensor candidateTensor = Tensor.create(candidatePad.getBytes());
        Tensor candidateLenTensor = Tensor.create(SEQUENCE_MAX_LEN))
    {
        List<Tensor> result = session.runner()
                .feed("source_tokens", queryTensor)
                .feed("source_len", queryLenTensor)
                .feed("source_candidate_tokens", candidateTensor)
                .feed("source_candidate_len", candidateLenTensor)
                .fetch("model/att_seq2seq/predicted_tokens_scalar")
                .run();

        Tensor predictedTensor = result.get(0);
        String predictedTokens = new String(predictedTensor.bytesValue(), "UTF-8");
        logger.info(String.format("biseq2seq model generate:\nquery:%s\ncandidate:%s\npredict_tokens:%s", query.trim(), candidateAnswer.trim(), predictedTokens));
        return predictedTokens;
    } catch (Exception e) {
        logger.error("exception:", e);
    }

    return null;
}

最佳答案

是的,这是编码问题。当我们在简单的 java 应用程序 (public static void main()) 中启动模型时,它的默认编码是 UTF-8,同时调用 getBytes()。但是当我们在tomcat中启动模型时,它的编码方案是ISO-8859-1。

张量 queryTensor = Tensor.create(queryPad.getBytes("UTF-8"))

张量 candidateTensor = Tensor.create(candidatePad.getBytes("UTF-8"))

    @Override
public String predict(String query, String candidateAnswer) {
    if (StringUtils.isEmpty(query) || StringUtils.isEmpty(candidateAnswer)) {
        logger.info(String.format("query:%s candidate:%s can't be empty or null!", query, candidateAnswer));
        return null;
    }
    String queryPad = preprocess(query, SEQUENCE_MAX_LEN);
    String candidatePad = preprocess(candidateAnswer, SEQUENCE_MAX_LEN);

    try(Tensor queryTensor = Tensor.create(queryPad.getBytes("UTF-8"));
        Tensor queryLenTensor = Tensor.create(SEQUENCE_MAX_LEN);
        Tensor candidateTensor = Tensor.create(candidatePad.getBytes("UTF-8"));
        Tensor candidateLenTensor = Tensor.create(SEQUENCE_MAX_LEN))
    {
        List<Tensor> result = session.runner()
                .feed("source_tokens", queryTensor)
                .feed("source_len", queryLenTensor)
                .feed("source_candidate_tokens", candidateTensor)
                .feed("source_candidate_len", candidateLenTensor)
                .fetch("model/att_seq2seq/predicted_tokens_scalar")
                .run();

        Tensor predictedTensor = result.get(0);
        String predictedTokens = new String(predictedTensor.bytesValue(), "UTF-8");
        logger.info(String.format("biseq2seq model generate:\nquery:%s\ncandidate:%s\npredict_tokens:%s", query.trim(), candidateAnswer.trim(), predictedTokens));
        return predictedTokens;
    } catch (Exception e) {
        logger.error("exception:", e);
    }

    return null;
}

关于java - Tomcat 中相同的 tensorflow 模型推理从简单的 Java 应用程序中得到不同的结果,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/46445343/

相关文章:

java - Tomcat 和 logback.xml

java - 因缺少系统属性而抛出的正确异常

java - 如何在 Java 中将这个 "Tue Nov 13 14:35:04 +0000 2012"String 格式转换为日期?

java - Maven 构建测试失败但 JUnit 运行器测试通过?

python - 在多处理中启动嵌套进程

python - 何时缓存 DataFrame?

java - 模块 javafx.controls 未找到异常

python - 如何对所有列使用 groupby agg 和重命名函数

java - apache tomcat 上的 jsp servlet 中的 session 不匹配

java - Windows 上 Tomcat 上的 Log4j2 产生警告 "unable to instantiate org.fusesource.jansi.WindowsAnsiOutputStream"