Skip to main content

zeph_memory/sqlite/
summaries.rs

1use super::SqliteStore;
2use crate::error::MemoryError;
3use crate::types::{ConversationId, MessageId};
4
5impl SqliteStore {
6    /// Save a summary and return its ID.
7    ///
8    /// # Errors
9    ///
10    /// Returns an error if the insert fails.
11    pub async fn save_summary(
12        &self,
13        conversation_id: ConversationId,
14        content: &str,
15        first_message_id: MessageId,
16        last_message_id: MessageId,
17        token_estimate: i64,
18    ) -> Result<i64, MemoryError> {
19        let row: (i64,) = sqlx::query_as(
20            "INSERT INTO summaries (conversation_id, content, first_message_id, last_message_id, token_estimate) \
21             VALUES (?, ?, ?, ?, ?) RETURNING id",
22        )
23        .bind(conversation_id)
24        .bind(content)
25        .bind(first_message_id)
26        .bind(last_message_id)
27        .bind(token_estimate)
28        .fetch_one(&self.pool)
29        .await
30        ?;
31        Ok(row.0)
32    }
33
34    /// Load all summaries for a conversation.
35    ///
36    /// # Errors
37    ///
38    /// Returns an error if the query fails.
39    pub async fn load_summaries(
40        &self,
41        conversation_id: ConversationId,
42    ) -> Result<Vec<(i64, ConversationId, String, MessageId, MessageId, i64)>, MemoryError> {
43        let rows: Vec<(i64, ConversationId, String, MessageId, MessageId, i64)> = sqlx::query_as(
44            "SELECT id, conversation_id, content, first_message_id, last_message_id, token_estimate \
45             FROM summaries WHERE conversation_id = ? ORDER BY id ASC",
46        )
47        .bind(conversation_id)
48        .fetch_all(&self.pool)
49        .await
50        ?;
51
52        Ok(rows)
53    }
54
55    /// Get the last message ID covered by the most recent summary.
56    ///
57    /// # Errors
58    ///
59    /// Returns an error if the query fails.
60    pub async fn latest_summary_last_message_id(
61        &self,
62        conversation_id: ConversationId,
63    ) -> Result<Option<MessageId>, MemoryError> {
64        let row: Option<(MessageId,)> = sqlx::query_as(
65            "SELECT last_message_id FROM summaries \
66             WHERE conversation_id = ? ORDER BY id DESC LIMIT 1",
67        )
68        .bind(conversation_id)
69        .fetch_optional(&self.pool)
70        .await?;
71
72        Ok(row.map(|r| r.0))
73    }
74}
75
#[cfg(test)]
mod tests {
    use super::*;

    /// Fresh in-memory store so each test starts from an empty schema.
    async fn test_store() -> SqliteStore {
        SqliteStore::new(":memory:").await.unwrap()
    }

    #[tokio::test]
    async fn save_and_load_summary() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        let msg_id1 = store.save_message(cid, "user", "hello").await.unwrap();
        let msg_id2 = store.save_message(cid, "assistant", "hi").await.unwrap();

        let summary_id = store
            .save_summary(cid, "User greeted assistant", msg_id1, msg_id2, 5)
            .await
            .unwrap();

        let summaries = store.load_summaries(cid).await.unwrap();
        assert_eq!(summaries.len(), 1);
        assert_eq!(summaries[0].0, summary_id);
        assert_eq!(summaries[0].2, "User greeted assistant");
        assert_eq!(summaries[0].3, msg_id1);
        assert_eq!(summaries[0].4, msg_id2);
        assert_eq!(summaries[0].5, 5);
    }

    #[tokio::test]
    async fn load_summaries_empty() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        let summaries = store.load_summaries(cid).await.unwrap();
        assert!(summaries.is_empty());
    }

    #[tokio::test]
    async fn load_summaries_ordered() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        let msg_id1 = store.save_message(cid, "user", "m1").await.unwrap();
        let msg_id2 = store.save_message(cid, "assistant", "m2").await.unwrap();
        let msg_id3 = store.save_message(cid, "user", "m3").await.unwrap();

        let s1 = store
            .save_summary(cid, "summary1", msg_id1, msg_id2, 3)
            .await
            .unwrap();
        let s2 = store
            .save_summary(cid, "summary2", msg_id2, msg_id3, 3)
            .await
            .unwrap();

        let summaries = store.load_summaries(cid).await.unwrap();
        assert_eq!(summaries.len(), 2);
        assert_eq!(summaries[0].0, s1);
        assert_eq!(summaries[1].0, s2);
    }

    #[tokio::test]
    async fn latest_summary_last_message_id_none() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        let last = store.latest_summary_last_message_id(cid).await.unwrap();
        assert!(last.is_none());
    }

    #[tokio::test]
    async fn latest_summary_last_message_id_some() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        let msg_id1 = store.save_message(cid, "user", "m1").await.unwrap();
        let msg_id2 = store.save_message(cid, "assistant", "m2").await.unwrap();
        let msg_id3 = store.save_message(cid, "user", "m3").await.unwrap();

        store
            .save_summary(cid, "summary1", msg_id1, msg_id2, 3)
            .await
            .unwrap();
        store
            .save_summary(cid, "summary2", msg_id2, msg_id3, 3)
            .await
            .unwrap();

        let last = store.latest_summary_last_message_id(cid).await.unwrap();
        assert_eq!(last, Some(msg_id3));
    }

    /// Summaries saved under one conversation must not leak into queries for
    /// another. Both read paths filter on `conversation_id`, so this guards
    /// against a regression that drops that filter.
    #[tokio::test]
    async fn summaries_are_scoped_to_conversation() {
        let store = test_store().await;
        let cid_a = store.create_conversation().await.unwrap();
        let cid_b = store.create_conversation().await.unwrap();

        let a1 = store.save_message(cid_a, "user", "a1").await.unwrap();
        let a2 = store.save_message(cid_a, "assistant", "a2").await.unwrap();
        store
            .save_summary(cid_a, "summary-a", a1, a2, 2)
            .await
            .unwrap();

        assert!(store.load_summaries(cid_b).await.unwrap().is_empty());
        assert!(store
            .latest_summary_last_message_id(cid_b)
            .await
            .unwrap()
            .is_none());

        let a_summaries = store.load_summaries(cid_a).await.unwrap();
        assert_eq!(a_summaries.len(), 1);
        assert_eq!(a_summaries[0].2, "summary-a");
        assert_eq!(
            store.latest_summary_last_message_id(cid_a).await.unwrap(),
            Some(a2)
        );
    }

    #[tokio::test]
    async fn cascade_delete_removes_summaries() {
        let store = test_store().await;
        let pool = store.pool();
        let cid = store.create_conversation().await.unwrap();

        let msg_id1 = store.save_message(cid, "user", "m1").await.unwrap();
        let msg_id2 = store.save_message(cid, "assistant", "m2").await.unwrap();

        store
            .save_summary(cid, "summary", msg_id1, msg_id2, 3)
            .await
            .unwrap();

        let before: (i64,) =
            sqlx::query_as("SELECT COUNT(*) FROM summaries WHERE conversation_id = ?")
                .bind(cid)
                .fetch_one(pool)
                .await
                .unwrap();
        assert_eq!(before.0, 1);

        sqlx::query("DELETE FROM conversations WHERE id = ?")
            .bind(cid)
            .execute(pool)
            .await
            .unwrap();

        let after: (i64,) =
            sqlx::query_as("SELECT COUNT(*) FROM summaries WHERE conversation_id = ?")
                .bind(cid)
                .fetch_one(pool)
                .await
                .unwrap();
        assert_eq!(after.0, 0);
    }
}