package com.ruoyi.ai.service.impl;
|
|
import com.ruoyi.ai.service.AiFileTextExtractor;
|
import com.ruoyi.ai.service.KnowledgeRagService;
|
import com.ruoyi.approve.pojo.KnowledgeBaseVector;
|
import com.ruoyi.approve.service.KnowledgeBaseVectorService;
|
import com.ruoyi.basic.pojo.StorageBlob;
|
import com.ruoyi.basic.service.StorageBlobService;
|
import com.ruoyi.common.config.FileProperties;
|
import dev.langchain4j.data.embedding.Embedding;
|
import dev.langchain4j.data.segment.TextSegment;
|
import dev.langchain4j.model.embedding.EmbeddingModel;
|
import dev.langchain4j.store.embedding.EmbeddingMatch;
|
import dev.langchain4j.store.embedding.EmbeddingSearchRequest;
|
import dev.langchain4j.store.embedding.EmbeddingSearchResult;
|
import dev.langchain4j.store.embedding.EmbeddingStore;
|
import lombok.RequiredArgsConstructor;
|
import lombok.extern.slf4j.Slf4j;
|
import org.springframework.scheduling.annotation.Async;
|
import org.springframework.stereotype.Service;
|
|
import java.io.File;
|
import java.nio.file.Files;
|
import java.util.ArrayList;
|
import java.util.HashMap;
|
import java.util.List;
|
import java.util.Map;
|
import java.util.stream.Collectors;
|
|
/**
|
* 知识库RAG服务实现
|
*/
|
@Slf4j
|
@Service
|
@RequiredArgsConstructor
|
public class KnowledgeRagServiceImpl implements KnowledgeRagService {
|
|
private final KnowledgeBaseVectorService knowledgeBaseVectorService;
|
private final StorageBlobService storageBlobService;
|
private final AiFileTextExtractor aiFileTextExtractor;
|
private final EmbeddingModel embeddingModel;
|
private final EmbeddingStore<TextSegment> embeddingStore;
|
private final FileProperties fileProperties;
|
|
private static final int CHUNK_SIZE = 500;
|
private static final int CHUNK_OVERLAP = 100;
|
|
@Override
|
@Async
|
public void processVectorAsync(Long vectorId) {
|
processVector(vectorId);
|
}
|
|
@Override
|
public void processVector(Long vectorId) {
|
KnowledgeBaseVector vector = knowledgeBaseVectorService.getById(vectorId);
|
if (vector == null) {
|
log.error("向量记录不存在: {}", vectorId);
|
return;
|
}
|
|
try {
|
// 更新状态为处理中
|
knowledgeBaseVectorService.updateVectorStatus(vectorId,
|
KnowledgeBaseVector.STATUS_PROCESSING, null, null);
|
|
// 获取文件内容
|
StorageBlob blob = storageBlobService.getById(vector.getStorageBlobId());
|
if (blob == null) {
|
throw new RuntimeException("文件不存在: " + vector.getStorageBlobId());
|
}
|
|
File file = getFile(blob);
|
|
// 直接读取文件内容,不使用 MultipartFile 包装
|
String content = extractFileContent(file, vector.getFileName(), blob.getContentType());
|
|
if (content == null || content.trim().isEmpty()) {
|
throw new RuntimeException("文件内容为空");
|
}
|
|
// 文本切片
|
List<TextSegment> chunks = splitText(content, vector);
|
|
// 批量生成嵌入向量并存储
|
int chunkCount = 0;
|
for (TextSegment chunk : chunks) {
|
Embedding embedding = embeddingModel.embed(chunk).content();
|
embeddingStore.add(embedding, chunk);
|
chunkCount++;
|
}
|
|
// 更新状态为完成
|
knowledgeBaseVectorService.updateVectorStatus(vectorId,
|
KnowledgeBaseVector.STATUS_COMPLETED, chunkCount, null);
|
|
log.info("向量化处理完成: vectorId={}, chunkCount={}", vectorId, chunkCount);
|
|
} catch (Exception e) {
|
log.error("向量化处理失败: vectorId={}", vectorId, e);
|
knowledgeBaseVectorService.updateVectorStatus(vectorId,
|
KnowledgeBaseVector.STATUS_FAILED, null, e.getMessage());
|
}
|
}
|
|
@Override
|
public List<String> searchRelevantContent(String namespace, String query, int maxResults) {
|
try {
|
// 生成查询向量
|
Embedding queryEmbedding = embeddingModel.embed(query).content();
|
|
// 构建搜索请求,使用元数据过滤
|
EmbeddingSearchRequest searchRequest = EmbeddingSearchRequest.builder()
|
.queryEmbedding(queryEmbedding)
|
.maxResults(maxResults)
|
.minScore(0.7)
|
.build();
|
|
EmbeddingSearchResult<TextSegment> searchResult = embeddingStore.search(searchRequest);
|
|
return searchResult.matches().stream()
|
.map(match -> match.embedded().text())
|
.collect(Collectors.toList());
|
|
} catch (Exception e) {
|
log.error("向量检索失败: namespace={}", namespace, e);
|
return new ArrayList<>();
|
}
|
}
|
|
@Override
|
public void deleteEmbeddings(String namespace, Long storageBlobId) {
|
// Pinecone 按命名空间删除需要特定实现
|
// 当前实现:通过 metadata 过滤删除
|
log.info("删除向量数据: namespace={}, storageBlobId={}", namespace, storageBlobId);
|
// 注意:Pinecone 的删除操作需要在 EmbeddingStore 层实现
|
// 当前使用 PineconeEmbeddingStore,可能需要调用 Pinecone 客户端直接删除
|
}
|
|
private File getFile(StorageBlob blob) {
|
String path = blob.getPath();
|
if (path != null && !path.isEmpty()) {
|
return new File(new File(fileProperties.getPath(), path), blob.getUidFilename());
|
}
|
return new File(fileProperties.getPath(), blob.getUidFilename());
|
}
|
|
/**
|
* 提取文件内容
|
*/
|
private String extractFileContent(File file, String fileName, String contentType) throws Exception {
|
String ext = getFileExtension(fileName);
|
|
// 根据文件类型提取内容
|
if (isPlainText(ext)) {
|
return Files.readString(file.toPath());
|
}
|
|
if ("docx".equals(ext)) {
|
return extractDocx(file);
|
}
|
|
if ("xlsx".equals(ext)) {
|
return extractXlsx(file);
|
}
|
|
if ("xls".equals(ext)) {
|
return extractXls(file);
|
}
|
|
// 默认尝试读取文本
|
return Files.readString(file.toPath());
|
}
|
|
private String getFileExtension(String fileName) {
|
if (fileName == null || !fileName.contains(".")) {
|
return "";
|
}
|
return fileName.substring(fileName.lastIndexOf('.') + 1).toLowerCase();
|
}
|
|
private boolean isPlainText(String ext) {
|
return "txt".equals(ext) || "md".equals(ext) || "json".equals(ext)
|
|| "csv".equals(ext) || "xml".equals(ext) || "yaml".equals(ext)
|
|| "yml".equals(ext);
|
}
|
|
private String extractDocx(File file) throws Exception {
|
try (var doc = new org.apache.poi.xwpf.usermodel.XWPFDocument(new java.io.FileInputStream(file));
|
var extractor = new org.apache.poi.xwpf.extractor.XWPFWordExtractor(doc)) {
|
return extractor.getText();
|
}
|
}
|
|
private String extractXlsx(File file) throws Exception {
|
try (var workbook = new org.apache.poi.xssf.usermodel.XSSFWorkbook(file)) {
|
return extractWorkbook(workbook);
|
}
|
}
|
|
private String extractXls(File file) throws Exception {
|
try (var workbook = new org.apache.poi.hssf.usermodel.HSSFWorkbook(new java.io.FileInputStream(file))) {
|
return extractWorkbook(workbook);
|
}
|
}
|
|
private String extractWorkbook(org.apache.poi.ss.usermodel.Workbook workbook) {
|
StringBuilder text = new StringBuilder();
|
var formatter = new org.apache.poi.ss.usermodel.DataFormatter();
|
for (int i = 0; i < workbook.getNumberOfSheets(); i++) {
|
var sheet = workbook.getSheetAt(i);
|
text.append("Sheet: ").append(sheet.getSheetName()).append("\n");
|
for (var row : sheet) {
|
for (var cell : row) {
|
text.append(formatter.formatCellValue(cell)).append("\t");
|
}
|
text.append("\n");
|
}
|
}
|
return text.toString();
|
}
|
|
/**
|
* 文本切片
|
*/
|
private List<TextSegment> splitText(String content, KnowledgeBaseVector vector) {
|
List<TextSegment> chunks = new ArrayList<>();
|
|
if (content.length() <= CHUNK_SIZE) {
|
Map<String, Object> metadata = buildMetadata(vector);
|
chunks.add(TextSegment.from(content, new dev.langchain4j.data.document.Metadata(metadata)));
|
return chunks;
|
}
|
|
int start = 0;
|
int chunkIndex = 0;
|
while (start < content.length()) {
|
int end = Math.min(start + CHUNK_SIZE, content.length());
|
|
// 尝试在句子边界切分
|
if (end < content.length()) {
|
int lastPeriod = content.lastIndexOf('。', end);
|
int lastNewline = content.lastIndexOf('\n', end);
|
int boundary = Math.max(lastPeriod, lastNewline);
|
if (boundary > start + CHUNK_SIZE / 2) {
|
end = boundary + 1;
|
}
|
}
|
|
String chunkText = content.substring(start, end).trim();
|
if (!chunkText.isEmpty()) {
|
Map<String, Object> metadata = buildMetadata(vector);
|
metadata.put("chunkIndex", chunkIndex);
|
chunks.add(TextSegment.from(chunkText, new dev.langchain4j.data.document.Metadata(metadata)));
|
chunkIndex++;
|
}
|
|
start = end - CHUNK_OVERLAP;
|
if (start < 0) start = 0;
|
if (start >= content.length() - CHUNK_OVERLAP) break;
|
}
|
|
return chunks;
|
}
|
|
private Map<String, Object> buildMetadata(KnowledgeBaseVector vector) {
|
Map<String, Object> metadata = new HashMap<>();
|
metadata.put("knowledgeBaseId", vector.getKnowledgeBaseId());
|
metadata.put("storageBlobId", vector.getStorageBlobId());
|
metadata.put("fileName", vector.getFileName());
|
metadata.put("namespace", vector.getNamespace());
|
return metadata;
|
}
|
}
|