package com.xindao.ocr.smartjavaai.model.table.translator; import ai.djl.Model; import ai.djl.modality.cv.Image; import ai.djl.modality.cv.util.NDImageUtils; import ai.djl.ndarray.NDArray; import ai.djl.ndarray.NDList; import ai.djl.ndarray.index.NDIndex; import ai.djl.ndarray.types.DataType; import ai.djl.ndarray.types.Shape; import ai.djl.translate.Batchifier; import ai.djl.translate.Translator; import ai.djl.translate.TranslatorContext; import ai.djl.util.Utils; import cn.smartjavaai.common.entity.Point; import com.xindao.ocr.smartjavaai.entity.OcrBox; import com.xindao.ocr.smartjavaai.entity.OcrItem; import com.xindao.ocr.smartjavaai.entity.TableStructureResult; import java.io.IOException; import java.io.InputStream; import java.util.ArrayList; import java.util.List; /** * 表格识别的前后处理 */ public class TableStructTranslator implements Translator { private final int maxLength = 488; private int height; private int width; private float scale = 1.0f; private float xScale; private float yScale; private List dict; private String beg_str = "sos"; private String end_str = "eos"; private List td_token = new ArrayList<>(); @Override public void prepare(TranslatorContext ctx) throws IOException { Model model = ctx.getModel(); try (InputStream is = model.getArtifact("table_structure_dict_ch.txt").openStream()) { dict = Utils.readLines(is, false); dict.add(0,beg_str); if(dict.contains("")) dict.remove(""); if(!dict.contains("")) dict.add(""); dict.add(end_str); } td_token.add(""); td_token.add(""); } @Override public NDList processInput(TranslatorContext ctx, Image input) { NDArray img = input.toNDArray(ctx.getNDManager(), Image.Flag.COLOR); height = input.getHeight(); width = input.getWidth(); img = ResizeTableImage(img, height, width, maxLength); img = PaddingTableImage(ctx, img, maxLength); img = img.transpose(2, 0, 1).div(255).flip(0); img = NDImageUtils.normalize( img, new float[]{0.485f, 0.456f, 0.406f}, new float[]{0.229f, 0.224f, 0.225f}); img = img.expandDims(0); return new NDList(img); } @Override public TableStructureResult processOutput(TranslatorContext ctx, NDList list) { NDArray bbox_preds = list.get(0); NDArray structure_probs = list.get(1); NDArray structure_idx = structure_probs.argMax(2); structure_probs = structure_probs.max(new int[]{2}); List> structure_batch_list = new ArrayList<>(); List> bbox_batch_list = new ArrayList<>(); List> result_score_list = new ArrayList<>(); // get ignored tokens int beg_idx = dict.indexOf(beg_str); int end_idx = dict.indexOf(end_str); long batch_size = structure_idx.size(0); for (int batch_idx = 0; batch_idx < batch_size; batch_idx++) { List structure_list = new ArrayList<>(); List bbox_list = new ArrayList<>(); List score_list = new ArrayList<>(); long len = structure_idx.get(batch_idx).size(); for (int idx = 0; idx < len; idx++) { int char_idx = (int) structure_idx.get(batch_idx).get(idx).toLongArray()[0]; if (idx > 0 && char_idx == end_idx) { break; } // if (char_idx == beg_idx || char_idx == end_idx) { // continue; // } String text = dict.get(char_idx); if(td_token.indexOf(text)>-1){ NDArray bbox = bbox_preds.get(batch_idx, idx); // bbox.set(new NDIndex("0::2"), bbox.get(new NDIndex("0::2"))); // bbox.set(new NDIndex("1::2"), bbox.get(new NDIndex("1::2"))); bbox_list.add(bbox); } structure_list.add(text); score_list.add(structure_probs.get(batch_idx, idx)); } structure_batch_list.add(structure_list); // structure_str bbox_batch_list.add(bbox_list); result_score_list.add(score_list); } List structure_str_list =structure_batch_list.get(0); List bbox_list = bbox_batch_list.get(0); List score_list = result_score_list.get(0); structure_str_list.add(0,""); structure_str_list.add(1,""); structure_str_list.add(2,""); structure_str_list.add("
"); structure_str_list.add(""); structure_str_list.add(""); List ocrItemList = new ArrayList<>(); for (int i = 0; i < bbox_list.size(); i++) { NDArray box = bbox_list.get(i); float[] arr = new float[4]; arr[0] = box.get(new NDIndex("0::2")).min().toFloatArray()[0]; arr[1] = box.get(new NDIndex("1::2")).min().toFloatArray()[0]; arr[2] = box.get(new NDIndex("0::2")).max().toFloatArray()[0]; arr[3] = box.get(new NDIndex("1::2")).max().toFloatArray()[0]; Point topLeft = new Point(arr[0] * xScale * width, arr[1] * yScale * height); Point topRight = new Point(arr[2] * xScale * width, arr[1] * yScale * height); Point bottomRight = new Point(arr[2] * xScale * width, arr[3] * yScale * height); Point bottomLeft = new Point(arr[0] * xScale * width, arr[3] * yScale * height); OcrBox ocrBox = new OcrBox(topLeft, topRight, bottomRight, bottomLeft); //String tag = structure_str_list.get(i + 3); // 前面加了 所以偏移+3 float score = score_list.get(i).toFloatArray()[0]; // 获取每个结构token的得分 OcrItem item = new OcrItem(); item.setOcrBox(ocrBox); item.setScore(score); //item.setTableTag(tag); ocrItemList.add(item); } return new TableStructureResult(ocrItemList, structure_str_list); } @Override public Batchifier getBatchifier() { return null; } private NDArray ResizeTableImage(NDArray img, int height, int width, int maxLen) { int localMax = Math.max(height, width); float ratio = maxLen * 1.0f / localMax; int resize_h = (int) (height * ratio); int resize_w = (int) (width * ratio); scale = ratio; if(width > height){ xScale = 1f; yScale = (float)width /(float)height; } else{ xScale = (float)height /(float)width; yScale = 1f; } img = NDImageUtils.resize(img, resize_w, resize_h); return img; } private NDArray PaddingTableImage(TranslatorContext ctx, NDArray img, int maxLen) { NDArray paddingImg = ctx.getNDManager().zeros(new Shape(maxLen, maxLen, 3), DataType.UINT8); paddingImg.set( new NDIndex("0:" + img.getShape().get(0) + ",0:" + img.getShape().get(1) + ",:"), img); return paddingImg; } }