Skip to main content

zeph_memory/semantic/
summarization.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use zeph_llm::provider::{LlmProvider as _, Message, MessageMetadata, Role};
5
6use super::{KEY_FACTS_COLLECTION, SemanticMemory};
7use crate::embedding_store::MessageKind;
8use crate::error::MemoryError;
9use crate::types::{ConversationId, MessageId};
10
/// Structured response requested from the LLM when summarizing a conversation.
///
/// The summarization prompt asks the model to answer in JSON with these three fields.
/// The `JsonSchema` derive presumably lets typed chat calls advertise this shape to the
/// model (NOTE(review): confirm against `chat_typed_erased`).
#[derive(Debug, Clone, serde::Deserialize, serde::Serialize, schemars::JsonSchema)]
pub struct StructuredSummary {
    /// Free-form summary text of the conversation.
    pub summary: String,
    /// Individual facts extracted from the conversation, as reported by the model.
    pub key_facts: Vec<String>,
    /// Entities mentioned in the conversation, as reported by the model.
    pub entities: Vec<String>,
}
17
/// A conversation summary as loaded from the SQLite store.
#[derive(Debug, Clone)]
pub struct Summary {
    /// Summary identifier assigned by the SQLite store.
    pub id: i64,
    /// Conversation this summary belongs to.
    pub conversation_id: ConversationId,
    /// Summary text.
    pub content: String,
    /// First message covered by this summary.
    /// `None` for session-level summaries (e.g. shutdown summaries) with no tracked message range.
    pub first_message_id: Option<MessageId>,
    /// Last message covered by this summary.
    /// `None` for session-level summaries (e.g. shutdown summaries) with no tracked message range.
    pub last_message_id: Option<MessageId>,
    /// Estimated token count of `content`.
    pub token_estimate: i64,
}
29
30#[must_use]
31pub fn build_summarization_prompt(messages: &[(MessageId, String, String)]) -> String {
32    let mut prompt = String::from(
33        "Summarize the following conversation. Extract key facts, decisions, entities, \
34         and context needed to continue the conversation.\n\n\
35         Respond in JSON with fields: summary (string), key_facts (list of strings), \
36         entities (list of strings).\n\nConversation:\n",
37    );
38
39    for (_, role, content) in messages {
40        prompt.push_str(role);
41        prompt.push_str(": ");
42        prompt.push_str(content);
43        prompt.push('\n');
44    }
45
46    prompt
47}
48
49impl SemanticMemory {
50    /// Load all summaries for a conversation.
51    ///
52    /// # Errors
53    ///
54    /// Returns an error if the query fails.
55    pub async fn load_summaries(
56        &self,
57        conversation_id: ConversationId,
58    ) -> Result<Vec<Summary>, MemoryError> {
59        let rows = self.sqlite.load_summaries(conversation_id).await?;
60        let summaries = rows
61            .into_iter()
62            .map(
63                |(
64                    id,
65                    conversation_id,
66                    content,
67                    first_message_id,
68                    last_message_id,
69                    token_estimate,
70                )| {
71                    Summary {
72                        id,
73                        conversation_id,
74                        content,
75                        first_message_id,
76                        last_message_id,
77                        token_estimate,
78                    }
79                },
80            )
81            .collect();
82        Ok(summaries)
83    }
84
85    /// Generate a summary of the oldest unsummarized messages.
86    ///
87    /// Returns `Ok(None)` if there are not enough messages to summarize.
88    ///
89    /// # Errors
90    ///
91    /// Returns an error if LLM call or database operation fails.
92    #[cfg_attr(
93        feature = "profiling",
94        tracing::instrument(name = "memory.summarize", skip_all, fields(input_msgs = %message_count, output_len = tracing::field::Empty))
95    )]
96    pub async fn summarize(
97        &self,
98        conversation_id: ConversationId,
99        message_count: usize,
100    ) -> Result<Option<i64>, MemoryError> {
101        let total = self.sqlite.count_messages(conversation_id).await?;
102
103        if total <= i64::try_from(message_count)? {
104            return Ok(None);
105        }
106
107        let after_id = self
108            .sqlite
109            .latest_summary_last_message_id(conversation_id)
110            .await?
111            .unwrap_or(MessageId(0));
112
113        let messages = self
114            .sqlite
115            .load_messages_range(conversation_id, after_id, message_count)
116            .await?;
117
118        if messages.is_empty() {
119            return Ok(None);
120        }
121
122        let prompt = build_summarization_prompt(&messages);
123        let chat_messages = vec![Message {
124            role: Role::User,
125            content: prompt,
126            parts: vec![],
127            metadata: MessageMetadata::default(),
128        }];
129
130        let structured = match self
131            .provider
132            .chat_typed_erased::<StructuredSummary>(&chat_messages)
133            .await
134        {
135            Ok(s) => s,
136            Err(e) => {
137                tracing::warn!(
138                    "structured summarization failed, falling back to plain text: {e:#}"
139                );
140                let plain = self.provider.chat(&chat_messages).await?;
141                StructuredSummary {
142                    summary: plain,
143                    key_facts: vec![],
144                    entities: vec![],
145                }
146            }
147        };
148        let summary_text = &structured.summary;
149
150        let token_estimate = i64::try_from(self.token_counter.count_tokens(summary_text))?;
151        let first_message_id = messages[0].0;
152        let last_message_id = messages[messages.len() - 1].0;
153
154        let summary_id = self
155            .sqlite
156            .save_summary(
157                conversation_id,
158                summary_text,
159                Some(first_message_id),
160                Some(last_message_id),
161                token_estimate,
162            )
163            .await?;
164
165        if let Some(qdrant) = &self.qdrant
166            && self.provider.supports_embeddings()
167        {
168            match self.provider.embed(summary_text).await {
169                Ok(vector) => {
170                    let vector_size = u64::try_from(vector.len()).unwrap_or(896);
171                    if let Err(e) = qdrant.ensure_collection(vector_size).await {
172                        tracing::warn!("Failed to ensure Qdrant collection: {e:#}");
173                    } else if let Err(e) = qdrant
174                        .store(
175                            MessageId(summary_id),
176                            conversation_id,
177                            "system",
178                            vector,
179                            MessageKind::Summary,
180                            &self.embedding_model,
181                            0,
182                        )
183                        .await
184                    {
185                        tracing::warn!("Failed to embed summary: {e:#}");
186                    }
187                }
188                Err(e) => {
189                    tracing::warn!("Failed to generate summary embedding: {e:#}");
190                }
191            }
192        }
193
194        if !structured.key_facts.is_empty() {
195            self.store_key_facts(conversation_id, summary_id, &structured.key_facts)
196                .await;
197        }
198
199        Ok(Some(summary_id))
200    }
201
202    pub(super) async fn store_key_facts(
203        &self,
204        conversation_id: ConversationId,
205        source_summary_id: i64,
206        key_facts: &[String],
207    ) {
208        let Some(qdrant) = &self.qdrant else {
209            return;
210        };
211        if !self.provider.supports_embeddings() {
212            return;
213        }
214
215        // Filter out transient policy-decision facts that describe a blocked or denied action.
216        // These reflect the agent's state at a single point in time and must not be recalled
217        // as stable world facts in future turns — doing so causes the agent to skip valid calls.
218        let filtered: Vec<&str> = key_facts
219            .iter()
220            .filter(|f| !is_policy_decision_fact(f.as_str()))
221            .map(String::as_str)
222            .collect();
223
224        let Some(first_fact) = filtered.first().copied() else {
225            return;
226        };
227        let first_vector = match self.provider.embed(first_fact).await {
228            Ok(v) => v,
229            Err(e) => {
230                tracing::warn!("Failed to embed key fact: {e:#}");
231                return;
232            }
233        };
234        let vector_size = u64::try_from(first_vector.len()).unwrap_or(896);
235        if let Err(e) = qdrant
236            .ensure_named_collection(KEY_FACTS_COLLECTION, vector_size)
237            .await
238        {
239            tracing::warn!("Failed to ensure key_facts collection: {e:#}");
240            return;
241        }
242
243        let threshold = self.key_facts_dedup_threshold;
244        self.store_key_fact_if_unique(
245            qdrant,
246            conversation_id,
247            source_summary_id,
248            first_fact,
249            first_vector,
250            threshold,
251        )
252        .await;
253
254        for fact in filtered[1..].iter().copied() {
255            match self.provider.embed(fact).await {
256                Ok(vector) => {
257                    self.store_key_fact_if_unique(
258                        qdrant,
259                        conversation_id,
260                        source_summary_id,
261                        fact,
262                        vector,
263                        threshold,
264                    )
265                    .await;
266                }
267                Err(e) => {
268                    tracing::warn!("Failed to embed key fact: {e:#}");
269                }
270            }
271        }
272    }
273
274    async fn store_key_fact_if_unique(
275        &self,
276        qdrant: &crate::embedding_store::EmbeddingStore,
277        conversation_id: ConversationId,
278        source_summary_id: i64,
279        fact: &str,
280        vector: Vec<f32>,
281        threshold: f32,
282    ) {
283        match qdrant
284            .search_collection(KEY_FACTS_COLLECTION, &vector, 1, None)
285            .await
286        {
287            Ok(hits) if hits.first().is_some_and(|h| h.score >= threshold) => {
288                tracing::debug!(
289                    score = hits[0].score,
290                    threshold,
291                    "key-facts: skipping near-duplicate fact"
292                );
293                return;
294            }
295            Ok(_) => {}
296            Err(e) => {
297                tracing::warn!("key-facts: dedup search failed, storing anyway: {e:#}");
298            }
299        }
300
301        let payload = serde_json::json!({
302            "conversation_id": conversation_id.0,
303            "fact_text": fact,
304            "source_summary_id": source_summary_id,
305        });
306        if let Err(e) = qdrant
307            .store_to_collection(KEY_FACTS_COLLECTION, payload, vector)
308            .await
309        {
310            tracing::warn!("Failed to store key fact: {e:#}");
311        }
312    }
313
314    /// Search key facts extracted from conversation summaries.
315    ///
316    /// # Errors
317    ///
318    /// Returns an error if embedding or Qdrant search fails.
319    pub async fn search_key_facts(
320        &self,
321        query: &str,
322        limit: usize,
323    ) -> Result<Vec<String>, MemoryError> {
324        let Some(qdrant) = &self.qdrant else {
325            tracing::debug!("key-facts: skipped, no vector store");
326            return Ok(Vec::new());
327        };
328        if !self.provider.supports_embeddings() {
329            tracing::debug!("key-facts: skipped, no embedding support");
330            return Ok(Vec::new());
331        }
332
333        let vector = self.provider.embed(query).await?;
334        let vector_size = u64::try_from(vector.len()).unwrap_or(896);
335        qdrant
336            .ensure_named_collection(KEY_FACTS_COLLECTION, vector_size)
337            .await?;
338
339        let points = qdrant
340            .search_collection(KEY_FACTS_COLLECTION, &vector, limit, None)
341            .await?;
342
343        tracing::debug!(results = points.len(), limit, "key-facts: search complete");
344
345        let facts = points
346            .into_iter()
347            .filter_map(|p| p.payload.get("fact_text")?.as_str().map(String::from))
348            .collect();
349
350        Ok(facts)
351    }
352
353    /// Search a named document collection by semantic similarity.
354    ///
355    /// Returns up to `limit` scored vector points whose payloads contain ingested document chunks.
356    /// Returns an empty vec when Qdrant is unavailable, the collection does not exist,
357    /// or the provider does not support embeddings.
358    ///
359    /// # Errors
360    ///
361    /// Returns an error if embedding generation or Qdrant search fails.
362    pub async fn search_document_collection(
363        &self,
364        collection: &str,
365        query: &str,
366        limit: usize,
367    ) -> Result<Vec<crate::ScoredVectorPoint>, MemoryError> {
368        let Some(qdrant) = &self.qdrant else {
369            return Ok(Vec::new());
370        };
371        if !self.provider.supports_embeddings() {
372            return Ok(Vec::new());
373        }
374        if !qdrant.collection_exists(collection).await? {
375            return Ok(Vec::new());
376        }
377        let vector = self.provider.embed(query).await?;
378        let results = qdrant
379            .search_collection(collection, &vector, limit, None)
380            .await?;
381
382        tracing::debug!(
383            results = results.len(),
384            limit,
385            collection,
386            "document-collection: search complete"
387        );
388
389        Ok(results)
390    }
391}
392
/// Heuristically decide whether a fact records a transient policy or permission decision.
///
/// The check is a case-insensitive substring match against a fixed list of markers, such as
/// "blocked", "permission denied" and "security polic". A fact like "reading /etc/passwd
/// was blocked by utility policy" captures enforcement state at a single point in time.
/// Recalling it later as durable knowledge would make the agent believe a tool is
/// permanently unavailable.
pub(crate) fn is_policy_decision_fact(fact: &str) -> bool {
    // Stems such as "polic" cover both "policy" and "policies".
    const POLICY_MARKERS: [&str; 9] = [
        "blocked",
        "skipped",
        "cannot access",
        "security polic",
        "utility polic",
        "not allowed",
        "permission denied",
        "access denied",
        "was denied",
    ];

    let haystack = fact.to_lowercase();
    for marker in POLICY_MARKERS {
        if haystack.contains(marker) {
            return true;
        }
    }
    false
}