zouyu
2025-09-26 3fbbfcc8f509c352c58dc8a126220b49b72ed5a0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
package com.xindao.ocr.smartjavaai.model.plate.translator;
 
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.output.BoundingBox;
import ai.djl.modality.cv.output.DetectedObjects;
import ai.djl.modality.cv.output.Rectangle;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDArrays;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.translate.Batchifier;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;
import cn.smartjavaai.common.utils.LetterBoxUtils;
import cn.smartjavaai.common.utils.NMSUtils;
 
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
 
/**
 * @author dwj
 */
public class Yolov8PlateDetectTranslator implements Translator<Image, DetectedObjects> {
 
    private int inputSize = 640;
    private float minConfThreshold = 0.3f;
    private float iouThreshold = 0.5f;
 
    private float confThreshold = 0;
 
    private int imageWidth;
    private int imageHeight;
 
    private int topK;
 
    private LetterBoxUtils.ResizeResult letterBoxResult;
 
    public Yolov8PlateDetectTranslator(Map<String, ?> arguments) {
        confThreshold =
                arguments.containsKey("confThreshold")
                        ? Integer.parseInt(arguments.get("confThreshold").toString())
                        : 0.3f;
 
        iouThreshold =
                arguments.containsKey("iouThreshold")
                        ? Integer.parseInt(arguments.get("iouThreshold").toString())
                        : 0.5f;
 
        topK = arguments.containsKey("topk")
                ? Integer.parseInt(arguments.get("topk").toString())
                : 100;
    }
 
    @Override
    public NDList processInput(TranslatorContext ctx, Image input) {
        NDManager manager = ctx.getNDManager();
        NDArray array = input.toNDArray(manager, Image.Flag.COLOR);
        imageWidth = (int) array.getShape().get(1);
        imageHeight = (int) array.getShape().get(0);
        //Letter box resize 640x640 with padding (保持比例,补边缘)
        letterBoxResult = LetterBoxUtils.letterbox(manager, array, inputSize, inputSize, 114f, LetterBoxUtils.PaddingPosition.CENTER);
        array = letterBoxResult.image;
        // 转为 float32 且归一化到 0~1
        array = array.toType(DataType.FLOAT32, false).div(255f); // HWC
        // HWC -> CHW
        array = array.transpose(2, 0, 1); // CHW
        return new NDList(array.expandDims(0));
    }
 
    @Override
    public DetectedObjects processOutput(TranslatorContext ctx, NDList list) {
        NDManager manager = ctx.getNDManager();
 
        NDArray preds = list.get(0); // shape: (1, 6, 8400)
        preds = preds.squeeze(0).transpose(1, 0); // shape: (8400, 6)
 
        // preds shape: (8400, 6)
        NDArray classScores = preds.get(":, 4:6"); // shape: (8400, 2)
 
        // 获取每行最大值(对应 Python 的 .amax(1))
        NDArray maxScores = classScores.max(new int[]{1}); // shape: (8400,)
 
        // 构造 mask:score > conf
        NDArray confMask = maxScores.gt(minConfThreshold); // shape: (8400,)
 
        // 应用 mask 筛选
        preds = preds.get(confMask); // shape: (N_filtered, 6)
 
        if (preds.isEmpty()) {
            return null;
        }
 
        // 提取 box (xywh),转换为 xyxy
        NDArray boxes = preds.get(":, 0:4"); // shape: (N, 4)
        boxes = xywh2xyxy(boxes); // 自定义函数:center xywh -> xyxy
 
        // 1. 得分和类别索引
        NDArray scoresAndClasses = preds.get(":, 4:6");  // shape (num, 2)
        NDArray scores = scoresAndClasses.max(new int[]{1}, true);  // keepDim = true
        NDArray index = scoresAndClasses.argMax(1).expandDims(1);  // 最大值索引,类别,shape (num, 1)
 
        // 4. 拼接
        NDArray result = NDArrays.concat(new NDList(boxes, scores, index), 1);  // 在列方向拼接
 
        // NMS 过滤掉重叠框
        int[] keepIndices = NMSUtils.nms(boxes, scores.squeeze(), iouThreshold); // scores.squeeze() ➝ (N,)
        NDArray kept = result.get(manager.create(keepIndices));
        // 如果超过 topK,则截断
        if (keepIndices.length > topK) {
            int[] topkIndices = new int[topK];
            System.arraycopy(keepIndices, 0, topkIndices, 0, topK);
            keepIndices = topkIndices;
        }
        //恢复原图坐标(除回比例,减掉 padding)
        NDArray restored = LetterBoxUtils.restoreBox(kept, letterBoxResult.r, letterBoxResult.left, letterBoxResult.top, 5,0);
 
        List<String> classNames = new ArrayList<>();
        List<Double> probabilities = new ArrayList<>();
        List<BoundingBox> boundingBoxes = new ArrayList<>();
 
        float[] flatData = restored.toFloatArray();
        long[] shape = restored.getShape().getShape(); // 比如 (N, 14)
        int rows = (int) shape[0];
        int cols = (int) shape[1];
 
        // 把一维数组重组为二维数组
        float[][] data = new float[rows][cols];
        for (int i = 0; i < rows; i++) {
            System.arraycopy(flatData, i * cols, data[i], 0, cols);
        }
 
        for (float[] row : data) {
            // row结构:(x1, y1, x2, y2, score, classIndex)
            float x1 = row[0];
            float y1 = row[1];
            float x2 = row[2];
            float y2 = row[3];
            float score = row[4];
            int classIndex = (int) row[5];
 
            double prob = score;
            String className = classIndex == 0 ? "single" : "double";
 
            // 转相对坐标,DJL的Rectangle用比例坐标(0~1)
            double rectX = x1 / imageWidth;
            double rectY = y1 / imageHeight;
            double rectW = (x2 - x1) / imageWidth;
            double rectH = (y2 - y1) / imageHeight;
 
            // 构建 Polygon 四个角点
//            List<Point> pointsSrc = new ArrayList<>();
//            pointsSrc.add(new Point(row[5], row[6]));
//            pointsSrc.add(new Point(row[7], row[8]));
//            pointsSrc.add(new Point(row[9], row[10]));
//            pointsSrc.add(new Point(row[11], row[12]));
 
            Rectangle rectangle = new Rectangle(rectX, rectY, rectW, rectH);
            classNames.add(className);
            probabilities.add(prob);
            boundingBoxes.add(rectangle);
        }
        DetectedObjects detectedObjects = new DetectedObjects(classNames, probabilities, boundingBoxes);
        return detectedObjects;
 
    }
 
    @Override
    public Batchifier getBatchifier() {
        return null;
    }
 
 
 
    public static NDArray xywh2xyxy(NDArray xywh) {
        NDArray x = xywh.get(":, 0");
        NDArray y = xywh.get(":, 1");
        NDArray w = xywh.get(":, 2").div(2);
        NDArray h = xywh.get(":, 3").div(2);
        NDArray x1 = x.sub(w);
        NDArray y1 = y.sub(h);
        NDArray x2 = x.add(w);
        NDArray y2 = y.add(h);
        return NDArrays.stack(new NDList(x1, y1, x2, y2), 1);
    }
 
}