package com.xindao.ocr.smartjavaai.model.plate; import ai.djl.MalformedModelException; import ai.djl.engine.Engine; import ai.djl.inference.Predictor; import ai.djl.modality.cv.Image; import ai.djl.modality.cv.ImageFactory; import ai.djl.modality.cv.output.DetectedObjects; import ai.djl.repository.zoo.Criteria; import ai.djl.repository.zoo.ModelNotFoundException; import ai.djl.repository.zoo.ModelZoo; import ai.djl.repository.zoo.ZooModel; import cn.smartjavaai.common.entity.DetectionRectangle; import cn.smartjavaai.common.entity.R; import cn.smartjavaai.common.pool.PredictorFactory; import cn.smartjavaai.common.utils.Base64ImageUtils; import cn.smartjavaai.common.utils.FileUtils; import cn.smartjavaai.common.utils.ImageUtils; import cn.smartjavaai.common.utils.OpenCVUtils; import com.xindao.ocr.smartjavaai.config.PlateRecModelConfig; import com.xindao.ocr.smartjavaai.entity.PlateInfo; import com.xindao.ocr.smartjavaai.entity.PlateResult; import com.xindao.ocr.smartjavaai.enums.PlateType; import com.xindao.ocr.smartjavaai.exception.OcrException; import com.xindao.ocr.smartjavaai.model.plate.criteria.PlateRecCriterialFactory; import com.xindao.ocr.smartjavaai.utils.OcrUtils; import lombok.extern.slf4j.Slf4j; import org.apache.commons.collections.CollectionUtils; import org.apache.commons.lang3.StringUtils; import org.apache.commons.pool2.impl.GenericObjectPool; import org.opencv.core.Core; import org.opencv.core.Mat; import org.opencv.core.Rect; import org.opencv.core.Size; import org.opencv.imgproc.Imgproc; import javax.imageio.ImageIO; import java.awt.image.BufferedImage; import java.io.ByteArrayInputStream; import java.io.File; import java.io.IOException; import java.io.InputStream; import java.nio.file.Paths; import java.util.ArrayList; import java.util.List; import java.util.Objects; /** * @author dwj */ @Slf4j public class CRNNPlateRecModel implements PlateRecModel{ private GenericObjectPool> recPredictorPool; private ZooModel recModel; private PlateRecModelConfig config; @Override public void loadModel(PlateRecModelConfig config) { if(StringUtils.isBlank(config.getModelPath())){ throw new OcrException("modelPath is null"); } this.config = config; //初始化 检测Criteria Criteria detCriteria = PlateRecCriterialFactory.createCriteria(config); try{ recModel = ModelZoo.loadModel(detCriteria); // 创建池子:每个线程独享 Predictor this.recPredictorPool = new GenericObjectPool<>(new PredictorFactory<>(recModel)); int predictorPoolSize = config.getPredictorPoolSize(); if(config.getPredictorPoolSize() <= 0){ predictorPoolSize = Runtime.getRuntime().availableProcessors(); // 默认等于CPU核心数 } recPredictorPool.setMaxTotal(predictorPoolSize); log.debug("当前设备: " + recModel.getNDManager().getDevice()); log.debug("当前引擎: " + Engine.getInstance().getEngineName()); log.debug("模型推理器线程池最大数量: " + predictorPoolSize); } catch (IOException | ModelNotFoundException | MalformedModelException e) { throw new OcrException("检测模型加载失败", e); } } @Override public R> recognize(String imagePath) { if(!FileUtils.isFileExists(imagePath)){ return R.fail(R.Status.FILE_NOT_FOUND); } Image img = null; try { img = ImageFactory.getInstance().fromFile(Paths.get(imagePath)); R> plateResult = recognize(img); return plateResult; } catch (IOException e) { throw new OcrException("无效的图片", e); } finally { ((Mat)img.getWrappedImage()).release(); } } @Override public R> recognizeBase64(String base64Image) { if(StringUtils.isBlank(base64Image)){ return R.fail(R.Status.INVALID_IMAGE); } byte[] imageData = Base64ImageUtils.base64ToImage(base64Image); return recognize(imageData); } @Override public R> recognize(BufferedImage image) { if(!ImageUtils.isImageValid(image)){ return R.fail(R.Status.INVALID_IMAGE); } Image img = ImageFactory.getInstance().fromImage(OpenCVUtils.image2Mat(image)); R> plateResult = recognize(img); ((Mat)img.getWrappedImage()).release(); return plateResult; } @Override public R> recognize(byte[] imageData) { if(Objects.isNull(imageData)){ return R.fail(R.Status.INVALID_IMAGE); } return recognize(new ByteArrayInputStream(imageData)); } @Override public R> recognize(Image image) { if(Objects.isNull(config.getPlateDetModel())){ return R.fail(R.Status.PARAM_ERROR.getCode(), "未指定车牌检测模型"); } DetectedObjects detectedObjects = config.getPlateDetModel().detect(image); if(Objects.isNull(detectedObjects) || detectedObjects.getNumberOfObjects() == 0){ return R.fail(R.Status.NO_OBJECT_DETECTED); } List plateInfoList = OcrUtils.convertToPlateInfo(detectedObjects, image); Predictor predictor = null; try { predictor = recPredictorPool.borrowObject(); for (PlateInfo plateInfo : plateInfoList){ DetectionRectangle detectionRectangle = plateInfo.getDetectionRectangle(); // Image subImage = image.getSubImage(detectionRectangle.getX(), detectionRectangle.getY(), detectionRectangle.getWidth(), detectionRectangle.getHeight()); //透视变换 Image subImage = OcrUtils.transformAndCrop((Mat)image.getWrappedImage(), plateInfo.getBox()); //双层车牌 if(plateInfo.getPlateType() == PlateType.DOUBLE){ Mat mergeImage = getSplitMerge((Mat)subImage.getWrappedImage()); subImage = ImageFactory.getInstance().fromImage(mergeImage); } PlateResult plateResult = predictor.predict(subImage); if(Objects.nonNull(plateResult)){ plateInfo.setPlateNumber(plateResult.getPlateNo()); plateInfo.setPlateColor(plateResult.getPlateColor()); } } return R.ok(plateInfoList); } catch (Exception e) { throw new OcrException("车牌识别错误", e); }finally { if (predictor != null) { try { recPredictorPool.returnObject(predictor); //归还 } catch (Exception e) { log.warn("归还Predictor失败", e); try { predictor.close(); // 归还失败才销毁 } catch (Exception ex) { log.error("关闭Predictor失败", ex); } } } } } /** * 双层车牌进行分割后识别 * @param img * @return */ private Mat getSplitMerge(Mat img) { int h = img.rows(); int w = img.cols(); // 上半部分:高度的前 5/12 Rect upperRect = new Rect(0, 0, w, (int)(5.0 / 12 * h)); Mat imgUpper = new Mat(img, upperRect); // 下半部分:高度从 1/3 开始 Rect lowerRect = new Rect(0, (int)(1.0 / 3 * h), w, h - (int)(1.0 / 3 * h)); Mat imgLower = new Mat(img, lowerRect); // 将上半部分 resize 到与下半部分相同大小 Mat resizedUpper = new Mat(); Size lowerSize = imgLower.size(); Imgproc.resize(imgUpper, resizedUpper, lowerSize); // 水平拼接(将上下拼成左右) List mergeList = new ArrayList<>(); mergeList.add(resizedUpper); mergeList.add(imgLower); Mat merged = new Mat(); Core.hconcat(mergeList, merged); return merged; } @Override public PlateResult recognizeCropped(Image image) { Predictor predictor = null; try { predictor = recPredictorPool.borrowObject(); return predictor.predict(image); } catch (Exception e) { throw new OcrException("车牌检测错误", e); }finally { if (predictor != null) { try { recPredictorPool.returnObject(predictor); //归还 } catch (Exception e) { log.warn("归还Predictor失败", e); try { predictor.close(); // 归还失败才销毁 } catch (Exception ex) { log.error("关闭Predictor失败", ex); } } } } } @Override public R> recognize(InputStream inputStream) { if(Objects.isNull(inputStream)){ return R.fail(R.Status.INVALID_IMAGE); } Image img = null; try { img = ImageFactory.getInstance().fromInputStream(inputStream); return recognize(img); } catch (IOException e) { throw new OcrException("无效图片输入流", e); } finally { if (img != null){ ((Mat)img.getWrappedImage()).release(); } } } @Override public R recognizeAndDraw(String imagePath, String outputPath) { if(!FileUtils.isFileExists(imagePath)){ return R.fail(R.Status.FILE_NOT_FOUND); } Image img = null; try { img = ImageFactory.getInstance().fromFile(Paths.get(imagePath)); R> plateResult = recognize(img); if(!plateResult.isSuccess()){ return R.fail(plateResult.getCode(), plateResult.getMessage()); } if(CollectionUtils.isEmpty(plateResult.getData())){ return R.fail(R.Status.NO_OBJECT_DETECTED); } BufferedImage bufferedImage = OpenCVUtils.mat2Image((Mat)img.getWrappedImage()); OcrUtils.drawPlateInfo(bufferedImage, plateResult.getData()); ImageIO.write(bufferedImage, "png", new File(outputPath)); return R.ok(); } catch (IOException e) { throw new OcrException(e); } finally { if (img != null){ ((Mat)img.getWrappedImage()).release(); } } } @Override public R recognizeAndDraw(BufferedImage sourceImage) { if(!ImageUtils.isImageValid(sourceImage)){ return R.fail(R.Status.INVALID_IMAGE); } try { R> plateResult = recognize(sourceImage); if(!plateResult.isSuccess()){ return R.fail(plateResult.getCode(), plateResult.getMessage()); } if(CollectionUtils.isEmpty(plateResult.getData())){ return R.fail(R.Status.NO_OBJECT_DETECTED); } OcrUtils.drawPlateInfo(sourceImage, plateResult.getData()); return R.ok(sourceImage); } catch (Exception e) { throw new OcrException("导出图片失败", e); } } @Override public GenericObjectPool> getPool() { return recPredictorPool; } @Override public void close() throws Exception { try { if (recPredictorPool != null) { recPredictorPool.close(); } } catch (Exception e) { log.warn("关闭 predictorPool 失败", e); } try { if (recModel != null) { recModel.close(); } } catch (Exception e) { log.warn("关闭 model 失败", e); } } }