package com.xindao.ocr.smartjavaai.model.common.recognize.translator;
|
|
import ai.djl.Model;
|
import ai.djl.modality.cv.Image;
|
import ai.djl.modality.cv.util.NDImageUtils;
|
import ai.djl.ndarray.NDArray;
|
import ai.djl.ndarray.NDList;
|
import ai.djl.ndarray.index.NDIndex;
|
import ai.djl.ndarray.types.DataType;
|
import ai.djl.ndarray.types.Shape;
|
import ai.djl.translate.Batchifier;
|
import ai.djl.translate.Translator;
|
import ai.djl.translate.TranslatorContext;
|
import ai.djl.util.Utils;
|
|
import java.io.IOException;
|
import java.io.InputStream;
|
import java.util.Arrays;
|
import java.util.List;
|
import java.util.Map;
|
|
/**
|
* 文字识别前后处理
|
*
|
*/
|
public class PPOCRRecTranslator implements Translator<Image, String> {
|
private List<String> table;
|
private final boolean use_space_char;
|
|
private String batchifier;
|
|
public PPOCRRecTranslator(Map<String, ?> arguments) {
|
use_space_char =
|
arguments.containsKey("use_space_char")
|
? Boolean.parseBoolean(arguments.get("use_space_char").toString())
|
: true;
|
batchifier = arguments.containsKey("batchifier")
|
? arguments.get("batchifier").toString()
|
: "padding";
|
}
|
|
@Override
|
public void prepare(TranslatorContext ctx) throws IOException {
|
Model model = ctx.getModel();
|
try (InputStream is = model.getArtifact("dict.txt").openStream()) {
|
table = Utils.readLines(is, true);
|
table.add(0, "blank");
|
if(use_space_char){
|
table.add(" ");
|
table.add(" ");
|
}
|
else{
|
table.add("");
|
table.add("");
|
}
|
|
}
|
}
|
|
@Override
|
public String processOutput(TranslatorContext ctx, NDList list) throws IOException {
|
StringBuilder sb = new StringBuilder();
|
NDArray tokens = list.singletonOrThrow();
|
|
// long[] indices = tokens.get(0).argMax(1).toLongArray();
|
long[] indices = tokens.argMax(1).toLongArray();
|
boolean[] selection = new boolean[indices.length];
|
Arrays.fill(selection, true);
|
for (int i = 1; i < indices.length; i++) {
|
if (indices[i] == indices[i - 1]) {
|
selection[i] = false;
|
}
|
}
|
|
// 字符置信度
|
// float[] probs = new float[indices.length];
|
// for (int row = 0; row < indices.length; row++) {
|
// NDArray value = tokens.get(0).get(new NDIndex(""+ row +":" + (row + 1) +"," + indices[row] +":" + ( indices[row] + 1)));
|
// probs[row] = value.toFloatArray()[0];
|
// }
|
|
int lastIdx = 0;
|
for (int i = 0; i < indices.length; i++) {
|
if (selection[i] == true && indices[i] > 0 && !(i > 0 && indices[i] == lastIdx)) {
|
sb.append(table.get((int) indices[i]));
|
}
|
}
|
return sb.toString();
|
}
|
|
@Override
|
public NDList processInput(TranslatorContext ctx, Image input) {
|
NDArray img = input.toNDArray(ctx.getNDManager(), Image.Flag.COLOR);
|
int imgC = 3;
|
int imgH = 48;
|
int imgW = 320;
|
|
float max_wh_ratio = (float) imgW / (float) imgH;
|
|
int h = input.getHeight();
|
int w = input.getWidth();
|
float wh_ratio = (float) w / (float) h;
|
|
max_wh_ratio = Math.max(max_wh_ratio,wh_ratio);
|
imgW = (int)(imgH * max_wh_ratio);
|
|
int resized_w;
|
if (Math.ceil(imgH * wh_ratio) > imgW) {
|
resized_w = imgW;
|
} else {
|
resized_w = (int) (Math.ceil(imgH * wh_ratio));
|
}
|
NDArray resized_image = NDImageUtils.resize(img, resized_w, imgH);
|
resized_image = resized_image.transpose(2, 0, 1).toType(DataType.FLOAT32,false);
|
resized_image.divi(255f).subi(0.5f).divi(0.5f);
|
NDArray padding_im = ctx.getNDManager().zeros(new Shape(imgC, imgH, imgW), DataType.FLOAT32);
|
padding_im.set(new NDIndex(":,:,0:" + resized_w), resized_image);
|
|
padding_im = padding_im.flip(0);
|
// padding_im = padding_im.expandDims(0);
|
return new NDList(padding_im);
|
}
|
|
@Override
|
public Batchifier getBatchifier() {
|
return Batchifier.fromString(batchifier);
|
}
|
|
}
|