Skip to main content

zeph_memory/sqlite/
summaries.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use super::SqliteStore;
5use crate::error::MemoryError;
6use crate::types::{ConversationId, MessageId};
7
8impl SqliteStore {
9    /// Save a summary and return its ID.
10    ///
11    /// `first_message_id` and `last_message_id` are `None` for session-level summaries
12    /// (e.g. shutdown summaries) that do not correspond to a specific message range.
13    ///
14    /// # Errors
15    ///
16    /// Returns an error if the insert fails.
17    pub async fn save_summary(
18        &self,
19        conversation_id: ConversationId,
20        content: &str,
21        first_message_id: Option<MessageId>,
22        last_message_id: Option<MessageId>,
23        token_estimate: i64,
24    ) -> Result<i64, MemoryError> {
25        let row: (i64,) = sqlx::query_as(
26            "INSERT INTO summaries (conversation_id, content, first_message_id, last_message_id, token_estimate) \
27             VALUES (?, ?, ?, ?, ?) RETURNING id",
28        )
29        .bind(conversation_id)
30        .bind(content)
31        .bind(first_message_id)
32        .bind(last_message_id)
33        .bind(token_estimate)
34        .fetch_one(&self.pool)
35        .await
36        ?;
37        Ok(row.0)
38    }
39
40    /// Load all summaries for a conversation.
41    ///
42    /// `first_message_id` and `last_message_id` are `None` for session-level summaries.
43    ///
44    /// # Errors
45    ///
46    /// Returns an error if the query fails.
47    pub async fn load_summaries(
48        &self,
49        conversation_id: ConversationId,
50    ) -> Result<
51        Vec<(
52            i64,
53            ConversationId,
54            String,
55            Option<MessageId>,
56            Option<MessageId>,
57            i64,
58        )>,
59        MemoryError,
60    > {
61        #[allow(clippy::type_complexity)]
62        let rows: Vec<(
63            i64,
64            ConversationId,
65            String,
66            Option<MessageId>,
67            Option<MessageId>,
68            i64,
69        )> = sqlx::query_as(
70            "SELECT id, conversation_id, content, first_message_id, last_message_id, \
71                 token_estimate FROM summaries WHERE conversation_id = ? ORDER BY id ASC",
72        )
73        .bind(conversation_id)
74        .fetch_all(&self.pool)
75        .await?;
76
77        Ok(rows)
78    }
79
80    /// Get the last message ID covered by the most recent summary.
81    ///
82    /// Returns `None` if no summaries exist or the most recent is a session-level summary
83    /// (shutdown summary) with no tracked message range.
84    ///
85    /// # Errors
86    ///
87    /// Returns an error if the query fails.
88    pub async fn latest_summary_last_message_id(
89        &self,
90        conversation_id: ConversationId,
91    ) -> Result<Option<MessageId>, MemoryError> {
92        let row: Option<(Option<MessageId>,)> = sqlx::query_as(
93            "SELECT last_message_id FROM summaries \
94             WHERE conversation_id = ? ORDER BY id DESC LIMIT 1",
95        )
96        .bind(conversation_id)
97        .fetch_optional(&self.pool)
98        .await?;
99
100        Ok(row.and_then(|r| r.0))
101    }
102}
103
#[cfg(test)]
mod tests {
    use super::*;

    async fn test_store() -> SqliteStore {
        SqliteStore::new(":memory:").await.unwrap()
    }

    #[tokio::test]
    async fn save_and_load_summary() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        let msg_id1 = store.save_message(cid, "user", "hello").await.unwrap();
        let msg_id2 = store.save_message(cid, "assistant", "hi").await.unwrap();

        let summary_id = store
            .save_summary(
                cid,
                "User greeted assistant",
                Some(msg_id1),
                Some(msg_id2),
                5,
            )
            .await
            .unwrap();

        let summaries = store.load_summaries(cid).await.unwrap();
        assert_eq!(summaries.len(), 1);
        assert_eq!(summaries[0].0, summary_id);
        assert_eq!(summaries[0].2, "User greeted assistant");
        assert_eq!(summaries[0].3, Some(msg_id1));
        assert_eq!(summaries[0].4, Some(msg_id2));
        assert_eq!(summaries[0].5, 5);
    }

    #[tokio::test]
    async fn save_and_load_session_summary_without_range() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        // Session-level (e.g. shutdown) summaries carry no message range.
        let summary_id = store
            .save_summary(cid, "session ended", None, None, 2)
            .await
            .unwrap();

        let summaries = store.load_summaries(cid).await.unwrap();
        assert_eq!(summaries.len(), 1);
        assert_eq!(summaries[0].0, summary_id);
        assert_eq!(summaries[0].2, "session ended");
        assert!(summaries[0].3.is_none());
        assert!(summaries[0].4.is_none());
        assert_eq!(summaries[0].5, 2);
    }

    #[tokio::test]
    async fn load_summaries_empty() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        let summaries = store.load_summaries(cid).await.unwrap();
        assert!(summaries.is_empty());
    }

    #[tokio::test]
    async fn load_summaries_ordered() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        let msg_id1 = store.save_message(cid, "user", "m1").await.unwrap();
        let msg_id2 = store.save_message(cid, "assistant", "m2").await.unwrap();
        let msg_id3 = store.save_message(cid, "user", "m3").await.unwrap();

        let s1 = store
            .save_summary(cid, "summary1", Some(msg_id1), Some(msg_id2), 3)
            .await
            .unwrap();
        let s2 = store
            .save_summary(cid, "summary2", Some(msg_id2), Some(msg_id3), 3)
            .await
            .unwrap();

        let summaries = store.load_summaries(cid).await.unwrap();
        assert_eq!(summaries.len(), 2);
        assert_eq!(summaries[0].0, s1);
        assert_eq!(summaries[1].0, s2);
    }

    #[tokio::test]
    async fn load_summaries_scoped_to_conversation() {
        let store = test_store().await;
        let cid_a = store.create_conversation().await.unwrap();
        let cid_b = store.create_conversation().await.unwrap();

        store
            .save_summary(cid_a, "summary a", None, None, 1)
            .await
            .unwrap();
        store
            .save_summary(cid_b, "summary b", None, None, 1)
            .await
            .unwrap();

        let a = store.load_summaries(cid_a).await.unwrap();
        let b = store.load_summaries(cid_b).await.unwrap();
        assert_eq!(a.len(), 1);
        assert_eq!(a[0].2, "summary a");
        assert_eq!(b.len(), 1);
        assert_eq!(b[0].2, "summary b");
    }

    #[tokio::test]
    async fn latest_summary_last_message_id_none() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        let last = store.latest_summary_last_message_id(cid).await.unwrap();
        assert!(last.is_none());
    }

    #[tokio::test]
    async fn latest_summary_last_message_id_some() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        let msg_id1 = store.save_message(cid, "user", "m1").await.unwrap();
        let msg_id2 = store.save_message(cid, "assistant", "m2").await.unwrap();
        let msg_id3 = store.save_message(cid, "user", "m3").await.unwrap();

        store
            .save_summary(cid, "summary1", Some(msg_id1), Some(msg_id2), 3)
            .await
            .unwrap();
        store
            .save_summary(cid, "summary2", Some(msg_id2), Some(msg_id3), 3)
            .await
            .unwrap();

        let last = store.latest_summary_last_message_id(cid).await.unwrap();
        assert_eq!(last, Some(msg_id3));
    }

    #[tokio::test]
    async fn latest_summary_last_message_id_none_after_session_summary() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        let msg_id1 = store.save_message(cid, "user", "m1").await.unwrap();
        let msg_id2 = store.save_message(cid, "assistant", "m2").await.unwrap();

        store
            .save_summary(cid, "ranged", Some(msg_id1), Some(msg_id2), 3)
            .await
            .unwrap();
        store
            .save_summary(cid, "session", None, None, 1)
            .await
            .unwrap();

        // Only the newest summary is consulted, so a trailing session-level summary
        // yields `None` even though an earlier summary has a tracked range.
        let last = store.latest_summary_last_message_id(cid).await.unwrap();
        assert!(last.is_none());
    }

    #[tokio::test]
    async fn cascade_delete_removes_summaries() {
        let store = test_store().await;
        let pool = store.pool();
        let cid = store.create_conversation().await.unwrap();

        let msg_id1 = store.save_message(cid, "user", "m1").await.unwrap();
        let msg_id2 = store.save_message(cid, "assistant", "m2").await.unwrap();

        store
            .save_summary(cid, "summary", Some(msg_id1), Some(msg_id2), 3)
            .await
            .unwrap();

        let before: (i64,) =
            sqlx::query_as("SELECT COUNT(*) FROM summaries WHERE conversation_id = ?")
                .bind(cid)
                .fetch_one(pool)
                .await
                .unwrap();
        assert_eq!(before.0, 1);

        sqlx::query("DELETE FROM conversations WHERE id = ?")
            .bind(cid)
            .execute(pool)
            .await
            .unwrap();

        let after: (i64,) =
            sqlx::query_as("SELECT COUNT(*) FROM summaries WHERE conversation_id = ?")
                .bind(cid)
                .fetch_one(pool)
                .await
                .unwrap();
        assert_eq!(after.0, 0);
    }
}