ONNX格式模型 学习笔记 (onnxRuntime部署)---用java调用yolov8模型来举例

慈云数据 2024-05-14 技术支持 76 0

ONNX(Open Neural Network Exchange)是一个开源项目,旨在建立一个开放的标准,使深度学习模型可以在不同的软件平台和工具之间轻松移动和重用。

ONNX模型可以用于各种应用场景,例如机器翻译、图像识别语音识别、自然语言处理等。

由于ONNX模型的互操作性,开发人员可以使用不同的框架来训练,模型可以更容易地在不同的框架之间转换,例如从PyTorch转换到TensorFlow,或从TensorFlow转换到MXNet等。然后将其部署到不同的环境中,例如云端、边缘设备或移动设备等。

ONNX还提供了一组工具和库,帮助开发人员更容易地创建、训练和部署深度学习模型。

ONNX模型是由多个节点(node)组成的图(graph),每个节点代表一个操作或一个张量(tensor)。ONNX模型还包含了一些元数据,例如模型的版本、输入和输出张量的名称等。

onnx官网

ONNX | Home

pytorch官方使用onnx模型格式举例

(optional) Exporting a Model from PyTorch to ONNX and Running it using ONNX Runtime — PyTorch Tutorials 2.2.0+cu121 documentation

TensorFlow官方使用onnx模型格式举例

https://github.com/onnx/tutorials/blob/master/tutorials/TensorflowToOnnx-1.ipynb

Netron可视化模型结构工具

Netron

你可通过该工具看到onnx具体的模型结构,点击每层都能看到其对应的内容信息

onnxRuntime  | 提供各种编程语言推导onnx格式模型的接口

ONNX Runtime | Home

比如我需要在java环境下调用一个onnx模型,我可以先导入onnxRuntime的依赖,对数据预处理后,调用onnx格式模型正向传播导出数据,然后将数据处理成我要的数据。 

onnxRuntime也提供了其他编程语言的接口,如C++、C#、JavaScript、python等等。

实际案例举例

python部分

python下利用ultralytics从网上下载并导出yolov8的onnx格式模型,用java调用onnxruntim接口,正向传播推导模型数据。

pip install ultralytics
from ultralytics import YOLO
# 加载模型
model = YOLO('yolov8n.pt')  # 加载官方模型
#加载自定义训练的模型
#model = YOLO('F:\\File\\AI\\Object\\yolov8_test\\runs\\detect\\train\\weights\\best.pt')  
# 导出模型
model.export(format='onnx')

java部分

前提安装java的opencv(Get Started - OpenCV),我这安装的是opencv480

maven依赖

        
            com.microsoft.onnxruntime
            onnxruntime
            1.12.0
        
        
        
            org.opencv
            opencv
            4.8.0
            system
            
            ${basedir}/lib/opencv-480.jar
        
        
            com.alibaba
            fastjson
            2.0.32
        

java完整代码

package com.sky;
//天宇 2023/12/21 20:23:13
import ai.onnxruntime.*;
import com.alibaba.fastjson.JSONObject;
import org.opencv.core.*;
import org.opencv.core.Point;
import org.opencv.highgui.HighGui;
import org.opencv.imgcodecs.Imgcodecs;
import org.opencv.imgproc.Imgproc;
import java.nio.FloatBuffer;
import java.text.DecimalFormat;
import java.util.*;
import java.util.List;
/**
 * onnx学习笔记  GTianyu
 */
public class onnxLoadTest01 {
    public static OrtEnvironment env;
    public static OrtSession session;
    public static JSONObject names;
    public static long count;
    public static long channels;
    public static long netHeight;
    public static long netWidth;
    public static  float srcw;
    public static  float srch;
    public static float confThreshold = 0.25f;
    public static float nmsThreshold = 0.5f;
    static Mat src;
    public static void load(String path) {
         String weight = path;
        try{
            env = OrtEnvironment.getEnvironment();
            session = env.createSession(weight, new OrtSession.SessionOptions());
            OnnxModelMetadata metadata = session.getMetadata();
            Map infoMap = session.getInputInfo();
            TensorInfo nodeInfo = (TensorInfo)infoMap.get("images").getInfo();
            String nameClass = metadata.getCustomMetadata().get("names");
            System.out.println("getProducerName="+metadata.getProducerName());
            System.out.println("getGraphName="+metadata.getGraphName());
            System.out.println("getDescription="+metadata.getDescription());
            System.out.println("getDomain="+metadata.getDomain());
            System.out.println("getVersion="+metadata.getVersion());
            System.out.println("getCustomMetadata="+metadata.getCustomMetadata());
            System.out.println("getInputInfo="+infoMap);
            System.out.println("nodeInfo="+nodeInfo);
            System.out.println(nameClass);
            names = JSONObject.parseObject(nameClass.replace("\"","\"\""));
            count = nodeInfo.getShape()[0];//1 模型每次处理一张图片
            channels = nodeInfo.getShape()[1];//3 模型通道数
            netHeight = nodeInfo.getShape()[2];//640 模型高
            netWidth = nodeInfo.getShape()[3];//640 模型宽
            System.out.println(names.get(0));
            // 加载opencc需要的动态库
            System.loadLibrary(Core.NATIVE_LIBRARY_NAME);
        }
        catch (Exception e){
            e.printStackTrace();
            System.exit(0);
        }
    }
    public static Map predict(String imgPath) throws Exception {
        src=Imgcodecs.imread(imgPath);
        return predictor();
    }
    public static Map predict(Mat mat) throws Exception {
        src=mat;
        return predictor();
    }
    public static OnnxTensor transferTensor(Mat dst){
        Imgproc.cvtColor(dst, dst, Imgproc.COLOR_BGR2RGB);
        dst.convertTo(dst, CvType.CV_32FC1, 1. / 255);
        float[] whc = new float[ Long.valueOf(channels).intValue() * Long.valueOf(netWidth).intValue() * Long.valueOf(netHeight).intValue() ];
        dst.get(0, 0, whc);
        float[] chw = whc2cwh(whc);
        OnnxTensor tensor = null;
        try {
            tensor = OnnxTensor.createTensor(env, FloatBuffer.wrap(chw), new long[]{count,channels,netWidth,netHeight});
        }
        catch (Exception e){
            e.printStackTrace();
            System.exit(0);
        }
        return tensor;
    }
    //宽 高 类型 to 类 宽 高
    public static float[] whc2cwh(float[] src) {
        float[] chw = new float[src.length];
        int j = 0;
        for (int ch = 0; ch  

效果:


参考文献:

使用 java-onnx 部署 yolovx 目标检测_java onnx-CSDN博客

微信扫一扫加客服

微信扫一扫加客服

点击启动AI问答
Draggable Icon