ONNX(Open Neural Network Exchange)是一个开源项目,旨在建立一个开放的标准,使深度学习模型可以在不同的软件平台和工具之间轻松移动和重用。
ONNX模型可以用于各种应用场景,例如机器翻译、图像识别、语音识别、自然语言处理等。
由于ONNX模型的互操作性,开发人员可以使用不同的框架来训练,模型可以更容易地在不同的框架之间转换,例如从PyTorch转换到TensorFlow,或从TensorFlow转换到MXNet等。然后将其部署到不同的环境中,例如云端、边缘设备或移动设备等。
ONNX还提供了一组工具和库,帮助开发人员更容易地创建、训练和部署深度学习模型。
ONNX模型是由多个节点(node)组成的图(graph),每个节点代表一个操作或一个张量(tensor)。ONNX模型还包含了一些元数据,例如模型的版本、输入和输出张量的名称等。
onnx官网
ONNX | Home
pytorch官方使用onnx模型格式举例
(optional) Exporting a Model from PyTorch to ONNX and Running it using ONNX Runtime — PyTorch Tutorials 2.2.0+cu121 documentation
TensorFlow官方使用onnx模型格式举例
https://github.com/onnx/tutorials/blob/master/tutorials/TensorflowToOnnx-1.ipynb
Netron可视化模型结构工具
Netron
你可通过该工具看到onnx具体的模型结构,点击每层都能看到其对应的内容信息
onnxRuntime | 提供各种编程语言推导onnx格式模型的接口
ONNX Runtime | Home
比如我需要在java环境下调用一个onnx模型,我可以先导入onnxRuntime的依赖,对数据预处理后,调用onnx格式模型正向传播导出数据,然后将数据处理成我要的数据。
onnxRuntime也提供了其他编程语言的接口,如C++、C#、JavaScript、python等等。
实际案例举例
python部分
python下利用ultralytics从网上下载并导出yolov8的onnx格式模型,用java调用onnxruntim接口,正向传播推导模型数据。
pip install ultralytics
from ultralytics import YOLO # 加载模型 model = YOLO('yolov8n.pt') # 加载官方模型 #加载自定义训练的模型 #model = YOLO('F:\\File\\AI\\Object\\yolov8_test\\runs\\detect\\train\\weights\\best.pt') # 导出模型 model.export(format='onnx')
java部分
前提安装java的opencv(Get Started - OpenCV),我这安装的是opencv480
maven依赖
com.microsoft.onnxruntime onnxruntime 1.12.0 org.opencv opencv 4.8.0 system ${basedir}/lib/opencv-480.jar com.alibaba fastjson 2.0.32
java完整代码
package com.sky; //天宇 2023/12/21 20:23:13 import ai.onnxruntime.*; import com.alibaba.fastjson.JSONObject; import org.opencv.core.*; import org.opencv.core.Point; import org.opencv.highgui.HighGui; import org.opencv.imgcodecs.Imgcodecs; import org.opencv.imgproc.Imgproc; import java.nio.FloatBuffer; import java.text.DecimalFormat; import java.util.*; import java.util.List; /** * onnx学习笔记 GTianyu */ public class onnxLoadTest01 { public static OrtEnvironment env; public static OrtSession session; public static JSONObject names; public static long count; public static long channels; public static long netHeight; public static long netWidth; public static float srcw; public static float srch; public static float confThreshold = 0.25f; public static float nmsThreshold = 0.5f; static Mat src; public static void load(String path) { String weight = path; try{ env = OrtEnvironment.getEnvironment(); session = env.createSession(weight, new OrtSession.SessionOptions()); OnnxModelMetadata metadata = session.getMetadata(); Map infoMap = session.getInputInfo(); TensorInfo nodeInfo = (TensorInfo)infoMap.get("images").getInfo(); String nameClass = metadata.getCustomMetadata().get("names"); System.out.println("getProducerName="+metadata.getProducerName()); System.out.println("getGraphName="+metadata.getGraphName()); System.out.println("getDescription="+metadata.getDescription()); System.out.println("getDomain="+metadata.getDomain()); System.out.println("getVersion="+metadata.getVersion()); System.out.println("getCustomMetadata="+metadata.getCustomMetadata()); System.out.println("getInputInfo="+infoMap); System.out.println("nodeInfo="+nodeInfo); System.out.println(nameClass); names = JSONObject.parseObject(nameClass.replace("\"","\"\"")); count = nodeInfo.getShape()[0];//1 模型每次处理一张图片 channels = nodeInfo.getShape()[1];//3 模型通道数 netHeight = nodeInfo.getShape()[2];//640 模型高 netWidth = nodeInfo.getShape()[3];//640 模型宽 System.out.println(names.get(0)); // 加载opencc需要的动态库 System.loadLibrary(Core.NATIVE_LIBRARY_NAME); } catch (Exception e){ e.printStackTrace(); System.exit(0); } } public static Map predict(String imgPath) throws Exception { src=Imgcodecs.imread(imgPath); return predictor(); } public static Map predict(Mat mat) throws Exception { src=mat; return predictor(); } public static OnnxTensor transferTensor(Mat dst){ Imgproc.cvtColor(dst, dst, Imgproc.COLOR_BGR2RGB); dst.convertTo(dst, CvType.CV_32FC1, 1. / 255); float[] whc = new float[ Long.valueOf(channels).intValue() * Long.valueOf(netWidth).intValue() * Long.valueOf(netHeight).intValue() ]; dst.get(0, 0, whc); float[] chw = whc2cwh(whc); OnnxTensor tensor = null; try { tensor = OnnxTensor.createTensor(env, FloatBuffer.wrap(chw), new long[]{count,channels,netWidth,netHeight}); } catch (Exception e){ e.printStackTrace(); System.exit(0); } return tensor; } //宽 高 类型 to 类 宽 高 public static float[] whc2cwh(float[] src) { float[] chw = new float[src.length]; int j = 0; for (int ch = 0; ch效果:
参考文献:
使用 java-onnx 部署 yolovx 目标检测_java onnx-CSDN博客