Skip to main content

zeph_memory/semantic/
summarization.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use zeph_llm::provider::{LlmProvider as _, Message, MessageMetadata, Role};
5
6use super::{KEY_FACTS_COLLECTION, SemanticMemory};
7use crate::embedding_store::MessageKind;
8use crate::error::MemoryError;
9use crate::types::{ConversationId, MessageId};
10
11#[derive(Debug, Clone, serde::Deserialize, serde::Serialize, schemars::JsonSchema)]
12pub struct StructuredSummary {
13    pub summary: String,
14    pub key_facts: Vec<String>,
15    pub entities: Vec<String>,
16}
17
18#[derive(Debug, Clone)]
19pub struct Summary {
20    pub id: i64,
21    pub conversation_id: ConversationId,
22    pub content: String,
23    /// `None` for session-level summaries (e.g. shutdown summaries) with no tracked message range.
24    pub first_message_id: Option<MessageId>,
25    /// `None` for session-level summaries (e.g. shutdown summaries) with no tracked message range.
26    pub last_message_id: Option<MessageId>,
27    pub token_estimate: i64,
28}
29
30#[must_use]
31pub fn build_summarization_prompt(messages: &[(MessageId, String, String)]) -> String {
32    let mut prompt = String::from(
33        "Summarize the following conversation. Extract key facts, decisions, entities, \
34         and context needed to continue the conversation.\n\n\
35         Respond in JSON with fields: summary (string), key_facts (list of strings), \
36         entities (list of strings).\n\nConversation:\n",
37    );
38
39    for (_, role, content) in messages {
40        prompt.push_str(role);
41        prompt.push_str(": ");
42        prompt.push_str(content);
43        prompt.push('\n');
44    }
45
46    prompt
47}
48
49impl SemanticMemory {
50    /// Load all summaries for a conversation.
51    ///
52    /// # Errors
53    ///
54    /// Returns an error if the query fails.
55    pub async fn load_summaries(
56        &self,
57        conversation_id: ConversationId,
58    ) -> Result<Vec<Summary>, MemoryError> {
59        let rows = self.sqlite.load_summaries(conversation_id).await?;
60        let summaries = rows
61            .into_iter()
62            .map(
63                |(
64                    id,
65                    conversation_id,
66                    content,
67                    first_message_id,
68                    last_message_id,
69                    token_estimate,
70                )| {
71                    Summary {
72                        id,
73                        conversation_id,
74                        content,
75                        first_message_id,
76                        last_message_id,
77                        token_estimate,
78                    }
79                },
80            )
81            .collect();
82        Ok(summaries)
83    }
84
85    /// Generate a summary of the oldest unsummarized messages.
86    ///
87    /// Returns `Ok(None)` if there are not enough messages to summarize.
88    ///
89    /// # Errors
90    ///
91    /// Returns an error if LLM call or database operation fails.
92    #[tracing::instrument(name = "memory.summarize", skip_all, fields(input_msgs = %message_count, output_len = tracing::field::Empty))]
93    pub async fn summarize(
94        &self,
95        conversation_id: ConversationId,
96        message_count: usize,
97    ) -> Result<Option<i64>, MemoryError> {
98        let total = self.sqlite.count_messages(conversation_id).await?;
99
100        if total <= i64::try_from(message_count)? {
101            return Ok(None);
102        }
103
104        let after_id = self
105            .sqlite
106            .latest_summary_last_message_id(conversation_id)
107            .await?
108            .unwrap_or(MessageId(0));
109
110        let messages = self
111            .sqlite
112            .load_messages_range(conversation_id, after_id, message_count)
113            .await?;
114
115        if messages.is_empty() {
116            return Ok(None);
117        }
118
119        let prompt = build_summarization_prompt(&messages);
120        let chat_messages = vec![Message {
121            role: Role::User,
122            content: prompt,
123            parts: vec![],
124            metadata: MessageMetadata::default(),
125        }];
126
127        let structured = self.call_summarization_llm(&chat_messages).await?;
128        let summary_text = &structured.summary;
129
130        let token_estimate = i64::try_from(self.token_counter.count_tokens(summary_text))?;
131        let first_message_id = messages[0].0;
132        let last_message_id = messages[messages.len() - 1].0;
133
134        let summary_id = self
135            .sqlite
136            .save_summary(
137                conversation_id,
138                summary_text,
139                Some(first_message_id),
140                Some(last_message_id),
141                token_estimate,
142            )
143            .await?;
144
145        if let Some(qdrant) = &self.qdrant
146            && self.effective_embed_provider().supports_embeddings()
147        {
148            match tokio::time::timeout(
149                std::time::Duration::from_secs(5),
150                self.effective_embed_provider().embed(summary_text),
151            )
152            .await
153            {
154                Ok(Ok(vector)) => {
155                    let vector_size = u64::try_from(vector.len()).unwrap_or(896);
156                    if let Err(e) = qdrant.ensure_collection(vector_size).await {
157                        tracing::warn!("Failed to ensure Qdrant collection: {e:#}");
158                    } else if let Err(e) = qdrant
159                        .store(
160                            MessageId(summary_id),
161                            conversation_id,
162                            "system",
163                            vector,
164                            MessageKind::Summary,
165                            &self.embedding_model,
166                            0,
167                        )
168                        .await
169                    {
170                        tracing::warn!("Failed to embed summary: {e:#}");
171                    }
172                }
173                Ok(Err(e)) => {
174                    tracing::warn!("Failed to generate summary embedding: {e:#}");
175                }
176                Err(_) => {
177                    tracing::warn!("summarize: embed timed out for summary text — skipping store");
178                }
179            }
180        }
181
182        if !structured.key_facts.is_empty() {
183            self.store_key_facts(conversation_id, summary_id, &structured.key_facts)
184                .await;
185        }
186
187        Ok(Some(summary_id))
188    }
189
190    /// Call the LLM to produce a [`StructuredSummary`], falling back to plain text on parse error.
191    ///
192    /// Both the structured and fallback calls are bounded by `summarization_llm_timeout_secs`.
193    ///
194    /// # Errors
195    ///
196    /// Returns [`MemoryError::Timeout`] if the LLM exceeds the deadline, or
197    /// [`MemoryError::Llm`] if the provider returns an error.
198    async fn call_summarization_llm(
199        &self,
200        chat_messages: &[Message],
201    ) -> Result<StructuredSummary, MemoryError> {
202        let timeout_secs = self.summarization_llm_timeout_secs;
203        let timeout = std::time::Duration::from_secs(timeout_secs);
204        match tokio::time::timeout(
205            timeout,
206            self.provider
207                .chat_typed_erased::<StructuredSummary>(chat_messages),
208        )
209        .await
210        {
211            Ok(Ok(s)) => Ok(s),
212            Ok(Err(e)) => {
213                tracing::warn!(
214                    "structured summarization failed, falling back to plain text: {e:#}"
215                );
216                match tokio::time::timeout(timeout, self.provider.chat(chat_messages)).await {
217                    Ok(Ok(plain)) => Ok(StructuredSummary {
218                        summary: plain,
219                        key_facts: vec![],
220                        entities: vec![],
221                    }),
222                    Ok(Err(e)) => Err(MemoryError::Llm(e)),
223                    Err(_elapsed) => {
224                        tracing::warn!(
225                            "summarization: plain text fallback LLM call timed out after {timeout_secs}s"
226                        );
227                        Err(MemoryError::Timeout("LLM call timed out".into()))
228                    }
229                }
230            }
231            Err(_elapsed) => {
232                tracing::warn!(
233                    "summarization: structured LLM call timed out after {timeout_secs}s"
234                );
235                Err(MemoryError::Timeout("LLM call timed out".into()))
236            }
237        }
238    }
239
240    pub(super) async fn store_key_facts(
241        &self,
242        conversation_id: ConversationId,
243        source_summary_id: i64,
244        key_facts: &[String],
245    ) {
246        let Some(qdrant) = &self.qdrant else {
247            return;
248        };
249        if !self.effective_embed_provider().supports_embeddings() {
250            return;
251        }
252
253        // Filter out transient policy-decision facts that describe a blocked or denied action.
254        // These reflect the agent's state at a single point in time and must not be recalled
255        // as stable world facts in future turns — doing so causes the agent to skip valid calls.
256        let filtered: Vec<&str> = key_facts
257            .iter()
258            .filter(|f| !is_policy_decision_fact(f.as_str()))
259            .map(String::as_str)
260            .collect();
261
262        let Some(first_fact) = filtered.first().copied() else {
263            return;
264        };
265        let first_vector = match tokio::time::timeout(
266            std::time::Duration::from_secs(5),
267            self.effective_embed_provider().embed(first_fact),
268        )
269        .await
270        {
271            Ok(Ok(v)) => v,
272            Ok(Err(e)) => {
273                tracing::warn!("Failed to embed key fact: {e:#}");
274                return;
275            }
276            Err(_) => {
277                tracing::warn!("store_key_facts: embed timed out for first fact — skipping");
278                return;
279            }
280        };
281        let vector_size = u64::try_from(first_vector.len()).unwrap_or(896);
282        if let Err(e) = qdrant
283            .ensure_named_collection(KEY_FACTS_COLLECTION, vector_size)
284            .await
285        {
286            tracing::warn!("Failed to ensure key_facts collection: {e:#}");
287            return;
288        }
289
290        let threshold = self.key_facts_dedup_threshold;
291        self.store_key_fact_if_unique(
292            qdrant,
293            conversation_id,
294            source_summary_id,
295            first_fact,
296            first_vector,
297            threshold,
298        )
299        .await;
300
301        for fact in filtered[1..].iter().copied() {
302            match tokio::time::timeout(
303                std::time::Duration::from_secs(5),
304                self.effective_embed_provider().embed(fact),
305            )
306            .await
307            {
308                Ok(Ok(vector)) => {
309                    self.store_key_fact_if_unique(
310                        qdrant,
311                        conversation_id,
312                        source_summary_id,
313                        fact,
314                        vector,
315                        threshold,
316                    )
317                    .await;
318                }
319                Ok(Err(e)) => {
320                    tracing::warn!("Failed to embed key fact: {e:#}");
321                }
322                Err(_) => {
323                    tracing::warn!("store_key_facts: embed timed out for fact — skipping");
324                }
325            }
326        }
327    }
328
329    async fn store_key_fact_if_unique(
330        &self,
331        qdrant: &crate::embedding_store::EmbeddingStore,
332        conversation_id: ConversationId,
333        source_summary_id: i64,
334        fact: &str,
335        vector: Vec<f32>,
336        threshold: f32,
337    ) {
338        match qdrant
339            .search_collection(KEY_FACTS_COLLECTION, &vector, 1, None)
340            .await
341        {
342            Ok(hits) if hits.first().is_some_and(|h| h.score >= threshold) => {
343                tracing::debug!(
344                    score = hits[0].score,
345                    threshold,
346                    "key-facts: skipping near-duplicate fact"
347                );
348                return;
349            }
350            Ok(_) => {}
351            Err(e) => {
352                tracing::warn!("key-facts: dedup search failed, storing anyway: {e:#}");
353            }
354        }
355
356        let payload = serde_json::json!({
357            "conversation_id": conversation_id.0,
358            "fact_text": fact,
359            "source_summary_id": source_summary_id,
360        });
361        if let Err(e) = qdrant
362            .store_to_collection(KEY_FACTS_COLLECTION, payload, vector)
363            .await
364        {
365            tracing::warn!("Failed to store key fact: {e:#}");
366        }
367    }
368
369    /// Search key facts extracted from conversation summaries.
370    ///
371    /// # Errors
372    ///
373    /// Returns an error if embedding or Qdrant search fails.
374    pub async fn search_key_facts(
375        &self,
376        query: &str,
377        limit: usize,
378    ) -> Result<Vec<String>, MemoryError> {
379        let Some(qdrant) = &self.qdrant else {
380            tracing::debug!("key-facts: skipped, no vector store");
381            return Ok(Vec::new());
382        };
383        if !self.effective_embed_provider().supports_embeddings() {
384            tracing::debug!("key-facts: skipped, no embedding support");
385            return Ok(Vec::new());
386        }
387
388        let vector = match tokio::time::timeout(
389            std::time::Duration::from_secs(5),
390            self.effective_embed_provider().embed(query),
391        )
392        .await
393        {
394            Ok(Ok(v)) => v,
395            Ok(Err(e)) => return Err(e.into()),
396            Err(_) => {
397                tracing::warn!("search_key_facts: embed timed out, returning empty results");
398                return Ok(Vec::new());
399            }
400        };
401        let vector_size = u64::try_from(vector.len()).unwrap_or(896);
402        qdrant
403            .ensure_named_collection(KEY_FACTS_COLLECTION, vector_size)
404            .await?;
405
406        let points = qdrant
407            .search_collection(KEY_FACTS_COLLECTION, &vector, limit, None)
408            .await?;
409
410        tracing::debug!(results = points.len(), limit, "key-facts: search complete");
411
412        let facts = points
413            .into_iter()
414            .filter_map(|p| p.payload.get("fact_text")?.as_str().map(String::from))
415            .collect();
416
417        Ok(facts)
418    }
419
420    /// Search a named document collection by semantic similarity.
421    ///
422    /// Returns up to `limit` scored vector points whose payloads contain ingested document chunks.
423    /// Returns an empty vec when Qdrant is unavailable, the collection does not exist,
424    /// or the provider does not support embeddings.
425    ///
426    /// # Errors
427    ///
428    /// Returns an error if embedding generation or Qdrant search fails.
429    pub async fn search_document_collection(
430        &self,
431        collection: &str,
432        query: &str,
433        limit: usize,
434    ) -> Result<Vec<crate::ScoredVectorPoint>, MemoryError> {
435        let Some(qdrant) = &self.qdrant else {
436            return Ok(Vec::new());
437        };
438        if !self.effective_embed_provider().supports_embeddings() {
439            return Ok(Vec::new());
440        }
441        if !qdrant.collection_exists(collection).await? {
442            return Ok(Vec::new());
443        }
444        let vector = match tokio::time::timeout(
445            std::time::Duration::from_secs(5),
446            self.effective_embed_provider().embed(query),
447        )
448        .await
449        {
450            Ok(Ok(v)) => v,
451            Ok(Err(e)) => return Err(e.into()),
452            Err(_) => {
453                tracing::warn!(
454                    "search_document_collection: embed timed out, returning empty results"
455                );
456                return Ok(Vec::new());
457            }
458        };
459        let results = qdrant
460            .search_collection(collection, &vector, limit, None)
461            .await?;
462
463        tracing::debug!(
464            results = results.len(),
465            limit,
466            collection,
467            "document-collection: search complete"
468        );
469
470        Ok(results)
471    }
472}
473
/// Decide whether `fact` records a one-off policy or permission outcome rather than durable
/// knowledge.
///
/// Matching is a case-insensitive substring test against a fixed list of markers. For example,
/// "reading /etc/passwd was blocked by utility policy" matches. Such facts are a snapshot of a
/// single turn's enforcement state. If they were stored and recalled later, the agent would
/// treat the tool as permanently unavailable.
pub(crate) fn is_policy_decision_fact(fact: &str) -> bool {
    // Markers must be lowercase because they are compared against a lowercased copy of `fact`.
    // The truncated "polic" stem matches both "policy" and "policies".
    const POLICY_MARKERS: [&str; 9] = [
        "blocked",
        "skipped",
        "cannot access",
        "security polic",
        "utility polic",
        "not allowed",
        "permission denied",
        "access denied",
        "was denied",
    ];

    let haystack = fact.to_lowercase();
    for &marker in POLICY_MARKERS.iter() {
        if haystack.contains(marker) {
            return true;
        }
    }
    false
}