1use super::SqliteStore;
2use crate::error::MemoryError;
3use crate::types::{ConversationId, MessageId};
4
5impl SqliteStore {
6 pub async fn save_summary(
12 &self,
13 conversation_id: ConversationId,
14 content: &str,
15 first_message_id: MessageId,
16 last_message_id: MessageId,
17 token_estimate: i64,
18 ) -> Result<i64, MemoryError> {
19 let row: (i64,) = sqlx::query_as(
20 "INSERT INTO summaries (conversation_id, content, first_message_id, last_message_id, token_estimate) \
21 VALUES (?, ?, ?, ?, ?) RETURNING id",
22 )
23 .bind(conversation_id)
24 .bind(content)
25 .bind(first_message_id)
26 .bind(last_message_id)
27 .bind(token_estimate)
28 .fetch_one(&self.pool)
29 .await
30 ?;
31 Ok(row.0)
32 }
33
34 pub async fn load_summaries(
40 &self,
41 conversation_id: ConversationId,
42 ) -> Result<Vec<(i64, ConversationId, String, MessageId, MessageId, i64)>, MemoryError> {
43 let rows: Vec<(i64, ConversationId, String, MessageId, MessageId, i64)> = sqlx::query_as(
44 "SELECT id, conversation_id, content, first_message_id, last_message_id, token_estimate \
45 FROM summaries WHERE conversation_id = ? ORDER BY id ASC",
46 )
47 .bind(conversation_id)
48 .fetch_all(&self.pool)
49 .await
50 ?;
51
52 Ok(rows)
53 }
54
55 pub async fn latest_summary_last_message_id(
61 &self,
62 conversation_id: ConversationId,
63 ) -> Result<Option<MessageId>, MemoryError> {
64 let row: Option<(MessageId,)> = sqlx::query_as(
65 "SELECT last_message_id FROM summaries \
66 WHERE conversation_id = ? ORDER BY id DESC LIMIT 1",
67 )
68 .bind(conversation_id)
69 .fetch_optional(&self.pool)
70 .await?;
71
72 Ok(row.map(|r| r.0))
73 }
74}
75
#[cfg(test)]
mod tests {
    use super::*;

    /// Opens a store on the `:memory:` SQLite path for a single test.
    async fn test_store() -> SqliteStore {
        SqliteStore::new(":memory:").await.unwrap()
    }

    /// Counts `summaries` rows for `cid` with raw SQL, independent of the store's own API.
    async fn count_summaries(store: &SqliteStore, cid: ConversationId) -> i64 {
        let (total,): (i64,) =
            sqlx::query_as("SELECT COUNT(*) FROM summaries WHERE conversation_id = ?")
                .bind(cid)
                .fetch_one(store.pool())
                .await
                .unwrap();
        total
    }

    #[tokio::test]
    async fn save_and_load_summary() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        let hello = store.save_message(cid, "user", "hello").await.unwrap();
        let hi = store.save_message(cid, "assistant", "hi").await.unwrap();

        let saved_id = store
            .save_summary(cid, "User greeted assistant", hello, hi, 5)
            .await
            .unwrap();

        let summaries = store.load_summaries(cid).await.unwrap();
        assert_eq!(summaries.len(), 1);

        let (id, _, content, first, last, tokens) = &summaries[0];
        assert_eq!(*id, saved_id);
        assert_eq!(content.as_str(), "User greeted assistant");
        assert_eq!(*first, hello);
        assert_eq!(*last, hi);
        assert_eq!(*tokens, 5);
    }

    #[tokio::test]
    async fn load_summaries_empty() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        assert!(store.load_summaries(cid).await.unwrap().is_empty());
    }

    #[tokio::test]
    async fn load_summaries_ordered() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        let m1 = store.save_message(cid, "user", "m1").await.unwrap();
        let m2 = store.save_message(cid, "assistant", "m2").await.unwrap();
        let m3 = store.save_message(cid, "user", "m3").await.unwrap();

        let first_summary = store
            .save_summary(cid, "summary1", m1, m2, 3)
            .await
            .unwrap();
        let second_summary = store
            .save_summary(cid, "summary2", m2, m3, 3)
            .await
            .unwrap();

        // Comparing the whole id list checks both the count and the ascending order.
        let loaded_ids: Vec<i64> = store
            .load_summaries(cid)
            .await
            .unwrap()
            .into_iter()
            .map(|row| row.0)
            .collect();
        assert_eq!(loaded_ids, vec![first_summary, second_summary]);
    }

    #[tokio::test]
    async fn latest_summary_last_message_id_none() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        assert!(store
            .latest_summary_last_message_id(cid)
            .await
            .unwrap()
            .is_none());
    }

    #[tokio::test]
    async fn latest_summary_last_message_id_some() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        let m1 = store.save_message(cid, "user", "m1").await.unwrap();
        let m2 = store.save_message(cid, "assistant", "m2").await.unwrap();
        let m3 = store.save_message(cid, "user", "m3").await.unwrap();

        store
            .save_summary(cid, "summary1", m1, m2, 3)
            .await
            .unwrap();
        store
            .save_summary(cid, "summary2", m2, m3, 3)
            .await
            .unwrap();

        assert_eq!(
            store.latest_summary_last_message_id(cid).await.unwrap(),
            Some(m3)
        );
    }

    #[tokio::test]
    async fn cascade_delete_removes_summaries() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        let m1 = store.save_message(cid, "user", "m1").await.unwrap();
        let m2 = store.save_message(cid, "assistant", "m2").await.unwrap();

        store
            .save_summary(cid, "summary", m1, m2, 3)
            .await
            .unwrap();
        assert_eq!(count_summaries(&store, cid).await, 1);

        // Deleting the parent conversation directly; its summaries are expected to go with it.
        sqlx::query("DELETE FROM conversations WHERE id = ?")
            .bind(cid)
            .execute(store.pool())
            .await
            .unwrap();

        assert_eq!(count_summaries(&store, cid).await, 0);
    }
}