package com.ruoyi.ai.service.impl;
|
|
import com.ruoyi.ai.service.KnowledgeRagService;
|
import com.ruoyi.approve.pojo.KnowledgeBaseVector;
|
import com.ruoyi.approve.service.KnowledgeBaseVectorService;
|
import com.ruoyi.basic.pojo.StorageBlob;
|
import com.ruoyi.basic.service.StorageBlobService;
|
import com.ruoyi.common.config.FileProperties;
|
import com.google.protobuf.Struct;
|
import com.google.protobuf.Value;
|
import dev.langchain4j.data.embedding.Embedding;
|
import dev.langchain4j.data.segment.TextSegment;
|
import dev.langchain4j.model.embedding.EmbeddingModel;
|
import dev.langchain4j.store.embedding.EmbeddingSearchRequest;
|
import dev.langchain4j.store.embedding.EmbeddingSearchResult;
|
import dev.langchain4j.store.embedding.EmbeddingStore;
|
import io.pinecone.clients.Index;
|
import lombok.extern.slf4j.Slf4j;
|
import org.springframework.scheduling.annotation.Async;
|
import org.springframework.stereotype.Service;
|
|
import java.io.File;
|
import java.nio.charset.Charset;
|
import java.nio.charset.StandardCharsets;
|
import java.nio.file.Files;
|
import java.util.ArrayList;
|
import java.util.HashMap;
|
import java.util.List;
|
import java.util.Map;
|
import java.util.stream.Collectors;
|
|
/**
|
* 知识库RAG服务实现
|
*/
|
@Slf4j
|
@Service
|
public class KnowledgeRagServiceImpl implements KnowledgeRagService {
|
|
/** Loads vector records and records their processing status. */
private final KnowledgeBaseVectorService knowledgeBaseVectorService;

/** Looks up the stored-file record referenced by a vector record. */
private final StorageBlobService storageBlobService;

/** Produces embeddings for both indexed segments and search queries. */
private final EmbeddingModel embeddingModel;

/** Store that receives embedded segments and answers similarity searches. */
private final EmbeddingStore<TextSegment> embeddingStore;

/** Supplies the base path under which stored files are resolved. */
private final FileProperties fileProperties;

/** Pinecone index client, used directly for filter-based deletes. */
private final Index pineconeIndex;

/**
 * Pinecone namespace passed to delete calls. Injected from the
 * {@code pinecone.namespace} property, defaulting to {@code knowledge-base}.
 */
@org.springframework.beans.factory.annotation.Value("${pinecone.namespace:knowledge-base}")
private String namespace;

/**
 * Creates the service with all of its collaborators supplied by the caller
 * (constructor injection). The arguments are stored as given.
 *
 * @param knowledgeBaseVectorService service for vector records and their status
 * @param storageBlobService service for stored-file records
 * @param embeddingModel model used to embed text
 * @param embeddingStore store for embedded text segments
 * @param fileProperties provider of the file storage base path
 * @param pineconeIndex Pinecone index client
 */
public KnowledgeRagServiceImpl(
        KnowledgeBaseVectorService knowledgeBaseVectorService,
        StorageBlobService storageBlobService,
        EmbeddingModel embeddingModel,
        EmbeddingStore<TextSegment> embeddingStore,
        FileProperties fileProperties,
        Index pineconeIndex) {
    // Record/file lookups.
    this.knowledgeBaseVectorService = knowledgeBaseVectorService;
    this.storageBlobService = storageBlobService;
    this.fileProperties = fileProperties;
    // Embedding pipeline.
    this.embeddingModel = embeddingModel;
    this.embeddingStore = embeddingStore;
    this.pineconeIndex = pineconeIndex;
}
|
|
private static final int CHUNK_SIZE = 500;
|
private static final int CHUNK_OVERLAP = 100;
|
private static final long CHUNK_THRESHOLD_BYTES = 80L * 1024 * 1024;
|
private static final int EMBEDDING_MAX_LENGTH = 8000;
|
|
@Override
|
@Async("threadPoolTaskExecutor")
|
public void processVectorAsync(Long vectorId) {
|
log.info("开始异步向量化处理: vectorId={}, thread={}", vectorId, Thread.currentThread().getName());
|
processVector(vectorId);
|
}
|
|
@Override
|
public void processVector(Long vectorId) {
|
log.info("开始处理向量化: vectorId={}", vectorId);
|
KnowledgeBaseVector vector = knowledgeBaseVectorService.getById(vectorId);
|
if (vector == null) {
|
log.error("向量记录不存在: {}", vectorId);
|
return;
|
}
|
|
try {
|
knowledgeBaseVectorService.updateVectorStatus(vectorId,
|
KnowledgeBaseVector.STATUS_PROCESSING, null, null);
|
|
StorageBlob blob = storageBlobService.getById(vector.getStorageBlobId());
|
if (blob == null) {
|
throw new RuntimeException("文件不存在: " + vector.getStorageBlobId());
|
}
|
|
File file = getFile(blob);
|
log.info("文件路径: {}, 是否存在: {}", file.getAbsolutePath(), file.exists());
|
long fileSize = file.length();
|
|
String content = extractFileContent(file, vector.getFileName());
|
log.info("文件内容长度: {}", content != null ? content.length() : 0);
|
|
if (content == null || content.trim().isEmpty()) {
|
throw new RuntimeException("文件内容为空");
|
}
|
|
List<TextSegment> chunks;
|
boolean needChunk = fileSize > CHUNK_THRESHOLD_BYTES || content.length() > EMBEDDING_MAX_LENGTH;
|
if (needChunk) {
|
log.info("开始切片: fileSize={}, contentLength={}", fileSize, content.length());
|
chunks = splitText(content, vector);
|
log.info("切片完成,共 {} 个块", chunks.size());
|
} else {
|
log.info("文件较小,不进行切片");
|
Map<String, Object> metadata = buildMetadata(vector);
|
chunks = List.of(TextSegment.from(content, new dev.langchain4j.data.document.Metadata(metadata)));
|
}
|
|
int chunkCount = 0;
|
for (TextSegment chunk : chunks) {
|
Embedding embedding = embeddingModel.embed(chunk).content();
|
embeddingStore.add(embedding, chunk);
|
chunkCount++;
|
}
|
|
knowledgeBaseVectorService.updateVectorStatus(vectorId,
|
KnowledgeBaseVector.STATUS_COMPLETED, chunkCount, null);
|
|
log.info("向量化处理完成: vectorId={}, chunkCount={}", vectorId, chunkCount);
|
|
} catch (Exception e) {
|
log.error("向量化处理失败: vectorId={}", vectorId, e);
|
knowledgeBaseVectorService.updateVectorStatus(vectorId,
|
KnowledgeBaseVector.STATUS_FAILED, null, e.getMessage());
|
}
|
}
|
|
@Override
|
public List<String> searchRelevantContent(String namespace, String query, int maxResults) {
|
try {
|
Embedding queryEmbedding = embeddingModel.embed(query).content();
|
|
EmbeddingSearchRequest searchRequest = EmbeddingSearchRequest.builder()
|
.queryEmbedding(queryEmbedding)
|
.maxResults(maxResults)
|
.minScore(0.7)
|
.build();
|
|
EmbeddingSearchResult<TextSegment> searchResult = embeddingStore.search(searchRequest);
|
|
return searchResult.matches().stream()
|
.map(match -> match.embedded().text())
|
.collect(Collectors.toList());
|
|
} catch (Exception e) {
|
log.error("向量检索失败: namespace={}", namespace, e);
|
return new ArrayList<>();
|
}
|
}
|
|
@Override
|
public void deleteEmbeddings(String namespace, Long storageBlobId) {
|
log.info("删除向量数据: namespace={}, storageBlobId={}", namespace, storageBlobId);
|
try {
|
Struct filter = Struct.newBuilder()
|
.putFields("storageBlobId", Value.newBuilder()
|
.setStructValue(Struct.newBuilder()
|
.putFields("$eq", Value.newBuilder()
|
.setNumberValue(storageBlobId.doubleValue())
|
.build()))
|
.build())
|
.build();
|
|
List<String> emptyIds = new ArrayList<>();
|
pineconeIndex.delete(emptyIds, false, this.namespace, filter);
|
log.info("向量删除完成: storageBlobId={}", storageBlobId);
|
} catch (Exception e) {
|
log.error("删除向量数据失败: namespace={}, storageBlobId={}", namespace, storageBlobId, e);
|
}
|
}
|
|
private File getFile(StorageBlob blob) {
|
String path = blob.getPath();
|
if (path != null && !path.isEmpty()) {
|
return new File(new File(fileProperties.getPath(), path), blob.getUidFilename());
|
}
|
return new File(fileProperties.getPath(), blob.getUidFilename());
|
}
|
|
private String extractFileContent(File file, String fileName) throws Exception {
|
String ext = getFileExtension(fileName);
|
|
if (isPlainText(ext)) {
|
return readFileWithEncoding(file);
|
}
|
|
if ("docx".equals(ext)) {
|
return extractDocx(file);
|
}
|
|
if ("xlsx".equals(ext)) {
|
return extractXlsx(file);
|
}
|
|
if ("xls".equals(ext)) {
|
return extractXls(file);
|
}
|
|
return readFileWithEncoding(file);
|
}
|
|
private String readFileWithEncoding(File file) throws Exception {
|
byte[] bytes = Files.readAllBytes(file.toPath());
|
|
String utf8Content = new String(bytes, StandardCharsets.UTF_8);
|
if (isValidUtf8(utf8Content)) {
|
log.debug("文件编码: UTF-8");
|
return utf8Content;
|
}
|
|
try {
|
Charset gbk = Charset.forName("GBK");
|
String gbkContent = new String(bytes, gbk);
|
log.debug("文件编码: GBK");
|
return gbkContent;
|
} catch (Exception e) {
|
log.warn("编码检测失败,使用 UTF-8");
|
return utf8Content;
|
}
|
}
|
|
private boolean isValidUtf8(String decoded) {
|
// 检查替换字符 U+FFFD (UTF-8 解码失败时出现)
|
if (decoded.contains("�")) {
|
return false;
|
}
|
int invalidCount = 0;
|
int checkLen = Math.min(decoded.length(), 1000);
|
for (int i = 0; i < checkLen; i++) {
|
char c = decoded.charAt(i);
|
// 检查私有使用区域 (U+E000-U+F8FF) 或异常控制字符
|
if ((c >= '' && c <= '') || (c < ' ' && c != '\n' && c != '\r' && c != '\t')) {
|
invalidCount++;
|
}
|
}
|
return invalidCount < checkLen * 0.05;
|
}
|
|
private String getFileExtension(String fileName) {
|
if (fileName == null || !fileName.contains(".")) {
|
return "";
|
}
|
return fileName.substring(fileName.lastIndexOf('.') + 1).toLowerCase();
|
}
|
|
private boolean isPlainText(String ext) {
|
return "txt".equals(ext) || "md".equals(ext) || "json".equals(ext)
|
|| "csv".equals(ext) || "xml".equals(ext) || "yaml".equals(ext)
|
|| "yml".equals(ext);
|
}
|
|
private String extractDocx(File file) throws Exception {
|
try (var doc = new org.apache.poi.xwpf.usermodel.XWPFDocument(new java.io.FileInputStream(file));
|
var extractor = new org.apache.poi.xwpf.extractor.XWPFWordExtractor(doc)) {
|
return extractor.getText();
|
}
|
}
|
|
private String extractXlsx(File file) throws Exception {
|
try (var workbook = new org.apache.poi.xssf.usermodel.XSSFWorkbook(file)) {
|
return extractWorkbook(workbook);
|
}
|
}
|
|
private String extractXls(File file) throws Exception {
|
try (var workbook = new org.apache.poi.hssf.usermodel.HSSFWorkbook(new java.io.FileInputStream(file))) {
|
return extractWorkbook(workbook);
|
}
|
}
|
|
private String extractWorkbook(org.apache.poi.ss.usermodel.Workbook workbook) {
|
StringBuilder text = new StringBuilder();
|
var formatter = new org.apache.poi.ss.usermodel.DataFormatter();
|
for (int i = 0; i < workbook.getNumberOfSheets(); i++) {
|
var sheet = workbook.getSheetAt(i);
|
text.append("Sheet: ").append(sheet.getSheetName()).append("\n");
|
for (var row : sheet) {
|
for (var cell : row) {
|
text.append(formatter.formatCellValue(cell)).append("\t");
|
}
|
text.append("\n");
|
}
|
}
|
return text.toString();
|
}
|
|
private List<TextSegment> splitText(String content, KnowledgeBaseVector vector) {
|
List<TextSegment> chunks = new ArrayList<>();
|
|
if (content.length() <= CHUNK_SIZE) {
|
Map<String, Object> metadata = buildMetadata(vector);
|
chunks.add(TextSegment.from(content, new dev.langchain4j.data.document.Metadata(metadata)));
|
return chunks;
|
}
|
|
int start = 0;
|
int chunkIndex = 0;
|
while (start < content.length()) {
|
int end = Math.min(start + CHUNK_SIZE, content.length());
|
|
if (end < content.length()) {
|
int lastPeriod = content.lastIndexOf('。', end);
|
int lastNewline = content.lastIndexOf('\n', end);
|
int boundary = Math.max(lastPeriod, lastNewline);
|
if (boundary > start + CHUNK_SIZE / 2) {
|
end = boundary + 1;
|
}
|
}
|
|
String chunkText = content.substring(start, end).trim();
|
if (!chunkText.isEmpty()) {
|
Map<String, Object> metadata = buildMetadata(vector);
|
metadata.put("chunkIndex", chunkIndex);
|
chunks.add(TextSegment.from(chunkText, new dev.langchain4j.data.document.Metadata(metadata)));
|
chunkIndex++;
|
}
|
|
start = end - CHUNK_OVERLAP;
|
if (start < 0) start = 0;
|
if (start >= content.length() - CHUNK_OVERLAP) break;
|
}
|
|
return chunks;
|
}
|
|
private Map<String, Object> buildMetadata(KnowledgeBaseVector vector) {
|
Map<String, Object> metadata = new HashMap<>();
|
metadata.put("knowledgeBaseId", vector.getKnowledgeBaseId());
|
metadata.put("storageBlobId", vector.getStorageBlobId());
|
metadata.put("fileName", vector.getFileName());
|
metadata.put("namespace", vector.getNamespace());
|
return metadata;
|
}
|
}
|