java - 在 Java 中加载 ONNX 模型

标签 java caffe2

我有一个经过训练的 PyTorch 模型,我现在想使用 ONNX 将其导出到 Caffe2。这部分看起来相当简单并且有据可查。但是,我现在想将该模型“加载”到 Java 程序中,以便在我的程序(Flink 流应用程序)中执行预测。做这个的最好方式是什么?我无法在网站上找到任何描述如何执行此操作的文档。

最佳答案

目前有点棘手,但有办法。您将需要使用 JavaCPP:

我将使用 single_relu.onnx例如:

    //read ONNX
    byte[] bytes = Files.readAllBytes(Paths.get("single_relu.onnx"));
    ModelProto model = new ModelProto(); 
    ParseProtoFromBytes(model, new BytePointer(bytes), bytes.length); // parse ONNX -> protobuf model

    //preprocess model in any way you like (you can skip this step)
    check_model(model);
    InferShapes(model);
    StringVector passes = new StringVector("eliminate_nop_transpose", "eliminate_nop_pad", "fuse_consecutive_transposes", "fuse_transpose_into_gemm");
    Optimize(model, passes);
    check_model(model);
    ConvertVersion(model, 8);
    BytePointer serialized = model.SerializeAsString();
    System.out.println("model="+serialized.getString());

    //prepare nGraph backend
    Backend backend = Backend.create("CPU");
    Shape shape = new Shape(new SizeTVector(1,2 ));
    Tensor input =backend.create_tensor(f32(), shape);
    Tensor output =backend.create_tensor(f32(), shape);
    Function ng_function = import_onnx_model(serialized); // convert ONNX -> nGraph
    Executable exec = backend.compile(ng_function);
    exec.call(new NgraphTensorVector(output), new NgraphTensorVector(input));

    //collect result to array
    float[] r = new float[2];
    FloatPointer p = new FloatPointer(r);
    output.read(p, 0, r.length * 4);
    p.get(r);

    //print result
    System.out.println("[");
    for (int i = 0; i < shape.get(0); i++) {
        System.out.print(" [");
        for (int j = 0; j < shape.get(1); j++) {
            System.out.print(r[i * (int)shape.get(1) + j] + " ");
        }
        System.out.println("]");
    }
    System.out.println("]");

关于java - 在 Java 中加载 ONNX 模型,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/47464416/

相关文章:

machine-learning - 使用brew 添加新的辅助方法会引发错误

java - 如何使用参数 T 模拟和验证方法

java - 如何使用随机分布的符号填充数组?

java - 有没有更快的方法来输出 PDF 文件?

java - 使用 Mockito 跳过方法执行

python - 可以直接从 gpu 给 Caffe 或 Caffe2 输入数据吗?

python - 为 Caffe2 创建图像 LMDB

nlp - 有没有办法并行使用 fastText 的单词表示过程?

opencv - 通过 TensorFlow Lite、Caffe2 或 OpenCV 部署 cnn 模型哪个更快?

java - 从网络下载的 ListView 中行的图像应该较小