package com.ruoyi.ai.service.impl;
|
|
import com.ruoyi.ai.service.KnowledgeRagService;
import com.ruoyi.approve.pojo.KnowledgeBaseVector;
import com.ruoyi.approve.service.KnowledgeBaseVectorService;
import com.ruoyi.basic.pojo.StorageBlob;
import com.ruoyi.basic.service.StorageBlobService;
import com.ruoyi.common.config.FileProperties;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.store.embedding.EmbeddingSearchRequest;
import dev.langchain4j.store.embedding.EmbeddingSearchResult;
import dev.langchain4j.store.embedding.EmbeddingStore;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.scheduling.annotation.Async;
import org.springframework.stereotype.Service;

import java.io.File;
import java.nio.ByteBuffer;
import java.nio.charset.CharacterCodingException;
import java.nio.charset.Charset;
import java.nio.charset.CodingErrorAction;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
|
|
/**
|
* 知识库RAG服务实现
|
*/
|
@Slf4j
|
@Service
|
@RequiredArgsConstructor
|
public class KnowledgeRagServiceImpl implements KnowledgeRagService {
|
|
private final KnowledgeBaseVectorService knowledgeBaseVectorService;
|
private final StorageBlobService storageBlobService;
|
private final EmbeddingModel embeddingModel;
|
private final EmbeddingStore<TextSegment> embeddingStore;
|
private final FileProperties fileProperties;
|
|
private static final int CHUNK_SIZE = 500;
|
private static final int CHUNK_OVERLAP = 100;
|
/**
|
* 文件大小阈值,超过此值才进行切片
|
* 80MB = 80 * 1024 * 1024 字节
|
*/
|
private static final long CHUNK_THRESHOLD_BYTES = 80L * 1024 * 1024;
|
/**
|
* Embedding 模型最大输入长度限制
|
* 阿里云 DashScope 限制为 8192 字符
|
*/
|
private static final int EMBEDDING_MAX_LENGTH = 8000;
|
|
@Override
|
@Async("threadPoolTaskExecutor")
|
public void processVectorAsync(Long vectorId) {
|
log.info("开始异步向量化处理: vectorId={}, thread={}", vectorId, Thread.currentThread().getName());
|
processVector(vectorId);
|
}
|
|
@Override
|
public void processVector(Long vectorId) {
|
log.info("开始处理向量化: vectorId={}", vectorId);
|
KnowledgeBaseVector vector = knowledgeBaseVectorService.getById(vectorId);
|
if (vector == null) {
|
log.error("向量记录不存在: {}", vectorId);
|
return;
|
}
|
|
try {
|
// 更新状态为处理中
|
log.info("更新状态为处理中: vectorId={}", vectorId);
|
knowledgeBaseVectorService.updateVectorStatus(vectorId,
|
KnowledgeBaseVector.STATUS_PROCESSING, null, null);
|
|
// 获取文件内容
|
log.info("获取文件信息: storageBlobId={}", vector.getStorageBlobId());
|
StorageBlob blob = storageBlobService.getById(vector.getStorageBlobId());
|
if (blob == null) {
|
throw new RuntimeException("文件不存在: " + vector.getStorageBlobId());
|
}
|
|
File file = getFile(blob);
|
log.info("文件路径: {}, 是否存在: {}", file.getAbsolutePath(), file.exists());
|
long fileSize = file.length();
|
log.info("文件大小: {} 字节", fileSize);
|
|
// 直接读取文件内容,不使用 MultipartFile 包装
|
log.info("提取文件内容: fileName={}", vector.getFileName());
|
String content = extractFileContent(file, vector.getFileName());
|
log.info("文件内容长度: {}", content != null ? content.length() : 0);
|
|
if (content == null || content.trim().isEmpty()) {
|
throw new RuntimeException("文件内容为空");
|
}
|
|
// 文本切片
|
List<TextSegment> chunks;
|
boolean needChunk = fileSize > CHUNK_THRESHOLD_BYTES || content.length() > EMBEDDING_MAX_LENGTH;
|
if (needChunk) {
|
log.info("开始切片: fileSize={}, contentLength={}", fileSize, content.length());
|
chunks = splitText(content, vector);
|
log.info("切片完成,共 {} 个块", chunks.size());
|
} else {
|
log.info("文件较小且内容长度{}不超过{},不进行切片", content.length(), EMBEDDING_MAX_LENGTH);
|
Map<String, Object> metadata = buildMetadata(vector);
|
chunks = List.of(TextSegment.from(content, new dev.langchain4j.data.document.Metadata(metadata)));
|
}
|
|
// 批量生成嵌入向量并存储
|
int chunkCount = 0;
|
for (TextSegment chunk : chunks) {
|
log.debug("处理第 {} 个块", chunkCount + 1);
|
Embedding embedding = embeddingModel.embed(chunk).content();
|
embeddingStore.add(embedding, chunk);
|
chunkCount++;
|
}
|
|
// 更新状态为完成
|
knowledgeBaseVectorService.updateVectorStatus(vectorId,
|
KnowledgeBaseVector.STATUS_COMPLETED, chunkCount, null);
|
|
log.info("向量化处理完成: vectorId={}, chunkCount={}", vectorId, chunkCount);
|
|
} catch (Exception e) {
|
log.error("向量化处理失败: vectorId={}", vectorId, e);
|
knowledgeBaseVectorService.updateVectorStatus(vectorId,
|
KnowledgeBaseVector.STATUS_FAILED, null, e.getMessage());
|
}
|
}
|
|
@Override
|
public List<String> searchRelevantContent(String namespace, String query, int maxResults) {
|
try {
|
// 生成查询向量
|
Embedding queryEmbedding = embeddingModel.embed(query).content();
|
|
// 构建搜索请求,使用元数据过滤
|
EmbeddingSearchRequest searchRequest = EmbeddingSearchRequest.builder()
|
.queryEmbedding(queryEmbedding)
|
.maxResults(maxResults)
|
.minScore(0.7)
|
.build();
|
|
EmbeddingSearchResult<TextSegment> searchResult = embeddingStore.search(searchRequest);
|
|
return searchResult.matches().stream()
|
.map(match -> match.embedded().text())
|
.collect(Collectors.toList());
|
|
} catch (Exception e) {
|
log.error("向量检索失败: namespace={}", namespace, e);
|
return new ArrayList<>();
|
}
|
}
|
|
@Override
|
public void deleteEmbeddings(String namespace, Long storageBlobId) {
|
// Pinecone 按命名空间删除需要特定实现
|
// 当前实现:通过 metadata 过滤删除
|
log.info("删除向量数据: namespace={}, storageBlobId={}", namespace, storageBlobId);
|
// 注意:Pinecone 的删除操作需要在 EmbeddingStore 层实现
|
// 当前使用 PineconeEmbeddingStore,可能需要调用 Pinecone 客户端直接删除
|
}
|
|
private File getFile(StorageBlob blob) {
|
String path = blob.getPath();
|
if (path != null && !path.isEmpty()) {
|
return new File(new File(fileProperties.getPath(), path), blob.getUidFilename());
|
}
|
return new File(fileProperties.getPath(), blob.getUidFilename());
|
}
|
|
/**
|
* 提取文件内容
|
*/
|
private String extractFileContent(File file, String fileName) throws Exception {
|
String ext = getFileExtension(fileName);
|
|
// 根据文件类型提取内容
|
if (isPlainText(ext)) {
|
return readFileWithEncoding(file);
|
}
|
|
if ("docx".equals(ext)) {
|
return extractDocx(file);
|
}
|
|
if ("xlsx".equals(ext)) {
|
return extractXlsx(file);
|
}
|
|
if ("xls".equals(ext)) {
|
return extractXls(file);
|
}
|
|
// 默认尝试读取文本
|
return readFileWithEncoding(file);
|
}
|
|
/**
|
* 自动检测文件编码并读取内容
|
* 优先尝试 UTF-8,失败则尝试 GBK
|
*/
|
private String readFileWithEncoding(File file) throws Exception {
|
byte[] bytes = Files.readAllBytes(file.toPath());
|
|
// 先尝试 UTF-8
|
String utf8Content = new String(bytes, StandardCharsets.UTF_8);
|
if (isValidUtf8(utf8Content)) {
|
log.debug("文件编码: UTF-8");
|
return utf8Content;
|
}
|
|
// 尝试 GBK
|
try {
|
Charset gbk = Charset.forName("GBK");
|
String gbkContent = new String(bytes, gbk);
|
log.debug("文件编码: GBK");
|
return gbkContent;
|
} catch (Exception e) {
|
log.warn("编码检测失败,使用 UTF-8");
|
return utf8Content;
|
}
|
}
|
|
/**
|
* 检查 UTF-8 解码是否有效
|
*/
|
private boolean isValidUtf8(String decoded) {
|
// 检查是否包含替换字符(说明 UTF-8 解码失败)
|
if (decoded.contains("�")) {
|
return false;
|
}
|
// 检查是否有过多的非打印字符(乱码特征)
|
int invalidCount = 0;
|
for (int i = 0; i < Math.min(decoded.length(), 1000); i++) {
|
char c = decoded.charAt(i);
|
// 检查私有使用区域或异常的控制字符
|
if ((c >= '' && c <= '') || (c < ' ' && c != '\n' && c != '\r' && c != '\t')) {
|
invalidCount++;
|
}
|
}
|
// 如果无效字符超过 5%,认为是编码错误
|
return invalidCount < Math.min(decoded.length(), 1000) * 0.05;
|
}
|
|
private String getFileExtension(String fileName) {
|
if (fileName == null || !fileName.contains(".")) {
|
return "";
|
}
|
return fileName.substring(fileName.lastIndexOf('.') + 1).toLowerCase();
|
}
|
|
private boolean isPlainText(String ext) {
|
return "txt".equals(ext) || "md".equals(ext) || "json".equals(ext)
|
|| "csv".equals(ext) || "xml".equals(ext) || "yaml".equals(ext)
|
|| "yml".equals(ext);
|
}
|
|
private String extractDocx(File file) throws Exception {
|
try (var doc = new org.apache.poi.xwpf.usermodel.XWPFDocument(new java.io.FileInputStream(file));
|
var extractor = new org.apache.poi.xwpf.extractor.XWPFWordExtractor(doc)) {
|
return extractor.getText();
|
}
|
}
|
|
private String extractXlsx(File file) throws Exception {
|
try (var workbook = new org.apache.poi.xssf.usermodel.XSSFWorkbook(file)) {
|
return extractWorkbook(workbook);
|
}
|
}
|
|
private String extractXls(File file) throws Exception {
|
try (var workbook = new org.apache.poi.hssf.usermodel.HSSFWorkbook(new java.io.FileInputStream(file))) {
|
return extractWorkbook(workbook);
|
}
|
}
|
|
private String extractWorkbook(org.apache.poi.ss.usermodel.Workbook workbook) {
|
StringBuilder text = new StringBuilder();
|
var formatter = new org.apache.poi.ss.usermodel.DataFormatter();
|
for (int i = 0; i < workbook.getNumberOfSheets(); i++) {
|
var sheet = workbook.getSheetAt(i);
|
text.append("Sheet: ").append(sheet.getSheetName()).append("\n");
|
for (var row : sheet) {
|
for (var cell : row) {
|
text.append(formatter.formatCellValue(cell)).append("\t");
|
}
|
text.append("\n");
|
}
|
}
|
return text.toString();
|
}
|
|
/**
|
* 文本切片
|
*/
|
private List<TextSegment> splitText(String content, KnowledgeBaseVector vector) {
|
List<TextSegment> chunks = new ArrayList<>();
|
|
if (content.length() <= CHUNK_SIZE) {
|
Map<String, Object> metadata = buildMetadata(vector);
|
chunks.add(TextSegment.from(content, new dev.langchain4j.data.document.Metadata(metadata)));
|
return chunks;
|
}
|
|
int start = 0;
|
int chunkIndex = 0;
|
while (start < content.length()) {
|
int end = Math.min(start + CHUNK_SIZE, content.length());
|
|
// 尝试在句子边界切分
|
if (end < content.length()) {
|
int lastPeriod = content.lastIndexOf('。', end);
|
int lastNewline = content.lastIndexOf('\n', end);
|
int boundary = Math.max(lastPeriod, lastNewline);
|
if (boundary > start + CHUNK_SIZE / 2) {
|
end = boundary + 1;
|
}
|
}
|
|
String chunkText = content.substring(start, end).trim();
|
if (!chunkText.isEmpty()) {
|
Map<String, Object> metadata = buildMetadata(vector);
|
metadata.put("chunkIndex", chunkIndex);
|
chunks.add(TextSegment.from(chunkText, new dev.langchain4j.data.document.Metadata(metadata)));
|
chunkIndex++;
|
}
|
|
start = end - CHUNK_OVERLAP;
|
if (start < 0) start = 0;
|
if (start >= content.length() - CHUNK_OVERLAP) break;
|
}
|
|
return chunks;
|
}
|
|
private Map<String, Object> buildMetadata(KnowledgeBaseVector vector) {
|
Map<String, Object> metadata = new HashMap<>();
|
metadata.put("knowledgeBaseId", vector.getKnowledgeBaseId());
|
metadata.put("storageBlobId", vector.getStorageBlobId());
|
metadata.put("fileName", vector.getFileName());
|
metadata.put("namespace", vector.getNamespace());
|
return metadata;
|
}
|
}
|