6 天以前 ea5a55deffa6d33048a1f7e03b71424c8add5e31
src/main/java/com/ruoyi/ai/store/MongoChatMemoryStore.java
@@ -5,47 +5,127 @@
import dev.langchain4j.data.message.ChatMessageDeserializer;
import dev.langchain4j.data.message.ChatMessageSerializer;
import dev.langchain4j.store.memory.chat.ChatMemoryStore;
import org.springframework.beans.factory.annotation.Autowired;
import lombok.RequiredArgsConstructor;
import org.springframework.data.mongodb.core.MongoTemplate;
import org.springframework.data.mongodb.core.query.Criteria;
import org.springframework.data.mongodb.core.query.Query;
import org.springframework.data.mongodb.core.query.Update;
import org.springframework.stereotype.Component;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;
import java.util.Date;
import java.util.LinkedList;
import java.util.List;
/**
 * @author :yys
 * @date : 2025/5/2 19:18
 */
@Component
@RequiredArgsConstructor
public class MongoChatMemoryStore implements ChatMemoryStore {
    @Autowired
    private MongoTemplate mongoTemplate;
    private final MongoTemplate mongoTemplate;
    @Override
    public List<ChatMessage> getMessages(Object memoryId) {
        Criteria criteria = Criteria.where("memoryId").is(memoryId);
        Query query = new Query(criteria);
        ChatMessages chatMessages = mongoTemplate.findOne(query, ChatMessages.class);
        if(chatMessages == null) return new LinkedList<>();
        ChatMessages chatMessages = findChatMessages(memoryId);
        if (chatMessages == null || chatMessages.getContent() == null) {
            return new LinkedList<>();
        }
        return ChatMessageDeserializer.messagesFromJson(chatMessages.getContent());
    }
    @Override
    public void updateMessages(Object memoryId, List<ChatMessage> messages) {
        Criteria criteria = Criteria.where("memoryId").is(memoryId);
        Query query = new Query(criteria);
        String memoryIdValue = memoryIdString(memoryId);
        Query query = Query.query(Criteria.where("memoryId").is(memoryIdValue));
        Update update = new Update();
        update.set("memoryId", memoryIdValue);
        update.set("content", ChatMessageSerializer.messagesToJson(messages));
        //根据query条件能查询出文档,则修改文档;否则新增文档
        update.set("updateTime", new Date());
        update.setOnInsert("createTime", new Date());
        mongoTemplate.upsert(query, update, ChatMessages.class);
    }
    @Override
    public void deleteMessages(Object memoryId) {
        Criteria criteria = Criteria.where("memoryId").is(memoryId);
        Query query = new Query(criteria);
        Query query = Query.query(Criteria.where("memoryId").is(memoryIdString(memoryId)));
        mongoTemplate.remove(query, ChatMessages.class);
    }
    public void appendMessages(Object memoryId, List<ChatMessage> appendList) {
        List<ChatMessage> messages = new LinkedList<>(getMessages(memoryId));
        messages.addAll(appendList);
        updateMessages(memoryId, messages);
    }
    public void appendAnalyzeFileContext(Object memoryId, String userQuestion, List<String> filePaths) {
        String memoryIdValue = memoryIdString(memoryId);
        if (!StringUtils.hasText(memoryIdValue)) {
            return;
        }
        List<String> validFilePaths = new LinkedList<>();
        if (!CollectionUtils.isEmpty(filePaths)) {
            for (String filePath : filePaths) {
                if (StringUtils.hasText(filePath)) {
                    validFilePaths.add(filePath);
                }
            }
        }
        if (!StringUtils.hasText(userQuestion) && validFilePaths.isEmpty()) {
            return;
        }
        Query query = Query.query(Criteria.where("memoryId").is(memoryIdValue));
        Update update = new Update();
        update.set("memoryId", memoryIdValue);
        update.set("updateTime", new Date());
        update.setOnInsert("createTime", new Date());
        if (StringUtils.hasText(userQuestion)) {
            update.push("analyzeUserQuestions", userQuestion);
        }
        if (!validFilePaths.isEmpty()) {
            update.push("analyzeFilePaths").each(validFilePaths.toArray());
            update.push("analyzeFilePathGroups", validFilePaths);
        }
        mongoTemplate.upsert(query, update, ChatMessages.class);
    }
    public List<String> getAnalyzeUserQuestions(Object memoryId) {
        ChatMessages chatMessages = findChatMessages(memoryId);
        if (chatMessages == null || CollectionUtils.isEmpty(chatMessages.getAnalyzeUserQuestions())) {
            return new LinkedList<>();
        }
        return new LinkedList<>(chatMessages.getAnalyzeUserQuestions());
    }
    public List<List<String>> getAnalyzeFilePathGroups(Object memoryId) {
        ChatMessages chatMessages = findChatMessages(memoryId);
        if (chatMessages == null) {
            return new LinkedList<>();
        }
        if (CollectionUtils.isEmpty(chatMessages.getAnalyzeFilePathGroups())) {
            if (CollectionUtils.isEmpty(chatMessages.getAnalyzeFilePaths())) {
                return new LinkedList<>();
            }
            List<List<String>> fallback = new LinkedList<>();
            fallback.add(new LinkedList<>(chatMessages.getAnalyzeFilePaths()));
            return fallback;
        }
        List<List<String>> groups = new LinkedList<>();
        for (List<String> group : chatMessages.getAnalyzeFilePathGroups()) {
            if (CollectionUtils.isEmpty(group)) {
                groups.add(new LinkedList<>());
            } else {
                groups.add(new LinkedList<>(group));
            }
        }
        return groups;
    }
    private String memoryIdString(Object memoryId) {
        return memoryId == null ? "" : memoryId.toString();
    }
    private ChatMessages findChatMessages(Object memoryId) {
        Query query = Query.query(Criteria.where("memoryId").is(memoryIdString(memoryId)));
        return mongoTemplate.findOne(query, ChatMessages.class);
    }
}