package com.xindao.ocr.smartjavaai.factory; import cn.smartjavaai.common.config.Config; import com.xindao.ocr.smartjavaai.config.PlateDetModelConfig; import com.xindao.ocr.smartjavaai.config.PlateRecModelConfig; import com.xindao.ocr.smartjavaai.enums.PlateDetModelEnum; import com.xindao.ocr.smartjavaai.enums.PlateRecModelEnum; import com.xindao.ocr.smartjavaai.exception.OcrException; import com.xindao.ocr.smartjavaai.model.plate.CRNNPlateRecModel; import com.xindao.ocr.smartjavaai.model.plate.PlateDetModel; import com.xindao.ocr.smartjavaai.model.plate.PlateRecModel; import com.xindao.ocr.smartjavaai.model.plate.Yolov5PlateDetModel; import lombok.extern.slf4j.Slf4j; import java.util.Map; import java.util.Objects; import java.util.concurrent.ConcurrentHashMap; /** * 车牌识别模型工厂 * @author dwj */ @Slf4j public class PlateModelFactory { // 使用 volatile 和双重检查锁定来确保线程安全的单例模式 private static volatile PlateModelFactory instance; /** * 模型缓存 */ private static final ConcurrentHashMap detModelMap = new ConcurrentHashMap<>(); /** * 模型缓存 */ private static final ConcurrentHashMap recModelMap = new ConcurrentHashMap<>(); /** * 模型注册表 */ private static final Map> detModelRegistry = new ConcurrentHashMap<>(); /** * 模型注册表 */ private static final Map> recModelRegistry = new ConcurrentHashMap<>(); public static PlateModelFactory getInstance() { if (instance == null) { synchronized (PlateModelFactory.class) { if (instance == null) { instance = new PlateModelFactory(); } } } return instance; } /** * 注册模型 * @param plateDetModelEnum * @param clazz */ private static void registerDetModel(PlateDetModelEnum plateDetModelEnum, Class clazz) { detModelRegistry.put(plateDetModelEnum, clazz); } /** * 注册模型 * @param plateRecModelEnum * @param clazz */ private static void registerRecModel(PlateRecModelEnum plateRecModelEnum, Class clazz) { recModelRegistry.put(plateRecModelEnum, clazz); } /** * 获取模型 * @param config * @return */ public PlateDetModel getDetModel(PlateDetModelConfig config) { if(Objects.isNull(config) || Objects.isNull(config.getModelEnum())){ throw new OcrException("未配置OCR模型"); } return detModelMap.computeIfAbsent(config.getModelEnum(), k -> { return createDetModel(config); }); } /** * 获取模型 * @param config * @return */ public PlateRecModel getRecModel(PlateRecModelConfig config) { if(Objects.isNull(config) || Objects.isNull(config.getModelEnum())){ throw new OcrException("未配置OCR模型"); } return recModelMap.computeIfAbsent(config.getModelEnum(), k -> { return createRecModel(config); }); } /** * 创建检测模型 * @param config * @return */ private PlateDetModel createDetModel(PlateDetModelConfig config) { Class clazz = detModelRegistry.get(config.getModelEnum()); if(clazz == null){ throw new OcrException("Unsupported model"); } PlateDetModel model = null; try { model = (PlateDetModel) clazz.newInstance(); } catch (InstantiationException | IllegalAccessException e) { throw new OcrException(e); } model.loadModel(config); return model; } /** * 创建识别模型 * @param config * @return */ private PlateRecModel createRecModel(PlateRecModelConfig config) { Class clazz = recModelRegistry.get(config.getModelEnum()); if(clazz == null){ throw new OcrException("Unsupported model"); } PlateRecModel model = null; try { model = (PlateRecModel) clazz.newInstance(); } catch (InstantiationException | IllegalAccessException e) { throw new OcrException(e); } model.loadModel(config); return model; } // 初始化默认算法 static { registerDetModel(PlateDetModelEnum.YOLOV5, Yolov5PlateDetModel.class); registerDetModel(PlateDetModelEnum.YOLOV7, Yolov5PlateDetModel.class); registerRecModel(PlateRecModelEnum.PLATE_REC_CRNN, CRNNPlateRecModel.class); log.debug("缓存目录:{}", Config.getCachePath()); } }