1use zeph_llm::provider::{LlmProvider as _, Message, MessageMetadata, Role};
5
6use super::{KEY_FACTS_COLLECTION, SemanticMemory};
7use crate::embedding_store::MessageKind;
8use crate::error::MemoryError;
9use crate::types::{ConversationId, MessageId};
10
// Shape the summarization LLM is asked to return. `build_summarization_prompt` names the
// same three fields in its instructions. It is requested via `chat_typed_erased` in
// `call_summarization_llm`; when that fails, the plain-text fallback fills only `summary`.
//
// NOTE(review): plain `//` comments are used on purpose here. `///` doc comments would be
// copied by `schemars` into the schema `description`s and could change the schema handed
// to the LLM.
#[derive(Debug, Clone, serde::Deserialize, serde::Serialize, schemars::JsonSchema)]
pub struct StructuredSummary {
    // Free-text summary; this is the text that gets stored, token-counted and embedded.
    pub summary: String,
    // Standalone facts. They are written to the key-facts vector collection (when one is
    // available) after policy-decision filtering and near-duplicate checks.
    pub key_facts: Vec<String>,
    // Entity names extracted by the LLM (not consumed anywhere in this file).
    pub entities: Vec<String>,
}
17
18#[derive(Debug, Clone)]
19pub struct Summary {
20 pub id: i64,
21 pub conversation_id: ConversationId,
22 pub content: String,
23 pub first_message_id: Option<MessageId>,
25 pub last_message_id: Option<MessageId>,
27 pub token_estimate: i64,
28}
29
30#[must_use]
31pub fn build_summarization_prompt(messages: &[(MessageId, String, String)]) -> String {
32 let mut prompt = String::from(
33 "Summarize the following conversation. Extract key facts, decisions, entities, \
34 and context needed to continue the conversation.\n\n\
35 Respond in JSON with fields: summary (string), key_facts (list of strings), \
36 entities (list of strings).\n\nConversation:\n",
37 );
38
39 for (_, role, content) in messages {
40 prompt.push_str(role);
41 prompt.push_str(": ");
42 prompt.push_str(content);
43 prompt.push('\n');
44 }
45
46 prompt
47}
48
49impl SemanticMemory {
50 pub async fn load_summaries(
56 &self,
57 conversation_id: ConversationId,
58 ) -> Result<Vec<Summary>, MemoryError> {
59 let rows = self.sqlite.load_summaries(conversation_id).await?;
60 let summaries = rows
61 .into_iter()
62 .map(
63 |(
64 id,
65 conversation_id,
66 content,
67 first_message_id,
68 last_message_id,
69 token_estimate,
70 )| {
71 Summary {
72 id,
73 conversation_id,
74 content,
75 first_message_id,
76 last_message_id,
77 token_estimate,
78 }
79 },
80 )
81 .collect();
82 Ok(summaries)
83 }
84
85 #[tracing::instrument(name = "memory.summarize", skip_all, fields(input_msgs = %message_count, output_len = tracing::field::Empty))]
93 pub async fn summarize(
94 &self,
95 conversation_id: ConversationId,
96 message_count: usize,
97 ) -> Result<Option<i64>, MemoryError> {
98 let total = self.sqlite.count_messages(conversation_id).await?;
99
100 if total <= i64::try_from(message_count)? {
101 return Ok(None);
102 }
103
104 let after_id = self
105 .sqlite
106 .latest_summary_last_message_id(conversation_id)
107 .await?
108 .unwrap_or(MessageId(0));
109
110 let messages = self
111 .sqlite
112 .load_messages_range(conversation_id, after_id, message_count)
113 .await?;
114
115 if messages.is_empty() {
116 return Ok(None);
117 }
118
119 let prompt = build_summarization_prompt(&messages);
120 let chat_messages = vec![Message {
121 role: Role::User,
122 content: prompt,
123 parts: vec![],
124 metadata: MessageMetadata::default(),
125 }];
126
127 let structured = self.call_summarization_llm(&chat_messages).await?;
128 let summary_text = &structured.summary;
129
130 let token_estimate = i64::try_from(self.token_counter.count_tokens(summary_text))?;
131 let first_message_id = messages[0].0;
132 let last_message_id = messages[messages.len() - 1].0;
133
134 let summary_id = self
135 .sqlite
136 .save_summary(
137 conversation_id,
138 summary_text,
139 Some(first_message_id),
140 Some(last_message_id),
141 token_estimate,
142 )
143 .await?;
144
145 if let Some(qdrant) = &self.qdrant
146 && self.effective_embed_provider().supports_embeddings()
147 {
148 match tokio::time::timeout(
149 std::time::Duration::from_secs(5),
150 self.effective_embed_provider().embed(summary_text),
151 )
152 .await
153 {
154 Ok(Ok(vector)) => {
155 let vector_size = u64::try_from(vector.len()).unwrap_or(896);
156 if let Err(e) = qdrant.ensure_collection(vector_size).await {
157 tracing::warn!("Failed to ensure Qdrant collection: {e:#}");
158 } else if let Err(e) = qdrant
159 .store(
160 MessageId(summary_id),
161 conversation_id,
162 "system",
163 vector,
164 MessageKind::Summary,
165 &self.embedding_model,
166 0,
167 )
168 .await
169 {
170 tracing::warn!("Failed to embed summary: {e:#}");
171 }
172 }
173 Ok(Err(e)) => {
174 tracing::warn!("Failed to generate summary embedding: {e:#}");
175 }
176 Err(_) => {
177 tracing::warn!("summarize: embed timed out for summary text — skipping store");
178 }
179 }
180 }
181
182 if !structured.key_facts.is_empty() {
183 self.store_key_facts(conversation_id, summary_id, &structured.key_facts)
184 .await;
185 }
186
187 Ok(Some(summary_id))
188 }
189
190 async fn call_summarization_llm(
199 &self,
200 chat_messages: &[Message],
201 ) -> Result<StructuredSummary, MemoryError> {
202 let timeout_secs = self.summarization_llm_timeout_secs;
203 let timeout = std::time::Duration::from_secs(timeout_secs);
204 match tokio::time::timeout(
205 timeout,
206 self.provider
207 .chat_typed_erased::<StructuredSummary>(chat_messages),
208 )
209 .await
210 {
211 Ok(Ok(s)) => Ok(s),
212 Ok(Err(e)) => {
213 tracing::warn!(
214 "structured summarization failed, falling back to plain text: {e:#}"
215 );
216 match tokio::time::timeout(timeout, self.provider.chat(chat_messages)).await {
217 Ok(Ok(plain)) => Ok(StructuredSummary {
218 summary: plain,
219 key_facts: vec![],
220 entities: vec![],
221 }),
222 Ok(Err(e)) => Err(MemoryError::Llm(e)),
223 Err(_elapsed) => {
224 tracing::warn!(
225 "summarization: plain text fallback LLM call timed out after {timeout_secs}s"
226 );
227 Err(MemoryError::Timeout("LLM call timed out".into()))
228 }
229 }
230 }
231 Err(_elapsed) => {
232 tracing::warn!(
233 "summarization: structured LLM call timed out after {timeout_secs}s"
234 );
235 Err(MemoryError::Timeout("LLM call timed out".into()))
236 }
237 }
238 }
239
240 pub(super) async fn store_key_facts(
241 &self,
242 conversation_id: ConversationId,
243 source_summary_id: i64,
244 key_facts: &[String],
245 ) {
246 let Some(qdrant) = &self.qdrant else {
247 return;
248 };
249 if !self.effective_embed_provider().supports_embeddings() {
250 return;
251 }
252
253 let filtered: Vec<&str> = key_facts
257 .iter()
258 .filter(|f| !is_policy_decision_fact(f.as_str()))
259 .map(String::as_str)
260 .collect();
261
262 let Some(first_fact) = filtered.first().copied() else {
263 return;
264 };
265 let first_vector = match tokio::time::timeout(
266 std::time::Duration::from_secs(5),
267 self.effective_embed_provider().embed(first_fact),
268 )
269 .await
270 {
271 Ok(Ok(v)) => v,
272 Ok(Err(e)) => {
273 tracing::warn!("Failed to embed key fact: {e:#}");
274 return;
275 }
276 Err(_) => {
277 tracing::warn!("store_key_facts: embed timed out for first fact — skipping");
278 return;
279 }
280 };
281 let vector_size = u64::try_from(first_vector.len()).unwrap_or(896);
282 if let Err(e) = qdrant
283 .ensure_named_collection(KEY_FACTS_COLLECTION, vector_size)
284 .await
285 {
286 tracing::warn!("Failed to ensure key_facts collection: {e:#}");
287 return;
288 }
289
290 let threshold = self.key_facts_dedup_threshold;
291 self.store_key_fact_if_unique(
292 qdrant,
293 conversation_id,
294 source_summary_id,
295 first_fact,
296 first_vector,
297 threshold,
298 )
299 .await;
300
301 for fact in filtered[1..].iter().copied() {
302 match tokio::time::timeout(
303 std::time::Duration::from_secs(5),
304 self.effective_embed_provider().embed(fact),
305 )
306 .await
307 {
308 Ok(Ok(vector)) => {
309 self.store_key_fact_if_unique(
310 qdrant,
311 conversation_id,
312 source_summary_id,
313 fact,
314 vector,
315 threshold,
316 )
317 .await;
318 }
319 Ok(Err(e)) => {
320 tracing::warn!("Failed to embed key fact: {e:#}");
321 }
322 Err(_) => {
323 tracing::warn!("store_key_facts: embed timed out for fact — skipping");
324 }
325 }
326 }
327 }
328
329 async fn store_key_fact_if_unique(
330 &self,
331 qdrant: &crate::embedding_store::EmbeddingStore,
332 conversation_id: ConversationId,
333 source_summary_id: i64,
334 fact: &str,
335 vector: Vec<f32>,
336 threshold: f32,
337 ) {
338 match qdrant
339 .search_collection(KEY_FACTS_COLLECTION, &vector, 1, None)
340 .await
341 {
342 Ok(hits) if hits.first().is_some_and(|h| h.score >= threshold) => {
343 tracing::debug!(
344 score = hits[0].score,
345 threshold,
346 "key-facts: skipping near-duplicate fact"
347 );
348 return;
349 }
350 Ok(_) => {}
351 Err(e) => {
352 tracing::warn!("key-facts: dedup search failed, storing anyway: {e:#}");
353 }
354 }
355
356 let payload = serde_json::json!({
357 "conversation_id": conversation_id.0,
358 "fact_text": fact,
359 "source_summary_id": source_summary_id,
360 });
361 if let Err(e) = qdrant
362 .store_to_collection(KEY_FACTS_COLLECTION, payload, vector)
363 .await
364 {
365 tracing::warn!("Failed to store key fact: {e:#}");
366 }
367 }
368
369 pub async fn search_key_facts(
375 &self,
376 query: &str,
377 limit: usize,
378 ) -> Result<Vec<String>, MemoryError> {
379 let Some(qdrant) = &self.qdrant else {
380 tracing::debug!("key-facts: skipped, no vector store");
381 return Ok(Vec::new());
382 };
383 if !self.effective_embed_provider().supports_embeddings() {
384 tracing::debug!("key-facts: skipped, no embedding support");
385 return Ok(Vec::new());
386 }
387
388 let vector = match tokio::time::timeout(
389 std::time::Duration::from_secs(5),
390 self.effective_embed_provider().embed(query),
391 )
392 .await
393 {
394 Ok(Ok(v)) => v,
395 Ok(Err(e)) => return Err(e.into()),
396 Err(_) => {
397 tracing::warn!("search_key_facts: embed timed out, returning empty results");
398 return Ok(Vec::new());
399 }
400 };
401 let vector_size = u64::try_from(vector.len()).unwrap_or(896);
402 qdrant
403 .ensure_named_collection(KEY_FACTS_COLLECTION, vector_size)
404 .await?;
405
406 let points = qdrant
407 .search_collection(KEY_FACTS_COLLECTION, &vector, limit, None)
408 .await?;
409
410 tracing::debug!(results = points.len(), limit, "key-facts: search complete");
411
412 let facts = points
413 .into_iter()
414 .filter_map(|p| p.payload.get("fact_text")?.as_str().map(String::from))
415 .collect();
416
417 Ok(facts)
418 }
419
420 pub async fn search_document_collection(
430 &self,
431 collection: &str,
432 query: &str,
433 limit: usize,
434 ) -> Result<Vec<crate::ScoredVectorPoint>, MemoryError> {
435 let Some(qdrant) = &self.qdrant else {
436 return Ok(Vec::new());
437 };
438 if !self.effective_embed_provider().supports_embeddings() {
439 return Ok(Vec::new());
440 }
441 if !qdrant.collection_exists(collection).await? {
442 return Ok(Vec::new());
443 }
444 let vector = match tokio::time::timeout(
445 std::time::Duration::from_secs(5),
446 self.effective_embed_provider().embed(query),
447 )
448 .await
449 {
450 Ok(Ok(v)) => v,
451 Ok(Err(e)) => return Err(e.into()),
452 Err(_) => {
453 tracing::warn!(
454 "search_document_collection: embed timed out, returning empty results"
455 );
456 return Ok(Vec::new());
457 }
458 };
459 let results = qdrant
460 .search_collection(collection, &vector, limit, None)
461 .await?;
462
463 tracing::debug!(
464 results = results.len(),
465 limit,
466 collection,
467 "document-collection: search complete"
468 );
469
470 Ok(results)
471 }
472}
473
/// Returns `true` when `fact` looks like a record of a policy or permission decision,
/// for example "the command was blocked".
///
/// The check is a case-insensitive substring match against a fixed list of markers.
/// Key-fact storage uses it to keep such transient denials out of long-term memory.
///
/// NOTE(review): the markers are broad substrings, so e.g. "unblocked" also matches
/// "blocked".
pub(crate) fn is_policy_decision_fact(fact: &str) -> bool {
    const POLICY_MARKERS: [&str; 9] = [
        "blocked",
        "skipped",
        "cannot access",
        "security polic",
        "utility polic",
        "not allowed",
        "permission denied",
        "access denied",
        "was denied",
    ];

    let haystack = fact.to_lowercase();
    for needle in POLICY_MARKERS {
        if haystack.contains(needle) {
            return true;
        }
    }
    false
}
493}