zouyu
2025-11-27 eed98e551c817ead7965e08820d4b7adbc4a47f0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
package com.xindao.ocr.smartjavaai.factory;
 
import cn.smartjavaai.common.config.Config;
import com.xindao.ocr.smartjavaai.config.TableStructureConfig;
import com.xindao.ocr.smartjavaai.enums.TableStructureModelEnum;
import com.xindao.ocr.smartjavaai.exception.OcrException;
import com.xindao.ocr.smartjavaai.model.table.CommonTableStructureModel;
import com.xindao.ocr.smartjavaai.model.table.TableStructureModel;
import lombok.extern.slf4j.Slf4j;
 
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;
 
/**
 * OCR 表格识别模型工厂
 * @author dwj
 */
@Slf4j
public class TableRecModelFactory {
 
    // 使用 volatile 和双重检查锁定来确保线程安全的单例模式
    private static volatile TableRecModelFactory instance;
 
    /**
     * 模型缓存
     */
    private static final ConcurrentHashMap<TableStructureModelEnum, TableStructureModel> tableStructureModelMap = new ConcurrentHashMap<>();
 
 
    /**
     * 模型注册表
     */
    private static final Map<TableStructureModelEnum, Class<? extends TableStructureModel>> tableStructureRegistry =
            new ConcurrentHashMap<>();
 
 
    public static TableRecModelFactory getInstance() {
        if (instance == null) {
            synchronized (TableRecModelFactory.class) {
                if (instance == null) {
                    instance = new TableRecModelFactory();
                }
            }
        }
        return instance;
    }
 
 
 
    /**
     * 注册模型
     * @param tableStructureModelEnum
     * @param clazz
     */
    private static void registerTableStructureModel(TableStructureModelEnum tableStructureModelEnum, Class<? extends TableStructureModel> clazz) {
        tableStructureRegistry.put(tableStructureModelEnum, clazz);
    }
 
 
    /**
     * 获取模型(通过配置)
     * @param config
     * @return
     */
    public TableStructureModel getTableStructureModel(TableStructureConfig config) {
        if(Objects.isNull(config) || Objects.isNull(config.getModelEnum())){
            throw new OcrException("未配置OCR模型");
        }
        return tableStructureModelMap.computeIfAbsent(config.getModelEnum(), k -> {
            return createTableStructureModel(config);
        });
    }
 
 
 
    /**
     * 创建模型
     * @param config
     * @return
     */
    private TableStructureModel createTableStructureModel(TableStructureConfig config) {
        Class<?> clazz = tableStructureRegistry.get(config.getModelEnum());
        if(clazz == null){
            throw new OcrException("Unsupported model");
        }
        TableStructureModel model = null;
        try {
            model = (TableStructureModel) clazz.newInstance();
        } catch (InstantiationException | IllegalAccessException e) {
            throw new OcrException(e);
        }
        model.loadModel(config);
        return model;
    }
 
 
 
 
    // 初始化默认算法
    static {
        registerTableStructureModel(TableStructureModelEnum.SLANET, CommonTableStructureModel.class);
        registerTableStructureModel(TableStructureModelEnum.SLANET_PLUS, CommonTableStructureModel.class);
        log.debug("缓存目录:{}", Config.getCachePath());
    }
 
}