package com.xindao.ocr.smartjavaai.factory;
|
|
import cn.smartjavaai.common.config.Config;
|
import com.xindao.ocr.smartjavaai.config.DirectionModelConfig;
|
import com.xindao.ocr.smartjavaai.config.OcrDetModelConfig;
|
import com.xindao.ocr.smartjavaai.config.OcrRecModelConfig;
|
import com.xindao.ocr.smartjavaai.enums.CommonDetModelEnum;
|
import com.xindao.ocr.smartjavaai.enums.CommonRecModelEnum;
|
import com.xindao.ocr.smartjavaai.enums.DirectionModelEnum;
|
import com.xindao.ocr.smartjavaai.exception.OcrException;
|
import com.xindao.ocr.smartjavaai.model.common.detect.OcrCommonDetModel;
|
import com.xindao.ocr.smartjavaai.model.common.detect.OcrCommonDetModelImpl;
|
import com.xindao.ocr.smartjavaai.model.common.direction.OcrDirectionModel;
|
import com.xindao.ocr.smartjavaai.model.common.direction.PPOCRMobileV2ClsModel;
|
import com.xindao.ocr.smartjavaai.model.common.recognize.OcrCommonRecModel;
|
import com.xindao.ocr.smartjavaai.model.common.recognize.OcrCommonRecModelImpl;
|
import lombok.extern.slf4j.Slf4j;
|
|
import java.util.Map;
|
import java.util.Objects;
|
import java.util.concurrent.ConcurrentHashMap;
|
|
/**
|
* OCR模型工厂
|
* @author dwj
|
*/
|
@Slf4j
|
public class OcrModelFactory {
|
|
// 使用 volatile 和双重检查锁定来确保线程安全的单例模式
|
private static volatile OcrModelFactory instance;
|
|
private static final ConcurrentHashMap<CommonDetModelEnum, OcrCommonDetModel> commonDetModelMap = new ConcurrentHashMap<>();
|
|
|
private static final ConcurrentHashMap<CommonRecModelEnum, OcrCommonRecModel> commonRecModelMap = new ConcurrentHashMap<>();
|
|
private static final ConcurrentHashMap<DirectionModelEnum, OcrDirectionModel> directionModelMap = new ConcurrentHashMap<>();
|
|
/**
|
* 检测模型注册表
|
*/
|
private static final Map<CommonDetModelEnum, Class<? extends OcrCommonDetModel>> commonDetRegistry =
|
new ConcurrentHashMap<>();
|
|
/**
|
* 识别模型注册表
|
*/
|
private static final Map<CommonRecModelEnum, Class<? extends OcrCommonRecModel>> commonRecRegistry =
|
new ConcurrentHashMap<>();
|
|
/**
|
* 方向分类模型注册表
|
*/
|
private static final Map<DirectionModelEnum, Class<? extends OcrDirectionModel>> directionRegistry =
|
new ConcurrentHashMap<>();
|
|
|
public static OcrModelFactory getInstance() {
|
if (instance == null) {
|
synchronized (OcrModelFactory.class) {
|
if (instance == null) {
|
instance = new OcrModelFactory();
|
}
|
}
|
}
|
return instance;
|
}
|
|
|
|
/**
|
* 注册通用检测模型
|
* @param detModelEnum
|
* @param clazz
|
*/
|
private static void registerCommonDetModel(CommonDetModelEnum detModelEnum, Class<? extends OcrCommonDetModel> clazz) {
|
commonDetRegistry.put(detModelEnum, clazz);
|
}
|
|
/**
|
* 注册通用识别模型
|
* @param recModelEnum
|
* @param clazz
|
*/
|
private static void registerCommonRecModel(CommonRecModelEnum recModelEnum, Class<? extends OcrCommonRecModel> clazz) {
|
commonRecRegistry.put(recModelEnum, clazz);
|
}
|
|
/**
|
* 注册通用方向分类模型
|
* @param directionModelEnum
|
* @param clazz
|
*/
|
private static void registerDirectionModel(DirectionModelEnum directionModelEnum, Class<? extends OcrDirectionModel> clazz) {
|
directionRegistry.put(directionModelEnum, clazz);
|
}
|
|
|
/**
|
* 获取检测模型(通过配置)
|
* @param config
|
* @return
|
*/
|
public OcrCommonDetModel getDetModel(OcrDetModelConfig config) {
|
if(Objects.isNull(config) || Objects.isNull(config.getModelEnum())){
|
throw new OcrException("未配置OCR模型");
|
}
|
return commonDetModelMap.computeIfAbsent(config.getModelEnum(), k -> {
|
return createCommonDetModel(config);
|
});
|
}
|
|
/**
|
* 获取识别模型(通过配置)
|
* @param config
|
* @return
|
*/
|
public OcrCommonRecModel getRecModel(OcrRecModelConfig config) {
|
if(Objects.isNull(config) || Objects.isNull(config.getRecModelEnum())){
|
throw new OcrException("未配置OCR模型");
|
}
|
return commonRecModelMap.computeIfAbsent(config.getRecModelEnum(), k -> {
|
return createCommonRecModel(config);
|
});
|
}
|
|
/**
|
* 获取模型(通过配置)
|
* @param config
|
* @return
|
*/
|
public OcrDirectionModel getDirectionModel(DirectionModelConfig config) {
|
if(Objects.isNull(config) || Objects.isNull(config.getModelEnum())){
|
throw new OcrException("未配置OCR模型");
|
}
|
return directionModelMap.computeIfAbsent(config.getModelEnum(), k -> {
|
return createDirectionModel(config);
|
});
|
}
|
|
|
|
/**
|
* 创建OCR通用检测模型
|
* @param config
|
* @return
|
*/
|
private OcrCommonDetModel createCommonDetModel(OcrDetModelConfig config) {
|
Class<?> clazz = commonDetRegistry.get(config.getModelEnum());
|
if(clazz == null){
|
throw new OcrException("Unsupported model");
|
}
|
OcrCommonDetModel model = null;
|
try {
|
model = (OcrCommonDetModel) clazz.newInstance();
|
} catch (InstantiationException | IllegalAccessException e) {
|
throw new OcrException(e);
|
}
|
model.loadModel(config);
|
return model;
|
}
|
|
|
/**
|
* 创建OCR通用识别模型
|
* @param config
|
* @return
|
*/
|
private OcrCommonRecModel createCommonRecModel(OcrRecModelConfig config) {
|
Class<?> clazz = commonRecRegistry.get(config.getRecModelEnum());
|
if(clazz == null){
|
throw new OcrException("Unsupported model");
|
}
|
OcrCommonRecModel model = null;
|
try {
|
model = (OcrCommonRecModel) clazz.newInstance();
|
} catch (InstantiationException | IllegalAccessException e) {
|
throw new OcrException(e);
|
}
|
model.loadModel(config);
|
return model;
|
}
|
|
/**
|
* 创建OCR方向分类模型
|
* @param config
|
* @return
|
*/
|
private OcrDirectionModel createDirectionModel(DirectionModelConfig config) {
|
Class<?> clazz = directionRegistry.get(config.getModelEnum());
|
if(clazz == null){
|
throw new OcrException("Unsupported model");
|
}
|
OcrDirectionModel model = null;
|
try {
|
model = (OcrDirectionModel) clazz.newInstance();
|
} catch (InstantiationException | IllegalAccessException e) {
|
throw new OcrException(e);
|
}
|
model.loadModel(config);
|
return model;
|
}
|
|
|
// 初始化默认算法
|
static {
|
//通用-检测模型
|
registerCommonDetModel(CommonDetModelEnum.PP_OCR_V5_SERVER_DET_MODEL, OcrCommonDetModelImpl.class);
|
registerCommonDetModel(CommonDetModelEnum.PP_OCR_V5_MOBILE_DET_MODEL, OcrCommonDetModelImpl.class);
|
registerCommonDetModel(CommonDetModelEnum.PP_OCR_V4_SERVER_DET_MODEL, OcrCommonDetModelImpl.class);
|
registerCommonDetModel(CommonDetModelEnum.PP_OCR_V4_MOBILE_DET_MODEL, OcrCommonDetModelImpl.class);
|
registerCommonRecModel(CommonRecModelEnum.PP_OCR_V5_SERVER_REC_MODEL, OcrCommonRecModelImpl.class);
|
registerCommonRecModel(CommonRecModelEnum.PP_OCR_V5_MOBILE_REC_MODEL, OcrCommonRecModelImpl.class);
|
registerCommonRecModel(CommonRecModelEnum.PP_OCR_V4_SERVER_REC_MODEL, OcrCommonRecModelImpl.class);
|
registerCommonRecModel(CommonRecModelEnum.PP_OCR_V4_MOBILE_REC_MODEL, OcrCommonRecModelImpl.class);
|
registerDirectionModel(DirectionModelEnum.CH_PPOCR_MOBILE_V2_CLS, PPOCRMobileV2ClsModel.class);
|
registerDirectionModel(DirectionModelEnum.PP_LCNET_X0_25, PPOCRMobileV2ClsModel.class);
|
registerDirectionModel(DirectionModelEnum.PP_LCNET_X1_0, PPOCRMobileV2ClsModel.class);
|
log.debug("缓存目录:{}", Config.getCachePath());
|
}
|
|
}
|