package com.xindao.ocr.smartjavaai.model.table; import ai.djl.modality.cv.Image; import ai.djl.modality.cv.ImageFactory; import cn.smartjavaai.common.entity.DetectionRectangle; import cn.smartjavaai.common.entity.R; import cn.smartjavaai.common.utils.FileUtils; import cn.smartjavaai.common.utils.ImageUtils; import cn.smartjavaai.common.utils.OpenCVUtils; import com.xindao.ocr.smartjavaai.config.OcrRecOptions; import com.xindao.ocr.smartjavaai.entity.OcrBox; import com.xindao.ocr.smartjavaai.entity.OcrInfo; import com.xindao.ocr.smartjavaai.entity.OcrItem; import com.xindao.ocr.smartjavaai.entity.TableStructureResult; import com.xindao.ocr.smartjavaai.exception.OcrException; import com.xindao.ocr.smartjavaai.model.common.detect.OcrCommonDetModel; import com.xindao.ocr.smartjavaai.model.common.direction.OcrDirectionModel; import com.xindao.ocr.smartjavaai.model.common.recognize.OcrCommonRecModel; import com.xindao.ocr.smartjavaai.utils.ConvertHtml2Excel; import lombok.extern.slf4j.Slf4j; import org.apache.commons.collections.CollectionUtils; import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.tuple.Pair; import org.apache.poi.hssf.usermodel.HSSFWorkbook; import org.opencv.core.Mat; import javax.imageio.ImageIO; import java.awt.*; import java.awt.image.BufferedImage; import java.io.ByteArrayInputStream; import java.io.File; import java.io.IOException; import java.io.OutputStream; import java.nio.file.Paths; import java.util.*; import java.util.List; import java.util.concurrent.ConcurrentHashMap; /** * 表格内容识别器 * @author dwj */ @Slf4j public class TableRecognizer { private OcrCommonDetModel textDetector; private TableStructureModel tableStructureModel; private OcrCommonRecModel textRecModel; private OcrDirectionModel directionModel; private TableRecognizer(Builder builder) { this.tableStructureModel = builder.tableStructureModel; this.textRecModel = builder.textRecModel; this.directionModel = builder.directionModel; this.textDetector = builder.textDetector; textRecModel.setTextDetModel(textDetector); textRecModel.setDirectionModel(directionModel); } public static Builder builder() { return new Builder(); } // 链式设置文本识别模型 public TableRecognizer withTextRecModel(OcrCommonRecModel textRecModel) { this.textRecModel = textRecModel; return this; } // 链式设置表格结构模型 public TableRecognizer withStructureModel(TableStructureModel tableStructureModel) { this.tableStructureModel = tableStructureModel; return this; } /** * 表格识别 * @param image * @return */ public R recognize(Image image) { //表格结构识别 R result = tableStructureModel.detect(image); if(!result.isSuccess()){ return R.fail(result.getCode(), result.getMessage()); } //文本检测+文字识别 boolean enableDirectionCorrect = directionModel == null ? false : true; OcrRecOptions options = new OcrRecOptions(enableDirectionCorrect, false); OcrInfo ocrInfo = textRecModel.recognize(image, options); List tableContentList = buildTable(result.getData(), ocrInfo); String html = convertHtml(result.getData().getTableTagList(), tableContentList); result.getData().setHtml(html); return result; } /** * 表格识别 * @param image * @return */ public R recognize(BufferedImage image) { if(!ImageUtils.isImageValid(image)){ return R.fail(R.Status.INVALID_IMAGE); } Image img = null; try { img = ImageFactory.getInstance().fromImage(OpenCVUtils.image2Mat(image)); return recognize(img); } catch (Exception e) { throw new OcrException(e); } finally { if(Objects.nonNull(img)){ ((Mat)img.getWrappedImage()).release(); } } } /** * 表格识别 * @param imagePath * @return */ public R recognize(String imagePath) { if(!FileUtils.isFileExists(imagePath)){ return R.fail(R.Status.FILE_NOT_FOUND); } Image img = null; try { img = ImageFactory.getInstance().fromFile(Paths.get(imagePath)); return recognize(img); } catch (IOException e) { throw new OcrException("无效的图片", e); } finally { if (Objects.nonNull(img)){ ((Mat)img.getWrappedImage()).release(); } } } /** * 表格识别 * @param imageData * @return */ public R recognize(byte[] imageData) { if(Objects.isNull(imageData)){ return R.fail(R.Status.INVALID_IMAGE); } try { BufferedImage image = ImageIO.read(new ByteArrayInputStream(imageData)); return recognize(image); } catch (IOException e) { throw new OcrException("错误的图像", e); } } /** * 绘制表格 * @param tableStructureResult * @param image * @param savePath */ public void drawTable(TableStructureResult tableStructureResult, BufferedImage image, String savePath){ if(Objects.isNull(tableStructureResult) || CollectionUtils.isEmpty(tableStructureResult.getTableTagList())){ throw new OcrException("表格结构为空"); } for (int i = 0; i < tableStructureResult.getOcrItemList().size(); i++){ OcrItem item = tableStructureResult.getOcrItemList().get(i); DetectionRectangle detectionRectangle = item.getOcrBox().toDetectionRectangle(); ImageUtils.drawImageRectWithText(image, detectionRectangle, i + "", Color.RED); } ImageUtils.saveImage(image, savePath); } /** * 绘制表格 * @param tableStructureResult * @param image * @return */ public BufferedImage drawTable(TableStructureResult tableStructureResult, BufferedImage image){ if(Objects.isNull(tableStructureResult) || CollectionUtils.isEmpty(tableStructureResult.getTableTagList())){ throw new OcrException("表格结构为空"); } for (int i = 0; i < tableStructureResult.getOcrItemList().size(); i++){ OcrItem item = tableStructureResult.getOcrItemList().get(i); DetectionRectangle detectionRectangle = item.getOcrBox().toDetectionRectangle(); ImageUtils.drawImageRectWithText(image, detectionRectangle, i + "", Color.RED); } return image; } /** * 删除 HTML 中第一个 段落 * @param html 原始 HTML * @return 去掉 ", styleStart); if (styleEnd == -1) { return html; // 没闭合标签,不处理 } styleEnd += "".length(); // 去掉 style 块 return html.substring(0, styleStart) + html.substring(styleEnd); } /** * 导出 Excel * @param html * @param out */ public void exportExcel(String html, OutputStream out){ String content = removeStyleBlock(html); content = content.replace("", ""); content = content.replace("", ""); try (HSSFWorkbook workbook = ConvertHtml2Excel.table2Excel(content)){ workbook.write(out); out.flush(); } catch (Exception e) { throw new OcrException("导出excel失败,请检查表结构是否识别正确"); } } /** * 导出 Excel * @param html * @param savePath */ public void exportExcel(String html, String savePath){ String content = removeStyleBlock(html); content = content.replace("", ""); content = content.replace("", ""); try (HSSFWorkbook workbook = ConvertHtml2Excel.table2Excel(content)){ workbook.write(new File(savePath)); } catch (Exception e) { throw new OcrException("导出excel失败,请检查表结构是否识别正确"); } } /** * 构建表格 * @param tableStructureResult * @param ocrInfo * @return */ public List buildTable(TableStructureResult tableStructureResult, OcrInfo ocrInfo) { // 获取 Cell 与 文本检测框 的对应关系(1:N)。 Map> matched = new ConcurrentHashMap<>(); List ocrItems = ocrInfo.getOcrItemList(); for (int i = 0; i < ocrItems.size(); i++) { OcrBox ocrBox = ocrItems.get(i).getOcrBox(); int[] box_1 = { (int)ocrBox.getTopLeft().getX(), (int)ocrBox.getTopLeft().getY(), (int)ocrBox.getBottomRight().getX(), (int)ocrBox.getBottomRight().getY() }; // 获取两两cell之间的L1距离和 1- IOU List> distances = new ArrayList<>(); for (OcrItem cell : tableStructureResult.getOcrItemList()) { OcrBox cellBox = cell.getOcrBox(); int[] box_2 = { (int)cellBox.getTopLeft().getX(), (int)cellBox.getTopLeft().getY(), (int)cellBox.getBottomRight().getX(), (int)cellBox.getBottomRight().getY() }; float distance = distance(box_1, box_2); float iou = 1 - computeIou(box_1, box_2); distances.add(Pair.of(distance, iou)); } // 根据距离和IOU挑选最"近"的cell Pair nearest = sorted(distances); // 获取最小距离对应的下标id,也等价于cell的下标id (distances列表是根据遍历cells生成的) int id = 0; for (int idx = 0; idx < distances.size(); idx++) { Pair current = distances.get(idx); if (current.getLeft().floatValue() == nearest.getLeft().floatValue() && current.getRight().floatValue() == nearest.getRight().floatValue()) { id = idx; break; } } if (!matched.containsKey(id)) { List textIds = new ArrayList<>(); textIds.add(i); // cell id, text id list (dt_boxes index list) matched.put(id, textIds); } else { matched.get(id).add(i); } } List cell_contents = new ArrayList<>(); List probs = new ArrayList<>(); for (int i = 0; i < tableStructureResult.getOcrItemList().size(); i++) { List textIds = matched.get(i); List contents = new ArrayList<>(); String content = ""; if (textIds != null) { for (Integer id : textIds) { contents.add(ocrItems.get(id).getText()); } content = StringUtils.join(contents, " "); } cell_contents.add(content); probs.add(-1.0); } return cell_contents; } /** * 计算欧式距离 * Calculate L1 distance * * @param box_1 * @param box_2 * @return */ private int distance(int[] box_1, int[] box_2) { int x1 = box_1[0]; int y1 = box_1[1]; int x2 = box_1[2]; int y2 = box_1[3]; int x3 = box_2[0]; int y3 = box_2[1]; int x4 = box_2[2]; int y4 = box_2[3]; int dis = Math.abs(x3 - x1) + Math.abs(y3 - y1) + Math.abs(x4 - x2) + Math.abs(y4 - y2); int dis_2 = Math.abs(x3 - x1) + Math.abs(y3 - y1); int dis_3 = Math.abs(x4 - x2) + Math.abs(y4 - y2); return dis + Math.min(dis_2, dis_3); } /** * 计算交并比 * computing IoU * * @param rec1: (y0, x0, y1, x1), which reflects (top, left, bottom, right) * @param rec2: (y0, x0, y1, x1) * @return scala value of IoU */ private float computeIou(int[] rec1, int[] rec2) { // computing area of each rectangles int S_rec1 = (rec1[2] - rec1[0]) * (rec1[3] - rec1[1]); int S_rec2 = (rec2[2] - rec2[0]) * (rec2[3] - rec2[1]); // computing the sum_area int sum_area = S_rec1 + S_rec2; // find the each edge of intersect rectangle int left_line = Math.max(rec1[1], rec2[1]); int right_line = Math.min(rec1[3], rec2[3]); int top_line = Math.max(rec1[0], rec2[0]); int bottom_line = Math.min(rec1[2], rec2[2]); // judge if there is an intersect if (left_line >= right_line || top_line >= bottom_line) { return 0.0f; } else { float intersect = (right_line - left_line) * (bottom_line - top_line); return (intersect / (sum_area - intersect)) * 1.0f; } } /** * 距离排序 * Distance sorted * * @param distances * @return */ private Pair sorted(List> distances) { Comparator> comparator = new Comparator>() { @Override public int compare(Pair a1, Pair a2) { // 首先根据IoU排序 if (a1.getRight().floatValue() > a2.getRight().floatValue()) { return 1; } else if (a1.getRight().floatValue() == a2.getRight().floatValue()) { // 然后根据L1距离排序 if (a1.getLeft().floatValue() > a2.getLeft().floatValue()) { return 1; } return -1; } return -1; } }; // 距离排序 List> newDistances = new ArrayList<>(); CollectionUtils.addAll(newDistances, new Object[distances.size()]); Collections.copy(newDistances, distances); Collections.sort(newDistances, comparator); return newDistances.get(0); } /** * 生成表格html * Generate table html * * @param pred_structures * @param cell_contents * @return */ public String convertHtml(List pred_structures, List cell_contents) { StringBuffer html = new StringBuffer(); // 添加统一的样式(可选放到中) html.append("\n"); int td_index = 0; for (String tag : pred_structures) { if (tag.contains("")) { String content = cell_contents.get(td_index); html.append(""); html.append(content); html.append(""); td_index++; continue; } html.append(tag); } return html.toString(); } public static class Builder { private TableStructureModel tableStructureModel; private OcrCommonRecModel textRecModel; private OcrDirectionModel directionModel; private OcrCommonDetModel textDetector; public Builder withStructureModel(TableStructureModel model) { this.tableStructureModel = model; return this; } public Builder withTextRecModel(OcrCommonRecModel model) { this.textRecModel = model; return this; } public Builder withDirectionModel(OcrDirectionModel model) { this.directionModel = model; return this; } public Builder withTextDetModel(OcrCommonDetModel model) { this.textDetector = model; return this; } public TableRecognizer build() { if (this.tableStructureModel == null) { throw new IllegalStateException("tableStructureModel 未设置"); } if (this.textDetector == null) { throw new IllegalStateException("textDetector 未设置"); } if (this.textRecModel == null) { throw new IllegalStateException("textRecModel 未设置"); } return new TableRecognizer(this); } } }