package com.xindao.ocr.smartjavaai.model.plate.translator; import ai.djl.modality.cv.Image; import ai.djl.modality.cv.output.BoundingBox; import ai.djl.modality.cv.output.DetectedObjects; import ai.djl.modality.cv.output.Landmark; import ai.djl.modality.cv.output.Point; import ai.djl.ndarray.NDArray; import ai.djl.ndarray.NDArrays; import ai.djl.ndarray.NDList; import ai.djl.ndarray.NDManager; import ai.djl.ndarray.types.DataType; import ai.djl.translate.Batchifier; import ai.djl.translate.Translator; import ai.djl.translate.TranslatorContext; import cn.smartjavaai.common.utils.LetterBoxUtils; import cn.smartjavaai.common.utils.NMSUtils; import java.util.ArrayList; import java.util.List; import java.util.Map; /** * @author dwj */ public class Yolov7PlateDetectTranslator implements Translator { private int inputSize = 640; private float minConfThreshold = 0.3f; private float iouThreshold = 0.5f; private float confThreshold = 0; private int imageWidth; private int imageHeight; private int topK; private LetterBoxUtils.ResizeResult letterBoxResult; public Yolov7PlateDetectTranslator(Map arguments) { confThreshold = arguments.containsKey("confThreshold") ? Integer.parseInt(arguments.get("confThreshold").toString()) : 0.3f; iouThreshold = arguments.containsKey("iouThreshold") ? Integer.parseInt(arguments.get("iouThreshold").toString()) : 0.5f; topK = arguments.containsKey("topk") ? Integer.parseInt(arguments.get("topk").toString()) : 100; } @Override public NDList processInput(TranslatorContext ctx, Image input) { NDManager manager = ctx.getNDManager(); NDArray array = input.toNDArray(manager, Image.Flag.COLOR); imageWidth = (int) array.getShape().get(1); imageHeight = (int) array.getShape().get(0); //Letter box resize 640x640 with padding (保持比例,补边缘) letterBoxResult = LetterBoxUtils.letterbox(manager, array, inputSize, inputSize, 114f, LetterBoxUtils.PaddingPosition.CENTER); array = letterBoxResult.image; // 转为 float32 且归一化到 0~1 array = array.toType(DataType.FLOAT32, false).div(255f); // HWC // HWC -> CHW array = array.transpose(2, 0, 1); // CHW return new NDList(array.expandDims(0)); } @Override public DetectedObjects processOutput(TranslatorContext ctx, NDList list) { NDManager manager = ctx.getNDManager(); int num_cls = 2; //[x_center, y_center, w, h, obj_conf, class1_conf, class2_conf,8个关键点] //目标置信度 obj_conf 5:13 关键点 [13:15]分类得分:单层车牌 / 双层车牌 NDArray dets = list.singletonOrThrow(); //置信度过滤 (1,25200, 15) NDArray dets0 = dets.get(0); NDArray conf = dets0.get(":, 4"); // shape [N] NDArray mask = conf.gt(minConfThreshold); //筛选出符合条件的框(17,15) NDArray detsFiltered = dets0.get(mask); // 筛掉低置信度 //把分类得分 [5:7] * 置信度 [4:5] 做联合概率 NDArray clsLogits = detsFiltered.get(":, 5:7"); // (N, 2) NDArray confFiltered = detsFiltered.get(":, 4").reshape(-1, 1); // (N, 1) clsLogits = clsLogits.mul(confFiltered); // (N, 2),变成 obj_conf * class_conf NDArray jointScore = clsLogits.max(new int[]{1}); // shape (N,) // 联合过滤 NDArray jointMask = jointScore.gt(confThreshold); detsFiltered = detsFiltered.get(jointMask); clsLogits = clsLogits.get(jointMask); //中心点框 [x,y,w,h] ➔ 左上右下 [x1,y1,x2,y2] NDArray xywh = detsFiltered.get(":, 0:4"); // (N, 4) NDArray halfWH = xywh.get(":, 2:4").div(2); // (N, 2) NDArray xy1 = xywh.get(":, 0:2").sub(halfWH); // (N, 2) NDArray xy2 = xywh.get(":, 0:2").add(halfWH); // (N, 2) NDArray boxes = NDArrays.concat(new NDList(xy1, xy2), 1); // (N, 4) // 分类得分最大值:score (N, 1),对应类别 index (N, 1) NDArray scores = clsLogits.max(new int[]{1}, true); // (N, 1) NDArray indices = clsLogits.argMax(1).reshape(-1, 1).toType(DataType.FLOAT32, false); // (N, 1) // 关键点坐标 [7,8,10,11,13,14,16,17] NDArray keyPoints = NDArrays.concat(new NDList( detsFiltered.get(":, 7:8"), detsFiltered.get(":, 8:9"), detsFiltered.get(":, 10:11"), detsFiltered.get(":, 11:12"), detsFiltered.get(":, 13:14"), detsFiltered.get(":, 14:15"), detsFiltered.get(":, 16:17"), detsFiltered.get(":, 17:18") ), 1); // 拼成 (N, 8) // 拼成最终结果:(x1, y1, x2, y2, score, 8关键点, index) NDArray output = NDArrays.concat(new NDList(boxes, scores, keyPoints, indices), 1); // (N, 14) // NMS 过滤掉重叠框 int[] keepIndices = NMSUtils.nms(boxes, scores.squeeze(), iouThreshold); // scores.squeeze() ➝ (N,) NDArray kept = output.get(manager.create(keepIndices)); // 如果超过 topK,则截断 if (keepIndices.length > topK) { int[] topkIndices = new int[topK]; System.arraycopy(keepIndices, 0, topkIndices, 0, topK); keepIndices = topkIndices; } //恢复原图坐标(除回比例,减掉 padding) NDArray restored = LetterBoxUtils.restoreBox(kept, letterBoxResult.r, letterBoxResult.left, letterBoxResult.top, 5,8); List classNames = new ArrayList<>(); List probabilities = new ArrayList<>(); List boundingBoxes = new ArrayList<>(); float[] flatData = restored.toFloatArray(); long[] shape = restored.getShape().getShape(); // 比如 (N, 14) int rows = (int) shape[0]; int cols = (int) shape[1]; // 把一维数组重组为二维数组 float[][] data = new float[rows][cols]; for (int i = 0; i < rows; i++) { System.arraycopy(flatData, i * cols, data[i], 0, cols); } for (float[] row : data) { // row结构:(x1, y1, x2, y2, score, kp1,..., kp8, classIndex) float x1 = row[0]; float y1 = row[1]; float x2 = row[2]; float y2 = row[3]; float score = row[4]; int classIndex = (int) row[13]; double prob = score; String className = classIndex == 0 ? "single" : "double"; // 转相对坐标,DJL的Rectangle用比例坐标(0~1) double rectX = x1 / imageWidth; double rectY = y1 / imageHeight; double rectW = (x2 - x1) / imageWidth; double rectH = (y2 - y1) / imageHeight; // 构建 Polygon 四个角点 List pointsSrc = new ArrayList<>(); pointsSrc.add(new Point(row[5], row[6])); pointsSrc.add(new Point(row[7], row[8])); pointsSrc.add(new Point(row[9], row[10])); pointsSrc.add(new Point(row[11], row[12])); Landmark box = new Landmark(rectX, rectY, rectW, rectH, pointsSrc); classNames.add(className); probabilities.add(prob); boundingBoxes.add(box); } DetectedObjects detectedObjects = new DetectedObjects(classNames, probabilities, boundingBoxes); return detectedObjects; } @Override public Batchifier getBatchifier() { return null; } }