// zeph_memory/semantic/summarization.rs
// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
// SPDX-License-Identifier: MIT OR Apache-2.0

use zeph_llm::provider::{LlmProvider as _, Message, MessageMetadata, Role};

use super::{KEY_FACTS_COLLECTION, SemanticMemory};
use crate::embedding_store::MessageKind;
use crate::error::MemoryError;
use crate::types::{ConversationId, MessageId};
/// Structured output requested from the LLM during summarization.
///
/// Field names mirror the JSON shape asked for in [`build_summarization_prompt`]:
/// a prose `summary` plus extracted `key_facts` and `entities`.
#[derive(Debug, Clone, serde::Deserialize, serde::Serialize, schemars::JsonSchema)]
pub struct StructuredSummary {
    /// Prose summary of the conversation segment.
    pub summary: String,
    /// Discrete facts pulled out of the conversation; each is embedded and
    /// stored individually in the key-facts collection.
    pub key_facts: Vec<String>,
    /// Entities mentioned in the conversation, as requested by the prompt.
    pub entities: Vec<String>,
}
17
/// A persisted conversation summary, as loaded from the SQLite store.
#[derive(Debug, Clone)]
pub struct Summary {
    /// Row id of the summary in the SQLite store.
    pub id: i64,
    /// Conversation this summary belongs to.
    pub conversation_id: ConversationId,
    /// The summary text produced by the LLM.
    pub content: String,
    /// `None` for session-level summaries (e.g. shutdown summaries) with no tracked message range.
    pub first_message_id: Option<MessageId>,
    /// `None` for session-level summaries (e.g. shutdown summaries) with no tracked message range.
    pub last_message_id: Option<MessageId>,
    /// Approximate token count of `content`, computed by the token counter at save time.
    pub token_estimate: i64,
}
29
30#[must_use]
31pub fn build_summarization_prompt(messages: &[(MessageId, String, String)]) -> String {
32    let mut prompt = String::from(
33        "Summarize the following conversation. Extract key facts, decisions, entities, \
34         and context needed to continue the conversation.\n\n\
35         Respond in JSON with fields: summary (string), key_facts (list of strings), \
36         entities (list of strings).\n\nConversation:\n",
37    );
38
39    for (_, role, content) in messages {
40        prompt.push_str(role);
41        prompt.push_str(": ");
42        prompt.push_str(content);
43        prompt.push('\n');
44    }
45
46    prompt
47}
48
49impl SemanticMemory {
50    /// Load all summaries for a conversation.
51    ///
52    /// # Errors
53    ///
54    /// Returns an error if the query fails.
55    pub async fn load_summaries(
56        &self,
57        conversation_id: ConversationId,
58    ) -> Result<Vec<Summary>, MemoryError> {
59        let rows = self.sqlite.load_summaries(conversation_id).await?;
60        let summaries = rows
61            .into_iter()
62            .map(
63                |(
64                    id,
65                    conversation_id,
66                    content,
67                    first_message_id,
68                    last_message_id,
69                    token_estimate,
70                )| {
71                    Summary {
72                        id,
73                        conversation_id,
74                        content,
75                        first_message_id,
76                        last_message_id,
77                        token_estimate,
78                    }
79                },
80            )
81            .collect();
82        Ok(summaries)
83    }
84
85    /// Generate a summary of the oldest unsummarized messages.
86    ///
87    /// Returns `Ok(None)` if there are not enough messages to summarize.
88    ///
89    /// # Errors
90    ///
91    /// Returns an error if LLM call or database operation fails.
92    pub async fn summarize(
93        &self,
94        conversation_id: ConversationId,
95        message_count: usize,
96    ) -> Result<Option<i64>, MemoryError> {
97        let total = self.sqlite.count_messages(conversation_id).await?;
98
99        if total <= i64::try_from(message_count)? {
100            return Ok(None);
101        }
102
103        let after_id = self
104            .sqlite
105            .latest_summary_last_message_id(conversation_id)
106            .await?
107            .unwrap_or(MessageId(0));
108
109        let messages = self
110            .sqlite
111            .load_messages_range(conversation_id, after_id, message_count)
112            .await?;
113
114        if messages.is_empty() {
115            return Ok(None);
116        }
117
118        let prompt = build_summarization_prompt(&messages);
119        let chat_messages = vec![Message {
120            role: Role::User,
121            content: prompt,
122            parts: vec![],
123            metadata: MessageMetadata::default(),
124        }];
125
126        let structured = match self
127            .provider
128            .chat_typed_erased::<StructuredSummary>(&chat_messages)
129            .await
130        {
131            Ok(s) => s,
132            Err(e) => {
133                tracing::warn!(
134                    "structured summarization failed, falling back to plain text: {e:#}"
135                );
136                let plain = self.provider.chat(&chat_messages).await?;
137                StructuredSummary {
138                    summary: plain,
139                    key_facts: vec![],
140                    entities: vec![],
141                }
142            }
143        };
144        let summary_text = &structured.summary;
145
146        let token_estimate = i64::try_from(self.token_counter.count_tokens(summary_text))?;
147        let first_message_id = messages[0].0;
148        let last_message_id = messages[messages.len() - 1].0;
149
150        let summary_id = self
151            .sqlite
152            .save_summary(
153                conversation_id,
154                summary_text,
155                Some(first_message_id),
156                Some(last_message_id),
157                token_estimate,
158            )
159            .await?;
160
161        if let Some(qdrant) = &self.qdrant
162            && self.provider.supports_embeddings()
163        {
164            match self.provider.embed(summary_text).await {
165                Ok(vector) => {
166                    let vector_size = u64::try_from(vector.len()).unwrap_or(896);
167                    if let Err(e) = qdrant.ensure_collection(vector_size).await {
168                        tracing::warn!("Failed to ensure Qdrant collection: {e:#}");
169                    } else if let Err(e) = qdrant
170                        .store(
171                            MessageId(summary_id),
172                            conversation_id,
173                            "system",
174                            vector,
175                            MessageKind::Summary,
176                            &self.embedding_model,
177                        )
178                        .await
179                    {
180                        tracing::warn!("Failed to embed summary: {e:#}");
181                    }
182                }
183                Err(e) => {
184                    tracing::warn!("Failed to generate summary embedding: {e:#}");
185                }
186            }
187        }
188
189        if !structured.key_facts.is_empty() {
190            self.store_key_facts(conversation_id, summary_id, &structured.key_facts)
191                .await;
192        }
193
194        Ok(Some(summary_id))
195    }
196
197    pub(super) async fn store_key_facts(
198        &self,
199        conversation_id: ConversationId,
200        source_summary_id: i64,
201        key_facts: &[String],
202    ) {
203        let Some(qdrant) = &self.qdrant else {
204            return;
205        };
206        if !self.provider.supports_embeddings() {
207            return;
208        }
209
210        let Some(first_fact) = key_facts.first() else {
211            return;
212        };
213        let first_vector = match self.provider.embed(first_fact).await {
214            Ok(v) => v,
215            Err(e) => {
216                tracing::warn!("Failed to embed key fact: {e:#}");
217                return;
218            }
219        };
220        let vector_size = u64::try_from(first_vector.len()).unwrap_or(896);
221        if let Err(e) = qdrant
222            .ensure_named_collection(KEY_FACTS_COLLECTION, vector_size)
223            .await
224        {
225            tracing::warn!("Failed to ensure key_facts collection: {e:#}");
226            return;
227        }
228
229        let first_payload = serde_json::json!({
230            "conversation_id": conversation_id.0,
231            "fact_text": first_fact,
232            "source_summary_id": source_summary_id,
233        });
234        if let Err(e) = qdrant
235            .store_to_collection(KEY_FACTS_COLLECTION, first_payload, first_vector)
236            .await
237        {
238            tracing::warn!("Failed to store key fact: {e:#}");
239        }
240
241        for fact in &key_facts[1..] {
242            match self.provider.embed(fact).await {
243                Ok(vector) => {
244                    let payload = serde_json::json!({
245                        "conversation_id": conversation_id.0,
246                        "fact_text": fact,
247                        "source_summary_id": source_summary_id,
248                    });
249                    if let Err(e) = qdrant
250                        .store_to_collection(KEY_FACTS_COLLECTION, payload, vector)
251                        .await
252                    {
253                        tracing::warn!("Failed to store key fact: {e:#}");
254                    }
255                }
256                Err(e) => {
257                    tracing::warn!("Failed to embed key fact: {e:#}");
258                }
259            }
260        }
261    }
262
263    /// Search key facts extracted from conversation summaries.
264    ///
265    /// # Errors
266    ///
267    /// Returns an error if embedding or Qdrant search fails.
268    pub async fn search_key_facts(
269        &self,
270        query: &str,
271        limit: usize,
272    ) -> Result<Vec<String>, MemoryError> {
273        let Some(qdrant) = &self.qdrant else {
274            tracing::debug!("key-facts: skipped, no vector store");
275            return Ok(Vec::new());
276        };
277        if !self.provider.supports_embeddings() {
278            tracing::debug!("key-facts: skipped, no embedding support");
279            return Ok(Vec::new());
280        }
281
282        let vector = self.provider.embed(query).await?;
283        let vector_size = u64::try_from(vector.len()).unwrap_or(896);
284        qdrant
285            .ensure_named_collection(KEY_FACTS_COLLECTION, vector_size)
286            .await?;
287
288        let points = qdrant
289            .search_collection(KEY_FACTS_COLLECTION, &vector, limit, None)
290            .await?;
291
292        tracing::debug!(results = points.len(), limit, "key-facts: search complete");
293
294        let facts = points
295            .into_iter()
296            .filter_map(|p| p.payload.get("fact_text")?.as_str().map(String::from))
297            .collect();
298
299        Ok(facts)
300    }
301
302    /// Search a named document collection by semantic similarity.
303    ///
304    /// Returns up to `limit` scored vector points whose payloads contain ingested document chunks.
305    /// Returns an empty vec when Qdrant is unavailable, the collection does not exist,
306    /// or the provider does not support embeddings.
307    ///
308    /// # Errors
309    ///
310    /// Returns an error if embedding generation or Qdrant search fails.
311    pub async fn search_document_collection(
312        &self,
313        collection: &str,
314        query: &str,
315        limit: usize,
316    ) -> Result<Vec<crate::ScoredVectorPoint>, MemoryError> {
317        let Some(qdrant) = &self.qdrant else {
318            return Ok(Vec::new());
319        };
320        if !self.provider.supports_embeddings() {
321            return Ok(Vec::new());
322        }
323        if !qdrant.collection_exists(collection).await? {
324            return Ok(Vec::new());
325        }
326        let vector = self.provider.embed(query).await?;
327        let results = qdrant
328            .search_collection(collection, &vector, limit, None)
329            .await?;
330
331        tracing::debug!(
332            results = results.len(),
333            limit,
334            collection,
335            "document-collection: search complete"
336        );
337
338        Ok(results)
339    }
340}