package com.xindao.ocr.smartjavaai.factory;
|
|
import cn.smartjavaai.common.config.Config;
|
import com.xindao.ocr.smartjavaai.config.PlateDetModelConfig;
|
import com.xindao.ocr.smartjavaai.config.PlateRecModelConfig;
|
import com.xindao.ocr.smartjavaai.enums.PlateDetModelEnum;
|
import com.xindao.ocr.smartjavaai.enums.PlateRecModelEnum;
|
import com.xindao.ocr.smartjavaai.exception.OcrException;
|
import com.xindao.ocr.smartjavaai.model.plate.CRNNPlateRecModel;
|
import com.xindao.ocr.smartjavaai.model.plate.PlateDetModel;
|
import com.xindao.ocr.smartjavaai.model.plate.PlateRecModel;
|
import com.xindao.ocr.smartjavaai.model.plate.Yolov5PlateDetModel;
|
import lombok.extern.slf4j.Slf4j;
|
|
import java.util.Map;
|
import java.util.Objects;
|
import java.util.concurrent.ConcurrentHashMap;
|
|
/**
|
* 车牌识别模型工厂
|
* @author dwj
|
*/
|
@Slf4j
|
public class PlateModelFactory {
|
|
// 使用 volatile 和双重检查锁定来确保线程安全的单例模式
|
private static volatile PlateModelFactory instance;
|
|
/**
|
* 模型缓存
|
*/
|
private static final ConcurrentHashMap<PlateDetModelEnum, PlateDetModel> detModelMap = new ConcurrentHashMap<>();
|
|
/**
|
* 模型缓存
|
*/
|
private static final ConcurrentHashMap<PlateRecModelEnum, PlateRecModel> recModelMap = new ConcurrentHashMap<>();
|
|
|
/**
|
* 模型注册表
|
*/
|
private static final Map<PlateDetModelEnum, Class<? extends PlateDetModel>> detModelRegistry =
|
new ConcurrentHashMap<>();
|
|
/**
|
* 模型注册表
|
*/
|
private static final Map<PlateRecModelEnum, Class<? extends PlateRecModel>> recModelRegistry =
|
new ConcurrentHashMap<>();
|
|
|
public static PlateModelFactory getInstance() {
|
if (instance == null) {
|
synchronized (PlateModelFactory.class) {
|
if (instance == null) {
|
instance = new PlateModelFactory();
|
}
|
}
|
}
|
return instance;
|
}
|
|
|
|
/**
|
* 注册模型
|
* @param plateDetModelEnum
|
* @param clazz
|
*/
|
private static void registerDetModel(PlateDetModelEnum plateDetModelEnum, Class<? extends PlateDetModel> clazz) {
|
detModelRegistry.put(plateDetModelEnum, clazz);
|
}
|
|
/**
|
* 注册模型
|
* @param plateRecModelEnum
|
* @param clazz
|
*/
|
private static void registerRecModel(PlateRecModelEnum plateRecModelEnum, Class<? extends PlateRecModel> clazz) {
|
recModelRegistry.put(plateRecModelEnum, clazz);
|
}
|
|
|
/**
|
* 获取模型
|
* @param config
|
* @return
|
*/
|
public PlateDetModel getDetModel(PlateDetModelConfig config) {
|
if(Objects.isNull(config) || Objects.isNull(config.getModelEnum())){
|
throw new OcrException("未配置OCR模型");
|
}
|
return detModelMap.computeIfAbsent(config.getModelEnum(), k -> {
|
return createDetModel(config);
|
});
|
}
|
|
/**
|
* 获取模型
|
* @param config
|
* @return
|
*/
|
public PlateRecModel getRecModel(PlateRecModelConfig config) {
|
if(Objects.isNull(config) || Objects.isNull(config.getModelEnum())){
|
throw new OcrException("未配置OCR模型");
|
}
|
return recModelMap.computeIfAbsent(config.getModelEnum(), k -> {
|
return createRecModel(config);
|
});
|
}
|
|
|
|
/**
|
* 创建检测模型
|
* @param config
|
* @return
|
*/
|
private PlateDetModel createDetModel(PlateDetModelConfig config) {
|
Class<?> clazz = detModelRegistry.get(config.getModelEnum());
|
if(clazz == null){
|
throw new OcrException("Unsupported model");
|
}
|
PlateDetModel model = null;
|
try {
|
model = (PlateDetModel) clazz.newInstance();
|
} catch (InstantiationException | IllegalAccessException e) {
|
throw new OcrException(e);
|
}
|
model.loadModel(config);
|
return model;
|
}
|
|
/**
|
* 创建识别模型
|
* @param config
|
* @return
|
*/
|
private PlateRecModel createRecModel(PlateRecModelConfig config) {
|
Class<?> clazz = recModelRegistry.get(config.getModelEnum());
|
if(clazz == null){
|
throw new OcrException("Unsupported model");
|
}
|
PlateRecModel model = null;
|
try {
|
model = (PlateRecModel) clazz.newInstance();
|
} catch (InstantiationException | IllegalAccessException e) {
|
throw new OcrException(e);
|
}
|
model.loadModel(config);
|
return model;
|
}
|
|
|
// 初始化默认算法
|
static {
|
registerDetModel(PlateDetModelEnum.YOLOV5, Yolov5PlateDetModel.class);
|
registerDetModel(PlateDetModelEnum.YOLOV7, Yolov5PlateDetModel.class);
|
registerRecModel(PlateRecModelEnum.PLATE_REC_CRNN, CRNNPlateRecModel.class);
|
log.debug("缓存目录:{}", Config.getCachePath());
|
}
|
|
}
|