// File: zeph_memory/semantic/summarization.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use zeph_llm::provider::{LlmProvider as _, Message, MessageMetadata, Role};
5
6use super::{KEY_FACTS_COLLECTION, SemanticMemory};
7use crate::embedding_store::MessageKind;
8use crate::error::MemoryError;
9use crate::types::{ConversationId, MessageId};
10
/// Structured output requested from the LLM during summarization.
///
/// Mirrors the JSON schema described in [`build_summarization_prompt`]:
/// a prose summary plus extracted key facts and entities. Derives
/// `schemars::JsonSchema` so the provider can request typed output via
/// `chat_typed_erased`.
#[derive(Debug, Clone, serde::Deserialize, serde::Serialize, schemars::JsonSchema)]
pub struct StructuredSummary {
    /// Prose summary of the conversation segment.
    pub summary: String,
    /// Standalone facts extracted from the conversation; embedded and stored
    /// in the key-facts collection (after policy-decision filtering).
    pub key_facts: Vec<String>,
    /// Named entities mentioned in the conversation.
    pub entities: Vec<String>,
}
17
/// A persisted conversation summary row as stored in SQLite.
#[derive(Debug, Clone)]
pub struct Summary {
    /// Primary key of the summary row.
    pub id: i64,
    /// Conversation this summary belongs to.
    pub conversation_id: ConversationId,
    /// The summary text itself.
    pub content: String,
    /// `None` for session-level summaries (e.g. shutdown summaries) with no tracked message range.
    pub first_message_id: Option<MessageId>,
    /// `None` for session-level summaries (e.g. shutdown summaries) with no tracked message range.
    pub last_message_id: Option<MessageId>,
    /// Estimated token count of `content`, as computed by the token counter
    /// at save time.
    pub token_estimate: i64,
}
29
30#[must_use]
31pub fn build_summarization_prompt(messages: &[(MessageId, String, String)]) -> String {
32    let mut prompt = String::from(
33        "Summarize the following conversation. Extract key facts, decisions, entities, \
34         and context needed to continue the conversation.\n\n\
35         Respond in JSON with fields: summary (string), key_facts (list of strings), \
36         entities (list of strings).\n\nConversation:\n",
37    );
38
39    for (_, role, content) in messages {
40        prompt.push_str(role);
41        prompt.push_str(": ");
42        prompt.push_str(content);
43        prompt.push('\n');
44    }
45
46    prompt
47}
48
impl SemanticMemory {
    /// Load all summaries for a conversation.
    ///
    /// # Errors
    ///
    /// Returns an error if the query fails.
    pub async fn load_summaries(
        &self,
        conversation_id: ConversationId,
    ) -> Result<Vec<Summary>, MemoryError> {
        let rows = self.sqlite.load_summaries(conversation_id).await?;
        // Map the raw sqlite tuples into the typed `Summary` struct.
        let summaries = rows
            .into_iter()
            .map(
                |(
                    id,
                    conversation_id,
                    content,
                    first_message_id,
                    last_message_id,
                    token_estimate,
                )| {
                    Summary {
                        id,
                        conversation_id,
                        content,
                        first_message_id,
                        last_message_id,
                        token_estimate,
                    }
                },
            )
            .collect();
        Ok(summaries)
    }

    /// Generate a summary of the oldest unsummarized messages.
    ///
    /// Loads up to `message_count` messages that come after the last
    /// summarized message, asks the LLM for a [`StructuredSummary`]
    /// (falling back to plain-text chat if typed output fails), persists the
    /// summary row, and then — best-effort — embeds the summary and stores
    /// the extracted key facts. Embedding/key-fact failures are logged, not
    /// propagated.
    ///
    /// Returns `Ok(None)` if there are not enough messages to summarize.
    ///
    /// # Errors
    ///
    /// Returns an error if LLM call or database operation fails.
    pub async fn summarize(
        &self,
        conversation_id: ConversationId,
        message_count: usize,
    ) -> Result<Option<i64>, MemoryError> {
        let total = self.sqlite.count_messages(conversation_id).await?;

        // Only summarize once the conversation exceeds the requested window.
        if total <= i64::try_from(message_count)? {
            return Ok(None);
        }

        // Resume after the last message covered by a previous summary;
        // MessageId(0) means "from the beginning".
        let after_id = self
            .sqlite
            .latest_summary_last_message_id(conversation_id)
            .await?
            .unwrap_or(MessageId(0));

        let messages = self
            .sqlite
            .load_messages_range(conversation_id, after_id, message_count)
            .await?;

        if messages.is_empty() {
            return Ok(None);
        }

        let prompt = build_summarization_prompt(&messages);
        let chat_messages = vec![Message {
            role: Role::User,
            content: prompt,
            parts: vec![],
            metadata: MessageMetadata::default(),
        }];

        // Prefer typed (JSON-schema) output; degrade to a plain-text summary
        // with empty key_facts/entities if the typed call fails.
        let structured = match self
            .provider
            .chat_typed_erased::<StructuredSummary>(&chat_messages)
            .await
        {
            Ok(s) => s,
            Err(e) => {
                tracing::warn!(
                    "structured summarization failed, falling back to plain text: {e:#}"
                );
                let plain = self.provider.chat(&chat_messages).await?;
                StructuredSummary {
                    summary: plain,
                    key_facts: vec![],
                    entities: vec![],
                }
            }
        };
        let summary_text = &structured.summary;

        let token_estimate = i64::try_from(self.token_counter.count_tokens(summary_text))?;
        // Safe to index: emptiness was checked above. Messages are assumed to
        // be ordered ascending by id (first = oldest, last = newest).
        let first_message_id = messages[0].0;
        let last_message_id = messages[messages.len() - 1].0;

        let summary_id = self
            .sqlite
            .save_summary(
                conversation_id,
                summary_text,
                Some(first_message_id),
                Some(last_message_id),
                token_estimate,
            )
            .await?;

        // Best-effort: embed the summary into Qdrant. Any failure here is
        // logged and swallowed so the already-persisted summary still counts.
        if let Some(qdrant) = &self.qdrant
            && self.provider.supports_embeddings()
        {
            match self.provider.embed(summary_text).await {
                Ok(vector) => {
                    // NOTE(review): 896 is a fallback dimension for the
                    // (practically unreachable) len->u64 conversion failure;
                    // confirm it matches the default embedding model's size.
                    let vector_size = u64::try_from(vector.len()).unwrap_or(896);
                    if let Err(e) = qdrant.ensure_collection(vector_size).await {
                        tracing::warn!("Failed to ensure Qdrant collection: {e:#}");
                    } else if let Err(e) = qdrant
                        .store(
                            MessageId(summary_id),
                            conversation_id,
                            "system",
                            vector,
                            MessageKind::Summary,
                            &self.embedding_model,
                            0,
                        )
                        .await
                    {
                        tracing::warn!("Failed to embed summary: {e:#}");
                    }
                }
                Err(e) => {
                    tracing::warn!("Failed to generate summary embedding: {e:#}");
                }
            }
        }

        if !structured.key_facts.is_empty() {
            self.store_key_facts(conversation_id, summary_id, &structured.key_facts)
                .await;
        }

        Ok(Some(summary_id))
    }

    /// Embed and store extracted key facts in the key-facts collection.
    ///
    /// Best-effort and infallible by design: every failure (no vector store,
    /// no embedding support, embed/ensure/store errors) is logged or silently
    /// skipped rather than propagated. Transient policy-decision facts are
    /// filtered out first, and near-duplicates are deduplicated against
    /// existing facts via similarity search.
    pub(super) async fn store_key_facts(
        &self,
        conversation_id: ConversationId,
        source_summary_id: i64,
        key_facts: &[String],
    ) {
        let Some(qdrant) = &self.qdrant else {
            return;
        };
        if !self.provider.supports_embeddings() {
            return;
        }

        // Filter out transient policy-decision facts that describe a blocked or denied action.
        // These reflect the agent's state at a single point in time and must not be recalled
        // as stable world facts in future turns — doing so causes the agent to skip valid calls.
        let filtered: Vec<&str> = key_facts
            .iter()
            .filter(|f| !is_policy_decision_fact(f.as_str()))
            .map(String::as_str)
            .collect();

        // The first fact is embedded up front: its vector length is needed to
        // size (and ensure) the collection before anything can be stored.
        let Some(first_fact) = filtered.first().copied() else {
            return;
        };
        let first_vector = match self.provider.embed(first_fact).await {
            Ok(v) => v,
            Err(e) => {
                tracing::warn!("Failed to embed key fact: {e:#}");
                return;
            }
        };
        // NOTE(review): same 896 fallback dimension as in `summarize`;
        // confirm it matches the default embedding model's size.
        let vector_size = u64::try_from(first_vector.len()).unwrap_or(896);
        if let Err(e) = qdrant
            .ensure_named_collection(KEY_FACTS_COLLECTION, vector_size)
            .await
        {
            tracing::warn!("Failed to ensure key_facts collection: {e:#}");
            return;
        }

        let threshold = self.key_facts_dedup_threshold;
        self.store_key_fact_if_unique(
            qdrant,
            conversation_id,
            source_summary_id,
            first_fact,
            first_vector,
            threshold,
        )
        .await;

        // Remaining facts: embed each and store unless it's a near-duplicate.
        // A failed embed skips only that fact, not the rest.
        for fact in filtered[1..].iter().copied() {
            match self.provider.embed(fact).await {
                Ok(vector) => {
                    self.store_key_fact_if_unique(
                        qdrant,
                        conversation_id,
                        source_summary_id,
                        fact,
                        vector,
                        threshold,
                    )
                    .await;
                }
                Err(e) => {
                    tracing::warn!("Failed to embed key fact: {e:#}");
                }
            }
        }
    }

    /// Store a single key fact unless a near-duplicate already exists.
    ///
    /// Searches the key-facts collection for the nearest existing point; if
    /// its score meets `threshold`, the fact is skipped as a duplicate. A
    /// failed dedup search stores the fact anyway (favoring recall over
    /// strict dedup). Store failures are logged, never propagated.
    async fn store_key_fact_if_unique(
        &self,
        qdrant: &crate::embedding_store::EmbeddingStore,
        conversation_id: ConversationId,
        source_summary_id: i64,
        fact: &str,
        vector: Vec<f32>,
        threshold: f32,
    ) {
        match qdrant
            .search_collection(KEY_FACTS_COLLECTION, &vector, 1, None)
            .await
        {
            Ok(hits) if hits.first().is_some_and(|h| h.score >= threshold) => {
                tracing::debug!(
                    score = hits[0].score,
                    threshold,
                    "key-facts: skipping near-duplicate fact"
                );
                return;
            }
            Ok(_) => {}
            Err(e) => {
                tracing::warn!("key-facts: dedup search failed, storing anyway: {e:#}");
            }
        }

        let payload = serde_json::json!({
            "conversation_id": conversation_id.0,
            "fact_text": fact,
            "source_summary_id": source_summary_id,
        });
        if let Err(e) = qdrant
            .store_to_collection(KEY_FACTS_COLLECTION, payload, vector)
            .await
        {
            tracing::warn!("Failed to store key fact: {e:#}");
        }
    }

    /// Search key facts extracted from conversation summaries.
    ///
    /// Embeds `query`, ensures the key-facts collection exists, and returns
    /// up to `limit` matching `fact_text` payload strings. Returns an empty
    /// vec when Qdrant is unavailable or the provider lacks embeddings.
    ///
    /// # Errors
    ///
    /// Returns an error if embedding or Qdrant search fails.
    pub async fn search_key_facts(
        &self,
        query: &str,
        limit: usize,
    ) -> Result<Vec<String>, MemoryError> {
        let Some(qdrant) = &self.qdrant else {
            tracing::debug!("key-facts: skipped, no vector store");
            return Ok(Vec::new());
        };
        if !self.provider.supports_embeddings() {
            tracing::debug!("key-facts: skipped, no embedding support");
            return Ok(Vec::new());
        }

        let vector = self.provider.embed(query).await?;
        // NOTE(review): same 896 fallback dimension as elsewhere in this file.
        let vector_size = u64::try_from(vector.len()).unwrap_or(896);
        // Ensure the collection so a first-ever search doesn't error on a
        // missing collection.
        qdrant
            .ensure_named_collection(KEY_FACTS_COLLECTION, vector_size)
            .await?;

        let points = qdrant
            .search_collection(KEY_FACTS_COLLECTION, &vector, limit, None)
            .await?;

        tracing::debug!(results = points.len(), limit, "key-facts: search complete");

        // Keep only points that carry a string `fact_text` payload field.
        let facts = points
            .into_iter()
            .filter_map(|p| p.payload.get("fact_text")?.as_str().map(String::from))
            .collect();

        Ok(facts)
    }

    /// Search a named document collection by semantic similarity.
    ///
    /// Returns up to `limit` scored vector points whose payloads contain ingested document chunks.
    /// Returns an empty vec when Qdrant is unavailable, the collection does not exist,
    /// or the provider does not support embeddings.
    ///
    /// # Errors
    ///
    /// Returns an error if embedding generation or Qdrant search fails.
    pub async fn search_document_collection(
        &self,
        collection: &str,
        query: &str,
        limit: usize,
    ) -> Result<Vec<crate::ScoredVectorPoint>, MemoryError> {
        let Some(qdrant) = &self.qdrant else {
            return Ok(Vec::new());
        };
        if !self.provider.supports_embeddings() {
            return Ok(Vec::new());
        }
        // Unlike key facts, document collections are ingested elsewhere, so a
        // missing collection is treated as "no results", not created here.
        if !qdrant.collection_exists(collection).await? {
            return Ok(Vec::new());
        }
        let vector = self.provider.embed(query).await?;
        let results = qdrant
            .search_collection(collection, &vector, limit, None)
            .await?;

        tracing::debug!(
            results = results.len(),
            limit,
            collection,
            "document-collection: search complete"
        );

        Ok(results)
    }
}
388
/// Returns `true` when a fact string describes a transient policy or permission decision.
///
/// Facts like "reading /etc/passwd was blocked by utility policy" are snapshots of a
/// single-turn enforcement state and must not be recalled as durable world knowledge.
/// Storing them causes the agent to believe a tool is permanently unavailable.
pub(crate) fn is_policy_decision_fact(fact: &str) -> bool {
    // Case-insensitive substring probe against known policy/permission
    // phrases. Truncated stems ("polic") match both "policy" and "policies".
    let normalized = fact.to_lowercase();
    [
        "blocked",
        "skipped",
        "cannot access",
        "security polic",
        "utility polic",
        "not allowed",
        "permission denied",
        "access denied",
        "was denied",
    ]
    .into_iter()
    .any(|marker| normalized.contains(marker))
}