// zeph_memory/semantic.rs
1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use zeph_llm::any::AnyProvider;
5use zeph_llm::provider::{LlmProvider, Message, MessageMetadata, Role};
6
7use std::sync::Arc;
8
9use crate::embedding_store::{EmbeddingStore, MessageKind, SearchFilter};
10use crate::error::MemoryError;
11use crate::sqlite::SqliteStore;
12use crate::token_counter::TokenCounter;
13use crate::types::{ConversationId, MessageId};
14use crate::vector_store::{FieldCondition, FieldValue, VectorFilter};
15
/// Vector-store collection holding one point per cross-session summary.
const SESSION_SUMMARIES_COLLECTION: &str = "zeph_session_summaries";
/// Vector-store collection holding one point per extracted key fact.
const KEY_FACTS_COLLECTION: &str = "zeph_key_facts";
18
/// Structured output schema for LLM-generated conversation summaries.
///
/// Deserialized from the model's JSON response via `chat_typed_erased`;
/// `schemars::JsonSchema` exposes the shape so providers can constrain
/// generation to it.
#[derive(Debug, Clone, serde::Deserialize, serde::Serialize, schemars::JsonSchema)]
pub struct StructuredSummary {
    // Free-form prose summary of the conversation window.
    pub summary: String,
    // Discrete facts, stored later as individual vector points.
    pub key_facts: Vec<String>,
    // Entities mentioned in the conversation (requested by the prompt;
    // not consumed anywhere visible in this file).
    pub entities: Vec<String>,
}
25
/// A message returned by `SemanticMemory::recall`, paired with its combined
/// (vector + keyword, post-decay/MMR) relevance score.
#[derive(Debug)]
pub struct RecalledMessage {
    // Full message hydrated from SQLite.
    pub message: Message,
    // Combined relevance score; higher is more relevant.
    pub score: f32,
}
31
/// A stored conversation summary row, as persisted in `SQLite`.
#[derive(Debug, Clone)]
pub struct Summary {
    // Row ID assigned by SQLite.
    pub id: i64,
    // Conversation this summary belongs to.
    pub conversation_id: ConversationId,
    // Summary text produced by the LLM.
    pub content: String,
    // First message covered by this summary.
    pub first_message_id: MessageId,
    // Last message covered by this summary.
    pub last_message_id: MessageId,
    // Estimated token count of `content`.
    pub token_estimate: i64,
}
41
/// A hit from searching the cross-session summaries collection.
#[derive(Debug, Clone)]
pub struct SessionSummaryResult {
    // Summary text recovered from the point payload.
    pub summary_text: String,
    // Vector-store similarity score for the query.
    pub score: f32,
    // Conversation the summary originally came from.
    pub conversation_id: ConversationId,
}
48
/// Cosine similarity between two equal-length vectors.
///
/// Returns `0.0` for mismatched lengths, empty inputs, or zero-magnitude
/// vectors, so callers never divide by zero.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.is_empty() || a.len() != b.len() {
        return 0.0;
    }
    // Single pass accumulating dot product and both squared magnitudes.
    let mut dot = 0.0f32;
    let mut sq_a = 0.0f32;
    let mut sq_b = 0.0f32;
    for (&x, &y) in a.iter().zip(b) {
        dot += x * y;
        sq_a += x * x;
        sq_b += y * y;
    }
    let norm_a = sq_a.sqrt();
    let norm_b = sq_b.sqrt();
    if norm_a == 0.0 || norm_b == 0.0 {
        return 0.0;
    }
    dot / (norm_a * norm_b)
}
61
62fn apply_temporal_decay(
63    ranked: &mut [(MessageId, f64)],
64    timestamps: &std::collections::HashMap<MessageId, i64>,
65    half_life_days: u32,
66) {
67    if half_life_days == 0 {
68        return;
69    }
70    let now = std::time::SystemTime::now()
71        .duration_since(std::time::UNIX_EPOCH)
72        .unwrap_or_default()
73        .as_secs()
74        .cast_signed();
75    let lambda = std::f64::consts::LN_2 / f64::from(half_life_days);
76
77    for (msg_id, score) in ranked.iter_mut() {
78        if let Some(&ts) = timestamps.get(msg_id) {
79            #[allow(clippy::cast_precision_loss)]
80            let age_days = (now - ts).max(0) as f64 / 86400.0;
81            *score *= (-lambda * age_days).exp();
82        }
83    }
84}
85
/// Maximal Marginal Relevance re-ranking: greedily select up to `limit`
/// results balancing relevance against redundancy.
///
/// `lambda` trades relevance (1.0) against diversity (0.0). `ranked` is
/// assumed to be sorted by relevance descending, so the first pick is the
/// head of the list. Candidates with no vector in `vectors` are treated as
/// having zero similarity to everything selected so far.
fn apply_mmr(
    ranked: &[(MessageId, f64)],
    vectors: &std::collections::HashMap<MessageId, Vec<f32>>,
    lambda: f32,
    limit: usize,
) -> Vec<(MessageId, f64)> {
    if ranked.is_empty() || limit == 0 {
        return Vec::new();
    }

    let lambda = f64::from(lambda);
    let mut selected: Vec<(MessageId, f64)> = Vec::with_capacity(limit);
    let mut remaining: Vec<(MessageId, f64)> = ranked.to_vec();

    while selected.len() < limit && !remaining.is_empty() {
        let best_idx = if selected.is_empty() {
            // Pick highest relevance first
            0
        } else {
            let mut best = 0usize;
            let mut best_score = f64::NEG_INFINITY;

            for (i, &(cand_id, relevance)) in remaining.iter().enumerate() {
                // Redundancy = max cosine similarity against anything
                // already selected.
                let max_sim = if let Some(cand_vec) = vectors.get(&cand_id) {
                    selected
                        .iter()
                        .filter_map(|(sel_id, _)| vectors.get(sel_id))
                        .map(|sel_vec| f64::from(cosine_similarity(cand_vec, sel_vec)))
                        .fold(f64::NEG_INFINITY, f64::max)
                } else {
                    0.0
                };
                // NEG_INFINITY means no selected item had a vector; treat as
                // zero redundancy so relevance alone decides.
                let max_sim = if max_sim == f64::NEG_INFINITY {
                    0.0
                } else {
                    max_sim
                };
                let mmr_score = lambda * relevance - (1.0 - lambda) * max_sim;
                // Strict `>` keeps the earlier (more relevant) index on ties.
                if mmr_score > best_score {
                    best_score = mmr_score;
                    best = i;
                }
            }
            best
        };

        selected.push(remaining.remove(best_idx));
    }

    selected
}
137
138fn build_summarization_prompt(messages: &[(MessageId, String, String)]) -> String {
139    let mut prompt = String::from(
140        "Summarize the following conversation. Extract key facts, decisions, entities, \
141         and context needed to continue the conversation.\n\n\
142         Respond in JSON with fields: summary (string), key_facts (list of strings), \
143         entities (list of strings).\n\nConversation:\n",
144    );
145
146    for (_, role, content) in messages {
147        prompt.push_str(role);
148        prompt.push_str(": ");
149        prompt.push_str(content);
150        prompt.push('\n');
151    }
152
153    prompt
154}
155
/// Long-term conversational memory combining `SQLite` persistence with an
/// optional vector backend for semantic (embedding-based) recall.
pub struct SemanticMemory {
    // Durable message/summary storage plus FTS5 keyword search.
    sqlite: SqliteStore,
    // Vector backend (Qdrant or SQLite-embedded); `None` disables semantic search.
    qdrant: Option<EmbeddingStore>,
    // LLM provider used for embeddings and summarization.
    provider: AnyProvider,
    // Model name recorded alongside stored embeddings.
    embedding_model: String,
    // Weight of normalized vector scores in hybrid ranking.
    vector_weight: f64,
    // Weight of normalized keyword (FTS5) scores in hybrid ranking.
    keyword_weight: f64,
    // Whether recall applies exponential age-based score decay.
    temporal_decay_enabled: bool,
    // Half-life in days for temporal decay (0 disables decay).
    temporal_decay_half_life_days: u32,
    // Whether recall applies MMR diversity re-ranking.
    mmr_enabled: bool,
    // MMR relevance/diversity trade-off (1.0 = pure relevance).
    mmr_lambda: f32,
    // Shared token counter used for summary token estimates.
    pub token_counter: Arc<TokenCounter>,
}
169
170impl SemanticMemory {
    /// Create a new `SemanticMemory` instance with default hybrid search weights (0.7/0.3).
    ///
    /// Qdrant connection is best-effort: if unavailable, semantic search is disabled.
    ///
    /// # Errors
    ///
    /// Returns an error if `SQLite` cannot be initialized.
    pub async fn new(
        sqlite_path: &str,
        qdrant_url: &str,
        provider: AnyProvider,
        embedding_model: &str,
    ) -> Result<Self, MemoryError> {
        // Thin delegator: 0.7 vector / 0.3 keyword are the default weights.
        Self::with_weights(sqlite_path, qdrant_url, provider, embedding_model, 0.7, 0.3).await
    }
186
187    /// Create a new `SemanticMemory` with custom vector/keyword weights for hybrid search.
188    ///
189    /// # Errors
190    ///
191    /// Returns an error if `SQLite` cannot be initialized.
192    pub async fn with_weights(
193        sqlite_path: &str,
194        qdrant_url: &str,
195        provider: AnyProvider,
196        embedding_model: &str,
197        vector_weight: f64,
198        keyword_weight: f64,
199    ) -> Result<Self, MemoryError> {
200        let sqlite = SqliteStore::new(sqlite_path).await?;
201        let pool = sqlite.pool().clone();
202
203        let qdrant = match EmbeddingStore::new(qdrant_url, pool) {
204            Ok(store) => Some(store),
205            Err(e) => {
206                tracing::warn!("Qdrant unavailable, semantic search disabled: {e:#}");
207                None
208            }
209        };
210
211        Ok(Self {
212            sqlite,
213            qdrant,
214            provider,
215            embedding_model: embedding_model.into(),
216            vector_weight,
217            keyword_weight,
218            temporal_decay_enabled: false,
219            temporal_decay_half_life_days: 30,
220            mmr_enabled: false,
221            mmr_lambda: 0.7,
222            token_counter: Arc::new(TokenCounter::new()),
223        })
224    }
225
226    /// Configure temporal decay and MMR re-ranking options.
227    #[must_use]
228    pub fn with_ranking_options(
229        mut self,
230        temporal_decay_enabled: bool,
231        temporal_decay_half_life_days: u32,
232        mmr_enabled: bool,
233        mmr_lambda: f32,
234    ) -> Self {
235        self.temporal_decay_enabled = temporal_decay_enabled;
236        self.temporal_decay_half_life_days = temporal_decay_half_life_days;
237        self.mmr_enabled = mmr_enabled;
238        self.mmr_lambda = mmr_lambda;
239        self
240    }
241
242    /// Create a `SemanticMemory` using the `SQLite`-embedded vector backend.
243    ///
244    /// # Errors
245    ///
246    /// Returns an error if `SQLite` cannot be initialized.
247    pub async fn with_sqlite_backend(
248        sqlite_path: &str,
249        provider: AnyProvider,
250        embedding_model: &str,
251        vector_weight: f64,
252        keyword_weight: f64,
253    ) -> Result<Self, MemoryError> {
254        let sqlite = SqliteStore::new(sqlite_path).await?;
255        let pool = sqlite.pool().clone();
256        let store = EmbeddingStore::new_sqlite(pool);
257
258        Ok(Self {
259            sqlite,
260            qdrant: Some(store),
261            provider,
262            embedding_model: embedding_model.into(),
263            vector_weight,
264            keyword_weight,
265            temporal_decay_enabled: false,
266            temporal_decay_half_life_days: 30,
267            mmr_enabled: false,
268            mmr_lambda: 0.7,
269            token_counter: Arc::new(TokenCounter::new()),
270        })
271    }
272
273    /// Save a message to `SQLite` and optionally embed and store in Qdrant.
274    ///
275    /// Returns the message ID assigned by `SQLite`.
276    ///
277    /// # Errors
278    ///
279    /// Returns an error if the `SQLite` save fails. Embedding failures are logged but not
280    /// propagated.
281    pub async fn remember(
282        &self,
283        conversation_id: ConversationId,
284        role: &str,
285        content: &str,
286    ) -> Result<MessageId, MemoryError> {
287        let message_id = self
288            .sqlite
289            .save_message(conversation_id, role, content)
290            .await?;
291
292        if let Some(qdrant) = &self.qdrant
293            && self.provider.supports_embeddings()
294        {
295            match self.provider.embed(content).await {
296                Ok(vector) => {
297                    // Ensure collection exists before storing
298                    let vector_size = u64::try_from(vector.len()).unwrap_or(896);
299                    if let Err(e) = qdrant.ensure_collection(vector_size).await {
300                        tracing::warn!("Failed to ensure Qdrant collection: {e:#}");
301                    } else if let Err(e) = qdrant
302                        .store(
303                            message_id,
304                            conversation_id,
305                            role,
306                            vector,
307                            MessageKind::Regular,
308                            &self.embedding_model,
309                        )
310                        .await
311                    {
312                        tracing::warn!("Failed to store embedding: {e:#}");
313                    }
314                }
315                Err(e) => {
316                    tracing::warn!("Failed to generate embedding: {e:#}");
317                }
318            }
319        }
320
321        Ok(message_id)
322    }
323
324    /// Save a message with pre-serialized parts JSON to `SQLite` and optionally embed in Qdrant.
325    ///
326    /// Returns `(message_id, embedding_stored)` tuple where `embedding_stored` is `true` if
327    /// an embedding was successfully generated and stored in Qdrant.
328    ///
329    /// # Errors
330    ///
331    /// Returns an error if the `SQLite` save fails.
332    pub async fn remember_with_parts(
333        &self,
334        conversation_id: ConversationId,
335        role: &str,
336        content: &str,
337        parts_json: &str,
338    ) -> Result<(MessageId, bool), MemoryError> {
339        let message_id = self
340            .sqlite
341            .save_message_with_parts(conversation_id, role, content, parts_json)
342            .await?;
343
344        let mut embedding_stored = false;
345
346        if let Some(qdrant) = &self.qdrant
347            && self.provider.supports_embeddings()
348        {
349            match self.provider.embed(content).await {
350                Ok(vector) => {
351                    let vector_size = u64::try_from(vector.len()).unwrap_or(896);
352                    if let Err(e) = qdrant.ensure_collection(vector_size).await {
353                        tracing::warn!("Failed to ensure Qdrant collection: {e:#}");
354                    } else if let Err(e) = qdrant
355                        .store(
356                            message_id,
357                            conversation_id,
358                            role,
359                            vector,
360                            MessageKind::Regular,
361                            &self.embedding_model,
362                        )
363                        .await
364                    {
365                        tracing::warn!("Failed to store embedding: {e:#}");
366                    } else {
367                        embedding_stored = true;
368                    }
369                }
370                Err(e) => {
371                    tracing::warn!("Failed to generate embedding: {e:#}");
372                }
373            }
374        }
375
376        Ok((message_id, embedding_stored))
377    }
378
    /// Save a message to `SQLite` without generating an embedding.
    ///
    /// Use this when embedding is intentionally skipped (e.g. autosave disabled for assistant).
    ///
    /// # Errors
    ///
    /// Returns an error if the `SQLite` save fails.
    pub async fn save_only(
        &self,
        conversation_id: ConversationId,
        role: &str,
        content: &str,
        parts_json: &str,
    ) -> Result<MessageId, MemoryError> {
        // Persist content plus pre-serialized parts JSON; no vector work at all.
        self.sqlite
            .save_message_with_parts(conversation_id, role, content, parts_json)
            .await
    }
397
    /// Recall relevant messages using hybrid search (vector + FTS5 keyword).
    ///
    /// When Qdrant is available, runs both vector and keyword searches, then merges
    /// results using weighted scoring. When Qdrant is unavailable, falls back to
    /// FTS5-only keyword search.
    ///
    /// Pipeline: keyword search → vector search → weighted score merge →
    /// optional temporal decay → optional MMR re-rank → truncate to `limit`
    /// → hydrate full messages from `SQLite`.
    ///
    /// # Errors
    ///
    /// Returns an error if embedding generation, Qdrant search, or FTS5 query fails.
    #[allow(clippy::too_many_lines)]
    pub async fn recall(
        &self,
        query: &str,
        limit: usize,
        filter: Option<SearchFilter>,
    ) -> Result<Vec<RecalledMessage>, MemoryError> {
        let conversation_id = filter.as_ref().and_then(|f| f.conversation_id);

        // FTS5 keyword search (always available). Over-fetch 2x so the
        // merge/re-rank stages have enough candidates to work with.
        let keyword_results = match self
            .sqlite
            .keyword_search(query, limit * 2, conversation_id)
            .await
        {
            Ok(results) => results,
            Err(e) => {
                // Keyword failure degrades to vector-only rather than erroring.
                tracing::warn!("FTS5 keyword search failed: {e:#}");
                Vec::new()
            }
        };

        // Vector search (only when Qdrant available)
        let vector_results = if let Some(qdrant) = &self.qdrant
            && self.provider.supports_embeddings()
        {
            let query_vector = self.provider.embed(query).await?;
            // 896 fallback matches the default used elsewhere in this file.
            let vector_size = u64::try_from(query_vector.len()).unwrap_or(896);
            qdrant.ensure_collection(vector_size).await?;
            qdrant.search(&query_vector, limit * 2, filter).await?
        } else {
            Vec::new()
        };

        // Merge results with weighted scoring. Each source is normalized by
        // its own max score so vector/keyword weights compare like-for-like.
        let mut scores: std::collections::HashMap<MessageId, f64> =
            std::collections::HashMap::new();

        if !vector_results.is_empty() {
            let max_vs = vector_results
                .iter()
                .map(|r| r.score)
                .fold(f32::NEG_INFINITY, f32::max);
            // Non-positive maxima would flip signs on division; skip normalizing.
            let norm = if max_vs > 0.0 { max_vs } else { 1.0 };
            for r in &vector_results {
                let normalized = f64::from(r.score / norm);
                *scores.entry(r.message_id).or_default() += normalized * self.vector_weight;
            }
        }

        if !keyword_results.is_empty() {
            let max_ks = keyword_results
                .iter()
                .map(|r| r.1)
                .fold(f64::NEG_INFINITY, f64::max);
            let norm = if max_ks > 0.0 { max_ks } else { 1.0 };
            for &(msg_id, score) in &keyword_results {
                let normalized = score / norm;
                *scores.entry(msg_id).or_default() += normalized * self.keyword_weight;
            }
        }

        if scores.is_empty() {
            return Ok(Vec::new());
        }

        // Sort by combined score descending
        let mut ranked: Vec<(MessageId, f64)> = scores.into_iter().collect();
        ranked.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));

        // Apply temporal decay (before MMR)
        if self.temporal_decay_enabled && self.temporal_decay_half_life_days > 0 {
            let ids: Vec<MessageId> = ranked.iter().map(|r| r.0).collect();
            match self.sqlite.message_timestamps(&ids).await {
                Ok(timestamps) => {
                    apply_temporal_decay(
                        &mut ranked,
                        &timestamps,
                        self.temporal_decay_half_life_days,
                    );
                    // Decay changes relative order; re-sort before MMR/truncate.
                    ranked
                        .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
                }
                Err(e) => {
                    // Best-effort: missing timestamps just skip decay.
                    tracing::warn!("temporal decay: failed to fetch timestamps: {e:#}");
                }
            }
        }

        // Apply MMR re-ranking (after decay, before truncation)
        if self.mmr_enabled && !vector_results.is_empty() {
            if let Some(qdrant) = &self.qdrant {
                let ids: Vec<MessageId> = ranked.iter().map(|r| r.0).collect();
                match qdrant.get_vectors(&ids).await {
                    Ok(vec_map) if !vec_map.is_empty() => {
                        // apply_mmr truncates to `limit` itself.
                        ranked = apply_mmr(&ranked, &vec_map, self.mmr_lambda, limit);
                    }
                    Ok(_) => {
                        ranked.truncate(limit);
                    }
                    Err(e) => {
                        tracing::warn!("MMR: failed to fetch vectors: {e:#}");
                        ranked.truncate(limit);
                    }
                }
            } else {
                ranked.truncate(limit);
            }
        } else {
            ranked.truncate(limit);
        }

        // Hydrate full messages from SQLite, preserving ranked order; IDs
        // missing from SQLite are silently dropped by the filter_map below.
        let ids: Vec<MessageId> = ranked.iter().map(|r| r.0).collect();
        let messages = self.sqlite.messages_by_ids(&ids).await?;
        let msg_map: std::collections::HashMap<MessageId, _> = messages.into_iter().collect();

        let recalled = ranked
            .iter()
            .filter_map(|(msg_id, score)| {
                msg_map.get(msg_id).map(|msg| RecalledMessage {
                    message: msg.clone(),
                    #[expect(clippy::cast_possible_truncation)]
                    score: *score as f32,
                })
            })
            .collect();

        Ok(recalled)
    }
536
537    /// Check whether an embedding exists for a given message ID.
538    ///
539    /// # Errors
540    ///
541    /// Returns an error if the `SQLite` query fails.
542    pub async fn has_embedding(&self, message_id: MessageId) -> Result<bool, MemoryError> {
543        match &self.qdrant {
544            Some(qdrant) => qdrant.has_embedding(message_id).await,
545            None => Ok(false),
546        }
547    }
548
549    /// Embed all messages that do not yet have embeddings.
550    ///
551    /// Returns the count of successfully embedded messages.
552    ///
553    /// # Errors
554    ///
555    /// Returns an error if collection initialization or database query fails.
556    /// Individual embedding failures are logged but do not stop processing.
557    pub async fn embed_missing(&self) -> Result<usize, MemoryError> {
558        let Some(qdrant) = &self.qdrant else {
559            return Ok(0);
560        };
561        if !self.provider.supports_embeddings() {
562            return Ok(0);
563        }
564
565        let unembedded = self.sqlite.unembedded_message_ids(Some(1000)).await?;
566
567        if unembedded.is_empty() {
568            return Ok(0);
569        }
570
571        let probe = self.provider.embed("probe").await?;
572        let vector_size = u64::try_from(probe.len())?;
573        qdrant.ensure_collection(vector_size).await?;
574
575        let mut count = 0;
576        for (msg_id, conversation_id, role, content) in &unembedded {
577            match self.provider.embed(content).await {
578                Ok(vector) => {
579                    if let Err(e) = qdrant
580                        .store(
581                            *msg_id,
582                            *conversation_id,
583                            role,
584                            vector,
585                            MessageKind::Regular,
586                            &self.embedding_model,
587                        )
588                        .await
589                    {
590                        tracing::warn!("Failed to store embedding for msg {msg_id}: {e:#}");
591                        continue;
592                    }
593                    count += 1;
594                }
595                Err(e) => {
596                    tracing::warn!("Failed to embed msg {msg_id}: {e:#}");
597                }
598            }
599        }
600
601        tracing::info!("Embedded {count}/{} missing messages", unembedded.len());
602        Ok(count)
603    }
604
605    /// Store a session summary into the dedicated `zeph_session_summaries` Qdrant collection.
606    ///
607    /// # Errors
608    ///
609    /// Returns an error if embedding or Qdrant storage fails.
610    pub async fn store_session_summary(
611        &self,
612        conversation_id: ConversationId,
613        summary_text: &str,
614    ) -> Result<(), MemoryError> {
615        let Some(qdrant) = &self.qdrant else {
616            return Ok(());
617        };
618        if !self.provider.supports_embeddings() {
619            return Ok(());
620        }
621
622        let vector = self.provider.embed(summary_text).await?;
623        let vector_size = u64::try_from(vector.len()).unwrap_or(896);
624        qdrant
625            .ensure_named_collection(SESSION_SUMMARIES_COLLECTION, vector_size)
626            .await?;
627
628        let payload = serde_json::json!({
629            "conversation_id": conversation_id.0,
630            "summary_text": summary_text,
631        });
632
633        qdrant
634            .store_to_collection(SESSION_SUMMARIES_COLLECTION, payload, vector)
635            .await?;
636
637        tracing::debug!(
638            conversation_id = conversation_id.0,
639            "stored session summary"
640        );
641        Ok(())
642    }
643
644    /// Search session summaries from other conversations.
645    ///
646    /// # Errors
647    ///
648    /// Returns an error if embedding or Qdrant search fails.
649    pub async fn search_session_summaries(
650        &self,
651        query: &str,
652        limit: usize,
653        exclude_conversation_id: Option<ConversationId>,
654    ) -> Result<Vec<SessionSummaryResult>, MemoryError> {
655        let Some(qdrant) = &self.qdrant else {
656            return Ok(Vec::new());
657        };
658        if !self.provider.supports_embeddings() {
659            return Ok(Vec::new());
660        }
661
662        let vector = self.provider.embed(query).await?;
663        let vector_size = u64::try_from(vector.len()).unwrap_or(896);
664        qdrant
665            .ensure_named_collection(SESSION_SUMMARIES_COLLECTION, vector_size)
666            .await?;
667
668        let filter = exclude_conversation_id.map(|cid| VectorFilter {
669            must: vec![],
670            must_not: vec![FieldCondition {
671                field: "conversation_id".into(),
672                value: FieldValue::Integer(cid.0),
673            }],
674        });
675
676        let points = qdrant
677            .search_collection(SESSION_SUMMARIES_COLLECTION, &vector, limit, filter)
678            .await?;
679
680        let results = points
681            .into_iter()
682            .filter_map(|point| {
683                let summary_text = point.payload.get("summary_text")?.as_str()?.to_owned();
684                let conversation_id =
685                    ConversationId(point.payload.get("conversation_id")?.as_i64()?);
686                Some(SessionSummaryResult {
687                    summary_text,
688                    score: point.score,
689                    conversation_id,
690                })
691            })
692            .collect();
693
694        Ok(results)
695    }
696
    /// Access the underlying `SqliteStore` for operations that don't involve semantics.
    ///
    /// Plain borrow of the internal store; no cloning involved.
    #[must_use]
    pub fn sqlite(&self) -> &SqliteStore {
        &self.sqlite
    }
702
703    /// Check if the vector store backend is reachable.
704    ///
705    /// Performs a real health check (Qdrant gRPC ping or `SQLite` query)
706    /// instead of just checking whether the client was created.
707    pub async fn is_vector_store_connected(&self) -> bool {
708        match self.qdrant.as_ref() {
709            Some(store) => store.health_check().await,
710            None => false,
711        }
712    }
713
    /// Check if a vector store client is configured (may not be connected).
    ///
    /// Use [`Self::is_vector_store_connected`] for an actual liveness check.
    #[must_use]
    pub fn has_vector_store(&self) -> bool {
        self.qdrant.is_some()
    }
719
    /// Count messages in a conversation.
    ///
    /// Thin delegation to the `SQLite` layer.
    ///
    /// # Errors
    ///
    /// Returns an error if the query fails.
    pub async fn message_count(&self, conversation_id: ConversationId) -> Result<i64, MemoryError> {
        self.sqlite.count_messages(conversation_id).await
    }
728
729    /// Count messages not yet covered by any summary.
730    ///
731    /// # Errors
732    ///
733    /// Returns an error if the query fails.
734    pub async fn unsummarized_message_count(
735        &self,
736        conversation_id: ConversationId,
737    ) -> Result<i64, MemoryError> {
738        let after_id = self
739            .sqlite
740            .latest_summary_last_message_id(conversation_id)
741            .await?
742            .unwrap_or(MessageId(0));
743        self.sqlite
744            .count_messages_after(conversation_id, after_id)
745            .await
746    }
747
748    /// Load all summaries for a conversation.
749    ///
750    /// # Errors
751    ///
752    /// Returns an error if the query fails.
753    pub async fn load_summaries(
754        &self,
755        conversation_id: ConversationId,
756    ) -> Result<Vec<Summary>, MemoryError> {
757        let rows = self.sqlite.load_summaries(conversation_id).await?;
758        let summaries = rows
759            .into_iter()
760            .map(
761                |(
762                    id,
763                    conversation_id,
764                    content,
765                    first_message_id,
766                    last_message_id,
767                    token_estimate,
768                )| {
769                    Summary {
770                        id,
771                        conversation_id,
772                        content,
773                        first_message_id,
774                        last_message_id,
775                        token_estimate,
776                    }
777                },
778            )
779            .collect();
780        Ok(summaries)
781    }
782
    /// Generate a summary of the oldest unsummarized messages.
    ///
    /// Returns `Ok(None)` if there are not enough messages to summarize.
    /// On success the new summary row ID is returned; the summary is also
    /// embedded (best-effort) and its key facts stored as individual vector
    /// points.
    ///
    /// # Errors
    ///
    /// Returns an error if LLM call or database operation fails.
    pub async fn summarize(
        &self,
        conversation_id: ConversationId,
        message_count: usize,
    ) -> Result<Option<i64>, MemoryError> {
        let total = self.sqlite.count_messages(conversation_id).await?;

        // Only summarize once the conversation exceeds the window size.
        if total <= i64::try_from(message_count)? {
            return Ok(None);
        }

        // Resume after the last summarized message; 0 = nothing summarized yet.
        let after_id = self
            .sqlite
            .latest_summary_last_message_id(conversation_id)
            .await?
            .unwrap_or(MessageId(0));

        let messages = self
            .sqlite
            .load_messages_range(conversation_id, after_id, message_count)
            .await?;

        if messages.is_empty() {
            return Ok(None);
        }

        let prompt = build_summarization_prompt(&messages);
        let chat_messages = vec![Message {
            role: Role::User,
            content: prompt,
            parts: vec![],
            metadata: MessageMetadata::default(),
        }];

        // Prefer structured (JSON) output; degrade to a plain-text summary
        // with empty key_facts/entities when the provider can't comply.
        let structured = match self
            .provider
            .chat_typed_erased::<StructuredSummary>(&chat_messages)
            .await
        {
            Ok(s) => s,
            Err(e) => {
                tracing::warn!(
                    "structured summarization failed, falling back to plain text: {e:#}"
                );
                let plain = self.provider.chat(&chat_messages).await?;
                StructuredSummary {
                    summary: plain,
                    key_facts: vec![],
                    entities: vec![],
                }
            }
        };
        let summary_text = &structured.summary;

        let token_estimate = i64::try_from(self.token_counter.count_tokens(summary_text))?;
        // Safe indexing: emptiness was checked above.
        let first_message_id = messages[0].0;
        let last_message_id = messages[messages.len() - 1].0;

        let summary_id = self
            .sqlite
            .save_summary(
                conversation_id,
                summary_text,
                first_message_id,
                last_message_id,
                token_estimate,
            )
            .await?;

        // Best-effort: embed the summary itself so recall can surface it.
        if let Some(qdrant) = &self.qdrant
            && self.provider.supports_embeddings()
        {
            match self.provider.embed(summary_text).await {
                Ok(vector) => {
                    // Ensure collection exists before storing
                    let vector_size = u64::try_from(vector.len()).unwrap_or(896);
                    if let Err(e) = qdrant.ensure_collection(vector_size).await {
                        tracing::warn!("Failed to ensure Qdrant collection: {e:#}");
                    } else if let Err(e) = qdrant
                        .store(
                            // NOTE(review): the summary row ID is reused as a
                            // MessageId for the vector point — assumes the ID
                            // spaces don't collide in the store; verify.
                            MessageId(summary_id),
                            conversation_id,
                            "system",
                            vector,
                            MessageKind::Summary,
                            &self.embedding_model,
                        )
                        .await
                    {
                        tracing::warn!("Failed to embed summary: {e:#}");
                    }
                }
                Err(e) => {
                    tracing::warn!("Failed to generate summary embedding: {e:#}");
                }
            }
        }

        // Store key facts as individual Qdrant points
        if !structured.key_facts.is_empty() {
            self.store_key_facts(conversation_id, summary_id, &structured.key_facts)
                .await;
        }

        Ok(Some(summary_id))
    }
896
897    async fn store_key_facts(
898        &self,
899        conversation_id: ConversationId,
900        source_summary_id: i64,
901        key_facts: &[String],
902    ) {
903        let Some(qdrant) = &self.qdrant else {
904            return;
905        };
906        if !self.provider.supports_embeddings() {
907            return;
908        }
909
910        let Some(first_fact) = key_facts.first() else {
911            return;
912        };
913        let first_vector = match self.provider.embed(first_fact).await {
914            Ok(v) => v,
915            Err(e) => {
916                tracing::warn!("Failed to embed key fact: {e:#}");
917                return;
918            }
919        };
920        let vector_size = u64::try_from(first_vector.len()).unwrap_or(896);
921        if let Err(e) = qdrant
922            .ensure_named_collection(KEY_FACTS_COLLECTION, vector_size)
923            .await
924        {
925            tracing::warn!("Failed to ensure key_facts collection: {e:#}");
926            return;
927        }
928
929        let first_payload = serde_json::json!({
930            "conversation_id": conversation_id.0,
931            "fact_text": first_fact,
932            "source_summary_id": source_summary_id,
933        });
934        if let Err(e) = qdrant
935            .store_to_collection(KEY_FACTS_COLLECTION, first_payload, first_vector)
936            .await
937        {
938            tracing::warn!("Failed to store key fact: {e:#}");
939        }
940
941        for fact in &key_facts[1..] {
942            match self.provider.embed(fact).await {
943                Ok(vector) => {
944                    let payload = serde_json::json!({
945                        "conversation_id": conversation_id.0,
946                        "fact_text": fact,
947                        "source_summary_id": source_summary_id,
948                    });
949                    if let Err(e) = qdrant
950                        .store_to_collection(KEY_FACTS_COLLECTION, payload, vector)
951                        .await
952                    {
953                        tracing::warn!("Failed to store key fact: {e:#}");
954                    }
955                }
956                Err(e) => {
957                    tracing::warn!("Failed to embed key fact: {e:#}");
958                }
959            }
960        }
961    }
962
963    /// Search key facts extracted from conversation summaries.
964    ///
965    /// # Errors
966    ///
967    /// Returns an error if embedding or Qdrant search fails.
968    pub async fn search_key_facts(
969        &self,
970        query: &str,
971        limit: usize,
972    ) -> Result<Vec<String>, MemoryError> {
973        let Some(qdrant) = &self.qdrant else {
974            return Ok(Vec::new());
975        };
976        if !self.provider.supports_embeddings() {
977            return Ok(Vec::new());
978        }
979
980        let vector = self.provider.embed(query).await?;
981        let vector_size = u64::try_from(vector.len()).unwrap_or(896);
982        qdrant
983            .ensure_named_collection(KEY_FACTS_COLLECTION, vector_size)
984            .await?;
985
986        let points = qdrant
987            .search_collection(KEY_FACTS_COLLECTION, &vector, limit, None)
988            .await?;
989
990        let facts = points
991            .into_iter()
992            .filter_map(|p| p.payload.get("fact_text")?.as_str().map(String::from))
993            .collect();
994
995        Ok(facts)
996    }
997}
998
999#[cfg(test)]
1000mod tests {
1001    use zeph_llm::mock::MockProvider;
1002    use zeph_llm::provider::Role;
1003
1004    use super::*;
1005
    // Default mock LLM provider used by the test fixtures below.
    fn test_provider() -> AnyProvider {
        AnyProvider::Mock(MockProvider::default())
    }
1009
1010    async fn test_semantic_memory(_supports_embeddings: bool) -> SemanticMemory {
1011        let provider = test_provider();
1012        let sqlite = SqliteStore::new(":memory:").await.unwrap();
1013
1014        SemanticMemory {
1015            sqlite,
1016            qdrant: None,
1017            provider,
1018            embedding_model: "test-model".into(),
1019            vector_weight: 0.7,
1020            keyword_weight: 0.3,
1021            temporal_decay_enabled: false,
1022            temporal_decay_half_life_days: 30,
1023            mmr_enabled: false,
1024            mmr_lambda: 0.7,
1025            token_counter: Arc::new(TokenCounter::new()),
1026        }
1027    }
1028
    // `remember` persists to SQLite, assigns ids starting at 1, and role
    // strings round-trip to `Role` variants via `load_history`.
    #[tokio::test]
    async fn remember_saves_to_sqlite() {
        let memory = test_semantic_memory(false).await;

        let cid = memory.sqlite.create_conversation().await.unwrap();
        let msg_id = memory.remember(cid, "user", "hello").await.unwrap();

        assert_eq!(msg_id, MessageId(1));

        let history = memory.sqlite.load_history(cid, 50).await.unwrap();
        assert_eq!(history.len(), 1);
        assert_eq!(history[0].role, Role::User);
        assert_eq!(history[0].content, "hello");
    }

    // `remember_with_parts` accepts a serialized parts payload alongside the
    // plain-text content; history surfaces only the text content.
    #[tokio::test]
    async fn remember_with_parts_saves_parts_json() {
        let memory = test_semantic_memory(false).await;
        let cid = memory.sqlite.create_conversation().await.unwrap();

        let parts_json =
            r#"[{"kind":"ToolOutput","tool_name":"shell","body":"hello","compacted_at":null}]"#;
        let (msg_id, _embedding_stored) = memory
            .remember_with_parts(cid, "assistant", "tool output", parts_json)
            .await
            .unwrap();
        assert!(msg_id > MessageId(0));

        let history = memory.sqlite.load_history(cid, 50).await.unwrap();
        assert_eq!(history.len(), 1);
        assert_eq!(history[0].content, "tool output");
    }

    // Without a vector store, recall degrades to empty rather than erroring.
    #[tokio::test]
    async fn recall_returns_empty_without_qdrant() {
        let memory = test_semantic_memory(true).await;

        let recalled = memory.recall("test", 5, None).await.unwrap();
        assert!(recalled.is_empty());
    }

    // Embedding lookup is a clean "no" when there is no vector store.
    #[tokio::test]
    async fn has_embedding_without_qdrant() {
        let memory = test_semantic_memory(true).await;

        let has_embedding = memory.has_embedding(MessageId(1)).await.unwrap();
        assert!(!has_embedding);
    }

    // Embedding backfill is a no-op without a vector store.
    #[tokio::test]
    async fn embed_missing_without_qdrant() {
        let memory = test_semantic_memory(true).await;

        let count = memory.embed_missing().await.unwrap();
        assert_eq!(count, 0);
    }

    // The `sqlite()` accessor exposes the underlying store for direct use.
    #[tokio::test]
    async fn sqlite_accessor() {
        let memory = test_semantic_memory(false).await;

        let cid = memory.sqlite().create_conversation().await.unwrap();
        assert_eq!(cid, ConversationId(1));

        memory
            .sqlite()
            .save_message(cid, "user", "test")
            .await
            .unwrap();

        let history = memory.sqlite().load_history(cid, 50).await.unwrap();
        assert_eq!(history.len(), 1);
    }
1102
    // `qdrant: None` means no vector store is reported as available.
    #[tokio::test]
    async fn has_vector_store_returns_false_when_unavailable() {
        let memory = test_semantic_memory(false).await;
        assert!(!memory.has_vector_store());
    }

    // Connectivity probe is false when no vector store is configured.
    #[tokio::test]
    async fn is_vector_store_connected_returns_false_when_unavailable() {
        let memory = test_semantic_memory(false).await;
        assert!(!memory.is_vector_store_connected().await);
    }

    // Recall also degrades to empty when the provider lacks embedding support.
    #[tokio::test]
    async fn recall_returns_empty_when_embeddings_not_supported() {
        let memory = test_semantic_memory(false).await;

        let recalled = memory.recall("test", 5, None).await.unwrap();
        assert!(recalled.is_empty());
    }

    // Even with unembedded messages present, backfill embeds nothing when the
    // provider cannot produce embeddings.
    #[tokio::test]
    async fn embed_missing_returns_zero_when_embeddings_not_supported() {
        let memory = test_semantic_memory(false).await;

        let cid = memory.sqlite().create_conversation().await.unwrap();
        memory
            .sqlite()
            .save_message(cid, "user", "test")
            .await
            .unwrap();

        let count = memory.embed_missing().await.unwrap();
        assert_eq!(count, 0);
    }

    // A freshly created conversation has zero messages.
    #[tokio::test]
    async fn message_count_empty_conversation() {
        let memory = test_semantic_memory(false).await;
        let cid = memory.sqlite().create_conversation().await.unwrap();

        let count = memory.message_count(cid).await.unwrap();
        assert_eq!(count, 0);
    }

    // `message_count` reflects every remembered message.
    #[tokio::test]
    async fn message_count_after_saves() {
        let memory = test_semantic_memory(false).await;
        let cid = memory.sqlite().create_conversation().await.unwrap();

        memory.remember(cid, "user", "msg1").await.unwrap();
        memory.remember(cid, "assistant", "msg2").await.unwrap();

        let count = memory.message_count(cid).await.unwrap();
        assert_eq!(count, 2);
    }

    // Summarization shrinks the unsummarized pool but never deletes messages
    // from the conversation itself.
    #[tokio::test]
    async fn unsummarized_count_decreases_after_summary() {
        let memory = test_semantic_memory(false).await;
        let cid = memory.sqlite().create_conversation().await.unwrap();

        for i in 0..10 {
            memory
                .remember(cid, "user", &format!("msg{i}"))
                .await
                .unwrap();
        }
        assert_eq!(memory.unsummarized_message_count(cid).await.unwrap(), 10);

        memory.summarize(cid, 5).await.unwrap();

        assert!(memory.unsummarized_message_count(cid).await.unwrap() < 10);
        assert_eq!(memory.message_count(cid).await.unwrap(), 10);
    }
1177
    // No summaries exist until one is created.
    #[tokio::test]
    async fn load_summaries_empty() {
        let memory = test_semantic_memory(false).await;
        let cid = memory.sqlite().create_conversation().await.unwrap();

        let summaries = memory.load_summaries(cid).await.unwrap();
        assert!(summaries.is_empty());
    }

    // Summaries load back in insertion order with ids and content intact.
    #[tokio::test]
    async fn load_summaries_ordered() {
        let memory = test_semantic_memory(false).await;
        let cid = memory.sqlite().create_conversation().await.unwrap();

        let msg_id1 = memory.remember(cid, "user", "m1").await.unwrap();
        let msg_id2 = memory.remember(cid, "assistant", "m2").await.unwrap();
        let msg_id3 = memory.remember(cid, "user", "m3").await.unwrap();

        let s1 = memory
            .sqlite()
            .save_summary(cid, "summary1", msg_id1, msg_id2, 3)
            .await
            .unwrap();
        let s2 = memory
            .sqlite()
            .save_summary(cid, "summary2", msg_id2, msg_id3, 3)
            .await
            .unwrap();

        let summaries = memory.load_summaries(cid).await.unwrap();
        assert_eq!(summaries.len(), 2);
        assert_eq!(summaries[0].id, s1);
        assert_eq!(summaries[0].content, "summary1");
        assert_eq!(summaries[1].id, s2);
        assert_eq!(summaries[1].content, "summary2");
    }

    // Below the message threshold, `summarize` declines and returns `None`.
    #[tokio::test]
    async fn summarize_below_threshold() {
        let memory = test_semantic_memory(false).await;
        let cid = memory.sqlite().create_conversation().await.unwrap();

        memory.remember(cid, "user", "hello").await.unwrap();

        let result = memory.summarize(cid, 10).await.unwrap();
        assert!(result.is_none());
    }

    // Above the threshold, `summarize` stores a non-empty summary row and
    // returns its id.
    #[tokio::test]
    async fn summarize_stores_summary() {
        let memory = test_semantic_memory(false).await;
        let cid = memory.sqlite().create_conversation().await.unwrap();

        for i in 0..5 {
            memory
                .remember(cid, "user", &format!("message {i}"))
                .await
                .unwrap();
        }

        let summary_id = memory.summarize(cid, 3).await.unwrap();
        assert!(summary_id.is_some());

        let summaries = memory.load_summaries(cid).await.unwrap();
        assert_eq!(summaries.len(), 1);
        assert_eq!(summaries[0].id, summary_id.unwrap());
        assert!(!summaries[0].content.is_empty());
    }

    // A second summarization starts after the previous one's range: the two
    // message ranges must not overlap.
    #[tokio::test]
    async fn summarize_respects_previous_summaries() {
        let memory = test_semantic_memory(false).await;
        let cid = memory.sqlite().create_conversation().await.unwrap();

        for i in 0..10 {
            memory
                .remember(cid, "user", &format!("message {i}"))
                .await
                .unwrap();
        }

        let s1 = memory.summarize(cid, 3).await.unwrap();
        assert!(s1.is_some());

        let s2 = memory.summarize(cid, 3).await.unwrap();
        assert!(s2.is_some());

        let summaries = memory.load_summaries(cid).await.unwrap();
        assert_eq!(summaries.len(), 2);
        assert!(summaries[0].last_message_id < summaries[1].first_message_id);
    }
1269
    // Message ids are strictly increasing within a conversation.
    #[tokio::test]
    async fn remember_multiple_messages_increments_ids() {
        let memory = test_semantic_memory(false).await;
        let cid = memory.sqlite.create_conversation().await.unwrap();

        let id1 = memory.remember(cid, "user", "first").await.unwrap();
        let id2 = memory.remember(cid, "assistant", "second").await.unwrap();
        let id3 = memory.remember(cid, "user", "third").await.unwrap();

        assert!(id1 < id2);
        assert!(id2 < id3);
    }

    // Counts are scoped per conversation, not global.
    #[tokio::test]
    async fn message_count_across_conversations() {
        let memory = test_semantic_memory(false).await;
        let cid1 = memory.sqlite().create_conversation().await.unwrap();
        let cid2 = memory.sqlite().create_conversation().await.unwrap();

        memory.remember(cid1, "user", "msg1").await.unwrap();
        memory.remember(cid1, "user", "msg2").await.unwrap();
        memory.remember(cid2, "user", "msg3").await.unwrap();

        assert_eq!(memory.message_count(cid1).await.unwrap(), 2);
        assert_eq!(memory.message_count(cid2).await.unwrap(), 1);
    }

    // The threshold is exclusive: exactly `threshold` messages is not enough.
    #[tokio::test]
    async fn summarize_exact_threshold_returns_none() {
        let memory = test_semantic_memory(false).await;
        let cid = memory.sqlite().create_conversation().await.unwrap();

        for i in 0..3 {
            memory
                .remember(cid, "user", &format!("msg {i}"))
                .await
                .unwrap();
        }

        let result = memory.summarize(cid, 3).await.unwrap();
        assert!(result.is_none());
    }

    // One message past the threshold is enough to trigger a summary.
    #[tokio::test]
    async fn summarize_one_above_threshold_produces_summary() {
        let memory = test_semantic_memory(false).await;
        let cid = memory.sqlite().create_conversation().await.unwrap();

        for i in 0..4 {
            memory
                .remember(cid, "user", &format!("msg {i}"))
                .await
                .unwrap();
        }

        let result = memory.summarize(cid, 3).await.unwrap();
        assert!(result.is_some());
    }

    // Every field of the stored `Summary` carries a sane value.
    #[tokio::test]
    async fn summary_fields_populated() {
        let memory = test_semantic_memory(false).await;
        let cid = memory.sqlite().create_conversation().await.unwrap();

        for i in 0..5 {
            memory
                .remember(cid, "user", &format!("msg {i}"))
                .await
                .unwrap();
        }

        memory.summarize(cid, 3).await.unwrap();
        let summaries = memory.load_summaries(cid).await.unwrap();
        let s = &summaries[0];

        assert_eq!(s.conversation_id, cid);
        assert!(s.first_message_id > MessageId(0));
        assert!(s.last_message_id >= s.first_message_id);
        assert!(s.token_estimate >= 0);
        assert!(!s.content.is_empty());
    }

    // The prompt embeds each message as "role: content" and requests the
    // structured `key_facts` output.
    #[test]
    fn build_summarization_prompt_format() {
        let messages = vec![
            (MessageId(1), "user".into(), "Hello".into()),
            (MessageId(2), "assistant".into(), "Hi there".into()),
        ];
        let prompt = build_summarization_prompt(&messages);
        assert!(prompt.contains("user: Hello"));
        assert!(prompt.contains("assistant: Hi there"));
        assert!(prompt.contains("key_facts"));
    }

    // Even an empty transcript still yields the structured-output instructions.
    #[test]
    fn build_summarization_prompt_empty() {
        let messages: Vec<(MessageId, String, String)> = vec![];
        let prompt = build_summarization_prompt(&messages);
        assert!(prompt.contains("key_facts"));
    }

    // `StructuredSummary` deserializes from the expected JSON shape.
    #[test]
    fn structured_summary_deserialize() {
        let json = r#"{"summary":"s","key_facts":["f1","f2"],"entities":["e1"]}"#;
        let ss: StructuredSummary = serde_json::from_str(json).unwrap();
        assert_eq!(ss.summary, "s");
        assert_eq!(ss.key_facts.len(), 2);
        assert_eq!(ss.entities.len(), 1);
    }

    // Empty fact/entity arrays are valid.
    #[test]
    fn structured_summary_empty_facts() {
        let json = r#"{"summary":"s","key_facts":[],"entities":[]}"#;
        let ss: StructuredSummary = serde_json::from_str(json).unwrap();
        assert!(ss.key_facts.is_empty());
        assert!(ss.entities.is_empty());
    }
1387
    // Key-fact search degrades to empty without a vector store.
    #[tokio::test]
    async fn search_key_facts_no_qdrant_empty() {
        let memory = test_semantic_memory(false).await;
        let facts = memory.search_key_facts("query", 5).await.unwrap();
        assert!(facts.is_empty());
    }

    // `RecalledMessage` has a useful `Debug` representation.
    #[test]
    fn recalled_message_debug() {
        let recalled = RecalledMessage {
            message: Message {
                role: Role::User,
                content: "test".into(),
                parts: vec![],
                metadata: MessageMetadata::default(),
            },
            score: 0.95,
        };
        let dbg = format!("{recalled:?}");
        assert!(dbg.contains("RecalledMessage"));
        assert!(dbg.contains("0.95"));
    }

    // `Summary` is `Clone` and the clone carries the same data.
    #[test]
    fn summary_clone() {
        let summary = Summary {
            id: 1,
            conversation_id: ConversationId(2),
            content: "test summary".into(),
            first_message_id: MessageId(1),
            last_message_id: MessageId(5),
            token_estimate: 10,
        };
        let cloned = summary.clone();
        assert_eq!(summary.id, cloned.id);
        assert_eq!(summary.content, cloned.content);
    }

    // Role strings "user"/"assistant"/"system" round-trip through storage.
    #[tokio::test]
    async fn remember_preserves_role_mapping() {
        let memory = test_semantic_memory(false).await;
        let cid = memory.sqlite.create_conversation().await.unwrap();

        memory.remember(cid, "user", "u").await.unwrap();
        memory.remember(cid, "assistant", "a").await.unwrap();
        memory.remember(cid, "system", "s").await.unwrap();

        let history = memory.sqlite.load_history(cid, 50).await.unwrap();
        assert_eq!(history.len(), 3);
        assert_eq!(history[0].role, Role::User);
        assert_eq!(history[1].role, Role::Assistant);
        assert_eq!(history[2].role, Role::System);
    }

    // Construction must not fail when Qdrant is unreachable — the memory
    // degrades gracefully instead of erroring at startup.
    #[tokio::test]
    async fn new_with_invalid_qdrant_url_graceful() {
        let mut mock = MockProvider::default();
        mock.supports_embeddings = true;
        let provider = AnyProvider::Mock(mock);
        let result =
            SemanticMemory::new(":memory:", "http://127.0.0.1:1", provider, "test-model").await;
        assert!(result.is_ok());
    }
1451
    // Full remember→recall round-trip using the SQLite-backed embedding store
    // in place of Qdrant.
    #[tokio::test]
    async fn test_semantic_memory_sqlite_remember_recall_roundtrip() {
        // Build SemanticMemory with EmbeddingStore backed by SQLite instead of Qdrant
        let mut mock = MockProvider::default();
        mock.supports_embeddings = true;
        // Provide deterministic embedding vectors: embed returns a fixed 4-element vector
        // MockProvider.embed always returns the same vector, so cosine similarity = 1.0
        let provider = AnyProvider::Mock(mock);

        let sqlite = SqliteStore::new(":memory:").await.unwrap();
        let pool = sqlite.pool().clone();
        let qdrant = Some(crate::embedding_store::EmbeddingStore::new_sqlite(pool));

        let memory = SemanticMemory {
            sqlite,
            qdrant,
            provider,
            embedding_model: "test-model".into(),
            vector_weight: 0.7,
            keyword_weight: 0.3,
            temporal_decay_enabled: false,
            temporal_decay_half_life_days: 30,
            mmr_enabled: false,
            mmr_lambda: 0.7,
            token_counter: Arc::new(TokenCounter::new()),
        };

        let cid = memory.sqlite().create_conversation().await.unwrap();

        // remember → stores in SQLite + SQLite vector store
        let id1 = memory
            .remember(cid, "user", "rust async programming")
            .await
            .unwrap();
        let id2 = memory
            .remember(cid, "assistant", "use tokio for async")
            .await
            .unwrap();
        assert!(id1 < id2);

        // recall → should return results via FTS5 keyword search
        let recalled = memory.recall("rust", 5, None).await.unwrap();
        assert!(
            !recalled.is_empty(),
            "recall must return at least one result"
        );

        // Verify history is accessible
        let history = memory.sqlite().load_history(cid, 50).await.unwrap();
        assert_eq!(history.len(), 2);
        assert_eq!(history[0].content, "rust async programming");
    }

    // An embedding-capable provider with no vector store still persists the
    // message to SQLite successfully.
    #[tokio::test]
    async fn remember_with_embeddings_supported_but_no_qdrant() {
        let memory = test_semantic_memory(true).await;
        let cid = memory.sqlite.create_conversation().await.unwrap();

        let msg_id = memory.remember(cid, "user", "hello embed").await.unwrap();
        assert!(msg_id > MessageId(0));

        let history = memory.sqlite.load_history(cid, 50).await.unwrap();
        assert_eq!(history.len(), 1);
        assert_eq!(history[0].content, "hello embed");
    }

    // Content comes back in insertion order from `load_history`.
    #[tokio::test]
    async fn remember_verifies_content_via_load_history() {
        let memory = test_semantic_memory(false).await;
        let cid = memory.sqlite.create_conversation().await.unwrap();

        memory.remember(cid, "user", "alpha").await.unwrap();
        memory.remember(cid, "assistant", "beta").await.unwrap();
        memory.remember(cid, "user", "gamma").await.unwrap();

        let history = memory.sqlite().load_history(cid, 50).await.unwrap();
        assert_eq!(history.len(), 3);
        assert_eq!(history[0].content, "alpha");
        assert_eq!(history[1].content, "beta");
        assert_eq!(history[2].content, "gamma");
    }

    // Counts stay isolated across several conversations, including an empty one.
    #[tokio::test]
    async fn message_count_multiple_conversations_isolated() {
        let memory = test_semantic_memory(false).await;
        let cid1 = memory.sqlite().create_conversation().await.unwrap();
        let cid2 = memory.sqlite().create_conversation().await.unwrap();
        let cid3 = memory.sqlite().create_conversation().await.unwrap();

        for _ in 0..5 {
            memory.remember(cid1, "user", "msg").await.unwrap();
        }
        for _ in 0..3 {
            memory.remember(cid2, "user", "msg").await.unwrap();
        }

        assert_eq!(memory.message_count(cid1).await.unwrap(), 5);
        assert_eq!(memory.message_count(cid2).await.unwrap(), 3);
        assert_eq!(memory.message_count(cid3).await.unwrap(), 0);
    }
1552
    // NOTE(review): despite the name, this test never asserts a `None`
    // result — it only checks that two consecutive summarizations both
    // succeed and produce two summaries. Consider renaming, or adding a
    // third `summarize` call asserted to return `None`.
    #[tokio::test]
    async fn summarize_empty_messages_range_returns_none() {
        let memory = test_semantic_memory(false).await;
        let cid = memory.sqlite().create_conversation().await.unwrap();

        for i in 0..6 {
            memory
                .remember(cid, "user", &format!("msg {i}"))
                .await
                .unwrap();
        }

        memory.summarize(cid, 3).await.unwrap();
        memory.summarize(cid, 3).await.unwrap();

        let summaries = memory.load_summaries(cid).await.unwrap();
        assert_eq!(summaries.len(), 2);
    }

    // The token estimate of a stored summary is strictly positive.
    #[tokio::test]
    async fn summarize_token_estimate_populated() {
        let memory = test_semantic_memory(false).await;
        let cid = memory.sqlite().create_conversation().await.unwrap();

        for i in 0..5 {
            memory
                .remember(cid, "user", &format!("message {i}"))
                .await
                .unwrap();
        }

        memory.summarize(cid, 3).await.unwrap();
        let summaries = memory.load_summaries(cid).await.unwrap();
        let token_est = summaries[0].token_estimate;
        assert!(token_est > 0);
    }

    // Uses a real Ollama provider pointed at an unroutable port so the chat
    // call fails — `summarize` must surface that error rather than store a
    // bogus summary.
    #[tokio::test]
    async fn summarize_fails_when_provider_chat_fails() {
        let sqlite = SqliteStore::new(":memory:").await.unwrap();
        let provider = AnyProvider::Ollama(zeph_llm::ollama::OllamaProvider::new(
            "http://127.0.0.1:1",
            "test".into(),
            "embed".into(),
        ));
        let memory = SemanticMemory {
            sqlite,
            qdrant: None,
            provider,
            embedding_model: "test".into(),
            vector_weight: 0.7,
            keyword_weight: 0.3,
            temporal_decay_enabled: false,
            temporal_decay_half_life_days: 30,
            mmr_enabled: false,
            mmr_lambda: 0.7,
            token_counter: Arc::new(TokenCounter::new()),
        };
        let cid = memory.sqlite().create_conversation().await.unwrap();

        for i in 0..5 {
            memory
                .remember(cid, "user", &format!("msg {i}"))
                .await
                .unwrap();
        }

        let result = memory.summarize(cid, 3).await;
        assert!(result.is_err());
    }

    // Backfill embeds nothing when the provider lacks embedding support.
    #[tokio::test]
    async fn embed_missing_without_embedding_support_returns_zero() {
        let memory = test_semantic_memory(false).await;
        let cid = memory.sqlite().create_conversation().await.unwrap();
        memory
            .sqlite()
            .save_message(cid, "user", "test message")
            .await
            .unwrap();

        let count = memory.embed_missing().await.unwrap();
        assert_eq!(count, 0);
    }

    // A remembered message reports no embedding when no vector store exists.
    #[tokio::test]
    async fn has_embedding_returns_false_when_no_qdrant() {
        let memory = test_semantic_memory(false).await;
        let cid = memory.sqlite.create_conversation().await.unwrap();
        let msg_id = memory.remember(cid, "user", "test").await.unwrap();
        assert!(!memory.has_embedding(msg_id).await.unwrap());
    }

    // A search filter does not change the empty result without a vector store.
    #[tokio::test]
    async fn recall_empty_without_qdrant_regardless_of_filter() {
        let memory = test_semantic_memory(true).await;
        let filter = SearchFilter {
            conversation_id: Some(ConversationId(1)),
            role: None,
        };
        let recalled = memory.recall("query", 10, Some(filter)).await.unwrap();
        assert!(recalled.is_empty());
    }
1656
    // The stored summary's message-id range is well-formed (non-zero start,
    // end not before start).
    #[tokio::test]
    async fn summarize_message_range_bounds() {
        let memory = test_semantic_memory(false).await;
        let cid = memory.sqlite().create_conversation().await.unwrap();

        for i in 0..8 {
            memory
                .remember(cid, "user", &format!("msg {i}"))
                .await
                .unwrap();
        }

        let summary_id = memory.summarize(cid, 4).await.unwrap().unwrap();
        let summaries = memory.load_summaries(cid).await.unwrap();
        assert_eq!(summaries.len(), 1);
        assert_eq!(summaries[0].id, summary_id);
        assert!(summaries[0].first_message_id >= MessageId(1));
        assert!(summaries[0].last_message_id >= summaries[0].first_message_id);
    }

    // Messages appear in the prompt in the order they were given.
    #[test]
    fn build_summarization_prompt_preserves_order() {
        let messages = vec![
            (MessageId(1), "user".into(), "first".into()),
            (MessageId(2), "assistant".into(), "second".into()),
            (MessageId(3), "user".into(), "third".into()),
        ];
        let prompt = build_summarization_prompt(&messages);
        let first_pos = prompt.find("user: first").unwrap();
        let second_pos = prompt.find("assistant: second").unwrap();
        let third_pos = prompt.find("user: third").unwrap();
        assert!(first_pos < second_pos);
        assert!(second_pos < third_pos);
    }

    // `Summary` has a `Debug` representation naming the type.
    #[test]
    fn summary_debug() {
        let summary = Summary {
            id: 1,
            conversation_id: ConversationId(2),
            content: "test".into(),
            first_message_id: MessageId(1),
            last_message_id: MessageId(5),
            token_estimate: 10,
        };
        let dbg = format!("{summary:?}");
        assert!(dbg.contains("Summary"));
    }

    // Unknown conversation ids count as zero, not an error.
    #[tokio::test]
    async fn message_count_nonexistent_conversation() {
        let memory = test_semantic_memory(false).await;
        let count = memory.message_count(ConversationId(999)).await.unwrap();
        assert_eq!(count, 0);
    }

    // Unknown conversation ids yield no summaries, not an error.
    #[tokio::test]
    async fn load_summaries_nonexistent_conversation() {
        let memory = test_semantic_memory(false).await;
        let summaries = memory.load_summaries(ConversationId(999)).await.unwrap();
        assert!(summaries.is_empty());
    }

    // Storing a session summary is a silent no-op without a vector store.
    #[tokio::test]
    async fn store_session_summary_no_qdrant_noop() {
        let memory = test_semantic_memory(true).await;
        let result = memory
            .store_session_summary(ConversationId(1), "test summary")
            .await;
        assert!(result.is_ok());
    }

    // …and also a no-op when the provider cannot produce embeddings.
    #[tokio::test]
    async fn store_session_summary_no_embeddings_noop() {
        let memory = test_semantic_memory(false).await;
        let result = memory
            .store_session_summary(ConversationId(1), "test summary")
            .await;
        assert!(result.is_ok());
    }

    // Session-summary search degrades to empty without a vector store.
    #[tokio::test]
    async fn search_session_summaries_no_qdrant_empty() {
        let memory = test_semantic_memory(true).await;
        let results = memory
            .search_session_summaries("query", 5, None)
            .await
            .unwrap();
        assert!(results.is_empty());
    }

    // A conversation filter does not change the empty degradation either.
    #[tokio::test]
    async fn search_session_summaries_no_embeddings_empty() {
        let memory = test_semantic_memory(false).await;
        let results = memory
            .search_session_summaries("query", 5, Some(ConversationId(1)))
            .await
            .unwrap();
        assert!(results.is_empty());
    }

    // `SessionSummaryResult` has a `Debug` representation naming the type.
    #[test]
    fn session_summary_result_debug() {
        let result = SessionSummaryResult {
            summary_text: "test".into(),
            score: 0.9,
            conversation_id: ConversationId(1),
        };
        let dbg = format!("{result:?}");
        assert!(dbg.contains("SessionSummaryResult"));
    }

    // `SessionSummaryResult` is `Clone` with field values preserved.
    #[test]
    fn session_summary_result_clone() {
        let result = SessionSummaryResult {
            summary_text: "test".into(),
            score: 0.9,
            conversation_id: ConversationId(1),
        };
        let cloned = result.clone();
        assert_eq!(result.summary_text, cloned.summary_text);
        assert_eq!(result.conversation_id, cloned.conversation_id);
    }
1780
1781    #[tokio::test]
1782    async fn recall_fts5_fallback_without_qdrant() {
1783        let memory = test_semantic_memory(false).await;
1784        let cid = memory.sqlite.create_conversation().await.unwrap();
1785
1786        memory
1787            .remember(cid, "user", "rust programming guide")
1788            .await
1789            .unwrap();
1790        memory
1791            .remember(cid, "assistant", "python tutorial")
1792            .await
1793            .unwrap();
1794        memory
1795            .remember(cid, "user", "advanced rust patterns")
1796            .await
1797            .unwrap();
1798
1799        let recalled = memory.recall("rust", 5, None).await.unwrap();
1800        assert_eq!(recalled.len(), 2);
1801        assert!(recalled[0].score >= recalled[1].score);
1802    }
1803
1804    #[tokio::test]
1805    async fn recall_fts5_fallback_with_filter() {
1806        let memory = test_semantic_memory(false).await;
1807        let cid1 = memory.sqlite.create_conversation().await.unwrap();
1808        let cid2 = memory.sqlite.create_conversation().await.unwrap();
1809
1810        memory.remember(cid1, "user", "hello world").await.unwrap();
1811        memory
1812            .remember(cid2, "user", "hello universe")
1813            .await
1814            .unwrap();
1815
1816        let filter = SearchFilter {
1817            conversation_id: Some(cid1),
1818            role: None,
1819        };
1820        let recalled = memory.recall("hello", 5, Some(filter)).await.unwrap();
1821        assert_eq!(recalled.len(), 1);
1822    }
1823
1824    #[tokio::test]
1825    async fn recall_fts5_no_matches_returns_empty() {
1826        let memory = test_semantic_memory(false).await;
1827        let cid = memory.sqlite.create_conversation().await.unwrap();
1828
1829        memory.remember(cid, "user", "hello world").await.unwrap();
1830
1831        let recalled = memory.recall("nonexistent", 5, None).await.unwrap();
1832        assert!(recalled.is_empty());
1833    }
1834
1835    #[tokio::test]
1836    async fn recall_fts5_respects_limit() {
1837        let memory = test_semantic_memory(false).await;
1838        let cid = memory.sqlite.create_conversation().await.unwrap();
1839
1840        for i in 0..10 {
1841            memory
1842                .remember(cid, "user", &format!("test message number {i}"))
1843                .await
1844                .unwrap();
1845        }
1846
1847        let recalled = memory.recall("test", 3, None).await.unwrap();
1848        assert_eq!(recalled.len(), 3);
1849    }
1850
1851    // Priority 2: summarize fallback path
1852
1853    #[tokio::test]
1854    async fn summarize_fallback_to_plain_text_when_structured_fails() {
1855        // Use OllamaProvider pointing at an unreachable URL for chat_typed_erased,
1856        // but MockProvider for the plain chat call.
1857        // The easiest way: MockProvider returns non-JSON plain text so chat_typed_erased
1858        // (which uses chat() + JSON parse) will fail to parse, then falls back to chat().
1859        // However MockProvider.chat_typed calls chat() which returns default_response.
1860        // chat_typed tries to parse it as JSON → fails → retries → fails → returns StructuredParse error.
1861        // Then the fallback calls plain chat() which succeeds.
1862        let sqlite = SqliteStore::new(":memory:").await.unwrap();
1863        let mut mock = MockProvider::default();
1864        // First two calls go to chat_typed (attempt + retry), third call is the plain fallback
1865        mock.default_response = "plain text summary".into();
1866        let provider = AnyProvider::Mock(mock);
1867
1868        let memory = SemanticMemory {
1869            sqlite,
1870            qdrant: None,
1871            provider,
1872            embedding_model: "test".into(),
1873            vector_weight: 0.7,
1874            keyword_weight: 0.3,
1875            temporal_decay_enabled: false,
1876            temporal_decay_half_life_days: 30,
1877            mmr_enabled: false,
1878            mmr_lambda: 0.7,
1879            token_counter: Arc::new(TokenCounter::new()),
1880        };
1881
1882        let cid = memory.sqlite().create_conversation().await.unwrap();
1883        for i in 0..5 {
1884            memory
1885                .remember(cid, "user", &format!("msg {i}"))
1886                .await
1887                .unwrap();
1888        }
1889
1890        let result = memory.summarize(cid, 3).await;
1891        // The summarize will either succeed (with plain text fallback) or fail
1892        // depending on how many retries chat_typed_erased does internally.
1893        // With MockProvider returning non-JSON plain text, chat_typed fails to parse.
1894        // The fallback plain chat() returns "plain text summary".
1895        // Result should be Ok with a summary stored.
1896        assert!(result.is_ok());
1897        let summaries = memory.load_summaries(cid).await.unwrap();
1898        assert_eq!(summaries.len(), 1);
1899        assert!(!summaries[0].content.is_empty());
1900    }
1901
1902    // Temporal decay tests
1903
1904    #[test]
1905    fn temporal_decay_disabled_leaves_scores_unchanged() {
1906        let mut ranked = vec![(MessageId(1), 1.0f64), (MessageId(2), 0.5f64)];
1907        let timestamps = std::collections::HashMap::new();
1908        apply_temporal_decay(&mut ranked, &timestamps, 30);
1909        assert!((ranked[0].1 - 1.0).abs() < f64::EPSILON);
1910        assert!((ranked[1].1 - 0.5).abs() < f64::EPSILON);
1911    }
1912
1913    #[test]
1914    fn temporal_decay_zero_age_preserves_score() {
1915        let now = std::time::SystemTime::now()
1916            .duration_since(std::time::UNIX_EPOCH)
1917            .unwrap_or_default()
1918            .as_secs()
1919            .cast_signed();
1920        let mut ranked = vec![(MessageId(1), 1.0f64)];
1921        let mut timestamps = std::collections::HashMap::new();
1922        timestamps.insert(MessageId(1), now);
1923        apply_temporal_decay(&mut ranked, &timestamps, 30);
1924        // age = 0 days, exp(0) = 1.0 → no change
1925        assert!((ranked[0].1 - 1.0).abs() < 0.01);
1926    }
1927
1928    #[test]
1929    fn temporal_decay_half_life_halves_score() {
1930        // Age exactly half_life_days → score should be halved
1931        let half_life = 30u32;
1932        let age_secs = i64::from(half_life) * 86400;
1933        let now = std::time::SystemTime::now()
1934            .duration_since(std::time::UNIX_EPOCH)
1935            .unwrap_or_default()
1936            .as_secs()
1937            .cast_signed();
1938        let ts = now - age_secs;
1939        let mut ranked = vec![(MessageId(1), 1.0f64)];
1940        let mut timestamps = std::collections::HashMap::new();
1941        timestamps.insert(MessageId(1), ts);
1942        apply_temporal_decay(&mut ranked, &timestamps, half_life);
1943        // exp(-ln2) = 0.5
1944        assert!(
1945            (ranked[0].1 - 0.5).abs() < 0.01,
1946            "score was {}",
1947            ranked[0].1
1948        );
1949    }
1950
1951    // MMR tests
1952
1953    #[test]
1954    fn mmr_empty_input_returns_empty() {
1955        let ranked = vec![];
1956        let vectors = std::collections::HashMap::new();
1957        let result = apply_mmr(&ranked, &vectors, 0.7, 5);
1958        assert!(result.is_empty());
1959    }
1960
1961    #[test]
1962    fn mmr_returns_up_to_limit() {
1963        let ranked = vec![
1964            (MessageId(1), 1.0f64),
1965            (MessageId(2), 0.9f64),
1966            (MessageId(3), 0.8f64),
1967        ];
1968        let mut vectors = std::collections::HashMap::new();
1969        vectors.insert(MessageId(1), vec![1.0f32, 0.0]);
1970        vectors.insert(MessageId(2), vec![0.0f32, 1.0]);
1971        vectors.insert(MessageId(3), vec![1.0f32, 0.0]);
1972        let result = apply_mmr(&ranked, &vectors, 0.7, 2);
1973        assert_eq!(result.len(), 2);
1974    }
1975
1976    #[test]
1977    fn mmr_without_vectors_picks_by_relevance() {
1978        let ranked = vec![(MessageId(1), 1.0f64), (MessageId(2), 0.5f64)];
1979        let vectors = std::collections::HashMap::new();
1980        let result = apply_mmr(&ranked, &vectors, 0.7, 2);
1981        assert_eq!(result.len(), 2);
1982        assert_eq!(result[0].0, MessageId(1));
1983    }
1984
1985    #[test]
1986    fn mmr_prefers_diverse_over_redundant() {
1987        // Two candidates with same relevance but msg 2 is orthogonal (more diverse)
1988        let ranked = vec![
1989            (MessageId(1), 1.0f64), // selected first
1990            (MessageId(2), 0.9f64), // orthogonal to 1
1991            (MessageId(3), 0.9f64), // parallel to 1 (redundant)
1992        ];
1993        let mut vectors = std::collections::HashMap::new();
1994        vectors.insert(MessageId(1), vec![1.0f32, 0.0]);
1995        vectors.insert(MessageId(2), vec![0.0f32, 1.0]); // orthogonal
1996        vectors.insert(MessageId(3), vec![1.0f32, 0.0]); // same as 1
1997        let result = apply_mmr(&ranked, &vectors, 0.5, 2);
1998        assert_eq!(result.len(), 2);
1999        assert_eq!(result[0].0, MessageId(1));
2000        // msg 2 should be preferred over msg 3 (diverse)
2001        assert_eq!(result[1].0, MessageId(2));
2002    }
2003
2004    #[test]
2005    fn temporal_decay_half_life_zero_is_noop() {
2006        let now = std::time::SystemTime::now()
2007            .duration_since(std::time::UNIX_EPOCH)
2008            .unwrap_or_default()
2009            .as_secs()
2010            .cast_signed();
2011        let age_secs = 30i64 * 86400;
2012        let ts = now - age_secs;
2013        let mut ranked = vec![(MessageId(1), 1.0f64)];
2014        let mut timestamps = std::collections::HashMap::new();
2015        timestamps.insert(MessageId(1), ts);
2016        // half_life=0 → guard returns early, score must remain 1.0
2017        apply_temporal_decay(&mut ranked, &timestamps, 0);
2018        assert!(
2019            (ranked[0].1 - 1.0).abs() < f64::EPSILON,
2020            "score was {}",
2021            ranked[0].1
2022        );
2023    }
2024
2025    #[test]
2026    fn temporal_decay_huge_age_near_zero() {
2027        let now = std::time::SystemTime::now()
2028            .duration_since(std::time::UNIX_EPOCH)
2029            .unwrap_or_default()
2030            .as_secs()
2031            .cast_signed();
2032        // 10 years = ~3650 days
2033        let age_secs = 3650i64 * 86400;
2034        let ts = now - age_secs;
2035        let mut ranked = vec![(MessageId(1), 1.0f64)];
2036        let mut timestamps = std::collections::HashMap::new();
2037        timestamps.insert(MessageId(1), ts);
2038        apply_temporal_decay(&mut ranked, &timestamps, 30);
2039        // After 3650 days with half_life=30, score should be essentially 0
2040        assert!(ranked[0].1 < 0.001, "score was {}", ranked[0].1);
2041    }
2042
2043    #[test]
2044    fn temporal_decay_small_half_life() {
2045        // Very small half_life (1 day), age = 7 days → 2^(-7) ≈ 0.0078
2046        let now = std::time::SystemTime::now()
2047            .duration_since(std::time::UNIX_EPOCH)
2048            .unwrap_or_default()
2049            .as_secs()
2050            .cast_signed();
2051        let ts = now - 7 * 86400i64;
2052        let mut ranked = vec![(MessageId(1), 1.0f64)];
2053        let mut timestamps = std::collections::HashMap::new();
2054        timestamps.insert(MessageId(1), ts);
2055        apply_temporal_decay(&mut ranked, &timestamps, 1);
2056        assert!(ranked[0].1 < 0.01, "score was {}", ranked[0].1);
2057    }
2058
2059    #[test]
2060    fn mmr_lambda_zero_max_diversity() {
2061        // lambda=0 → pure diversity: second item should be most dissimilar
2062        let ranked = vec![
2063            (MessageId(1), 1.0f64),  // selected first (always highest relevance)
2064            (MessageId(2), 0.9f64),  // orthogonal to 1
2065            (MessageId(3), 0.85f64), // parallel to 1 (max_sim=1.0)
2066        ];
2067        let mut vectors = std::collections::HashMap::new();
2068        vectors.insert(MessageId(1), vec![1.0f32, 0.0]);
2069        vectors.insert(MessageId(2), vec![0.0f32, 1.0]); // orthogonal
2070        vectors.insert(MessageId(3), vec![1.0f32, 0.0]); // same direction
2071        let result = apply_mmr(&ranked, &vectors, 0.0, 3);
2072        assert_eq!(result.len(), 3);
2073        // After 1 is selected: mmr(2) = 0 - (1-0)*0 = 0, mmr(3) = 0 - 1*1 = -1 → 2 wins
2074        assert_eq!(result[1].0, MessageId(2));
2075    }
2076
2077    #[test]
2078    fn mmr_lambda_one_pure_relevance() {
2079        // lambda=1 → pure relevance, should pick in relevance order
2080        let ranked = vec![
2081            (MessageId(1), 1.0f64),
2082            (MessageId(2), 0.8f64),
2083            (MessageId(3), 0.6f64),
2084        ];
2085        let mut vectors = std::collections::HashMap::new();
2086        vectors.insert(MessageId(1), vec![1.0f32, 0.0]);
2087        vectors.insert(MessageId(2), vec![0.0f32, 1.0]);
2088        vectors.insert(MessageId(3), vec![0.5f32, 0.5]);
2089        let result = apply_mmr(&ranked, &vectors, 1.0, 3);
2090        assert_eq!(result.len(), 3);
2091        assert_eq!(result[0].0, MessageId(1));
2092        assert_eq!(result[1].0, MessageId(2));
2093        assert_eq!(result[2].0, MessageId(3));
2094    }
2095
2096    #[test]
2097    fn mmr_limit_zero_returns_empty() {
2098        let ranked = vec![(MessageId(1), 1.0f64), (MessageId(2), 0.8f64)];
2099        let mut vectors = std::collections::HashMap::new();
2100        vectors.insert(MessageId(1), vec![1.0f32, 0.0]);
2101        vectors.insert(MessageId(2), vec![0.0f32, 1.0]);
2102        let result = apply_mmr(&ranked, &vectors, 0.7, 0);
2103        assert!(result.is_empty());
2104    }
2105
2106    #[test]
2107    fn mmr_duplicate_vectors_penalizes_second() {
2108        // Two items with identical embeddings: second should be heavily penalized
2109        let ranked = vec![
2110            (MessageId(1), 1.0f64),
2111            (MessageId(2), 1.0f64), // same relevance, same direction
2112            (MessageId(3), 0.9f64), // orthogonal, lower relevance
2113        ];
2114        let mut vectors = std::collections::HashMap::new();
2115        vectors.insert(MessageId(1), vec![1.0f32, 0.0]);
2116        vectors.insert(MessageId(2), vec![1.0f32, 0.0]); // duplicate
2117        vectors.insert(MessageId(3), vec![0.0f32, 1.0]); // orthogonal
2118        let result = apply_mmr(&ranked, &vectors, 0.5, 3);
2119        assert_eq!(result.len(), 3);
2120        assert_eq!(result[0].0, MessageId(1));
2121        // msg3 (orthogonal) should be preferred over msg2 (duplicate) with lambda=0.5
2122        assert_eq!(result[1].0, MessageId(3));
2123    }
2124
2125    // Priority 3: proptest
2126
    use proptest::prelude::*;

    proptest! {
        // Fuzz test: count_tokens must not panic on arbitrary input strings
        // (the `.*` strategy generates arbitrary unicode). The count itself
        // is deliberately ignored — only panic-freedom is under test.
        #[test]
        fn count_tokens_never_panics(s in ".*") {
            let counter = crate::token_counter::TokenCounter::new();
            let _ = counter.count_tokens(&s);
        }
    }
2136}