package com.xindao.ocr.smartjavaai.model.plate; import ai.djl.MalformedModelException; import ai.djl.engine.Engine; import ai.djl.inference.Predictor; import ai.djl.modality.cv.Image; import ai.djl.modality.cv.ImageFactory; import ai.djl.modality.cv.output.DetectedObjects; import ai.djl.repository.zoo.Criteria; import ai.djl.repository.zoo.ModelNotFoundException; import ai.djl.repository.zoo.ModelZoo; import ai.djl.repository.zoo.ZooModel; import cn.smartjavaai.common.entity.R; import cn.smartjavaai.common.pool.PredictorFactory; import cn.smartjavaai.common.utils.Base64ImageUtils; import cn.smartjavaai.common.utils.FileUtils; import cn.smartjavaai.common.utils.ImageUtils; import cn.smartjavaai.common.utils.OpenCVUtils; import com.xindao.ocr.smartjavaai.config.PlateDetModelConfig; import com.xindao.ocr.smartjavaai.entity.PlateInfo; import com.xindao.ocr.smartjavaai.exception.OcrException; import com.xindao.ocr.smartjavaai.model.plate.criteria.PlateDetCriterialFactory; import com.xindao.ocr.smartjavaai.utils.OcrUtils; import lombok.extern.slf4j.Slf4j; import org.apache.commons.lang3.StringUtils; import org.apache.commons.pool2.impl.GenericObjectPool; import org.opencv.core.Mat; import javax.imageio.ImageIO; import java.awt.image.BufferedImage; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.IOException; import java.io.InputStream; import java.nio.file.Files; import java.nio.file.Path; import java.nio.file.Paths; import java.util.List; import java.util.Objects; /** * Yolov5 车牌检测模型 * @author dwj * @date 2025/7/23 */ @Slf4j public class Yolov5PlateDetModel implements PlateDetModel{ private GenericObjectPool> detPredictorPool; private ZooModel detectionModel; private PlateDetModelConfig config; @Override public void loadModel(PlateDetModelConfig config) { if(StringUtils.isBlank(config.getModelPath())){ throw new OcrException("modelPath is null"); } this.config = config; //初始化 检测Criteria Criteria detCriteria = PlateDetCriterialFactory.createCriteria(config); try{ detectionModel = ModelZoo.loadModel(detCriteria); // 创建池子:每个线程独享 Predictor this.detPredictorPool = new GenericObjectPool<>(new PredictorFactory<>(detectionModel)); int predictorPoolSize = config.getPredictorPoolSize(); if(config.getPredictorPoolSize() <= 0){ predictorPoolSize = Runtime.getRuntime().availableProcessors(); // 默认等于CPU核心数 } detPredictorPool.setMaxTotal(predictorPoolSize); log.debug("当前设备: " + detectionModel.getNDManager().getDevice()); log.debug("当前引擎: " + Engine.getInstance().getEngineName()); log.debug("模型推理器线程池最大数量: " + predictorPoolSize); } catch (IOException | ModelNotFoundException | MalformedModelException e) { throw new OcrException("检测模型加载失败", e); } } @Override public R> detect(String imagePath) { if(!FileUtils.isFileExists(imagePath)){ return R.fail(R.Status.FILE_NOT_FOUND); } Image img = null; try { img = ImageFactory.getInstance().fromFile(Paths.get(imagePath)); } catch (IOException e) { throw new OcrException("无效的图片", e); } DetectedObjects detectedObjects = detect(img); if (Objects.isNull(detectedObjects) || detectedObjects.getNumberOfObjects() == 0){ return R.fail(R.Status.NO_OBJECT_DETECTED); } List plateInfoList = OcrUtils.convertToPlateInfo(detectedObjects, img); ((Mat)img.getWrappedImage()).release(); return R.ok(plateInfoList); } @Override public R> detectBase64(String base64Image) { if(StringUtils.isBlank(base64Image)){ return R.fail(R.Status.INVALID_IMAGE); } byte[] imageData = Base64ImageUtils.base64ToImage(base64Image); return detect(imageData); } @Override public R> detect(BufferedImage image) { if(!ImageUtils.isImageValid(image)){ return R.fail(R.Status.INVALID_IMAGE); } Image img = ImageFactory.getInstance().fromImage(OpenCVUtils.image2Mat(image)); DetectedObjects detectedObjects = detect(img); if (Objects.isNull(detectedObjects) || detectedObjects.getNumberOfObjects() == 0){ return R.fail(R.Status.NO_OBJECT_DETECTED); } List plateInfoList = OcrUtils.convertToPlateInfo(detectedObjects, img); ((Mat)img.getWrappedImage()).release(); return R.ok(plateInfoList); } @Override public R> detect(byte[] imageData) { if(Objects.isNull(imageData)){ return R.fail(R.Status.INVALID_IMAGE); } return detect(new ByteArrayInputStream(imageData)); } @Override public DetectedObjects detect(Image image) { Predictor predictor = null; try { predictor = detPredictorPool.borrowObject(); return predictor.predict(image); } catch (Exception e) { throw new OcrException("车牌检测错误", e); }finally { if (predictor != null) { try { detPredictorPool.returnObject(predictor); //归还 } catch (Exception e) { log.warn("归还Predictor失败", e); try { predictor.close(); // 归还失败才销毁 } catch (Exception ex) { log.error("关闭Predictor失败", ex); } } } } } @Override public R> detect(InputStream inputStream) { if(Objects.isNull(inputStream)){ return R.fail(R.Status.INVALID_IMAGE); } try { Image img = ImageFactory.getInstance().fromInputStream(inputStream); DetectedObjects detection = detect(img); List plateInfoList = OcrUtils.convertToPlateInfo(detection, img); ((Mat)img.getWrappedImage()).release(); return R.ok(plateInfoList); } catch (IOException e) { throw new OcrException("无效图片输入流", e); } } @Override public R detectAndDraw(String imagePath, String outputPath) { if(!FileUtils.isFileExists(imagePath)){ return R.fail(R.Status.FILE_NOT_FOUND); } try { Image img = ImageFactory.getInstance().fromFile(Paths.get(imagePath)); DetectedObjects detectedObjects = detect(img); if(Objects.isNull(detectedObjects) || detectedObjects.getNumberOfObjects() == 0){ return R.fail(R.Status.NO_FACE_DETECTED); } img.drawBoundingBoxes(detectedObjects); Path output = Paths.get(outputPath); log.debug("Saving to {}", output.toAbsolutePath().toString()); img.save(Files.newOutputStream(output), "png"); return R.ok(); } catch (IOException e) { throw new OcrException(e); } } @Override public R detectAndDraw(BufferedImage sourceImage) { if(!ImageUtils.isImageValid(sourceImage)){ return R.fail(R.Status.INVALID_IMAGE); } Image img = ImageFactory.getInstance().fromImage(OpenCVUtils.image2Mat(sourceImage)); DetectedObjects detectedObjects = detect(img); if(Objects.isNull(detectedObjects) || detectedObjects.getNumberOfObjects() == 0){ return R.fail(R.Status.NO_FACE_DETECTED); } img.drawBoundingBoxes(detectedObjects); try { ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); // 调用 save 方法将 Image 写入字节流 img.save(outputStream, "png"); // 将字节流转换为 BufferedImage byte[] imageBytes = outputStream.toByteArray(); return R.ok(ImageIO.read(new ByteArrayInputStream(imageBytes))); } catch (IOException e) { throw new OcrException("导出图片失败", e); } } @Override public GenericObjectPool> getPool() { return detPredictorPool; } @Override public void close() throws Exception { try { if (detPredictorPool != null) { detPredictorPool.close(); } } catch (Exception e) { log.warn("关闭 predictorPool 失败", e); } try { if (detectionModel != null) { detectionModel.close(); } } catch (Exception e) { log.warn("关闭 model 失败", e); } } }