synaptic_memory/
summary_buffer.rs1use std::{collections::HashMap, sync::Arc};
2
3use async_trait::async_trait;
4use synaptic_core::{ChatModel, ChatRequest, MemoryStore, Message, SynapticError};
5use tokio::sync::RwLock;
6
7pub struct ConversationSummaryBufferMemory {
13 store: Arc<dyn MemoryStore>,
14 model: Arc<dyn ChatModel>,
15 summary: Arc<RwLock<HashMap<String, String>>>,
16 max_token_limit: usize,
17}
18
19impl ConversationSummaryBufferMemory {
20 pub fn new(
26 store: Arc<dyn MemoryStore>,
27 model: Arc<dyn ChatModel>,
28 max_token_limit: usize,
29 ) -> Self {
30 Self {
31 store,
32 model,
33 summary: Arc::new(RwLock::new(HashMap::new())),
34 max_token_limit,
35 }
36 }
37
38 fn estimate_tokens(text: &str) -> usize {
39 (text.len() / 4).max(1)
40 }
41
42 async fn summarize(&self, messages: &[Message]) -> Result<String, SynapticError> {
43 let conversation = messages
44 .iter()
45 .map(|m| format!("{}: {}", m.role(), m.content()))
46 .collect::<Vec<_>>()
47 .join("\n");
48
49 let prompt = format!("Summarize the following conversation concisely:\n\n{conversation}");
50 let request = ChatRequest::new(vec![Message::human(prompt)]);
51 let response = self.model.chat(request).await?;
52 Ok(response.message.content().to_string())
53 }
54}
55
56#[async_trait]
57impl MemoryStore for ConversationSummaryBufferMemory {
58 async fn append(&self, session_id: &str, message: Message) -> Result<(), SynapticError> {
59 self.store.append(session_id, message).await?;
60
61 let messages = self.store.load(session_id).await?;
62 let total_tokens: usize = messages
63 .iter()
64 .map(|m| Self::estimate_tokens(m.content()))
65 .sum();
66
67 if total_tokens > self.max_token_limit && messages.len() > 1 {
68 let half_limit = self.max_token_limit / 2;
70 let mut recent_tokens = 0;
71 let mut split_point = messages.len();
72
73 for (i, msg) in messages.iter().enumerate().rev() {
74 let tokens = Self::estimate_tokens(msg.content());
75 if recent_tokens + tokens > half_limit {
76 split_point = i + 1;
77 break;
78 }
79 recent_tokens += tokens;
80 }
81
82 if split_point == 0 {
84 split_point = 1;
85 }
86 if split_point >= messages.len() {
87 split_point = messages.len() - 1;
88 }
89
90 let older = &messages[..split_point];
91 let recent = &messages[split_point..];
92
93 let existing_summary = {
95 let summaries = self.summary.read().await;
96 summaries.get(session_id).cloned()
97 };
98
99 let to_summarize = if let Some(ref existing) = existing_summary {
100 let mut with_context =
101 vec![Message::system(format!("Previous summary: {existing}"))];
102 with_context.extend_from_slice(older);
103 with_context
104 } else {
105 older.to_vec()
106 };
107
108 let new_summary = self.summarize(&to_summarize).await?;
109
110 {
111 let mut summaries = self.summary.write().await;
112 summaries.insert(session_id.to_string(), new_summary);
113 }
114
115 self.store.clear(session_id).await?;
117 for msg in recent {
118 self.store.append(session_id, msg.clone()).await?;
119 }
120 }
121
122 Ok(())
123 }
124
125 async fn load(&self, session_id: &str) -> Result<Vec<Message>, SynapticError> {
126 let messages = self.store.load(session_id).await?;
127 let summaries = self.summary.read().await;
128
129 if let Some(summary_text) = summaries.get(session_id) {
130 let mut result = vec![Message::system(format!(
131 "Summary of earlier conversation: {summary_text}"
132 ))];
133 result.extend(messages);
134 Ok(result)
135 } else {
136 Ok(messages)
137 }
138 }
139
140 async fn clear(&self, session_id: &str) -> Result<(), SynapticError> {
141 self.store.clear(session_id).await?;
142 let mut summaries = self.summary.write().await;
143 summaries.remove(session_id);
144 Ok(())
145 }
146}