package com.xindao.ocr.smartjavaai.model.table.criteria; import ai.djl.Device; import ai.djl.modality.cv.Image; import ai.djl.repository.zoo.Criteria; import ai.djl.training.util.ProgressBar; import cn.smartjavaai.common.enums.DeviceEnum; import com.xindao.ocr.smartjavaai.config.TableStructureConfig; import com.xindao.ocr.smartjavaai.entity.TableStructureResult; import com.xindao.ocr.smartjavaai.enums.TableStructureModelEnum; import com.xindao.ocr.smartjavaai.model.table.translator.TableStructTranslator; import java.nio.file.Paths; import java.util.Objects; /** * @author dwj * @date 2025/7/10 */ public class StructureCriteriaFactory { public static Criteria createCriteria(TableStructureConfig config) { Device device = null; if(!Objects.isNull(config.getDevice())){ device = config.getDevice() == DeviceEnum.CPU ? Device.cpu() : Device.gpu(config.getGpuId()); } Criteria criteria = null; if(config.getModelEnum() == TableStructureModelEnum.SLANET){ criteria = Criteria.builder() .optEngine("OnnxRuntime") .setTypes(Image.class, TableStructureResult.class) .optModelPath(Paths.get(config.getModelPath())) .optOption("removePass", "repeated_fc_relu_fuse_pass") .optDevice(device) .optTranslator(new TableStructTranslator()) .optProgress(new ProgressBar()) .build(); }else if(config.getModelEnum() == TableStructureModelEnum.SLANET_PLUS){ criteria = Criteria.builder() .optEngine("OnnxRuntime") .setTypes(Image.class, TableStructureResult.class) .optModelPath(Paths.get(config.getModelPath())) .optOption("removePass", "repeated_fc_relu_fuse_pass") .optDevice(device) .optTranslator(new TableStructTranslator()) .optProgress(new ProgressBar()) .build(); } return criteria; } }