| | |
| | | import dev.langchain4j.data.message.ChatMessageDeserializer; |
| | | import dev.langchain4j.data.message.ChatMessageSerializer; |
| | | import dev.langchain4j.store.memory.chat.ChatMemoryStore; |
| | | import org.springframework.beans.factory.annotation.Autowired; |
| | | import lombok.RequiredArgsConstructor; |
| | | import org.springframework.data.mongodb.core.MongoTemplate; |
| | | import org.springframework.data.mongodb.core.query.Criteria; |
| | | import org.springframework.data.mongodb.core.query.Query; |
| | | import org.springframework.data.mongodb.core.query.Update; |
| | | import org.springframework.stereotype.Component; |
| | | |
| | | import java.util.Date; |
| | | import java.util.LinkedList; |
| | | import java.util.List; |
| | | |
| | | /** |
| | | * @author :yys |
| | | * @date : 2025/5/2 19:18 |
| | | */ |
| | | @Component |
| | | @RequiredArgsConstructor |
| | | public class MongoChatMemoryStore implements ChatMemoryStore { |
| | | |
| | | @Autowired |
| | | private MongoTemplate mongoTemplate; |
| | | private final MongoTemplate mongoTemplate; |
| | | |
| | | @Override |
| | | public List<ChatMessage> getMessages(Object memoryId) { |
| | | Criteria criteria = Criteria.where("memoryId").is(memoryId); |
| | | Query query = new Query(criteria); |
| | | Query query = Query.query(Criteria.where("memoryId").is(memoryIdString(memoryId))); |
| | | ChatMessages chatMessages = mongoTemplate.findOne(query, ChatMessages.class); |
| | | if(chatMessages == null) return new LinkedList<>(); |
| | | if (chatMessages == null || chatMessages.getContent() == null) { |
| | | return new LinkedList<>(); |
| | | } |
| | | return ChatMessageDeserializer.messagesFromJson(chatMessages.getContent()); |
| | | } |
| | | |
| | | @Override |
| | | public void updateMessages(Object memoryId, List<ChatMessage> messages) { |
| | | Criteria criteria = Criteria.where("memoryId").is(memoryId); |
| | | Query query = new Query(criteria); |
| | | String memoryIdValue = memoryIdString(memoryId); |
| | | Query query = Query.query(Criteria.where("memoryId").is(memoryIdValue)); |
| | | Update update = new Update(); |
| | | update.set("memoryId", memoryIdValue); |
| | | update.set("content", ChatMessageSerializer.messagesToJson(messages)); |
| | | //根据query条件能查询出文档,则修改文档;否则新增文档 |
| | | update.set("updateTime", new Date()); |
| | | update.setOnInsert("createTime", new Date()); |
| | | mongoTemplate.upsert(query, update, ChatMessages.class); |
| | | } |
| | | |
| | | @Override |
| | | public void deleteMessages(Object memoryId) { |
| | | Criteria criteria = Criteria.where("memoryId").is(memoryId); |
| | | Query query = new Query(criteria); |
| | | Query query = Query.query(Criteria.where("memoryId").is(memoryIdString(memoryId))); |
| | | mongoTemplate.remove(query, ChatMessages.class); |
| | | } |
| | | |
| | | public void appendMessages(Object memoryId, List<ChatMessage> appendList) { |
| | | List<ChatMessage> messages = new LinkedList<>(getMessages(memoryId)); |
| | | messages.addAll(appendList); |
| | | updateMessages(memoryId, messages); |
| | | } |
| | | |
| | | private String memoryIdString(Object memoryId) { |
| | | return memoryId == null ? "" : memoryId.toString(); |
| | | } |
| | | } |