Skip to main content

zeph_memory/store/
summaries.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use super::SqliteStore;
5use crate::error::MemoryError;
6use crate::types::{ConversationId, MessageId};
7#[allow(unused_imports)]
8use zeph_db::sql;
9
10impl SqliteStore {
11    /// Save a summary and return its ID.
12    ///
13    /// `first_message_id` and `last_message_id` are `None` for session-level summaries
14    /// (e.g. shutdown summaries) that do not correspond to a specific message range.
15    ///
16    /// # Errors
17    ///
18    /// Returns an error if the insert fails.
19    pub async fn save_summary(
20        &self,
21        conversation_id: ConversationId,
22        content: &str,
23        first_message_id: Option<MessageId>,
24        last_message_id: Option<MessageId>,
25        token_estimate: i64,
26    ) -> Result<i64, MemoryError> {
27        let row: (i64,) = zeph_db::query_as(
28            sql!("INSERT INTO summaries (conversation_id, content, first_message_id, last_message_id, token_estimate) \
29             VALUES (?, ?, ?, ?, ?) RETURNING id"),
30        )
31        .bind(conversation_id)
32        .bind(content)
33        .bind(first_message_id)
34        .bind(last_message_id)
35        .bind(token_estimate)
36        .fetch_one(&self.pool)
37        .await
38        ?;
39        Ok(row.0)
40    }
41
42    /// Load all summaries for a conversation.
43    ///
44    /// `first_message_id` and `last_message_id` are `None` for session-level summaries.
45    ///
46    /// # Errors
47    ///
48    /// Returns an error if the query fails.
49    pub async fn load_summaries(
50        &self,
51        conversation_id: ConversationId,
52    ) -> Result<
53        Vec<(
54            i64,
55            ConversationId,
56            String,
57            Option<MessageId>,
58            Option<MessageId>,
59            i64,
60        )>,
61        MemoryError,
62    > {
63        #[allow(clippy::type_complexity)]
64        let rows: Vec<(
65            i64,
66            ConversationId,
67            String,
68            Option<MessageId>,
69            Option<MessageId>,
70            i64,
71        )> = zeph_db::query_as(sql!(
72            "SELECT id, conversation_id, content, first_message_id, last_message_id, \
73                 token_estimate FROM summaries WHERE conversation_id = ? ORDER BY id ASC"
74        ))
75        .bind(conversation_id)
76        .fetch_all(&self.pool)
77        .await?;
78
79        Ok(rows)
80    }
81
82    /// Get the last message ID covered by the most recent summary.
83    ///
84    /// Returns `None` if no summaries exist or the most recent is a session-level summary
85    /// (shutdown summary) with no tracked message range.
86    ///
87    /// # Errors
88    ///
89    /// Returns an error if the query fails.
90    pub async fn latest_summary_last_message_id(
91        &self,
92        conversation_id: ConversationId,
93    ) -> Result<Option<MessageId>, MemoryError> {
94        let row: Option<(Option<MessageId>,)> = zeph_db::query_as(sql!(
95            "SELECT last_message_id FROM summaries \
96             WHERE conversation_id = ? ORDER BY id DESC LIMIT 1"
97        ))
98        .bind(conversation_id)
99        .fetch_optional(&self.pool)
100        .await?;
101
102        Ok(row.and_then(|r| r.0))
103    }
104}
105
#[cfg(test)]
mod tests {
    use super::*;

    /// Build the store used by every test from the `:memory:` path.
    async fn test_store() -> SqliteStore {
        SqliteStore::new(":memory:").await.unwrap()
    }

    /// A saved summary comes back from `load_summaries` with its fields intact.
    #[tokio::test]
    async fn save_and_load_summary() {
        let store = test_store().await;
        let conversation = store.create_conversation().await.unwrap();

        let first = store
            .save_message(conversation, "user", "hello")
            .await
            .unwrap();
        let last = store
            .save_message(conversation, "assistant", "hi")
            .await
            .unwrap();

        let saved_id = store
            .save_summary(
                conversation,
                "User greeted assistant",
                Some(first),
                Some(last),
                5,
            )
            .await
            .unwrap();

        let loaded = store.load_summaries(conversation).await.unwrap();
        assert_eq!(loaded.len(), 1);

        let only = &loaded[0];
        assert_eq!(only.0, saved_id);
        assert_eq!(only.2, "User greeted assistant");
        assert_eq!(only.3, Some(first));
        assert_eq!(only.4, Some(last));
        assert_eq!(only.5, 5);
    }

    /// A conversation without summaries loads as an empty list.
    #[tokio::test]
    async fn load_summaries_empty() {
        let store = test_store().await;
        let conversation = store.create_conversation().await.unwrap();

        let loaded = store.load_summaries(conversation).await.unwrap();
        assert!(loaded.is_empty());
    }

    /// Summaries are returned in the order they were inserted (ascending id).
    #[tokio::test]
    async fn load_summaries_ordered() {
        let store = test_store().await;
        let conversation = store.create_conversation().await.unwrap();

        let m1 = store.save_message(conversation, "user", "m1").await.unwrap();
        let m2 = store
            .save_message(conversation, "assistant", "m2")
            .await
            .unwrap();
        let m3 = store.save_message(conversation, "user", "m3").await.unwrap();

        let older = store
            .save_summary(conversation, "summary1", Some(m1), Some(m2), 3)
            .await
            .unwrap();
        let newer = store
            .save_summary(conversation, "summary2", Some(m2), Some(m3), 3)
            .await
            .unwrap();

        let loaded = store.load_summaries(conversation).await.unwrap();
        assert_eq!(loaded.len(), 2);

        let ids: Vec<i64> = loaded.iter().map(|row| row.0).collect();
        assert_eq!(ids, [older, newer]);
    }

    /// With no summaries stored, there is no "latest" last message id.
    #[tokio::test]
    async fn latest_summary_last_message_id_none() {
        let store = test_store().await;
        let conversation = store.create_conversation().await.unwrap();

        let latest = store
            .latest_summary_last_message_id(conversation)
            .await
            .unwrap();
        assert_eq!(latest, None);
    }

    /// The most recently inserted summary decides the reported last message id.
    #[tokio::test]
    async fn latest_summary_last_message_id_some() {
        let store = test_store().await;
        let conversation = store.create_conversation().await.unwrap();

        let m1 = store.save_message(conversation, "user", "m1").await.unwrap();
        let m2 = store
            .save_message(conversation, "assistant", "m2")
            .await
            .unwrap();
        let m3 = store.save_message(conversation, "user", "m3").await.unwrap();

        // Two overlapping ranges, inserted oldest first.
        for (text, from, to) in [("summary1", m1, m2), ("summary2", m2, m3)] {
            store
                .save_summary(conversation, text, Some(from), Some(to), 3)
                .await
                .unwrap();
        }

        let latest = store
            .latest_summary_last_message_id(conversation)
            .await
            .unwrap();
        assert_eq!(latest, Some(m3));
    }

    /// Deleting a conversation row leaves no summaries behind for it.
    #[tokio::test]
    async fn cascade_delete_removes_summaries() {
        let store = test_store().await;
        let pool = store.pool();
        let conversation = store.create_conversation().await.unwrap();

        let first = store.save_message(conversation, "user", "m1").await.unwrap();
        let last = store
            .save_message(conversation, "assistant", "m2")
            .await
            .unwrap();

        store
            .save_summary(conversation, "summary", Some(first), Some(last), 3)
            .await
            .unwrap();

        // Sanity check: the summary exists before the parent row is removed.
        let (count_before,): (i64,) = zeph_db::query_as(sql!(
            "SELECT COUNT(*) FROM summaries WHERE conversation_id = ?"
        ))
        .bind(conversation)
        .fetch_one(pool)
        .await
        .unwrap();
        assert_eq!(count_before, 1);

        zeph_db::query(sql!("DELETE FROM conversations WHERE id = ?"))
            .bind(conversation)
            .execute(pool)
            .await
            .unwrap();

        let (count_after,): (i64,) = zeph_db::query_as(sql!(
            "SELECT COUNT(*) FROM summaries WHERE conversation_id = ?"
        ))
        .bind(conversation)
        .fetch_one(pool)
        .await
        .unwrap();
        assert_eq!(count_after, 0);
    }
}