zouyu
2025-11-27 eed98e551c817ead7965e08820d4b7adbc4a47f0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
package com.xindao.ocr.smartjavaai.factory;
 
import cn.smartjavaai.common.config.Config;
import com.xindao.ocr.smartjavaai.config.DirectionModelConfig;
import com.xindao.ocr.smartjavaai.config.OcrDetModelConfig;
import com.xindao.ocr.smartjavaai.config.OcrRecModelConfig;
import com.xindao.ocr.smartjavaai.enums.CommonDetModelEnum;
import com.xindao.ocr.smartjavaai.enums.CommonRecModelEnum;
import com.xindao.ocr.smartjavaai.enums.DirectionModelEnum;
import com.xindao.ocr.smartjavaai.exception.OcrException;
import com.xindao.ocr.smartjavaai.model.common.detect.OcrCommonDetModel;
import com.xindao.ocr.smartjavaai.model.common.detect.OcrCommonDetModelImpl;
import com.xindao.ocr.smartjavaai.model.common.direction.OcrDirectionModel;
import com.xindao.ocr.smartjavaai.model.common.direction.PPOCRMobileV2ClsModel;
import com.xindao.ocr.smartjavaai.model.common.recognize.OcrCommonRecModel;
import com.xindao.ocr.smartjavaai.model.common.recognize.OcrCommonRecModelImpl;
import lombok.extern.slf4j.Slf4j;
 
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;
 
/**
 * OCR模型工厂
 * @author dwj
 */
@Slf4j
public class OcrModelFactory {
 
    // 使用 volatile 和双重检查锁定来确保线程安全的单例模式
    private static volatile OcrModelFactory instance;
 
    private static final ConcurrentHashMap<CommonDetModelEnum, OcrCommonDetModel> commonDetModelMap = new ConcurrentHashMap<>();
 
 
    private static final ConcurrentHashMap<CommonRecModelEnum, OcrCommonRecModel> commonRecModelMap = new ConcurrentHashMap<>();
 
    private static final ConcurrentHashMap<DirectionModelEnum, OcrDirectionModel> directionModelMap = new ConcurrentHashMap<>();
 
    /**
     * 检测模型注册表
     */
    private static final Map<CommonDetModelEnum, Class<? extends OcrCommonDetModel>> commonDetRegistry =
            new ConcurrentHashMap<>();
 
    /**
     * 识别模型注册表
     */
    private static final Map<CommonRecModelEnum, Class<? extends OcrCommonRecModel>> commonRecRegistry =
            new ConcurrentHashMap<>();
 
    /**
     * 方向分类模型注册表
     */
    private static final Map<DirectionModelEnum, Class<? extends OcrDirectionModel>> directionRegistry =
            new ConcurrentHashMap<>();
 
 
    public static OcrModelFactory getInstance() {
        if (instance == null) {
            synchronized (OcrModelFactory.class) {
                if (instance == null) {
                    instance = new OcrModelFactory();
                }
            }
        }
        return instance;
    }
 
 
 
    /**
     * 注册通用检测模型
     * @param detModelEnum
     * @param clazz
     */
    private static void registerCommonDetModel(CommonDetModelEnum detModelEnum, Class<? extends OcrCommonDetModel> clazz) {
        commonDetRegistry.put(detModelEnum, clazz);
    }
 
    /**
     * 注册通用识别模型
     * @param recModelEnum
     * @param clazz
     */
    private static void registerCommonRecModel(CommonRecModelEnum recModelEnum, Class<? extends OcrCommonRecModel> clazz) {
        commonRecRegistry.put(recModelEnum, clazz);
    }
 
    /**
     * 注册通用方向分类模型
     * @param directionModelEnum
     * @param clazz
     */
    private static void registerDirectionModel(DirectionModelEnum directionModelEnum, Class<? extends OcrDirectionModel> clazz) {
        directionRegistry.put(directionModelEnum, clazz);
    }
 
 
    /**
     * 获取检测模型(通过配置)
     * @param config
     * @return
     */
    public OcrCommonDetModel getDetModel(OcrDetModelConfig config) {
        if(Objects.isNull(config) || Objects.isNull(config.getModelEnum())){
            throw new OcrException("未配置OCR模型");
        }
        return commonDetModelMap.computeIfAbsent(config.getModelEnum(), k -> {
            return createCommonDetModel(config);
        });
    }
 
    /**
     * 获取识别模型(通过配置)
     * @param config
     * @return
     */
    public OcrCommonRecModel getRecModel(OcrRecModelConfig config) {
        if(Objects.isNull(config) || Objects.isNull(config.getRecModelEnum())){
            throw new OcrException("未配置OCR模型");
        }
        return commonRecModelMap.computeIfAbsent(config.getRecModelEnum(), k -> {
            return createCommonRecModel(config);
        });
    }
 
    /**
     * 获取模型(通过配置)
     * @param config
     * @return
     */
    public OcrDirectionModel getDirectionModel(DirectionModelConfig config) {
        if(Objects.isNull(config) || Objects.isNull(config.getModelEnum())){
            throw new OcrException("未配置OCR模型");
        }
        return directionModelMap.computeIfAbsent(config.getModelEnum(), k -> {
            return createDirectionModel(config);
        });
    }
 
 
 
    /**
     * 创建OCR通用检测模型
     * @param config
     * @return
     */
    private OcrCommonDetModel createCommonDetModel(OcrDetModelConfig config) {
        Class<?> clazz = commonDetRegistry.get(config.getModelEnum());
        if(clazz == null){
            throw new OcrException("Unsupported model");
        }
        OcrCommonDetModel model = null;
        try {
            model = (OcrCommonDetModel) clazz.newInstance();
        } catch (InstantiationException | IllegalAccessException e) {
            throw new OcrException(e);
        }
        model.loadModel(config);
        return model;
    }
 
 
    /**
     * 创建OCR通用识别模型
     * @param config
     * @return
     */
    private OcrCommonRecModel createCommonRecModel(OcrRecModelConfig config) {
        Class<?> clazz = commonRecRegistry.get(config.getRecModelEnum());
        if(clazz == null){
            throw new OcrException("Unsupported model");
        }
        OcrCommonRecModel model = null;
        try {
            model = (OcrCommonRecModel) clazz.newInstance();
        } catch (InstantiationException | IllegalAccessException e) {
            throw new OcrException(e);
        }
        model.loadModel(config);
        return model;
    }
 
    /**
     * 创建OCR方向分类模型
     * @param config
     * @return
     */
    private OcrDirectionModel createDirectionModel(DirectionModelConfig config) {
        Class<?> clazz = directionRegistry.get(config.getModelEnum());
        if(clazz == null){
            throw new OcrException("Unsupported model");
        }
        OcrDirectionModel model = null;
        try {
            model = (OcrDirectionModel) clazz.newInstance();
        } catch (InstantiationException | IllegalAccessException e) {
            throw new OcrException(e);
        }
        model.loadModel(config);
        return model;
    }
 
 
    // 初始化默认算法
    static {
        //通用-检测模型
        registerCommonDetModel(CommonDetModelEnum.PP_OCR_V5_SERVER_DET_MODEL, OcrCommonDetModelImpl.class);
        registerCommonDetModel(CommonDetModelEnum.PP_OCR_V5_MOBILE_DET_MODEL, OcrCommonDetModelImpl.class);
        registerCommonDetModel(CommonDetModelEnum.PP_OCR_V4_SERVER_DET_MODEL, OcrCommonDetModelImpl.class);
        registerCommonDetModel(CommonDetModelEnum.PP_OCR_V4_MOBILE_DET_MODEL, OcrCommonDetModelImpl.class);
        registerCommonRecModel(CommonRecModelEnum.PP_OCR_V5_SERVER_REC_MODEL, OcrCommonRecModelImpl.class);
        registerCommonRecModel(CommonRecModelEnum.PP_OCR_V5_MOBILE_REC_MODEL, OcrCommonRecModelImpl.class);
        registerCommonRecModel(CommonRecModelEnum.PP_OCR_V4_SERVER_REC_MODEL, OcrCommonRecModelImpl.class);
        registerCommonRecModel(CommonRecModelEnum.PP_OCR_V4_MOBILE_REC_MODEL, OcrCommonRecModelImpl.class);
        registerDirectionModel(DirectionModelEnum.CH_PPOCR_MOBILE_V2_CLS, PPOCRMobileV2ClsModel.class);
        registerDirectionModel(DirectionModelEnum.PP_LCNET_X0_25, PPOCRMobileV2ClsModel.class);
        registerDirectionModel(DirectionModelEnum.PP_LCNET_X1_0, PPOCRMobileV2ClsModel.class);
        log.debug("缓存目录:{}", Config.getCachePath());
    }
 
}