package com.xindao.ocr.smartjavaai.factory;

import cn.smartjavaai.common.config.Config;
import com.xindao.ocr.smartjavaai.config.TableStructureConfig;
import com.xindao.ocr.smartjavaai.enums.TableStructureModelEnum;
import com.xindao.ocr.smartjavaai.exception.OcrException;
import com.xindao.ocr.smartjavaai.model.table.CommonTableStructureModel;
import com.xindao.ocr.smartjavaai.model.table.TableStructureModel;
import lombok.extern.slf4j.Slf4j;

import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;

/**
 * Factory for OCR table-structure recognition models.
 *
 * <p>Implementation classes are registered per {@link TableStructureModelEnum} in the static
 * initializer at the bottom of this class. Model instances are created reflectively on first
 * request, loaded via {@code loadModel(config)}, and then cached per enum value.
 *
 * @author dwj
 */
@Slf4j
public class TableRecModelFactory {

    /**
     * Lazily created singleton instance. Declared {@code volatile} so the double-checked locking
     * in {@link #getInstance()} publishes a fully constructed object to other threads.
     */
    private static volatile TableRecModelFactory instance;

    /** Cache of loaded models, keyed by model type. */
    private static final ConcurrentHashMap<TableStructureModelEnum, TableStructureModel> tableStructureModelMap =
            new ConcurrentHashMap<>();

    /** Registry mapping each supported model type to the class that implements it. */
    private static final Map<TableStructureModelEnum, Class<? extends TableStructureModel>> tableStructureRegistry =
            new ConcurrentHashMap<>();

    /**
     * Returns the shared factory instance, creating it on first use.
     *
     * @return the singleton {@code TableRecModelFactory}
     */
    public static TableRecModelFactory getInstance() {
        if (instance == null) {
            synchronized (TableRecModelFactory.class) {
                // Re-check under the lock: another thread may have created it meanwhile.
                if (instance == null) {
                    instance = new TableRecModelFactory();
                }
            }
        }
        return instance;
    }

    /**
     * Registers the implementation class for a model type.
     *
     * @param tableStructureModelEnum the model type
     * @param clazz                   the implementation class; instantiated via its no-arg constructor
     */
    private static void registerTableStructureModel(TableStructureModelEnum tableStructureModelEnum,
                                                    Class<? extends TableStructureModel> clazz) {
        tableStructureRegistry.put(tableStructureModelEnum, clazz);
    }

    /**
     * Returns the model for the configured model type, creating and loading it on first request.
     *
     * @param config the model configuration; must be non-null with a non-null model enum
     * @return the cached (or newly loaded) model
     * @throws OcrException if {@code config} or its model enum is null, if the model type is not
     *                      registered, or if the implementation class cannot be instantiated
     */
    public TableStructureModel getTableStructureModel(TableStructureConfig config) {
        if (Objects.isNull(config) || Objects.isNull(config.getModelEnum())) {
            throw new OcrException("未配置OCR模型");
        }
        // computeIfAbsent applies the loader at most once per key, and nothing is cached if the
        // loader throws. Because the load runs inside it, other updates to the same map bin wait
        // until the load finishes.
        // NOTE(review): the cache key is only the model enum, so a later config with the same enum
        // but different settings gets the first-loaded model and its config is ignored. Confirm
        // this is intended.
        return tableStructureModelMap.computeIfAbsent(config.getModelEnum(),
                k -> createTableStructureModel(config));
    }

    /**
     * Instantiates the registered implementation for {@code config.getModelEnum()} and loads it.
     *
     * @param config the model configuration, passed to {@code loadModel}
     * @return the newly loaded model
     * @throws OcrException if no implementation is registered or reflective instantiation fails
     */
    private TableStructureModel createTableStructureModel(TableStructureConfig config) {
        Class<? extends TableStructureModel> clazz = tableStructureRegistry.get(config.getModelEnum());
        if (clazz == null) {
            throw new OcrException("Unsupported model");
        }
        TableStructureModel model;
        try {
            // Replaces the deprecated Class.newInstance(), which could leak undeclared checked
            // exceptions thrown by the constructor.
            model = clazz.getDeclaredConstructor().newInstance();
        } catch (ReflectiveOperationException e) {
            throw new OcrException(e);
        }
        model.loadModel(config);
        return model;
    }

    // Register the default implementations. This must stay after the field declarations above so
    // the registry is initialized before it is written to.
    static {
        registerTableStructureModel(TableStructureModelEnum.SLANET, CommonTableStructureModel.class);
        registerTableStructureModel(TableStructureModelEnum.SLANET_PLUS, CommonTableStructureModel.class);
        log.debug("缓存目录:{}", Config.getCachePath());
    }
}