1use super::SqliteStore;
5use crate::error::MemoryError;
6use crate::types::{ConversationId, MessageId};
7
8impl SqliteStore {
9 pub async fn save_summary(
18 &self,
19 conversation_id: ConversationId,
20 content: &str,
21 first_message_id: Option<MessageId>,
22 last_message_id: Option<MessageId>,
23 token_estimate: i64,
24 ) -> Result<i64, MemoryError> {
25 let row: (i64,) = sqlx::query_as(
26 "INSERT INTO summaries (conversation_id, content, first_message_id, last_message_id, token_estimate) \
27 VALUES (?, ?, ?, ?, ?) RETURNING id",
28 )
29 .bind(conversation_id)
30 .bind(content)
31 .bind(first_message_id)
32 .bind(last_message_id)
33 .bind(token_estimate)
34 .fetch_one(&self.pool)
35 .await
36 ?;
37 Ok(row.0)
38 }
39
40 pub async fn load_summaries(
48 &self,
49 conversation_id: ConversationId,
50 ) -> Result<
51 Vec<(
52 i64,
53 ConversationId,
54 String,
55 Option<MessageId>,
56 Option<MessageId>,
57 i64,
58 )>,
59 MemoryError,
60 > {
61 #[allow(clippy::type_complexity)]
62 let rows: Vec<(
63 i64,
64 ConversationId,
65 String,
66 Option<MessageId>,
67 Option<MessageId>,
68 i64,
69 )> = sqlx::query_as(
70 "SELECT id, conversation_id, content, first_message_id, last_message_id, \
71 token_estimate FROM summaries WHERE conversation_id = ? ORDER BY id ASC",
72 )
73 .bind(conversation_id)
74 .fetch_all(&self.pool)
75 .await?;
76
77 Ok(rows)
78 }
79
80 pub async fn latest_summary_last_message_id(
89 &self,
90 conversation_id: ConversationId,
91 ) -> Result<Option<MessageId>, MemoryError> {
92 let row: Option<(Option<MessageId>,)> = sqlx::query_as(
93 "SELECT last_message_id FROM summaries \
94 WHERE conversation_id = ? ORDER BY id DESC LIMIT 1",
95 )
96 .bind(conversation_id)
97 .fetch_optional(&self.pool)
98 .await?;
99
100 Ok(row.and_then(|r| r.0))
101 }
102}
103
#[cfg(test)]
mod tests {
    use super::*;

    /// Creates a store via `SqliteStore::new(":memory:")`, panicking if
    /// construction fails.
    async fn test_store() -> SqliteStore {
        SqliteStore::new(":memory:").await.unwrap()
    }

    /// Counts the rows in `summaries` whose `conversation_id` equals `cid`.
    async fn summary_count(store: &SqliteStore, cid: ConversationId) -> i64 {
        let (count,): (i64,) =
            sqlx::query_as("SELECT COUNT(*) FROM summaries WHERE conversation_id = ?")
                .bind(cid)
                .fetch_one(store.pool())
                .await
                .unwrap();
        count
    }

    /// A saved summary is returned by `load_summaries` with the values it was
    /// saved with.
    #[tokio::test]
    async fn save_and_load_summary() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        let first_msg = store.save_message(cid, "user", "hello").await.unwrap();
        let last_msg = store.save_message(cid, "assistant", "hi").await.unwrap();

        let saved_id = store
            .save_summary(
                cid,
                "User greeted assistant",
                Some(first_msg),
                Some(last_msg),
                5,
            )
            .await
            .unwrap();

        let loaded = store.load_summaries(cid).await.unwrap();
        assert_eq!(loaded.len(), 1);

        let (id, _conversation, content, first, last, tokens) = &loaded[0];
        assert_eq!(*id, saved_id);
        assert_eq!(content.as_str(), "User greeted assistant");
        assert_eq!(*first, Some(first_msg));
        assert_eq!(*last, Some(last_msg));
        assert_eq!(*tokens, 5);
    }

    /// A conversation without summaries yields an empty vector.
    #[tokio::test]
    async fn load_summaries_empty() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        let loaded = store.load_summaries(cid).await.unwrap();
        assert!(loaded.is_empty());
    }

    /// Summaries come back in the order they were inserted (ascending id).
    #[tokio::test]
    async fn load_summaries_ordered() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        let m1 = store.save_message(cid, "user", "m1").await.unwrap();
        let m2 = store.save_message(cid, "assistant", "m2").await.unwrap();
        let m3 = store.save_message(cid, "user", "m3").await.unwrap();

        let older = store
            .save_summary(cid, "summary1", Some(m1), Some(m2), 3)
            .await
            .unwrap();
        let newer = store
            .save_summary(cid, "summary2", Some(m2), Some(m3), 3)
            .await
            .unwrap();

        let loaded = store.load_summaries(cid).await.unwrap();
        assert_eq!(loaded.len(), 2);
        assert_eq!(loaded[0].0, older);
        assert_eq!(loaded[1].0, newer);
    }

    /// With no summaries stored, there is no latest `last_message_id`.
    #[tokio::test]
    async fn latest_summary_last_message_id_none() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        let latest = store.latest_summary_last_message_id(cid).await.unwrap();
        assert!(latest.is_none());
    }

    /// The `last_message_id` of the most recently inserted summary is returned.
    #[tokio::test]
    async fn latest_summary_last_message_id_some() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        let m1 = store.save_message(cid, "user", "m1").await.unwrap();
        let m2 = store.save_message(cid, "assistant", "m2").await.unwrap();
        let m3 = store.save_message(cid, "user", "m3").await.unwrap();

        store
            .save_summary(cid, "summary1", Some(m1), Some(m2), 3)
            .await
            .unwrap();
        store
            .save_summary(cid, "summary2", Some(m2), Some(m3), 3)
            .await
            .unwrap();

        let latest = store.latest_summary_last_message_id(cid).await.unwrap();
        assert_eq!(latest, Some(m3));
    }

    /// Deleting a conversation row leaves no summaries that reference it.
    #[tokio::test]
    async fn cascade_delete_removes_summaries() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        let m1 = store.save_message(cid, "user", "m1").await.unwrap();
        let m2 = store.save_message(cid, "assistant", "m2").await.unwrap();

        store
            .save_summary(cid, "summary", Some(m1), Some(m2), 3)
            .await
            .unwrap();

        assert_eq!(summary_count(&store, cid).await, 1);

        sqlx::query("DELETE FROM conversations WHERE id = ?")
            .bind(cid)
            .execute(store.pool())
            .await
            .unwrap();

        assert_eq!(summary_count(&store, cid).await, 0);
    }
}