package com.xindao.ocr.smartjavaai.factory;
|
|
import cn.smartjavaai.common.config.Config;
|
import com.xindao.ocr.smartjavaai.config.TableStructureConfig;
|
import com.xindao.ocr.smartjavaai.enums.TableStructureModelEnum;
|
import com.xindao.ocr.smartjavaai.exception.OcrException;
|
import com.xindao.ocr.smartjavaai.model.table.CommonTableStructureModel;
|
import com.xindao.ocr.smartjavaai.model.table.TableStructureModel;
|
import lombok.extern.slf4j.Slf4j;
|
|
import java.util.Map;
|
import java.util.Objects;
|
import java.util.concurrent.ConcurrentHashMap;
|
|
/**
|
* OCR 表格识别模型工厂
|
* @author dwj
|
*/
|
@Slf4j
|
public class TableRecModelFactory {
|
|
// 使用 volatile 和双重检查锁定来确保线程安全的单例模式
|
private static volatile TableRecModelFactory instance;
|
|
/**
|
* 模型缓存
|
*/
|
private static final ConcurrentHashMap<TableStructureModelEnum, TableStructureModel> tableStructureModelMap = new ConcurrentHashMap<>();
|
|
|
/**
|
* 模型注册表
|
*/
|
private static final Map<TableStructureModelEnum, Class<? extends TableStructureModel>> tableStructureRegistry =
|
new ConcurrentHashMap<>();
|
|
|
public static TableRecModelFactory getInstance() {
|
if (instance == null) {
|
synchronized (TableRecModelFactory.class) {
|
if (instance == null) {
|
instance = new TableRecModelFactory();
|
}
|
}
|
}
|
return instance;
|
}
|
|
|
|
/**
|
* 注册模型
|
* @param tableStructureModelEnum
|
* @param clazz
|
*/
|
private static void registerTableStructureModel(TableStructureModelEnum tableStructureModelEnum, Class<? extends TableStructureModel> clazz) {
|
tableStructureRegistry.put(tableStructureModelEnum, clazz);
|
}
|
|
|
/**
|
* 获取模型(通过配置)
|
* @param config
|
* @return
|
*/
|
public TableStructureModel getTableStructureModel(TableStructureConfig config) {
|
if(Objects.isNull(config) || Objects.isNull(config.getModelEnum())){
|
throw new OcrException("未配置OCR模型");
|
}
|
return tableStructureModelMap.computeIfAbsent(config.getModelEnum(), k -> {
|
return createTableStructureModel(config);
|
});
|
}
|
|
|
|
/**
|
* 创建模型
|
* @param config
|
* @return
|
*/
|
private TableStructureModel createTableStructureModel(TableStructureConfig config) {
|
Class<?> clazz = tableStructureRegistry.get(config.getModelEnum());
|
if(clazz == null){
|
throw new OcrException("Unsupported model");
|
}
|
TableStructureModel model = null;
|
try {
|
model = (TableStructureModel) clazz.newInstance();
|
} catch (InstantiationException | IllegalAccessException e) {
|
throw new OcrException(e);
|
}
|
model.loadModel(config);
|
return model;
|
}
|
|
|
|
|
// 初始化默认算法
|
static {
|
registerTableStructureModel(TableStructureModelEnum.SLANET, CommonTableStructureModel.class);
|
registerTableStructureModel(TableStructureModelEnum.SLANET_PLUS, CommonTableStructureModel.class);
|
log.debug("缓存目录:{}", Config.getCachePath());
|
}
|
|
}
|