package com.xindao.ocr.smartjavaai.model.common.recognize.translator; import ai.djl.Model; import ai.djl.modality.cv.Image; import ai.djl.modality.cv.util.NDImageUtils; import ai.djl.ndarray.NDArray; import ai.djl.ndarray.NDList; import ai.djl.ndarray.index.NDIndex; import ai.djl.ndarray.types.DataType; import ai.djl.ndarray.types.Shape; import ai.djl.translate.Batchifier; import ai.djl.translate.Translator; import ai.djl.translate.TranslatorContext; import ai.djl.util.Utils; import java.io.IOException; import java.io.InputStream; import java.util.Arrays; import java.util.List; import java.util.Map; /** * 文字识别前后处理 * */ public class PPOCRRecTranslator implements Translator { private List table; private final boolean use_space_char; private String batchifier; public PPOCRRecTranslator(Map arguments) { use_space_char = arguments.containsKey("use_space_char") ? Boolean.parseBoolean(arguments.get("use_space_char").toString()) : true; batchifier = arguments.containsKey("batchifier") ? arguments.get("batchifier").toString() : "padding"; } @Override public void prepare(TranslatorContext ctx) throws IOException { Model model = ctx.getModel(); try (InputStream is = model.getArtifact("dict.txt").openStream()) { table = Utils.readLines(is, true); table.add(0, "blank"); if(use_space_char){ table.add(" "); table.add(" "); } else{ table.add(""); table.add(""); } } } @Override public String processOutput(TranslatorContext ctx, NDList list) throws IOException { StringBuilder sb = new StringBuilder(); NDArray tokens = list.singletonOrThrow(); // long[] indices = tokens.get(0).argMax(1).toLongArray(); long[] indices = tokens.argMax(1).toLongArray(); boolean[] selection = new boolean[indices.length]; Arrays.fill(selection, true); for (int i = 1; i < indices.length; i++) { if (indices[i] == indices[i - 1]) { selection[i] = false; } } // 字符置信度 // float[] probs = new float[indices.length]; // for (int row = 0; row < indices.length; row++) { // NDArray value = tokens.get(0).get(new NDIndex(""+ row +":" + (row + 1) +"," + indices[row] +":" + ( indices[row] + 1))); // probs[row] = value.toFloatArray()[0]; // } int lastIdx = 0; for (int i = 0; i < indices.length; i++) { if (selection[i] == true && indices[i] > 0 && !(i > 0 && indices[i] == lastIdx)) { sb.append(table.get((int) indices[i])); } } return sb.toString(); } @Override public NDList processInput(TranslatorContext ctx, Image input) { NDArray img = input.toNDArray(ctx.getNDManager(), Image.Flag.COLOR); int imgC = 3; int imgH = 48; int imgW = 320; float max_wh_ratio = (float) imgW / (float) imgH; int h = input.getHeight(); int w = input.getWidth(); float wh_ratio = (float) w / (float) h; max_wh_ratio = Math.max(max_wh_ratio,wh_ratio); imgW = (int)(imgH * max_wh_ratio); int resized_w; if (Math.ceil(imgH * wh_ratio) > imgW) { resized_w = imgW; } else { resized_w = (int) (Math.ceil(imgH * wh_ratio)); } NDArray resized_image = NDImageUtils.resize(img, resized_w, imgH); resized_image = resized_image.transpose(2, 0, 1).toType(DataType.FLOAT32,false); resized_image.divi(255f).subi(0.5f).divi(0.5f); NDArray padding_im = ctx.getNDManager().zeros(new Shape(imgC, imgH, imgW), DataType.FLOAT32); padding_im.set(new NDIndex(":,:,0:" + resized_w), resized_image); padding_im = padding_im.flip(0); // padding_im = padding_im.expandDims(0); return new NDList(padding_im); } @Override public Batchifier getBatchifier() { return Batchifier.fromString(batchifier); } }