package com.xindao.ocr.smartjavaai.model.common.recognize; import ai.djl.MalformedModelException; import ai.djl.engine.Engine; import ai.djl.inference.Predictor; import ai.djl.modality.cv.Image; import ai.djl.modality.cv.ImageFactory; import ai.djl.ndarray.NDArray; import ai.djl.ndarray.NDManager; import ai.djl.repository.zoo.Criteria; import ai.djl.repository.zoo.ModelNotFoundException; import ai.djl.repository.zoo.ModelZoo; import ai.djl.repository.zoo.ZooModel; import cn.hutool.core.img.ImgUtil; import cn.smartjavaai.common.pool.PredictorFactory; import cn.smartjavaai.common.utils.FileUtils; import cn.smartjavaai.common.utils.ImageUtils; import cn.smartjavaai.common.utils.OpenCVUtils; import com.xindao.ocr.smartjavaai.config.OcrRecModelConfig; import com.xindao.ocr.smartjavaai.config.OcrRecOptions; import com.xindao.ocr.smartjavaai.entity.*; import com.xindao.ocr.smartjavaai.exception.OcrException; import com.xindao.ocr.smartjavaai.model.common.detect.OcrCommonDetModel; import com.xindao.ocr.smartjavaai.model.common.direction.OcrDirectionModel; import com.xindao.ocr.smartjavaai.model.common.recognize.criteria.OcrCommonRecCriterialFactory; import com.xindao.ocr.smartjavaai.utils.OcrUtils; import lombok.extern.slf4j.Slf4j; import org.apache.commons.collections.CollectionUtils; import org.apache.commons.lang3.StringUtils; import org.apache.commons.pool2.impl.GenericObjectPool; import org.opencv.core.Mat; import javax.imageio.ImageIO; import java.awt.image.BufferedImage; import java.io.ByteArrayInputStream; import java.io.IOException; import java.nio.file.Paths; import java.util.ArrayList; import java.util.Collections; import java.util.List; import java.util.Objects; import java.util.stream.Collectors; /** * PPOCRV5 识别模型 * * @author dwj */ @Slf4j public class OcrCommonRecModelImpl implements OcrCommonRecModel { private GenericObjectPool> recPredictorPool; private OcrRecModelConfig config; private ZooModel recognitionModel; private OcrDirectionModel directionModel; private OcrCommonDetModel textDetModel; @Override public void loadModel(OcrRecModelConfig config) { if (StringUtils.isBlank(config.getRecModelPath())) { throw new OcrException("recModelPath is null"); } this.config = config; this.directionModel = config.getDirectionModel(); this.textDetModel = config.getTextDetModel(); //初始化 识别Criteria Criteria recCriteria = OcrCommonRecCriterialFactory.createCriteria(config); try { recognitionModel = ModelZoo.loadModel(recCriteria); this.recPredictorPool = new GenericObjectPool<>(new PredictorFactory<>(recognitionModel)); int predictorPoolSize = config.getPredictorPoolSize(); if (config.getPredictorPoolSize() <= 0) { predictorPoolSize = Runtime.getRuntime().availableProcessors(); // 默认等于CPU核心数 } recPredictorPool.setMaxTotal(predictorPoolSize); log.debug("当前设备: " + recognitionModel.getNDManager().getDevice()); log.debug("当前引擎: " + Engine.getInstance().getEngineName()); log.debug("模型推理器线程池最大数量: " + predictorPoolSize); } catch (IOException | ModelNotFoundException | MalformedModelException e) { throw new OcrException("识别模型加载失败", e); } } @Override public OcrInfo recognize(String imagePath, OcrRecOptions options) { if (StringUtils.isBlank(config.getRecModelPath())) { throw new OcrException("recModelPath为空,无法识别"); } if (!FileUtils.isFileExists(imagePath)) { throw new OcrException("图像文件不存在"); } Image img = null; try { img = ImageFactory.getInstance().fromFile(Paths.get(imagePath)); return recognize(img, options); } catch (IOException e) { throw new OcrException("无效的图片", e); } finally { if (img != null) { ((Mat) img.getWrappedImage()).release(); } } } /** * @param image * @param options * @return */ @Override public OcrInfo recognize(Image image, OcrRecOptions options) { List result = batchRecognizeDJLImage(Collections.singletonList(image), options); if (CollectionUtils.isEmpty(result)) { throw new OcrException("OCR识别结果为空"); } return result.get(0); } /** * 批量矫正文本框 * * @param boxList * @param srcMat * @param manager * @return */ private List batchAlign(List boxList, Mat srcMat, NDManager manager) { List imageList = new ArrayList<>(boxList.size()); for (int i = 0; i < boxList.size(); i++) { //透视变换 + 裁剪 Image subImg = OcrUtils.transformAndCrop(srcMat, boxList.get(i)); //ImageUtils.saveImage(subImg, i + "crop.png", "build/output"); //高宽比 > 1.5 if (subImg.getHeight() * 1.0 / subImg.getWidth() > 1.5) { //旋转图片90度 subImg = OcrUtils.rotateImg(manager, subImg); //ImageUtils.saveImage(subImg, i + "rotate.png", "build/output"); } imageList.add(subImg); } return imageList; } /** * 批量矫正文本框 * * @param itemList * @param srcMat * @param manager * @return */ private List batchAlignWithDirection(List itemList, Mat srcMat, NDManager manager) { List imageList = new ArrayList<>(itemList.size()); for (OcrItem ocrItem : itemList) { //放射变换+裁剪 Image subImage = OcrUtils.transformAndCrop(srcMat, ocrItem.getOcrBox()); //ImageUtils.saveImage(subImage, UUID.randomUUID().toString() + "_aaa.png", "build/output"); //纠正文本框 subImage = OcrUtils.rotateImg(subImage, ocrItem.getAngle()); imageList.add(subImage); } return imageList; } // private RotatedBox recognize(OcrBox box,Mat srcMat,Predictor recPredictor,NDManager manager){ // try { // //透视变换 + 裁剪 // Image subImg = OcrUtils.transformAndCrop(srcMat, box); // //ImageUtils.saveImage(subImg, i + "crop.png", "build/output"); // //高宽比 > 1.5 // if (subImg.getHeight() * 1.0 / subImg.getWidth() > 1.5) { // //旋转图片90度 // subImg = OcrUtils.rotateImg(manager, subImg); // //ImageUtils.saveImage(subImg, i + "rotate.png", "build/output"); // } // String name = recPredictor.predict(subImg); // ((Mat)subImg.getWrappedImage()).release(); // NDArray pointsArray = manager.create(box.toFloatArray()); // return new RotatedBox(pointsArray, name); // } catch (Exception e) { // throw new OcrException("OCR检测错误", e); // } // } /** * 后处理:排序,分行 * * @param rotatedBoxes */ private OcrInfo postProcessOcrResult(List rotatedBoxes, OcrRecOptions ocrRecOptions) { //不分行 if (!ocrRecOptions.isEnableLineSplit()) { return OcrUtils.convertRotatedBoxesToOcrItems(rotatedBoxes); } //Y坐标升序排序 List initList = new ArrayList<>(); for (RotatedBox result : rotatedBoxes) { initList.add(result); } Collections.sort(initList); //多行文本框的集合 List> lines = new ArrayList<>(); List line = new ArrayList<>(); RotatedBoxCompX firstBox = new RotatedBoxCompX(initList.get(0).getBox(), initList.get(0).getText()); line.add(firstBox); lines.add((ArrayList) line); //分行判断 for (int i = 1; i < initList.size(); i++) { RotatedBoxCompX tmpBox = new RotatedBoxCompX(initList.get(i).getBox(), initList.get(i).getText()); float y1 = firstBox.getBox().toFloatArray()[1]; float y2 = tmpBox.getBox().toFloatArray()[1]; float dis = Math.abs(y2 - y1); if (dis < 20) { // 认为是同 1 行 - Considered to be in the same line line.add(tmpBox); } else { // 换行 - Line break firstBox = tmpBox; Collections.sort(line); line = new ArrayList<>(); line.add(firstBox); lines.add((ArrayList) line); } } return OcrUtils.convertToOcrInfo(lines); } @Override public void recognizeAndDraw(String imagePath, String outputPath, int fontSize, OcrRecOptions options) { if (!FileUtils.isFileExists(imagePath)) { throw new OcrException("图像文件不存在"); } try { Image img = ImageFactory.getInstance().fromFile(Paths.get(imagePath)); OcrInfo ocrInfo = recognize(img, options); if (Objects.isNull(ocrInfo) || Objects.isNull(ocrInfo.getLineList()) || ocrInfo.getLineList().isEmpty()) { throw new OcrException("未检测到文字"); } Mat wrappedImage = (Mat) img.getWrappedImage(); BufferedImage bufferedImage = OpenCVUtils.mat2Image(wrappedImage); OcrUtils.drawRectWithText(bufferedImage, ocrInfo, fontSize); ImageUtils.saveImage(bufferedImage, outputPath); wrappedImage.release(); } catch (IOException e) { throw new OcrException(e); } } @Override public OcrInfo recognize(BufferedImage image, OcrRecOptions options) { if (!ImageUtils.isImageValid(image)) { throw new OcrException("图像无效"); } Image img = ImageFactory.getInstance().fromImage(OpenCVUtils.image2Mat(image)); OcrInfo ocrInfo = recognize(img, options); ((Mat) img.getWrappedImage()).release(); return ocrInfo; } @Override public OcrInfo recognize(byte[] imageData, OcrRecOptions options) { if (Objects.isNull(imageData)) { throw new OcrException("图像无效"); } try { BufferedImage image = ImageIO.read(new ByteArrayInputStream(imageData)); return recognize(image, options); } catch (IOException e) { throw new OcrException("错误的图像", e); } } @Override public BufferedImage recognizeAndDraw(BufferedImage sourceImage, int fontSize, OcrRecOptions options) { if (!ImageUtils.isImageValid(sourceImage)) { throw new OcrException("图像无效"); } Image img = ImageFactory.getInstance().fromImage(OpenCVUtils.image2Mat(sourceImage)); OcrInfo ocrInfo = recognize(img, options); if (Objects.isNull(ocrInfo) || Objects.isNull(ocrInfo.getLineList()) || ocrInfo.getLineList().isEmpty()) { throw new OcrException("未检测到文字"); } OcrUtils.drawRectWithText(sourceImage, ocrInfo, fontSize); return sourceImage; } @Override public String recognizeAndDrawToBase64(byte[] imageData, int fontSize, OcrRecOptions options) { if (Objects.isNull(imageData)) { throw new OcrException("图像无效"); } OcrInfo ocrInfo = recognize(imageData, options); if (Objects.isNull(ocrInfo) || Objects.isNull(ocrInfo.getLineList()) || ocrInfo.getLineList().isEmpty()) { throw new OcrException("未检测到文字"); } try { BufferedImage sourceImage = ImageIO.read(new ByteArrayInputStream(imageData)); OcrUtils.drawRectWithText(sourceImage, ocrInfo, fontSize); return ImgUtil.toBase64(sourceImage, "png"); } catch (IOException e) { throw new OcrException("导出图片失败", e); } } @Override public OcrInfo recognizeAndDraw(byte[] imageData, int fontSize, OcrRecOptions options) { if (Objects.isNull(imageData)) { throw new OcrException("图像无效"); } OcrInfo ocrInfo = recognize(imageData, options); if (Objects.isNull(ocrInfo) || Objects.isNull(ocrInfo.getLineList()) || ocrInfo.getLineList().isEmpty()) { throw new OcrException("未检测到文字"); } try { BufferedImage sourceImage = ImageIO.read(new ByteArrayInputStream(imageData)); OcrUtils.drawRectWithText(sourceImage, ocrInfo, fontSize); ocrInfo.setBase64Img(ImgUtil.toBase64(sourceImage, "png")); return ocrInfo; } catch (IOException e) { throw new OcrException("导出图片失败", e); } } @Override public List batchRecognize(List imageList, OcrRecOptions options) { List djlImageList = new ArrayList<>(imageList.size()); try { for (BufferedImage bufferedImage : imageList) { djlImageList.add(ImageFactory.getInstance().fromImage(OpenCVUtils.image2Mat(bufferedImage))); } return batchRecognizeDJLImage(djlImageList, options); } catch (Exception e) { throw new OcrException(e); } finally { djlImageList.forEach(image -> ((Mat) image.getWrappedImage()).release()); } } @Override public List batchRecognizeDJLImage(List imageList, OcrRecOptions options) { if (Objects.isNull(textDetModel)) { throw new OcrException("textDetModel is null"); } OcrRecOptions ocrRecOptions = options; if (Objects.isNull(options)) { ocrRecOptions = new OcrRecOptions(); } if (CollectionUtils.isEmpty(imageList)) { throw new OcrException("imageList is empty"); } //检测文本 List> boxeList = textDetModel.batchDetectDJLImage(imageList); if (CollectionUtils.isEmpty(boxeList) || boxeList.size() != imageList.size()) { throw new OcrException("未检测到文本"); } Predictor predictor = null; List ocrInfoList = new ArrayList(); try (NDManager manager = NDManager.newBaseManager()) { predictor = recPredictorPool.borrowObject(); List allImageAlignList = new ArrayList(); //检测方向 if (ocrRecOptions.isEnableDirectionCorrect()) { if (Objects.isNull(directionModel)) { throw new OcrException("请配置方向模型"); } List matList = imageList.stream() .map(image -> (Mat) image.getWrappedImage()) .collect(Collectors.toList()); List> ocrItemList = directionModel.batchDetect(boxeList, matList); if (CollectionUtils.isEmpty(ocrItemList) || ocrItemList.size() != imageList.size()) { throw new OcrException("方向检测失败"); } allImageAlignList = new ArrayList(); for (int i = 0; i < ocrItemList.size(); i++) { Mat srcMat = (Mat) imageList.get(i).getWrappedImage(); List imageAlignList = batchAlignWithDirection(ocrItemList.get(i), srcMat, manager); // for(int j = 0; j < imageAlignList.size(); j++){ // ImageUtils.saveImage(imageAlignList.get(j),"dir-"+i+"-"+j+".png","/Users/xxx/Downloads/testing33"); // } allImageAlignList.addAll(imageAlignList); } } else { for (int i = 0; i < boxeList.size(); i++) { Mat srcMat = (Mat) imageList.get(i).getWrappedImage(); List imageAlignList = batchAlign(boxeList.get(i), srcMat, manager); // for(int j = 0; j < imageAlignList.size(); j++){ // ImageUtils.saveImage(imageAlignList.get(j),i+"-"+j+".png","/Users/xxx/Downloads/testing33"); // } allImageAlignList.addAll(imageAlignList); } } List textList = batchRecognize(allImageAlignList); int textIndex = 0; for (int i = 0; i < boxeList.size(); i++) { List rotatedBoxes = new ArrayList<>(); for (int j = 0; j < boxeList.get(i).size(); j++) { if (textIndex >= textList.size()) { throw new OcrException("识别失败: 第" + i + "张图片, 第" + j + "个文本块,未识别到文本"); } OcrBox box = boxeList.get(i).get(j); NDArray pointsArray = manager.create(box.toFloatArray()); rotatedBoxes.add(new RotatedBox(pointsArray, textList.get(textIndex))); textIndex++; } OcrInfo ocrInfo = postProcessOcrResult(rotatedBoxes, ocrRecOptions); ocrInfoList.add(ocrInfo); } return ocrInfoList; } catch (Exception e) { throw new OcrException("OCR检测错误", e); } finally { if (predictor != null) { try { recPredictorPool.returnObject(predictor); //归还 } catch (Exception e) { log.warn("归还Predictor失败", e); try { predictor.close(); // 归还失败才销毁 } catch (Exception ex) { log.error("关闭Predictor失败", ex); } } } } } private List batchRecognize(List imageAlignList) { Predictor predictor = null; try { predictor = recPredictorPool.borrowObject(); List textList = predictor.batchPredict(imageAlignList); imageAlignList.forEach(subImg -> ((Mat) subImg.getWrappedImage()).release()); return textList; } catch (Exception e) { throw new OcrException("OCR检测错误", e); } finally { if (predictor != null) { try { recPredictorPool.returnObject(predictor); //归还 } catch (Exception e) { log.warn("归还Predictor失败", e); try { predictor.close(); // 归还失败才销毁 } catch (Exception ex) { log.error("关闭Predictor失败", ex); } } } } } @Override public void setTextDetModel(OcrCommonDetModel detModel) { this.textDetModel = detModel; } @Override public OcrCommonDetModel getTextDetModel() { return textDetModel; } @Override public void setDirectionModel(OcrDirectionModel directionModel) { this.directionModel = directionModel; } @Override public OcrDirectionModel getDirectionModel() { return directionModel; } public GenericObjectPool> getRecPredictorPool() { return recPredictorPool; } @Override public void close() throws Exception { try { if (recPredictorPool != null) { recPredictorPool.close(); } } catch (Exception e) { log.warn("关闭 predictorPool 失败", e); } try { if (recognitionModel != null) { recognitionModel.close(); } } catch (Exception e) { log.warn("关闭 model 失败", e); } } }