zeph_memory/semantic/
cross_session.rs1use zeph_llm::provider::LlmProvider as _;
5
6use crate::error::MemoryError;
7use crate::types::ConversationId;
8use crate::vector_store::{FieldCondition, FieldValue, VectorFilter};
9
10use super::{SESSION_SUMMARIES_COLLECTION, SemanticMemory};
11
12#[derive(Debug, Clone)]
13pub struct SessionSummaryResult {
14 pub summary_text: String,
15 pub score: f32,
16 pub conversation_id: ConversationId,
17}
18
19impl SemanticMemory {
20 pub async fn has_session_summary(
30 &self,
31 conversation_id: ConversationId,
32 ) -> Result<bool, MemoryError> {
33 let summaries = self.sqlite.load_summaries(conversation_id).await?;
34 Ok(!summaries.is_empty())
35 }
36
37 pub async fn store_shutdown_summary(
49 &self,
50 conversation_id: ConversationId,
51 summary_text: &str,
52 key_facts: &[String],
53 ) -> Result<(), MemoryError> {
54 let token_estimate =
55 i64::try_from(self.token_counter.count_tokens(summary_text)).unwrap_or(0);
56 let summary_id = self
59 .sqlite
60 .save_summary(conversation_id, summary_text, None, None, token_estimate)
61 .await?;
62
63 if let Err(e) = self
65 .store_session_summary(conversation_id, summary_text)
66 .await
67 {
68 tracing::warn!("shutdown summary: failed to embed into session summaries: {e:#}");
69 }
70
71 if !key_facts.is_empty() {
72 self.store_key_facts(conversation_id, summary_id, key_facts)
73 .await;
74 }
75
76 tracing::debug!(
77 conversation_id = conversation_id.0,
78 summary_id,
79 "stored shutdown session summary"
80 );
81 Ok(())
82 }
83
84 pub async fn store_session_summary(
90 &self,
91 conversation_id: ConversationId,
92 summary_text: &str,
93 ) -> Result<(), MemoryError> {
94 let Some(qdrant) = &self.qdrant else {
95 return Ok(());
96 };
97 if !self.provider.supports_embeddings() {
98 return Ok(());
99 }
100
101 let vector = self.provider.embed(summary_text).await?;
102 let vector_size = u64::try_from(vector.len()).unwrap_or(896);
103 qdrant
104 .ensure_named_collection(SESSION_SUMMARIES_COLLECTION, vector_size)
105 .await?;
106
107 let payload = serde_json::json!({
108 "conversation_id": conversation_id.0,
109 "summary_text": summary_text,
110 });
111
112 qdrant
113 .store_to_collection(SESSION_SUMMARIES_COLLECTION, payload, vector)
114 .await?;
115
116 tracing::debug!(
117 conversation_id = conversation_id.0,
118 "stored session summary"
119 );
120 Ok(())
121 }
122
123 pub async fn search_session_summaries(
129 &self,
130 query: &str,
131 limit: usize,
132 exclude_conversation_id: Option<ConversationId>,
133 ) -> Result<Vec<SessionSummaryResult>, MemoryError> {
134 let Some(qdrant) = &self.qdrant else {
135 tracing::debug!("session-summaries: skipped, no vector store");
136 return Ok(Vec::new());
137 };
138 if !self.provider.supports_embeddings() {
139 tracing::debug!("session-summaries: skipped, no embedding support");
140 return Ok(Vec::new());
141 }
142
143 let vector = self.provider.embed(query).await?;
144 let vector_size = u64::try_from(vector.len()).unwrap_or(896);
145 qdrant
146 .ensure_named_collection(SESSION_SUMMARIES_COLLECTION, vector_size)
147 .await?;
148
149 let filter = exclude_conversation_id.map(|cid| VectorFilter {
150 must: vec![],
151 must_not: vec![FieldCondition {
152 field: "conversation_id".into(),
153 value: FieldValue::Integer(cid.0),
154 }],
155 });
156
157 let points = qdrant
158 .search_collection(SESSION_SUMMARIES_COLLECTION, &vector, limit, filter)
159 .await?;
160
161 tracing::debug!(
162 results = points.len(),
163 limit,
164 exclude_conversation_id = exclude_conversation_id.map(|c| c.0),
165 "session-summaries: search complete"
166 );
167
168 let results = points
169 .into_iter()
170 .filter_map(|point| {
171 let summary_text = point.payload.get("summary_text")?.as_str()?.to_owned();
172 let conversation_id =
173 ConversationId(point.payload.get("conversation_id")?.as_i64()?);
174 Some(SessionSummaryResult {
175 summary_text,
176 score: point.score,
177 conversation_id,
178 })
179 })
180 .collect();
181
182 Ok(results)
183 }
184}
185
#[cfg(test)]
mod tests {
    use zeph_llm::any::AnyProvider;
    use zeph_llm::mock::MockProvider;

    use crate::types::MessageId;

    use super::*;

    /// Builds a `SemanticMemory` with the following backing pieces:
    /// - in-memory SQLite;
    /// - a mock LLM provider;
    /// - a vector-store URL on port 1 (presumably unreachable in test environments).
    async fn make_memory() -> SemanticMemory {
        SemanticMemory::new(
            ":memory:",
            "http://127.0.0.1:1",
            AnyProvider::Mock(MockProvider::default()),
            "test-model",
        )
        .await
        .unwrap()
    }

    /// Inserts one user message into `cid` and returns its id, so a summary can reference it.
    async fn insert_message(memory: &SemanticMemory, cid: ConversationId) -> MessageId {
        memory
            .sqlite()
            .save_message(cid, "user", "test message")
            .await
            .unwrap()
    }

    #[tokio::test]
    async fn has_session_summary_returns_false_when_no_summaries() {
        let memory = make_memory().await;
        let cid = memory.sqlite().create_conversation().await.unwrap();

        let result = memory.has_session_summary(cid).await.unwrap();
        assert!(!result, "new conversation must have no summaries");
    }

    #[tokio::test]
    async fn has_session_summary_returns_true_after_summary_stored_via_sqlite() {
        let memory = make_memory().await;
        let cid = memory.sqlite().create_conversation().await.unwrap();
        let msg_id = insert_message(&memory, cid).await;

        memory
            .sqlite()
            .save_summary(
                cid,
                "session about Rust and async",
                Some(msg_id),
                Some(msg_id),
                10,
            )
            .await
            .unwrap();

        let result = memory.has_session_summary(cid).await.unwrap();
        assert!(result, "must return true after a summary is persisted");
    }

    #[tokio::test]
    async fn has_session_summary_is_isolated_per_conversation() {
        let memory = make_memory().await;
        let cid_a = memory.sqlite().create_conversation().await.unwrap();
        let cid_b = memory.sqlite().create_conversation().await.unwrap();
        let msg_id = insert_message(&memory, cid_a).await;

        memory
            .sqlite()
            .save_summary(
                cid_a,
                "summary for conversation A",
                Some(msg_id),
                Some(msg_id),
                5,
            )
            .await
            .unwrap();

        assert!(
            memory.has_session_summary(cid_a).await.unwrap(),
            "cid_a must have a summary"
        );
        assert!(
            !memory.has_session_summary(cid_b).await.unwrap(),
            "cid_b must not be affected by cid_a summary"
        );
    }

    #[tokio::test]
    async fn store_shutdown_summary_succeeds_with_null_message_ids() {
        let memory = make_memory().await;
        let cid = memory.sqlite().create_conversation().await.unwrap();

        let result = memory
            .store_shutdown_summary(cid, "summary text", &[])
            .await;

        assert!(
            result.is_ok(),
            "shutdown summary must succeed without messages"
        );
        assert!(
            memory.has_session_summary(cid).await.unwrap(),
            "SQLite must record the shutdown summary"
        );
    }
}