1use super::SqliteStore;
5use crate::error::MemoryError;
6use crate::types::{ConversationId, MessageId};
7
8impl SqliteStore {
9 pub async fn save_summary(
15 &self,
16 conversation_id: ConversationId,
17 content: &str,
18 first_message_id: MessageId,
19 last_message_id: MessageId,
20 token_estimate: i64,
21 ) -> Result<i64, MemoryError> {
22 let row: (i64,) = sqlx::query_as(
23 "INSERT INTO summaries (conversation_id, content, first_message_id, last_message_id, token_estimate) \
24 VALUES (?, ?, ?, ?, ?) RETURNING id",
25 )
26 .bind(conversation_id)
27 .bind(content)
28 .bind(first_message_id)
29 .bind(last_message_id)
30 .bind(token_estimate)
31 .fetch_one(&self.pool)
32 .await
33 ?;
34 Ok(row.0)
35 }
36
37 pub async fn load_summaries(
43 &self,
44 conversation_id: ConversationId,
45 ) -> Result<Vec<(i64, ConversationId, String, MessageId, MessageId, i64)>, MemoryError> {
46 let rows: Vec<(i64, ConversationId, String, MessageId, MessageId, i64)> = sqlx::query_as(
47 "SELECT id, conversation_id, content, first_message_id, last_message_id, token_estimate \
48 FROM summaries WHERE conversation_id = ? ORDER BY id ASC",
49 )
50 .bind(conversation_id)
51 .fetch_all(&self.pool)
52 .await
53 ?;
54
55 Ok(rows)
56 }
57
58 pub async fn latest_summary_last_message_id(
64 &self,
65 conversation_id: ConversationId,
66 ) -> Result<Option<MessageId>, MemoryError> {
67 let row: Option<(MessageId,)> = sqlx::query_as(
68 "SELECT last_message_id FROM summaries \
69 WHERE conversation_id = ? ORDER BY id DESC LIMIT 1",
70 )
71 .bind(conversation_id)
72 .fetch_optional(&self.pool)
73 .await?;
74
75 Ok(row.map(|r| r.0))
76 }
77}
78
#[cfg(test)]
mod tests {
    use super::*;

    /// Opens a fresh store backed by an SQLite `:memory:` database.
    async fn test_store() -> SqliteStore {
        SqliteStore::new(":memory:").await.unwrap()
    }

    #[tokio::test]
    async fn save_and_load_summary() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        let msg_id1 = store.save_message(cid, "user", "hello").await.unwrap();
        let msg_id2 = store.save_message(cid, "assistant", "hi").await.unwrap();

        let summary_id = store
            .save_summary(cid, "User greeted assistant", msg_id1, msg_id2, 5)
            .await
            .unwrap();

        let summaries = store.load_summaries(cid).await.unwrap();
        assert_eq!(summaries.len(), 1);
        assert_eq!(summaries[0].0, summary_id);
        assert_eq!(summaries[0].2, "User greeted assistant");
        assert_eq!(summaries[0].3, msg_id1);
        assert_eq!(summaries[0].4, msg_id2);
        assert_eq!(summaries[0].5, 5);
    }

    #[tokio::test]
    async fn load_summaries_empty() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        let summaries = store.load_summaries(cid).await.unwrap();
        assert!(summaries.is_empty());
    }

    #[tokio::test]
    async fn load_summaries_ordered() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        let msg_id1 = store.save_message(cid, "user", "m1").await.unwrap();
        let msg_id2 = store.save_message(cid, "assistant", "m2").await.unwrap();
        let msg_id3 = store.save_message(cid, "user", "m3").await.unwrap();

        let s1 = store
            .save_summary(cid, "summary1", msg_id1, msg_id2, 3)
            .await
            .unwrap();
        let s2 = store
            .save_summary(cid, "summary2", msg_id2, msg_id3, 3)
            .await
            .unwrap();

        let summaries = store.load_summaries(cid).await.unwrap();
        assert_eq!(summaries.len(), 2);
        assert_eq!(summaries[0].0, s1);
        assert_eq!(summaries[1].0, s2);
    }

    /// Both queries filter on `conversation_id`; summaries saved for one
    /// conversation must never show up when querying another.
    #[tokio::test]
    async fn summaries_scoped_to_conversation() {
        let store = test_store().await;
        let cid_a = store.create_conversation().await.unwrap();
        let cid_b = store.create_conversation().await.unwrap();

        let a1 = store.save_message(cid_a, "user", "a1").await.unwrap();
        let a2 = store.save_message(cid_a, "assistant", "a2").await.unwrap();
        let b1 = store.save_message(cid_b, "user", "b1").await.unwrap();
        let b2 = store.save_message(cid_b, "assistant", "b2").await.unwrap();

        let sa = store
            .save_summary(cid_a, "summary a", a1, a2, 4)
            .await
            .unwrap();
        let sb = store
            .save_summary(cid_b, "summary b", b1, b2, 4)
            .await
            .unwrap();

        let for_a = store.load_summaries(cid_a).await.unwrap();
        assert_eq!(for_a.len(), 1);
        assert_eq!(for_a[0].0, sa);
        assert_eq!(for_a[0].2, "summary a");

        let for_b = store.load_summaries(cid_b).await.unwrap();
        assert_eq!(for_b.len(), 1);
        assert_eq!(for_b[0].0, sb);
        assert_eq!(for_b[0].2, "summary b");

        assert_eq!(
            store.latest_summary_last_message_id(cid_a).await.unwrap(),
            Some(a2)
        );
        assert_eq!(
            store.latest_summary_last_message_id(cid_b).await.unwrap(),
            Some(b2)
        );
    }

    #[tokio::test]
    async fn latest_summary_last_message_id_none() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        let last = store.latest_summary_last_message_id(cid).await.unwrap();
        assert!(last.is_none());
    }

    #[tokio::test]
    async fn latest_summary_last_message_id_some() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        let msg_id1 = store.save_message(cid, "user", "m1").await.unwrap();
        let msg_id2 = store.save_message(cid, "assistant", "m2").await.unwrap();
        let msg_id3 = store.save_message(cid, "user", "m3").await.unwrap();

        store
            .save_summary(cid, "summary1", msg_id1, msg_id2, 3)
            .await
            .unwrap();
        store
            .save_summary(cid, "summary2", msg_id2, msg_id3, 3)
            .await
            .unwrap();

        let last = store.latest_summary_last_message_id(cid).await.unwrap();
        assert_eq!(last, Some(msg_id3));
    }

    /// Deleting a conversation row must also remove its summaries.
    /// NOTE(review): this relies on the schema declaring a cascading foreign
    /// key and on foreign keys being enabled for the pool — presumably done in
    /// `SqliteStore::new`/migrations; confirm there.
    #[tokio::test]
    async fn cascade_delete_removes_summaries() {
        let store = test_store().await;
        let pool = store.pool();
        let cid = store.create_conversation().await.unwrap();

        let msg_id1 = store.save_message(cid, "user", "m1").await.unwrap();
        let msg_id2 = store.save_message(cid, "assistant", "m2").await.unwrap();

        store
            .save_summary(cid, "summary", msg_id1, msg_id2, 3)
            .await
            .unwrap();

        let before: (i64,) =
            sqlx::query_as("SELECT COUNT(*) FROM summaries WHERE conversation_id = ?")
                .bind(cid)
                .fetch_one(pool)
                .await
                .unwrap();
        assert_eq!(before.0, 1);

        sqlx::query("DELETE FROM conversations WHERE id = ?")
            .bind(cid)
            .execute(pool)
            .await
            .unwrap();

        let after: (i64,) =
            sqlx::query_as("SELECT COUNT(*) FROM summaries WHERE conversation_id = ?")
                .bind(cid)
                .fetch_one(pool)
                .await
                .unwrap();
        assert_eq!(after.0, 0);
    }
}