package com.ruoyi.ai.store;
|
|
import com.ruoyi.ai.mongodbBean.ChatMessages;
|
import dev.langchain4j.data.message.ChatMessage;
|
import dev.langchain4j.data.message.ChatMessageDeserializer;
|
import dev.langchain4j.data.message.ChatMessageSerializer;
|
import dev.langchain4j.store.memory.chat.ChatMemoryStore;
|
import lombok.RequiredArgsConstructor;
|
import org.springframework.data.mongodb.core.MongoTemplate;
|
import org.springframework.data.mongodb.core.query.Criteria;
|
import org.springframework.data.mongodb.core.query.Query;
|
import org.springframework.data.mongodb.core.query.Update;
|
import org.springframework.stereotype.Component;
|
import org.springframework.util.CollectionUtils;
|
import org.springframework.util.StringUtils;
|
|
import java.util.Date;
|
import java.util.LinkedList;
|
import java.util.List;
|
|
@Component
|
@RequiredArgsConstructor
|
public class MongoChatMemoryStore implements ChatMemoryStore {
|
|
private final MongoTemplate mongoTemplate;
|
|
@Override
|
public List<ChatMessage> getMessages(Object memoryId) {
|
ChatMessages chatMessages = findChatMessages(memoryId);
|
if (chatMessages == null || chatMessages.getContent() == null) {
|
return new LinkedList<>();
|
}
|
return ChatMessageDeserializer.messagesFromJson(chatMessages.getContent());
|
}
|
|
@Override
|
public void updateMessages(Object memoryId, List<ChatMessage> messages) {
|
String memoryIdValue = memoryIdString(memoryId);
|
Query query = Query.query(Criteria.where("memoryId").is(memoryIdValue));
|
Update update = new Update();
|
update.set("memoryId", memoryIdValue);
|
update.set("content", ChatMessageSerializer.messagesToJson(messages));
|
update.set("updateTime", new Date());
|
update.setOnInsert("createTime", new Date());
|
mongoTemplate.upsert(query, update, ChatMessages.class);
|
}
|
|
@Override
|
public void deleteMessages(Object memoryId) {
|
Query query = Query.query(Criteria.where("memoryId").is(memoryIdString(memoryId)));
|
mongoTemplate.remove(query, ChatMessages.class);
|
}
|
|
public void appendMessages(Object memoryId, List<ChatMessage> appendList) {
|
List<ChatMessage> messages = new LinkedList<>(getMessages(memoryId));
|
messages.addAll(appendList);
|
updateMessages(memoryId, messages);
|
}
|
|
public void appendAnalyzeFileContext(Object memoryId, String userQuestion, List<String> filePaths) {
|
String memoryIdValue = memoryIdString(memoryId);
|
if (!StringUtils.hasText(memoryIdValue)) {
|
return;
|
}
|
List<String> validFilePaths = new LinkedList<>();
|
if (!CollectionUtils.isEmpty(filePaths)) {
|
for (String filePath : filePaths) {
|
if (StringUtils.hasText(filePath)) {
|
validFilePaths.add(filePath);
|
}
|
}
|
}
|
if (!StringUtils.hasText(userQuestion) && validFilePaths.isEmpty()) {
|
return;
|
}
|
Query query = Query.query(Criteria.where("memoryId").is(memoryIdValue));
|
Update update = new Update();
|
update.set("memoryId", memoryIdValue);
|
update.set("updateTime", new Date());
|
update.setOnInsert("createTime", new Date());
|
if (StringUtils.hasText(userQuestion)) {
|
update.push("analyzeUserQuestions", userQuestion);
|
}
|
if (!validFilePaths.isEmpty()) {
|
update.push("analyzeFilePaths").each(validFilePaths.toArray());
|
update.push("analyzeFilePathGroups", validFilePaths);
|
}
|
mongoTemplate.upsert(query, update, ChatMessages.class);
|
}
|
|
public List<String> getAnalyzeUserQuestions(Object memoryId) {
|
ChatMessages chatMessages = findChatMessages(memoryId);
|
if (chatMessages == null || CollectionUtils.isEmpty(chatMessages.getAnalyzeUserQuestions())) {
|
return new LinkedList<>();
|
}
|
return new LinkedList<>(chatMessages.getAnalyzeUserQuestions());
|
}
|
|
public List<List<String>> getAnalyzeFilePathGroups(Object memoryId) {
|
ChatMessages chatMessages = findChatMessages(memoryId);
|
if (chatMessages == null) {
|
return new LinkedList<>();
|
}
|
if (CollectionUtils.isEmpty(chatMessages.getAnalyzeFilePathGroups())) {
|
if (CollectionUtils.isEmpty(chatMessages.getAnalyzeFilePaths())) {
|
return new LinkedList<>();
|
}
|
List<List<String>> fallback = new LinkedList<>();
|
fallback.add(new LinkedList<>(chatMessages.getAnalyzeFilePaths()));
|
return fallback;
|
}
|
List<List<String>> groups = new LinkedList<>();
|
for (List<String> group : chatMessages.getAnalyzeFilePathGroups()) {
|
if (CollectionUtils.isEmpty(group)) {
|
groups.add(new LinkedList<>());
|
} else {
|
groups.add(new LinkedList<>(group));
|
}
|
}
|
return groups;
|
}
|
|
private String memoryIdString(Object memoryId) {
|
return memoryId == null ? "" : memoryId.toString();
|
}
|
|
private ChatMessages findChatMessages(Object memoryId) {
|
Query query = Query.query(Criteria.where("memoryId").is(memoryIdString(memoryId)));
|
return mongoTemplate.findOne(query, ChatMessages.class);
|
}
|
}
|