// zeph_memory/semantic.rs
1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use zeph_llm::any::AnyProvider;
5use zeph_llm::provider::{LlmProvider, Message, MessageMetadata, Role};
6
7use std::sync::Arc;
8
9use crate::embedding_store::{EmbeddingStore, MessageKind, SearchFilter};
10use crate::error::MemoryError;
11use crate::sqlite::SqliteStore;
12use crate::token_counter::TokenCounter;
13use crate::types::{ConversationId, MessageId};
14use crate::vector_store::{FieldCondition, FieldValue, VectorFilter};
15
16const SESSION_SUMMARIES_COLLECTION: &str = "zeph_session_summaries";
17const KEY_FACTS_COLLECTION: &str = "zeph_key_facts";
18
/// LLM-produced conversation summary in structured form.
///
/// Deserialized from the model's JSON response; the requested shape is
/// described in `build_summarization_prompt`.
#[derive(Debug, Clone, serde::Deserialize, serde::Serialize, schemars::JsonSchema)]
pub struct StructuredSummary {
    /// Free-text summary of the summarized conversation window.
    pub summary: String,
    /// Discrete facts extracted from the conversation (may be empty on plain-text fallback).
    pub key_facts: Vec<String>,
    /// Named entities mentioned in the conversation (may be empty on plain-text fallback).
    pub entities: Vec<String>,
}
25
/// A message returned by `recall`, paired with its combined relevance score.
#[derive(Debug)]
pub struct RecalledMessage {
    /// The message loaded from `SQLite`.
    pub message: Message,
    /// Hybrid (vector + keyword, optionally decayed/MMR-adjusted) score; higher is more relevant.
    pub score: f32,
}
31
/// A stored conversation summary row covering a contiguous message range.
#[derive(Debug, Clone)]
pub struct Summary {
    /// Row id assigned by `SQLite`.
    pub id: i64,
    /// Conversation this summary belongs to.
    pub conversation_id: ConversationId,
    /// The summary text itself.
    pub content: String,
    /// First message covered by this summary (inclusive).
    pub first_message_id: MessageId,
    /// Last message covered by this summary (inclusive).
    pub last_message_id: MessageId,
    /// Approximate token count of `content` (from `TokenCounter`).
    pub token_estimate: i64,
}
41
/// A cross-session summary hit returned by `search_session_summaries`.
#[derive(Debug, Clone)]
pub struct SessionSummaryResult {
    /// Summary text stored in the session-summaries collection payload.
    pub summary_text: String,
    /// Vector-search similarity score for the query.
    pub score: f32,
    /// Conversation the summary originated from.
    pub conversation_id: ConversationId,
}
48
/// Cosine similarity between two equal-length vectors.
///
/// Returns `0.0` for mismatched lengths, empty inputs, or zero-magnitude
/// vectors, so callers never see NaN/inf from degenerate inputs.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.is_empty() || a.len() != b.len() {
        return 0.0;
    }
    // Single pass accumulating dot product and both squared magnitudes.
    let mut dot = 0.0f32;
    let mut sq_a = 0.0f32;
    let mut sq_b = 0.0f32;
    for (&x, &y) in a.iter().zip(b.iter()) {
        dot += x * y;
        sq_a += x * x;
        sq_b += y * y;
    }
    let norm_a = sq_a.sqrt();
    let norm_b = sq_b.sqrt();
    if norm_a == 0.0 || norm_b == 0.0 {
        return 0.0;
    }
    dot / (norm_a * norm_b)
}
61
62fn apply_temporal_decay(
63    ranked: &mut [(MessageId, f64)],
64    timestamps: &std::collections::HashMap<MessageId, i64>,
65    half_life_days: u32,
66) {
67    if half_life_days == 0 {
68        return;
69    }
70    let now = std::time::SystemTime::now()
71        .duration_since(std::time::UNIX_EPOCH)
72        .unwrap_or_default()
73        .as_secs()
74        .cast_signed();
75    let lambda = std::f64::consts::LN_2 / f64::from(half_life_days);
76
77    for (msg_id, score) in ranked.iter_mut() {
78        if let Some(&ts) = timestamps.get(msg_id) {
79            #[allow(clippy::cast_precision_loss)]
80            let age_days = (now - ts).max(0) as f64 / 86400.0;
81            *score *= (-lambda * age_days).exp();
82        }
83    }
84}
85
86fn apply_mmr(
87    ranked: &[(MessageId, f64)],
88    vectors: &std::collections::HashMap<MessageId, Vec<f32>>,
89    lambda: f32,
90    limit: usize,
91) -> Vec<(MessageId, f64)> {
92    if ranked.is_empty() || limit == 0 {
93        return Vec::new();
94    }
95
96    let lambda = f64::from(lambda);
97    let mut selected: Vec<(MessageId, f64)> = Vec::with_capacity(limit);
98    let mut remaining: Vec<(MessageId, f64)> = ranked.to_vec();
99
100    while selected.len() < limit && !remaining.is_empty() {
101        let best_idx = if selected.is_empty() {
102            // Pick highest relevance first
103            0
104        } else {
105            let mut best = 0usize;
106            let mut best_score = f64::NEG_INFINITY;
107
108            for (i, &(cand_id, relevance)) in remaining.iter().enumerate() {
109                let max_sim = if let Some(cand_vec) = vectors.get(&cand_id) {
110                    selected
111                        .iter()
112                        .filter_map(|(sel_id, _)| vectors.get(sel_id))
113                        .map(|sel_vec| f64::from(cosine_similarity(cand_vec, sel_vec)))
114                        .fold(f64::NEG_INFINITY, f64::max)
115                } else {
116                    0.0
117                };
118                let max_sim = if max_sim == f64::NEG_INFINITY {
119                    0.0
120                } else {
121                    max_sim
122                };
123                let mmr_score = lambda * relevance - (1.0 - lambda) * max_sim;
124                if mmr_score > best_score {
125                    best_score = mmr_score;
126                    best = i;
127                }
128            }
129            best
130        };
131
132        selected.push(remaining.remove(best_idx));
133    }
134
135    selected
136}
137
138fn build_summarization_prompt(messages: &[(MessageId, String, String)]) -> String {
139    let mut prompt = String::from(
140        "Summarize the following conversation. Extract key facts, decisions, entities, \
141         and context needed to continue the conversation.\n\n\
142         Respond in JSON with fields: summary (string), key_facts (list of strings), \
143         entities (list of strings).\n\nConversation:\n",
144    );
145
146    for (_, role, content) in messages {
147        prompt.push_str(role);
148        prompt.push_str(": ");
149        prompt.push_str(content);
150        prompt.push('\n');
151    }
152
153    prompt
154}
155
/// Long-term memory combining `SQLite` persistence with optional vector search.
pub struct SemanticMemory {
    // Durable message/summary store; also serves FTS5 keyword search.
    sqlite: SqliteStore,
    // Vector store backend (Qdrant or SQLite-embedded); None disables semantic search.
    qdrant: Option<EmbeddingStore>,
    // LLM provider used for embeddings and summarization.
    provider: AnyProvider,
    // Model name recorded alongside stored embeddings.
    embedding_model: String,
    // Weight of the normalized vector-search score in hybrid ranking.
    vector_weight: f64,
    // Weight of the normalized FTS5 keyword score in hybrid ranking.
    keyword_weight: f64,
    // When true, recall scores decay exponentially with message age.
    temporal_decay_enabled: bool,
    // Half-life in days for temporal decay; 0 disables decay.
    temporal_decay_half_life_days: u32,
    // When true, recall results are re-ranked with MMR for diversity.
    mmr_enabled: bool,
    // MMR relevance/diversity trade-off (1.0 = pure relevance).
    mmr_lambda: f32,
    pub token_counter: Arc<TokenCounter>,
}
169
170impl SemanticMemory {
171    /// Create a new `SemanticMemory` instance with default hybrid search weights (0.7/0.3).
172    ///
173    /// Qdrant connection is best-effort: if unavailable, semantic search is disabled.
174    ///
175    /// # Errors
176    ///
177    /// Returns an error if `SQLite` cannot be initialized.
178    pub async fn new(
179        sqlite_path: &str,
180        qdrant_url: &str,
181        provider: AnyProvider,
182        embedding_model: &str,
183    ) -> Result<Self, MemoryError> {
184        Self::with_weights(sqlite_path, qdrant_url, provider, embedding_model, 0.7, 0.3).await
185    }
186
187    /// Create a new `SemanticMemory` with custom vector/keyword weights for hybrid search.
188    ///
189    /// # Errors
190    ///
191    /// Returns an error if `SQLite` cannot be initialized.
192    pub async fn with_weights(
193        sqlite_path: &str,
194        qdrant_url: &str,
195        provider: AnyProvider,
196        embedding_model: &str,
197        vector_weight: f64,
198        keyword_weight: f64,
199    ) -> Result<Self, MemoryError> {
200        Self::with_weights_and_pool_size(
201            sqlite_path,
202            qdrant_url,
203            provider,
204            embedding_model,
205            vector_weight,
206            keyword_weight,
207            5,
208        )
209        .await
210    }
211
212    /// Create a new `SemanticMemory` with custom weights and configurable pool size.
213    ///
214    /// # Errors
215    ///
216    /// Returns an error if `SQLite` cannot be initialized.
217    pub async fn with_weights_and_pool_size(
218        sqlite_path: &str,
219        qdrant_url: &str,
220        provider: AnyProvider,
221        embedding_model: &str,
222        vector_weight: f64,
223        keyword_weight: f64,
224        pool_size: u32,
225    ) -> Result<Self, MemoryError> {
226        let sqlite = SqliteStore::with_pool_size(sqlite_path, pool_size).await?;
227        let pool = sqlite.pool().clone();
228
229        let qdrant = match EmbeddingStore::new(qdrant_url, pool) {
230            Ok(store) => Some(store),
231            Err(e) => {
232                tracing::warn!("Qdrant unavailable, semantic search disabled: {e:#}");
233                None
234            }
235        };
236
237        Ok(Self {
238            sqlite,
239            qdrant,
240            provider,
241            embedding_model: embedding_model.into(),
242            vector_weight,
243            keyword_weight,
244            temporal_decay_enabled: false,
245            temporal_decay_half_life_days: 30,
246            mmr_enabled: false,
247            mmr_lambda: 0.7,
248            token_counter: Arc::new(TokenCounter::new()),
249        })
250    }
251
252    /// Configure temporal decay and MMR re-ranking options.
253    #[must_use]
254    pub fn with_ranking_options(
255        mut self,
256        temporal_decay_enabled: bool,
257        temporal_decay_half_life_days: u32,
258        mmr_enabled: bool,
259        mmr_lambda: f32,
260    ) -> Self {
261        self.temporal_decay_enabled = temporal_decay_enabled;
262        self.temporal_decay_half_life_days = temporal_decay_half_life_days;
263        self.mmr_enabled = mmr_enabled;
264        self.mmr_lambda = mmr_lambda;
265        self
266    }
267
268    /// Create a `SemanticMemory` using the `SQLite`-embedded vector backend.
269    ///
270    /// # Errors
271    ///
272    /// Returns an error if `SQLite` cannot be initialized.
273    pub async fn with_sqlite_backend(
274        sqlite_path: &str,
275        provider: AnyProvider,
276        embedding_model: &str,
277        vector_weight: f64,
278        keyword_weight: f64,
279    ) -> Result<Self, MemoryError> {
280        Self::with_sqlite_backend_and_pool_size(
281            sqlite_path,
282            provider,
283            embedding_model,
284            vector_weight,
285            keyword_weight,
286            5,
287        )
288        .await
289    }
290
291    /// Create a `SemanticMemory` using the `SQLite`-embedded vector backend with configurable pool size.
292    ///
293    /// # Errors
294    ///
295    /// Returns an error if `SQLite` cannot be initialized.
296    pub async fn with_sqlite_backend_and_pool_size(
297        sqlite_path: &str,
298        provider: AnyProvider,
299        embedding_model: &str,
300        vector_weight: f64,
301        keyword_weight: f64,
302        pool_size: u32,
303    ) -> Result<Self, MemoryError> {
304        let sqlite = SqliteStore::with_pool_size(sqlite_path, pool_size).await?;
305        let pool = sqlite.pool().clone();
306        let store = EmbeddingStore::new_sqlite(pool);
307
308        Ok(Self {
309            sqlite,
310            qdrant: Some(store),
311            provider,
312            embedding_model: embedding_model.into(),
313            vector_weight,
314            keyword_weight,
315            temporal_decay_enabled: false,
316            temporal_decay_half_life_days: 30,
317            mmr_enabled: false,
318            mmr_lambda: 0.7,
319            token_counter: Arc::new(TokenCounter::new()),
320        })
321    }
322
323    /// Save a message to `SQLite` and optionally embed and store in Qdrant.
324    ///
325    /// Returns the message ID assigned by `SQLite`.
326    ///
327    /// # Errors
328    ///
329    /// Returns an error if the `SQLite` save fails. Embedding failures are logged but not
330    /// propagated.
331    pub async fn remember(
332        &self,
333        conversation_id: ConversationId,
334        role: &str,
335        content: &str,
336    ) -> Result<MessageId, MemoryError> {
337        let message_id = self
338            .sqlite
339            .save_message(conversation_id, role, content)
340            .await?;
341
342        if let Some(qdrant) = &self.qdrant
343            && self.provider.supports_embeddings()
344        {
345            match self.provider.embed(content).await {
346                Ok(vector) => {
347                    // Ensure collection exists before storing
348                    let vector_size = u64::try_from(vector.len()).unwrap_or(896);
349                    if let Err(e) = qdrant.ensure_collection(vector_size).await {
350                        tracing::warn!("Failed to ensure Qdrant collection: {e:#}");
351                    } else if let Err(e) = qdrant
352                        .store(
353                            message_id,
354                            conversation_id,
355                            role,
356                            vector,
357                            MessageKind::Regular,
358                            &self.embedding_model,
359                        )
360                        .await
361                    {
362                        tracing::warn!("Failed to store embedding: {e:#}");
363                    }
364                }
365                Err(e) => {
366                    tracing::warn!("Failed to generate embedding: {e:#}");
367                }
368            }
369        }
370
371        Ok(message_id)
372    }
373
    /// Save a message with pre-serialized parts JSON to `SQLite` and optionally embed in Qdrant.
    ///
    /// Returns `(message_id, embedding_stored)` tuple where `embedding_stored` is `true` if
    /// an embedding was successfully generated and stored in Qdrant.
    ///
    /// # Errors
    ///
    /// Returns an error if the `SQLite` save fails.
    pub async fn remember_with_parts(
        &self,
        conversation_id: ConversationId,
        role: &str,
        content: &str,
        parts_json: &str,
    ) -> Result<(MessageId, bool), MemoryError> {
        // Persist first: the message must survive even if embedding fails below.
        let message_id = self
            .sqlite
            .save_message_with_parts(conversation_id, role, content, parts_json)
            .await?;

        let mut embedding_stored = false;

        // Embedding is best-effort: failures are logged, never propagated.
        if let Some(qdrant) = &self.qdrant
            && self.provider.supports_embeddings()
        {
            match self.provider.embed(content).await {
                Ok(vector) => {
                    // Size the collection from this embedding's dimensionality.
                    let vector_size = u64::try_from(vector.len()).unwrap_or(896);
                    if let Err(e) = qdrant.ensure_collection(vector_size).await {
                        tracing::warn!("Failed to ensure Qdrant collection: {e:#}");
                    } else if let Err(e) = qdrant
                        .store(
                            message_id,
                            conversation_id,
                            role,
                            vector,
                            MessageKind::Regular,
                            &self.embedding_model,
                        )
                        .await
                    {
                        tracing::warn!("Failed to store embedding: {e:#}");
                    } else {
                        // Only flag success when the full ensure+store chain succeeded.
                        embedding_stored = true;
                    }
                }
                Err(e) => {
                    tracing::warn!("Failed to generate embedding: {e:#}");
                }
            }
        }

        Ok((message_id, embedding_stored))
    }
428
429    /// Save a message to `SQLite` without generating an embedding.
430    ///
431    /// Use this when embedding is intentionally skipped (e.g. autosave disabled for assistant).
432    ///
433    /// # Errors
434    ///
435    /// Returns an error if the `SQLite` save fails.
436    pub async fn save_only(
437        &self,
438        conversation_id: ConversationId,
439        role: &str,
440        content: &str,
441        parts_json: &str,
442    ) -> Result<MessageId, MemoryError> {
443        self.sqlite
444            .save_message_with_parts(conversation_id, role, content, parts_json)
445            .await
446    }
447
    /// Recall relevant messages using hybrid search (vector + FTS5 keyword).
    ///
    /// When Qdrant is available, runs both vector and keyword searches, then merges
    /// results using weighted scoring. When Qdrant is unavailable, falls back to
    /// FTS5-only keyword search.
    ///
    /// Pipeline: over-fetch from both sources → max-normalize and weight-merge
    /// scores → optional temporal decay → optional MMR re-rank → truncate to
    /// `limit` → hydrate full messages from `SQLite`.
    ///
    /// # Errors
    ///
    /// Returns an error if embedding generation, Qdrant search, or FTS5 query fails.
    #[allow(clippy::too_many_lines)]
    pub async fn recall(
        &self,
        query: &str,
        limit: usize,
        filter: Option<SearchFilter>,
    ) -> Result<Vec<RecalledMessage>, MemoryError> {
        let conversation_id = filter.as_ref().and_then(|f| f.conversation_id);

        // FTS5 keyword search (always available)
        // `limit * 2` over-fetches so the merge step has candidates from both sources.
        // A keyword failure is non-fatal: vector results alone can still be returned.
        let keyword_results = match self
            .sqlite
            .keyword_search(query, limit * 2, conversation_id)
            .await
        {
            Ok(results) => results,
            Err(e) => {
                tracing::warn!("FTS5 keyword search failed: {e:#}");
                Vec::new()
            }
        };

        // Vector search (only when Qdrant available)
        // Unlike keyword search, embedding/search failures here propagate as errors.
        let vector_results = if let Some(qdrant) = &self.qdrant
            && self.provider.supports_embeddings()
        {
            let query_vector = self.provider.embed(query).await?;
            let vector_size = u64::try_from(query_vector.len()).unwrap_or(896);
            qdrant.ensure_collection(vector_size).await?;
            qdrant.search(&query_vector, limit * 2, filter).await?
        } else {
            Vec::new()
        };

        // Merge results with weighted scoring
        let mut scores: std::collections::HashMap<MessageId, f64> =
            std::collections::HashMap::new();

        if !vector_results.is_empty() {
            // Normalize by the max positive score so the best vector hit contributes
            // exactly `vector_weight`; a non-positive max leaves scores unscaled.
            let max_vs = vector_results
                .iter()
                .map(|r| r.score)
                .fold(f32::NEG_INFINITY, f32::max);
            let norm = if max_vs > 0.0 { max_vs } else { 1.0 };
            for r in &vector_results {
                let normalized = f64::from(r.score / norm);
                *scores.entry(r.message_id).or_default() += normalized * self.vector_weight;
            }
        }

        if !keyword_results.is_empty() {
            // Same max-normalization for the keyword side; messages found by both
            // searches accumulate both weighted contributions.
            let max_ks = keyword_results
                .iter()
                .map(|r| r.1)
                .fold(f64::NEG_INFINITY, f64::max);
            let norm = if max_ks > 0.0 { max_ks } else { 1.0 };
            for &(msg_id, score) in &keyword_results {
                let normalized = score / norm;
                *scores.entry(msg_id).or_default() += normalized * self.keyword_weight;
            }
        }

        if scores.is_empty() {
            return Ok(Vec::new());
        }

        // Sort by combined score descending
        let mut ranked: Vec<(MessageId, f64)> = scores.into_iter().collect();
        ranked.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));

        // Apply temporal decay (before MMR)
        // Decay failure degrades gracefully to the undecayed ranking.
        if self.temporal_decay_enabled && self.temporal_decay_half_life_days > 0 {
            let ids: Vec<MessageId> = ranked.iter().map(|r| r.0).collect();
            match self.sqlite.message_timestamps(&ids).await {
                Ok(timestamps) => {
                    apply_temporal_decay(
                        &mut ranked,
                        &timestamps,
                        self.temporal_decay_half_life_days,
                    );
                    // Decay changes relative order, so re-sort.
                    ranked
                        .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
                }
                Err(e) => {
                    tracing::warn!("temporal decay: failed to fetch timestamps: {e:#}");
                }
            }
        }

        // Apply MMR re-ranking (after decay, before truncation)
        // `apply_mmr` both re-orders for diversity and truncates to `limit`;
        // every fallback path truncates explicitly instead.
        if self.mmr_enabled && !vector_results.is_empty() {
            if let Some(qdrant) = &self.qdrant {
                let ids: Vec<MessageId> = ranked.iter().map(|r| r.0).collect();
                match qdrant.get_vectors(&ids).await {
                    Ok(vec_map) if !vec_map.is_empty() => {
                        ranked = apply_mmr(&ranked, &vec_map, self.mmr_lambda, limit);
                    }
                    Ok(_) => {
                        ranked.truncate(limit);
                    }
                    Err(e) => {
                        tracing::warn!("MMR: failed to fetch vectors: {e:#}");
                        ranked.truncate(limit);
                    }
                }
            } else {
                ranked.truncate(limit);
            }
        } else {
            ranked.truncate(limit);
        }

        // Hydrate the surviving ids into full messages, preserving ranked order.
        let ids: Vec<MessageId> = ranked.iter().map(|r| r.0).collect();
        let messages = self.sqlite.messages_by_ids(&ids).await?;
        let msg_map: std::collections::HashMap<MessageId, _> = messages.into_iter().collect();

        // Ids missing from SQLite (e.g. deleted since indexing) are silently dropped.
        let recalled = ranked
            .iter()
            .filter_map(|(msg_id, score)| {
                msg_map.get(msg_id).map(|msg| RecalledMessage {
                    message: msg.clone(),
                    #[expect(clippy::cast_possible_truncation)]
                    score: *score as f32,
                })
            })
            .collect();

        Ok(recalled)
    }
586
587    /// Check whether an embedding exists for a given message ID.
588    ///
589    /// # Errors
590    ///
591    /// Returns an error if the `SQLite` query fails.
592    pub async fn has_embedding(&self, message_id: MessageId) -> Result<bool, MemoryError> {
593        match &self.qdrant {
594            Some(qdrant) => qdrant.has_embedding(message_id).await,
595            None => Ok(false),
596        }
597    }
598
599    /// Embed all messages that do not yet have embeddings.
600    ///
601    /// Returns the count of successfully embedded messages.
602    ///
603    /// # Errors
604    ///
605    /// Returns an error if collection initialization or database query fails.
606    /// Individual embedding failures are logged but do not stop processing.
607    pub async fn embed_missing(&self) -> Result<usize, MemoryError> {
608        let Some(qdrant) = &self.qdrant else {
609            return Ok(0);
610        };
611        if !self.provider.supports_embeddings() {
612            return Ok(0);
613        }
614
615        let unembedded = self.sqlite.unembedded_message_ids(Some(1000)).await?;
616
617        if unembedded.is_empty() {
618            return Ok(0);
619        }
620
621        let probe = self.provider.embed("probe").await?;
622        let vector_size = u64::try_from(probe.len())?;
623        qdrant.ensure_collection(vector_size).await?;
624
625        let mut count = 0;
626        for (msg_id, conversation_id, role, content) in &unembedded {
627            match self.provider.embed(content).await {
628                Ok(vector) => {
629                    if let Err(e) = qdrant
630                        .store(
631                            *msg_id,
632                            *conversation_id,
633                            role,
634                            vector,
635                            MessageKind::Regular,
636                            &self.embedding_model,
637                        )
638                        .await
639                    {
640                        tracing::warn!("Failed to store embedding for msg {msg_id}: {e:#}");
641                        continue;
642                    }
643                    count += 1;
644                }
645                Err(e) => {
646                    tracing::warn!("Failed to embed msg {msg_id}: {e:#}");
647                }
648            }
649        }
650
651        tracing::info!("Embedded {count}/{} missing messages", unembedded.len());
652        Ok(count)
653    }
654
655    /// Store a session summary into the dedicated `zeph_session_summaries` Qdrant collection.
656    ///
657    /// # Errors
658    ///
659    /// Returns an error if embedding or Qdrant storage fails.
660    pub async fn store_session_summary(
661        &self,
662        conversation_id: ConversationId,
663        summary_text: &str,
664    ) -> Result<(), MemoryError> {
665        let Some(qdrant) = &self.qdrant else {
666            return Ok(());
667        };
668        if !self.provider.supports_embeddings() {
669            return Ok(());
670        }
671
672        let vector = self.provider.embed(summary_text).await?;
673        let vector_size = u64::try_from(vector.len()).unwrap_or(896);
674        qdrant
675            .ensure_named_collection(SESSION_SUMMARIES_COLLECTION, vector_size)
676            .await?;
677
678        let payload = serde_json::json!({
679            "conversation_id": conversation_id.0,
680            "summary_text": summary_text,
681        });
682
683        qdrant
684            .store_to_collection(SESSION_SUMMARIES_COLLECTION, payload, vector)
685            .await?;
686
687        tracing::debug!(
688            conversation_id = conversation_id.0,
689            "stored session summary"
690        );
691        Ok(())
692    }
693
694    /// Search session summaries from other conversations.
695    ///
696    /// # Errors
697    ///
698    /// Returns an error if embedding or Qdrant search fails.
699    pub async fn search_session_summaries(
700        &self,
701        query: &str,
702        limit: usize,
703        exclude_conversation_id: Option<ConversationId>,
704    ) -> Result<Vec<SessionSummaryResult>, MemoryError> {
705        let Some(qdrant) = &self.qdrant else {
706            return Ok(Vec::new());
707        };
708        if !self.provider.supports_embeddings() {
709            return Ok(Vec::new());
710        }
711
712        let vector = self.provider.embed(query).await?;
713        let vector_size = u64::try_from(vector.len()).unwrap_or(896);
714        qdrant
715            .ensure_named_collection(SESSION_SUMMARIES_COLLECTION, vector_size)
716            .await?;
717
718        let filter = exclude_conversation_id.map(|cid| VectorFilter {
719            must: vec![],
720            must_not: vec![FieldCondition {
721                field: "conversation_id".into(),
722                value: FieldValue::Integer(cid.0),
723            }],
724        });
725
726        let points = qdrant
727            .search_collection(SESSION_SUMMARIES_COLLECTION, &vector, limit, filter)
728            .await?;
729
730        let results = points
731            .into_iter()
732            .filter_map(|point| {
733                let summary_text = point.payload.get("summary_text")?.as_str()?.to_owned();
734                let conversation_id =
735                    ConversationId(point.payload.get("conversation_id")?.as_i64()?);
736                Some(SessionSummaryResult {
737                    summary_text,
738                    score: point.score,
739                    conversation_id,
740                })
741            })
742            .collect();
743
744        Ok(results)
745    }
746
747    /// Access the underlying `SqliteStore` for operations that don't involve semantics.
748    #[must_use]
749    pub fn sqlite(&self) -> &SqliteStore {
750        &self.sqlite
751    }
752
753    /// Check if the vector store backend is reachable.
754    ///
755    /// Performs a real health check (Qdrant gRPC ping or `SQLite` query)
756    /// instead of just checking whether the client was created.
757    pub async fn is_vector_store_connected(&self) -> bool {
758        match self.qdrant.as_ref() {
759            Some(store) => store.health_check().await,
760            None => false,
761        }
762    }
763
    /// Check if a vector store client is configured (may not be connected).
    ///
    /// Unlike [`Self::is_vector_store_connected`], this performs no I/O.
    #[must_use]
    pub fn has_vector_store(&self) -> bool {
        self.qdrant.is_some()
    }
769
    /// Count messages in a conversation.
    ///
    /// Thin delegation to the underlying `SQLite` store.
    ///
    /// # Errors
    ///
    /// Returns an error if the query fails.
    pub async fn message_count(&self, conversation_id: ConversationId) -> Result<i64, MemoryError> {
        self.sqlite.count_messages(conversation_id).await
    }
778
779    /// Count messages not yet covered by any summary.
780    ///
781    /// # Errors
782    ///
783    /// Returns an error if the query fails.
784    pub async fn unsummarized_message_count(
785        &self,
786        conversation_id: ConversationId,
787    ) -> Result<i64, MemoryError> {
788        let after_id = self
789            .sqlite
790            .latest_summary_last_message_id(conversation_id)
791            .await?
792            .unwrap_or(MessageId(0));
793        self.sqlite
794            .count_messages_after(conversation_id, after_id)
795            .await
796    }
797
798    /// Load all summaries for a conversation.
799    ///
800    /// # Errors
801    ///
802    /// Returns an error if the query fails.
803    pub async fn load_summaries(
804        &self,
805        conversation_id: ConversationId,
806    ) -> Result<Vec<Summary>, MemoryError> {
807        let rows = self.sqlite.load_summaries(conversation_id).await?;
808        let summaries = rows
809            .into_iter()
810            .map(
811                |(
812                    id,
813                    conversation_id,
814                    content,
815                    first_message_id,
816                    last_message_id,
817                    token_estimate,
818                )| {
819                    Summary {
820                        id,
821                        conversation_id,
822                        content,
823                        first_message_id,
824                        last_message_id,
825                        token_estimate,
826                    }
827                },
828            )
829            .collect();
830        Ok(summaries)
831    }
832
    /// Generate a summary of the oldest unsummarized messages.
    ///
    /// Returns `Ok(None)` if there are not enough messages to summarize.
    ///
    /// Attempts a structured (JSON) summary first, falling back to a plain-text
    /// chat completion on failure. The summary row is always saved to `SQLite`;
    /// embedding the summary and storing key facts are best-effort extras.
    ///
    /// # Errors
    ///
    /// Returns an error if LLM call or database operation fails.
    pub async fn summarize(
        &self,
        conversation_id: ConversationId,
        message_count: usize,
    ) -> Result<Option<i64>, MemoryError> {
        let total = self.sqlite.count_messages(conversation_id).await?;

        // Only summarize once the conversation exceeds the requested window,
        // so the most recent `message_count` messages stay verbatim.
        if total <= i64::try_from(message_count)? {
            return Ok(None);
        }

        // Resume after the last summarized message (0 = nothing summarized yet).
        let after_id = self
            .sqlite
            .latest_summary_last_message_id(conversation_id)
            .await?
            .unwrap_or(MessageId(0));

        let messages = self
            .sqlite
            .load_messages_range(conversation_id, after_id, message_count)
            .await?;

        if messages.is_empty() {
            return Ok(None);
        }

        let prompt = build_summarization_prompt(&messages);
        let chat_messages = vec![Message {
            role: Role::User,
            content: prompt,
            parts: vec![],
            metadata: MessageMetadata::default(),
        }];

        // Prefer a schema-typed response; degrade to plain text with empty
        // key_facts/entities if the structured call fails.
        let structured = match self
            .provider
            .chat_typed_erased::<StructuredSummary>(&chat_messages)
            .await
        {
            Ok(s) => s,
            Err(e) => {
                tracing::warn!(
                    "structured summarization failed, falling back to plain text: {e:#}"
                );
                let plain = self.provider.chat(&chat_messages).await?;
                StructuredSummary {
                    summary: plain,
                    key_facts: vec![],
                    entities: vec![],
                }
            }
        };
        let summary_text = &structured.summary;

        let token_estimate = i64::try_from(self.token_counter.count_tokens(summary_text))?;
        // Safe indexing: `messages` is non-empty (checked above).
        let first_message_id = messages[0].0;
        let last_message_id = messages[messages.len() - 1].0;

        let summary_id = self
            .sqlite
            .save_summary(
                conversation_id,
                summary_text,
                first_message_id,
                last_message_id,
                token_estimate,
            )
            .await?;

        // Best-effort: embed the summary for recall; failures are only logged.
        if let Some(qdrant) = &self.qdrant
            && self.provider.supports_embeddings()
        {
            match self.provider.embed(summary_text).await {
                Ok(vector) => {
                    // Ensure collection exists before storing
                    let vector_size = u64::try_from(vector.len()).unwrap_or(896);
                    if let Err(e) = qdrant.ensure_collection(vector_size).await {
                        tracing::warn!("Failed to ensure Qdrant collection: {e:#}");
                    } else if let Err(e) = qdrant
                        .store(
                            // Summary rows share the message-id space under MessageKind::Summary.
                            MessageId(summary_id),
                            conversation_id,
                            "system",
                            vector,
                            MessageKind::Summary,
                            &self.embedding_model,
                        )
                        .await
                    {
                        tracing::warn!("Failed to embed summary: {e:#}");
                    }
                }
                Err(e) => {
                    tracing::warn!("Failed to generate summary embedding: {e:#}");
                }
            }
        }

        // Store key facts as individual Qdrant points
        if !structured.key_facts.is_empty() {
            self.store_key_facts(conversation_id, summary_id, &structured.key_facts)
                .await;
        }

        Ok(Some(summary_id))
    }
946
947    async fn store_key_facts(
948        &self,
949        conversation_id: ConversationId,
950        source_summary_id: i64,
951        key_facts: &[String],
952    ) {
953        let Some(qdrant) = &self.qdrant else {
954            return;
955        };
956        if !self.provider.supports_embeddings() {
957            return;
958        }
959
960        let Some(first_fact) = key_facts.first() else {
961            return;
962        };
963        let first_vector = match self.provider.embed(first_fact).await {
964            Ok(v) => v,
965            Err(e) => {
966                tracing::warn!("Failed to embed key fact: {e:#}");
967                return;
968            }
969        };
970        let vector_size = u64::try_from(first_vector.len()).unwrap_or(896);
971        if let Err(e) = qdrant
972            .ensure_named_collection(KEY_FACTS_COLLECTION, vector_size)
973            .await
974        {
975            tracing::warn!("Failed to ensure key_facts collection: {e:#}");
976            return;
977        }
978
979        let first_payload = serde_json::json!({
980            "conversation_id": conversation_id.0,
981            "fact_text": first_fact,
982            "source_summary_id": source_summary_id,
983        });
984        if let Err(e) = qdrant
985            .store_to_collection(KEY_FACTS_COLLECTION, first_payload, first_vector)
986            .await
987        {
988            tracing::warn!("Failed to store key fact: {e:#}");
989        }
990
991        for fact in &key_facts[1..] {
992            match self.provider.embed(fact).await {
993                Ok(vector) => {
994                    let payload = serde_json::json!({
995                        "conversation_id": conversation_id.0,
996                        "fact_text": fact,
997                        "source_summary_id": source_summary_id,
998                    });
999                    if let Err(e) = qdrant
1000                        .store_to_collection(KEY_FACTS_COLLECTION, payload, vector)
1001                        .await
1002                    {
1003                        tracing::warn!("Failed to store key fact: {e:#}");
1004                    }
1005                }
1006                Err(e) => {
1007                    tracing::warn!("Failed to embed key fact: {e:#}");
1008                }
1009            }
1010        }
1011    }
1012
1013    /// Search key facts extracted from conversation summaries.
1014    ///
1015    /// # Errors
1016    ///
1017    /// Returns an error if embedding or Qdrant search fails.
1018    pub async fn search_key_facts(
1019        &self,
1020        query: &str,
1021        limit: usize,
1022    ) -> Result<Vec<String>, MemoryError> {
1023        let Some(qdrant) = &self.qdrant else {
1024            return Ok(Vec::new());
1025        };
1026        if !self.provider.supports_embeddings() {
1027            return Ok(Vec::new());
1028        }
1029
1030        let vector = self.provider.embed(query).await?;
1031        let vector_size = u64::try_from(vector.len()).unwrap_or(896);
1032        qdrant
1033            .ensure_named_collection(KEY_FACTS_COLLECTION, vector_size)
1034            .await?;
1035
1036        let points = qdrant
1037            .search_collection(KEY_FACTS_COLLECTION, &vector, limit, None)
1038            .await?;
1039
1040        let facts = points
1041            .into_iter()
1042            .filter_map(|p| p.payload.get("fact_text")?.as_str().map(String::from))
1043            .collect();
1044
1045        Ok(facts)
1046    }
1047}
1048
1049#[cfg(test)]
1050mod tests {
1051    use zeph_llm::mock::MockProvider;
1052    use zeph_llm::provider::Role;
1053
1054    use super::*;
1055
    // Mock LLM provider used by every unit test in this module.
    fn test_provider() -> AnyProvider {
        AnyProvider::Mock(MockProvider::default())
    }

    // Builds a SemanticMemory backed by in-memory SQLite with no vector store
    // (`qdrant: None`) and decay/MMR features disabled.
    //
    // NOTE(review): `_supports_embeddings` is currently ignored — the mock
    // provider's default capability is used regardless. TODO confirm whether
    // any test relies on the flag having an effect.
    async fn test_semantic_memory(_supports_embeddings: bool) -> SemanticMemory {
        let provider = test_provider();
        let sqlite = SqliteStore::new(":memory:").await.unwrap();

        SemanticMemory {
            sqlite,
            qdrant: None,
            provider,
            embedding_model: "test-model".into(),
            vector_weight: 0.7,
            keyword_weight: 0.3,
            temporal_decay_enabled: false,
            temporal_decay_half_life_days: 30,
            mmr_enabled: false,
            mmr_lambda: 0.7,
            token_counter: Arc::new(TokenCounter::new()),
        }
    }
1078
    // remember() persists to SQLite and returns the newly assigned row id.
    #[tokio::test]
    async fn remember_saves_to_sqlite() {
        let memory = test_semantic_memory(false).await;

        let cid = memory.sqlite.create_conversation().await.unwrap();
        let msg_id = memory.remember(cid, "user", "hello").await.unwrap();

        assert_eq!(msg_id, MessageId(1));

        let history = memory.sqlite.load_history(cid, 50).await.unwrap();
        assert_eq!(history.len(), 1);
        assert_eq!(history[0].role, Role::User);
        assert_eq!(history[0].content, "hello");
    }

    // remember_with_parts() stores serialized message parts while keeping the
    // plain-text content intact.
    #[tokio::test]
    async fn remember_with_parts_saves_parts_json() {
        let memory = test_semantic_memory(false).await;
        let cid = memory.sqlite.create_conversation().await.unwrap();

        let parts_json =
            r#"[{"kind":"ToolOutput","tool_name":"shell","body":"hello","compacted_at":null}]"#;
        let (msg_id, _embedding_stored) = memory
            .remember_with_parts(cid, "assistant", "tool output", parts_json)
            .await
            .unwrap();
        assert!(msg_id > MessageId(0));

        let history = memory.sqlite.load_history(cid, 50).await.unwrap();
        assert_eq!(history.len(), 1);
        assert_eq!(history[0].content, "tool output");
    }

    // recall() degrades to an empty result when no vector store is configured.
    #[tokio::test]
    async fn recall_returns_empty_without_qdrant() {
        let memory = test_semantic_memory(true).await;

        let recalled = memory.recall("test", 5, None).await.unwrap();
        assert!(recalled.is_empty());
    }

    // has_embedding() is false when no vector store is configured.
    #[tokio::test]
    async fn has_embedding_without_qdrant() {
        let memory = test_semantic_memory(true).await;

        let has_embedding = memory.has_embedding(MessageId(1)).await.unwrap();
        assert!(!has_embedding);
    }

    // embed_missing() is a no-op (0 embedded) without a vector store.
    #[tokio::test]
    async fn embed_missing_without_qdrant() {
        let memory = test_semantic_memory(true).await;

        let count = memory.embed_missing().await.unwrap();
        assert_eq!(count, 0);
    }
1135
    // The sqlite() accessor exposes the underlying store for direct use.
    #[tokio::test]
    async fn sqlite_accessor() {
        let memory = test_semantic_memory(false).await;

        let cid = memory.sqlite().create_conversation().await.unwrap();
        assert_eq!(cid, ConversationId(1));

        memory
            .sqlite()
            .save_message(cid, "user", "test")
            .await
            .unwrap();

        let history = memory.sqlite().load_history(cid, 50).await.unwrap();
        assert_eq!(history.len(), 1);
    }

    // Vector-store presence checks report false when qdrant is None.
    #[tokio::test]
    async fn has_vector_store_returns_false_when_unavailable() {
        let memory = test_semantic_memory(false).await;
        assert!(!memory.has_vector_store());
    }

    #[tokio::test]
    async fn is_vector_store_connected_returns_false_when_unavailable() {
        let memory = test_semantic_memory(false).await;
        assert!(!memory.is_vector_store_connected().await);
    }

    // recall() is also empty when the provider lacks embedding support.
    #[tokio::test]
    async fn recall_returns_empty_when_embeddings_not_supported() {
        let memory = test_semantic_memory(false).await;

        let recalled = memory.recall("test", 5, None).await.unwrap();
        assert!(recalled.is_empty());
    }

    // embed_missing() embeds nothing when the provider can't embed.
    #[tokio::test]
    async fn embed_missing_returns_zero_when_embeddings_not_supported() {
        let memory = test_semantic_memory(false).await;

        let cid = memory.sqlite().create_conversation().await.unwrap();
        memory
            .sqlite()
            .save_message(cid, "user", "test")
            .await
            .unwrap();

        let count = memory.embed_missing().await.unwrap();
        assert_eq!(count, 0);
    }

    // message_count() on a fresh conversation is zero.
    #[tokio::test]
    async fn message_count_empty_conversation() {
        let memory = test_semantic_memory(false).await;
        let cid = memory.sqlite().create_conversation().await.unwrap();

        let count = memory.message_count(cid).await.unwrap();
        assert_eq!(count, 0);
    }

    // message_count() tracks saved messages.
    #[tokio::test]
    async fn message_count_after_saves() {
        let memory = test_semantic_memory(false).await;
        let cid = memory.sqlite().create_conversation().await.unwrap();

        memory.remember(cid, "user", "msg1").await.unwrap();
        memory.remember(cid, "assistant", "msg2").await.unwrap();

        let count = memory.message_count(cid).await.unwrap();
        assert_eq!(count, 2);
    }

    // summarize() covers messages, shrinking the unsummarized count while
    // leaving the total message count untouched.
    #[tokio::test]
    async fn unsummarized_count_decreases_after_summary() {
        let memory = test_semantic_memory(false).await;
        let cid = memory.sqlite().create_conversation().await.unwrap();

        for i in 0..10 {
            memory
                .remember(cid, "user", &format!("msg{i}"))
                .await
                .unwrap();
        }
        assert_eq!(memory.unsummarized_message_count(cid).await.unwrap(), 10);

        memory.summarize(cid, 5).await.unwrap();

        assert!(memory.unsummarized_message_count(cid).await.unwrap() < 10);
        assert_eq!(memory.message_count(cid).await.unwrap(), 10);
    }
1227
    // load_summaries() on a conversation with no summaries is empty.
    #[tokio::test]
    async fn load_summaries_empty() {
        let memory = test_semantic_memory(false).await;
        let cid = memory.sqlite().create_conversation().await.unwrap();

        let summaries = memory.load_summaries(cid).await.unwrap();
        assert!(summaries.is_empty());
    }

    // Summaries come back in insertion order with content preserved.
    #[tokio::test]
    async fn load_summaries_ordered() {
        let memory = test_semantic_memory(false).await;
        let cid = memory.sqlite().create_conversation().await.unwrap();

        let msg_id1 = memory.remember(cid, "user", "m1").await.unwrap();
        let msg_id2 = memory.remember(cid, "assistant", "m2").await.unwrap();
        let msg_id3 = memory.remember(cid, "user", "m3").await.unwrap();

        let s1 = memory
            .sqlite()
            .save_summary(cid, "summary1", msg_id1, msg_id2, 3)
            .await
            .unwrap();
        let s2 = memory
            .sqlite()
            .save_summary(cid, "summary2", msg_id2, msg_id3, 3)
            .await
            .unwrap();

        let summaries = memory.load_summaries(cid).await.unwrap();
        assert_eq!(summaries.len(), 2);
        assert_eq!(summaries[0].id, s1);
        assert_eq!(summaries[0].content, "summary1");
        assert_eq!(summaries[1].id, s2);
        assert_eq!(summaries[1].content, "summary2");
    }

    // summarize() returns None when message count is at or below threshold.
    #[tokio::test]
    async fn summarize_below_threshold() {
        let memory = test_semantic_memory(false).await;
        let cid = memory.sqlite().create_conversation().await.unwrap();

        memory.remember(cid, "user", "hello").await.unwrap();

        let result = memory.summarize(cid, 10).await.unwrap();
        assert!(result.is_none());
    }

    // summarize() above threshold persists a non-empty summary row.
    #[tokio::test]
    async fn summarize_stores_summary() {
        let memory = test_semantic_memory(false).await;
        let cid = memory.sqlite().create_conversation().await.unwrap();

        for i in 0..5 {
            memory
                .remember(cid, "user", &format!("message {i}"))
                .await
                .unwrap();
        }

        let summary_id = memory.summarize(cid, 3).await.unwrap();
        assert!(summary_id.is_some());

        let summaries = memory.load_summaries(cid).await.unwrap();
        assert_eq!(summaries.len(), 1);
        assert_eq!(summaries[0].id, summary_id.unwrap());
        assert!(!summaries[0].content.is_empty());
    }

    // Consecutive summaries cover disjoint, ordered message ranges.
    #[tokio::test]
    async fn summarize_respects_previous_summaries() {
        let memory = test_semantic_memory(false).await;
        let cid = memory.sqlite().create_conversation().await.unwrap();

        for i in 0..10 {
            memory
                .remember(cid, "user", &format!("message {i}"))
                .await
                .unwrap();
        }

        let s1 = memory.summarize(cid, 3).await.unwrap();
        assert!(s1.is_some());

        let s2 = memory.summarize(cid, 3).await.unwrap();
        assert!(s2.is_some());

        let summaries = memory.load_summaries(cid).await.unwrap();
        assert_eq!(summaries.len(), 2);
        assert!(summaries[0].last_message_id < summaries[1].first_message_id);
    }
1319
    // Message ids are monotonically increasing within a conversation.
    #[tokio::test]
    async fn remember_multiple_messages_increments_ids() {
        let memory = test_semantic_memory(false).await;
        let cid = memory.sqlite.create_conversation().await.unwrap();

        let id1 = memory.remember(cid, "user", "first").await.unwrap();
        let id2 = memory.remember(cid, "assistant", "second").await.unwrap();
        let id3 = memory.remember(cid, "user", "third").await.unwrap();

        assert!(id1 < id2);
        assert!(id2 < id3);
    }

    // message_count() is scoped per conversation.
    #[tokio::test]
    async fn message_count_across_conversations() {
        let memory = test_semantic_memory(false).await;
        let cid1 = memory.sqlite().create_conversation().await.unwrap();
        let cid2 = memory.sqlite().create_conversation().await.unwrap();

        memory.remember(cid1, "user", "msg1").await.unwrap();
        memory.remember(cid1, "user", "msg2").await.unwrap();
        memory.remember(cid2, "user", "msg3").await.unwrap();

        assert_eq!(memory.message_count(cid1).await.unwrap(), 2);
        assert_eq!(memory.message_count(cid2).await.unwrap(), 1);
    }

    // The threshold is strict: exactly `message_count` messages → no summary.
    #[tokio::test]
    async fn summarize_exact_threshold_returns_none() {
        let memory = test_semantic_memory(false).await;
        let cid = memory.sqlite().create_conversation().await.unwrap();

        for i in 0..3 {
            memory
                .remember(cid, "user", &format!("msg {i}"))
                .await
                .unwrap();
        }

        let result = memory.summarize(cid, 3).await.unwrap();
        assert!(result.is_none());
    }

    // One message above the threshold is enough to trigger a summary.
    #[tokio::test]
    async fn summarize_one_above_threshold_produces_summary() {
        let memory = test_semantic_memory(false).await;
        let cid = memory.sqlite().create_conversation().await.unwrap();

        for i in 0..4 {
            memory
                .remember(cid, "user", &format!("msg {i}"))
                .await
                .unwrap();
        }

        let result = memory.summarize(cid, 3).await.unwrap();
        assert!(result.is_some());
    }

    // Every Summary field is populated with sensible values.
    #[tokio::test]
    async fn summary_fields_populated() {
        let memory = test_semantic_memory(false).await;
        let cid = memory.sqlite().create_conversation().await.unwrap();

        for i in 0..5 {
            memory
                .remember(cid, "user", &format!("msg {i}"))
                .await
                .unwrap();
        }

        memory.summarize(cid, 3).await.unwrap();
        let summaries = memory.load_summaries(cid).await.unwrap();
        let s = &summaries[0];

        assert_eq!(s.conversation_id, cid);
        assert!(s.first_message_id > MessageId(0));
        assert!(s.last_message_id >= s.first_message_id);
        assert!(s.token_estimate >= 0);
        assert!(!s.content.is_empty());
    }

    // The prompt includes "role: content" lines and the key_facts schema hint.
    #[test]
    fn build_summarization_prompt_format() {
        let messages = vec![
            (MessageId(1), "user".into(), "Hello".into()),
            (MessageId(2), "assistant".into(), "Hi there".into()),
        ];
        let prompt = build_summarization_prompt(&messages);
        assert!(prompt.contains("user: Hello"));
        assert!(prompt.contains("assistant: Hi there"));
        assert!(prompt.contains("key_facts"));
    }

    // Even an empty transcript still yields the schema instructions.
    #[test]
    fn build_summarization_prompt_empty() {
        let messages: Vec<(MessageId, String, String)> = vec![];
        let prompt = build_summarization_prompt(&messages);
        assert!(prompt.contains("key_facts"));
    }

    // StructuredSummary round-trips from the expected JSON shape.
    #[test]
    fn structured_summary_deserialize() {
        let json = r#"{"summary":"s","key_facts":["f1","f2"],"entities":["e1"]}"#;
        let ss: StructuredSummary = serde_json::from_str(json).unwrap();
        assert_eq!(ss.summary, "s");
        assert_eq!(ss.key_facts.len(), 2);
        assert_eq!(ss.entities.len(), 1);
    }

    #[test]
    fn structured_summary_empty_facts() {
        let json = r#"{"summary":"s","key_facts":[],"entities":[]}"#;
        let ss: StructuredSummary = serde_json::from_str(json).unwrap();
        assert!(ss.key_facts.is_empty());
        assert!(ss.entities.is_empty());
    }
1437
    // search_key_facts() degrades to empty without a vector store.
    #[tokio::test]
    async fn search_key_facts_no_qdrant_empty() {
        let memory = test_semantic_memory(false).await;
        let facts = memory.search_key_facts("query", 5).await.unwrap();
        assert!(facts.is_empty());
    }

    // Debug derive on RecalledMessage includes the type name and score.
    #[test]
    fn recalled_message_debug() {
        let recalled = RecalledMessage {
            message: Message {
                role: Role::User,
                content: "test".into(),
                parts: vec![],
                metadata: MessageMetadata::default(),
            },
            score: 0.95,
        };
        let dbg = format!("{recalled:?}");
        assert!(dbg.contains("RecalledMessage"));
        assert!(dbg.contains("0.95"));
    }

    // Clone derive on Summary produces an independent equal copy.
    #[test]
    fn summary_clone() {
        let summary = Summary {
            id: 1,
            conversation_id: ConversationId(2),
            content: "test summary".into(),
            first_message_id: MessageId(1),
            last_message_id: MessageId(5),
            token_estimate: 10,
        };
        let cloned = summary.clone();
        assert_eq!(summary.id, cloned.id);
        assert_eq!(summary.content, cloned.content);
    }

    // Role strings round-trip through storage into the Role enum.
    #[tokio::test]
    async fn remember_preserves_role_mapping() {
        let memory = test_semantic_memory(false).await;
        let cid = memory.sqlite.create_conversation().await.unwrap();

        memory.remember(cid, "user", "u").await.unwrap();
        memory.remember(cid, "assistant", "a").await.unwrap();
        memory.remember(cid, "system", "s").await.unwrap();

        let history = memory.sqlite.load_history(cid, 50).await.unwrap();
        assert_eq!(history.len(), 3);
        assert_eq!(history[0].role, Role::User);
        assert_eq!(history[1].role, Role::Assistant);
        assert_eq!(history[2].role, Role::System);
    }

    // Construction succeeds even when the Qdrant URL is unreachable —
    // connectivity is deferred, not validated at startup.
    #[tokio::test]
    async fn new_with_invalid_qdrant_url_graceful() {
        let mut mock = MockProvider::default();
        mock.supports_embeddings = true;
        let provider = AnyProvider::Mock(mock);
        let result =
            SemanticMemory::new(":memory:", "http://127.0.0.1:1", provider, "test-model").await;
        assert!(result.is_ok());
    }
1501
    // End-to-end roundtrip using the SQLite-backed EmbeddingStore variant in
    // place of Qdrant: remember() must index messages and recall() must find
    // them again (via hybrid vector + FTS5 keyword search).
    #[tokio::test]
    async fn test_semantic_memory_sqlite_remember_recall_roundtrip() {
        // Build SemanticMemory with EmbeddingStore backed by SQLite instead of Qdrant
        let mut mock = MockProvider::default();
        mock.supports_embeddings = true;
        // Provide deterministic embedding vectors: embed returns a fixed 4-element vector
        // MockProvider.embed always returns the same vector, so cosine similarity = 1.0
        let provider = AnyProvider::Mock(mock);

        let sqlite = SqliteStore::new(":memory:").await.unwrap();
        let pool = sqlite.pool().clone();
        let qdrant = Some(crate::embedding_store::EmbeddingStore::new_sqlite(pool));

        let memory = SemanticMemory {
            sqlite,
            qdrant,
            provider,
            embedding_model: "test-model".into(),
            vector_weight: 0.7,
            keyword_weight: 0.3,
            temporal_decay_enabled: false,
            temporal_decay_half_life_days: 30,
            mmr_enabled: false,
            mmr_lambda: 0.7,
            token_counter: Arc::new(TokenCounter::new()),
        };

        let cid = memory.sqlite().create_conversation().await.unwrap();

        // remember → stores in SQLite + SQLite vector store
        let id1 = memory
            .remember(cid, "user", "rust async programming")
            .await
            .unwrap();
        let id2 = memory
            .remember(cid, "assistant", "use tokio for async")
            .await
            .unwrap();
        assert!(id1 < id2);

        // recall → should return results via FTS5 keyword search
        let recalled = memory.recall("rust", 5, None).await.unwrap();
        assert!(
            !recalled.is_empty(),
            "recall must return at least one result"
        );

        // Verify history is accessible
        let history = memory.sqlite().load_history(cid, 50).await.unwrap();
        assert_eq!(history.len(), 2);
        assert_eq!(history[0].content, "rust async programming");
    }
1554
    // remember() still persists normally when embeddings are supported but
    // there is no vector store to write them to.
    #[tokio::test]
    async fn remember_with_embeddings_supported_but_no_qdrant() {
        let memory = test_semantic_memory(true).await;
        let cid = memory.sqlite.create_conversation().await.unwrap();

        let msg_id = memory.remember(cid, "user", "hello embed").await.unwrap();
        assert!(msg_id > MessageId(0));

        let history = memory.sqlite.load_history(cid, 50).await.unwrap();
        assert_eq!(history.len(), 1);
        assert_eq!(history[0].content, "hello embed");
    }

    // Stored content comes back in insertion order via load_history().
    #[tokio::test]
    async fn remember_verifies_content_via_load_history() {
        let memory = test_semantic_memory(false).await;
        let cid = memory.sqlite.create_conversation().await.unwrap();

        memory.remember(cid, "user", "alpha").await.unwrap();
        memory.remember(cid, "assistant", "beta").await.unwrap();
        memory.remember(cid, "user", "gamma").await.unwrap();

        let history = memory.sqlite().load_history(cid, 50).await.unwrap();
        assert_eq!(history.len(), 3);
        assert_eq!(history[0].content, "alpha");
        assert_eq!(history[1].content, "beta");
        assert_eq!(history[2].content, "gamma");
    }

    // Counts stay isolated across three conversations (incl. an empty one).
    #[tokio::test]
    async fn message_count_multiple_conversations_isolated() {
        let memory = test_semantic_memory(false).await;
        let cid1 = memory.sqlite().create_conversation().await.unwrap();
        let cid2 = memory.sqlite().create_conversation().await.unwrap();
        let cid3 = memory.sqlite().create_conversation().await.unwrap();

        for _ in 0..5 {
            memory.remember(cid1, "user", "msg").await.unwrap();
        }
        for _ in 0..3 {
            memory.remember(cid2, "user", "msg").await.unwrap();
        }

        assert_eq!(memory.message_count(cid1).await.unwrap(), 5);
        assert_eq!(memory.message_count(cid2).await.unwrap(), 3);
        assert_eq!(memory.message_count(cid3).await.unwrap(), 0);
    }

    // Back-to-back summarize() calls each produce a summary while unconsumed
    // messages remain.
    #[tokio::test]
    async fn summarize_empty_messages_range_returns_none() {
        let memory = test_semantic_memory(false).await;
        let cid = memory.sqlite().create_conversation().await.unwrap();

        for i in 0..6 {
            memory
                .remember(cid, "user", &format!("msg {i}"))
                .await
                .unwrap();
        }

        memory.summarize(cid, 3).await.unwrap();
        memory.summarize(cid, 3).await.unwrap();

        let summaries = memory.load_summaries(cid).await.unwrap();
        assert_eq!(summaries.len(), 2);
    }

    // The token estimate of a stored summary is strictly positive.
    #[tokio::test]
    async fn summarize_token_estimate_populated() {
        let memory = test_semantic_memory(false).await;
        let cid = memory.sqlite().create_conversation().await.unwrap();

        for i in 0..5 {
            memory
                .remember(cid, "user", &format!("message {i}"))
                .await
                .unwrap();
        }

        memory.summarize(cid, 3).await.unwrap();
        let summaries = memory.load_summaries(cid).await.unwrap();
        let token_est = summaries[0].token_estimate;
        assert!(token_est > 0);
    }

    // When the provider is unreachable (both typed and plain chat fail),
    // summarize() propagates the error instead of storing a summary.
    #[tokio::test]
    async fn summarize_fails_when_provider_chat_fails() {
        let sqlite = SqliteStore::new(":memory:").await.unwrap();
        let provider = AnyProvider::Ollama(zeph_llm::ollama::OllamaProvider::new(
            "http://127.0.0.1:1",
            "test".into(),
            "embed".into(),
        ));
        let memory = SemanticMemory {
            sqlite,
            qdrant: None,
            provider,
            embedding_model: "test".into(),
            vector_weight: 0.7,
            keyword_weight: 0.3,
            temporal_decay_enabled: false,
            temporal_decay_half_life_days: 30,
            mmr_enabled: false,
            mmr_lambda: 0.7,
            token_counter: Arc::new(TokenCounter::new()),
        };
        let cid = memory.sqlite().create_conversation().await.unwrap();

        for i in 0..5 {
            memory
                .remember(cid, "user", &format!("msg {i}"))
                .await
                .unwrap();
        }

        let result = memory.summarize(cid, 3).await;
        assert!(result.is_err());
    }
1673
    // embed_missing() embeds nothing when the provider can't embed, even with
    // unembedded messages present.
    #[tokio::test]
    async fn embed_missing_without_embedding_support_returns_zero() {
        let memory = test_semantic_memory(false).await;
        let cid = memory.sqlite().create_conversation().await.unwrap();
        memory
            .sqlite()
            .save_message(cid, "user", "test message")
            .await
            .unwrap();

        let count = memory.embed_missing().await.unwrap();
        assert_eq!(count, 0);
    }

    // A freshly remembered message has no embedding without a vector store.
    #[tokio::test]
    async fn has_embedding_returns_false_when_no_qdrant() {
        let memory = test_semantic_memory(false).await;
        let cid = memory.sqlite.create_conversation().await.unwrap();
        let msg_id = memory.remember(cid, "user", "test").await.unwrap();
        assert!(!memory.has_embedding(msg_id).await.unwrap());
    }

    // Supplying a SearchFilter does not change the no-vector-store behavior.
    #[tokio::test]
    async fn recall_empty_without_qdrant_regardless_of_filter() {
        let memory = test_semantic_memory(true).await;
        let filter = SearchFilter {
            conversation_id: Some(ConversationId(1)),
            role: None,
        };
        let recalled = memory.recall("query", 10, Some(filter)).await.unwrap();
        assert!(recalled.is_empty());
    }

    // The stored summary's message id range is well-formed.
    #[tokio::test]
    async fn summarize_message_range_bounds() {
        let memory = test_semantic_memory(false).await;
        let cid = memory.sqlite().create_conversation().await.unwrap();

        for i in 0..8 {
            memory
                .remember(cid, "user", &format!("msg {i}"))
                .await
                .unwrap();
        }

        let summary_id = memory.summarize(cid, 4).await.unwrap().unwrap();
        let summaries = memory.load_summaries(cid).await.unwrap();
        assert_eq!(summaries.len(), 1);
        assert_eq!(summaries[0].id, summary_id);
        assert!(summaries[0].first_message_id >= MessageId(1));
        assert!(summaries[0].last_message_id >= summaries[0].first_message_id);
    }

    // The prompt renders messages in their original order.
    #[test]
    fn build_summarization_prompt_preserves_order() {
        let messages = vec![
            (MessageId(1), "user".into(), "first".into()),
            (MessageId(2), "assistant".into(), "second".into()),
            (MessageId(3), "user".into(), "third".into()),
        ];
        let prompt = build_summarization_prompt(&messages);
        let first_pos = prompt.find("user: first").unwrap();
        let second_pos = prompt.find("assistant: second").unwrap();
        let third_pos = prompt.find("user: third").unwrap();
        assert!(first_pos < second_pos);
        assert!(second_pos < third_pos);
    }

    // Debug derive on Summary includes the type name.
    #[test]
    fn summary_debug() {
        let summary = Summary {
            id: 1,
            conversation_id: ConversationId(2),
            content: "test".into(),
            first_message_id: MessageId(1),
            last_message_id: MessageId(5),
            token_estimate: 10,
        };
        let dbg = format!("{summary:?}");
        assert!(dbg.contains("Summary"));
    }

    // Queries against unknown conversation ids are empty, not errors.
    #[tokio::test]
    async fn message_count_nonexistent_conversation() {
        let memory = test_semantic_memory(false).await;
        let count = memory.message_count(ConversationId(999)).await.unwrap();
        assert_eq!(count, 0);
    }

    #[tokio::test]
    async fn load_summaries_nonexistent_conversation() {
        let memory = test_semantic_memory(false).await;
        let summaries = memory.load_summaries(ConversationId(999)).await.unwrap();
        assert!(summaries.is_empty());
    }

    // store_session_summary() is a silent no-op when it cannot embed/store.
    #[tokio::test]
    async fn store_session_summary_no_qdrant_noop() {
        let memory = test_semantic_memory(true).await;
        let result = memory
            .store_session_summary(ConversationId(1), "test summary")
            .await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn store_session_summary_no_embeddings_noop() {
        let memory = test_semantic_memory(false).await;
        let result = memory
            .store_session_summary(ConversationId(1), "test summary")
            .await;
        assert!(result.is_ok());
    }
1787
1788    #[tokio::test]
1789    async fn search_session_summaries_no_qdrant_empty() {
1790        let memory = test_semantic_memory(true).await;
1791        let results = memory
1792            .search_session_summaries("query", 5, None)
1793            .await
1794            .unwrap();
1795        assert!(results.is_empty());
1796    }
1797
1798    #[tokio::test]
1799    async fn search_session_summaries_no_embeddings_empty() {
1800        let memory = test_semantic_memory(false).await;
1801        let results = memory
1802            .search_session_summaries("query", 5, Some(ConversationId(1)))
1803            .await
1804            .unwrap();
1805        assert!(results.is_empty());
1806    }
1807
1808    #[test]
1809    fn session_summary_result_debug() {
1810        let result = SessionSummaryResult {
1811            summary_text: "test".into(),
1812            score: 0.9,
1813            conversation_id: ConversationId(1),
1814        };
1815        let dbg = format!("{result:?}");
1816        assert!(dbg.contains("SessionSummaryResult"));
1817    }
1818
1819    #[test]
1820    fn session_summary_result_clone() {
1821        let result = SessionSummaryResult {
1822            summary_text: "test".into(),
1823            score: 0.9,
1824            conversation_id: ConversationId(1),
1825        };
1826        let cloned = result.clone();
1827        assert_eq!(result.summary_text, cloned.summary_text);
1828        assert_eq!(result.conversation_id, cloned.conversation_id);
1829    }
1830
1831    #[tokio::test]
1832    async fn recall_fts5_fallback_without_qdrant() {
1833        let memory = test_semantic_memory(false).await;
1834        let cid = memory.sqlite.create_conversation().await.unwrap();
1835
1836        memory
1837            .remember(cid, "user", "rust programming guide")
1838            .await
1839            .unwrap();
1840        memory
1841            .remember(cid, "assistant", "python tutorial")
1842            .await
1843            .unwrap();
1844        memory
1845            .remember(cid, "user", "advanced rust patterns")
1846            .await
1847            .unwrap();
1848
1849        let recalled = memory.recall("rust", 5, None).await.unwrap();
1850        assert_eq!(recalled.len(), 2);
1851        assert!(recalled[0].score >= recalled[1].score);
1852    }
1853
1854    #[tokio::test]
1855    async fn recall_fts5_fallback_with_filter() {
1856        let memory = test_semantic_memory(false).await;
1857        let cid1 = memory.sqlite.create_conversation().await.unwrap();
1858        let cid2 = memory.sqlite.create_conversation().await.unwrap();
1859
1860        memory.remember(cid1, "user", "hello world").await.unwrap();
1861        memory
1862            .remember(cid2, "user", "hello universe")
1863            .await
1864            .unwrap();
1865
1866        let filter = SearchFilter {
1867            conversation_id: Some(cid1),
1868            role: None,
1869        };
1870        let recalled = memory.recall("hello", 5, Some(filter)).await.unwrap();
1871        assert_eq!(recalled.len(), 1);
1872    }
1873
1874    #[tokio::test]
1875    async fn recall_fts5_no_matches_returns_empty() {
1876        let memory = test_semantic_memory(false).await;
1877        let cid = memory.sqlite.create_conversation().await.unwrap();
1878
1879        memory.remember(cid, "user", "hello world").await.unwrap();
1880
1881        let recalled = memory.recall("nonexistent", 5, None).await.unwrap();
1882        assert!(recalled.is_empty());
1883    }
1884
1885    #[tokio::test]
1886    async fn recall_fts5_respects_limit() {
1887        let memory = test_semantic_memory(false).await;
1888        let cid = memory.sqlite.create_conversation().await.unwrap();
1889
1890        for i in 0..10 {
1891            memory
1892                .remember(cid, "user", &format!("test message number {i}"))
1893                .await
1894                .unwrap();
1895        }
1896
1897        let recalled = memory.recall("test", 3, None).await.unwrap();
1898        assert_eq!(recalled.len(), 3);
1899    }
1900
    // Summarize fallback path: structured parsing fails, plain chat succeeds.
1902
1903    #[tokio::test]
1904    async fn summarize_fallback_to_plain_text_when_structured_fails() {
1905        // Use OllamaProvider pointing at an unreachable URL for chat_typed_erased,
1906        // but MockProvider for the plain chat call.
1907        // The easiest way: MockProvider returns non-JSON plain text so chat_typed_erased
1908        // (which uses chat() + JSON parse) will fail to parse, then falls back to chat().
1909        // However MockProvider.chat_typed calls chat() which returns default_response.
1910        // chat_typed tries to parse it as JSON → fails → retries → fails → returns StructuredParse error.
1911        // Then the fallback calls plain chat() which succeeds.
1912        let sqlite = SqliteStore::new(":memory:").await.unwrap();
1913        let mut mock = MockProvider::default();
1914        // First two calls go to chat_typed (attempt + retry), third call is the plain fallback
1915        mock.default_response = "plain text summary".into();
1916        let provider = AnyProvider::Mock(mock);
1917
1918        let memory = SemanticMemory {
1919            sqlite,
1920            qdrant: None,
1921            provider,
1922            embedding_model: "test".into(),
1923            vector_weight: 0.7,
1924            keyword_weight: 0.3,
1925            temporal_decay_enabled: false,
1926            temporal_decay_half_life_days: 30,
1927            mmr_enabled: false,
1928            mmr_lambda: 0.7,
1929            token_counter: Arc::new(TokenCounter::new()),
1930        };
1931
1932        let cid = memory.sqlite().create_conversation().await.unwrap();
1933        for i in 0..5 {
1934            memory
1935                .remember(cid, "user", &format!("msg {i}"))
1936                .await
1937                .unwrap();
1938        }
1939
1940        let result = memory.summarize(cid, 3).await;
1941        // The summarize will either succeed (with plain text fallback) or fail
1942        // depending on how many retries chat_typed_erased does internally.
1943        // With MockProvider returning non-JSON plain text, chat_typed fails to parse.
1944        // The fallback plain chat() returns "plain text summary".
1945        // Result should be Ok with a summary stored.
1946        assert!(result.is_ok());
1947        let summaries = memory.load_summaries(cid).await.unwrap();
1948        assert_eq!(summaries.len(), 1);
1949        assert!(!summaries[0].content.is_empty());
1950    }
1951
1952    // Temporal decay tests
1953
1954    #[test]
1955    fn temporal_decay_disabled_leaves_scores_unchanged() {
1956        let mut ranked = vec![(MessageId(1), 1.0f64), (MessageId(2), 0.5f64)];
1957        let timestamps = std::collections::HashMap::new();
1958        apply_temporal_decay(&mut ranked, &timestamps, 30);
1959        assert!((ranked[0].1 - 1.0).abs() < f64::EPSILON);
1960        assert!((ranked[1].1 - 0.5).abs() < f64::EPSILON);
1961    }
1962
1963    #[test]
1964    fn temporal_decay_zero_age_preserves_score() {
1965        let now = std::time::SystemTime::now()
1966            .duration_since(std::time::UNIX_EPOCH)
1967            .unwrap_or_default()
1968            .as_secs()
1969            .cast_signed();
1970        let mut ranked = vec![(MessageId(1), 1.0f64)];
1971        let mut timestamps = std::collections::HashMap::new();
1972        timestamps.insert(MessageId(1), now);
1973        apply_temporal_decay(&mut ranked, &timestamps, 30);
1974        // age = 0 days, exp(0) = 1.0 → no change
1975        assert!((ranked[0].1 - 1.0).abs() < 0.01);
1976    }
1977
1978    #[test]
1979    fn temporal_decay_half_life_halves_score() {
1980        // Age exactly half_life_days → score should be halved
1981        let half_life = 30u32;
1982        let age_secs = i64::from(half_life) * 86400;
1983        let now = std::time::SystemTime::now()
1984            .duration_since(std::time::UNIX_EPOCH)
1985            .unwrap_or_default()
1986            .as_secs()
1987            .cast_signed();
1988        let ts = now - age_secs;
1989        let mut ranked = vec![(MessageId(1), 1.0f64)];
1990        let mut timestamps = std::collections::HashMap::new();
1991        timestamps.insert(MessageId(1), ts);
1992        apply_temporal_decay(&mut ranked, &timestamps, half_life);
1993        // exp(-ln2) = 0.5
1994        assert!(
1995            (ranked[0].1 - 0.5).abs() < 0.01,
1996            "score was {}",
1997            ranked[0].1
1998        );
1999    }
2000
2001    // MMR tests
2002
2003    #[test]
2004    fn mmr_empty_input_returns_empty() {
2005        let ranked = vec![];
2006        let vectors = std::collections::HashMap::new();
2007        let result = apply_mmr(&ranked, &vectors, 0.7, 5);
2008        assert!(result.is_empty());
2009    }
2010
2011    #[test]
2012    fn mmr_returns_up_to_limit() {
2013        let ranked = vec![
2014            (MessageId(1), 1.0f64),
2015            (MessageId(2), 0.9f64),
2016            (MessageId(3), 0.8f64),
2017        ];
2018        let mut vectors = std::collections::HashMap::new();
2019        vectors.insert(MessageId(1), vec![1.0f32, 0.0]);
2020        vectors.insert(MessageId(2), vec![0.0f32, 1.0]);
2021        vectors.insert(MessageId(3), vec![1.0f32, 0.0]);
2022        let result = apply_mmr(&ranked, &vectors, 0.7, 2);
2023        assert_eq!(result.len(), 2);
2024    }
2025
2026    #[test]
2027    fn mmr_without_vectors_picks_by_relevance() {
2028        let ranked = vec![(MessageId(1), 1.0f64), (MessageId(2), 0.5f64)];
2029        let vectors = std::collections::HashMap::new();
2030        let result = apply_mmr(&ranked, &vectors, 0.7, 2);
2031        assert_eq!(result.len(), 2);
2032        assert_eq!(result[0].0, MessageId(1));
2033    }
2034
2035    #[test]
2036    fn mmr_prefers_diverse_over_redundant() {
2037        // Two candidates with same relevance but msg 2 is orthogonal (more diverse)
2038        let ranked = vec![
2039            (MessageId(1), 1.0f64), // selected first
2040            (MessageId(2), 0.9f64), // orthogonal to 1
2041            (MessageId(3), 0.9f64), // parallel to 1 (redundant)
2042        ];
2043        let mut vectors = std::collections::HashMap::new();
2044        vectors.insert(MessageId(1), vec![1.0f32, 0.0]);
2045        vectors.insert(MessageId(2), vec![0.0f32, 1.0]); // orthogonal
2046        vectors.insert(MessageId(3), vec![1.0f32, 0.0]); // same as 1
2047        let result = apply_mmr(&ranked, &vectors, 0.5, 2);
2048        assert_eq!(result.len(), 2);
2049        assert_eq!(result[0].0, MessageId(1));
2050        // msg 2 should be preferred over msg 3 (diverse)
2051        assert_eq!(result[1].0, MessageId(2));
2052    }
2053
2054    #[test]
2055    fn temporal_decay_half_life_zero_is_noop() {
2056        let now = std::time::SystemTime::now()
2057            .duration_since(std::time::UNIX_EPOCH)
2058            .unwrap_or_default()
2059            .as_secs()
2060            .cast_signed();
2061        let age_secs = 30i64 * 86400;
2062        let ts = now - age_secs;
2063        let mut ranked = vec![(MessageId(1), 1.0f64)];
2064        let mut timestamps = std::collections::HashMap::new();
2065        timestamps.insert(MessageId(1), ts);
2066        // half_life=0 → guard returns early, score must remain 1.0
2067        apply_temporal_decay(&mut ranked, &timestamps, 0);
2068        assert!(
2069            (ranked[0].1 - 1.0).abs() < f64::EPSILON,
2070            "score was {}",
2071            ranked[0].1
2072        );
2073    }
2074
2075    #[test]
2076    fn temporal_decay_huge_age_near_zero() {
2077        let now = std::time::SystemTime::now()
2078            .duration_since(std::time::UNIX_EPOCH)
2079            .unwrap_or_default()
2080            .as_secs()
2081            .cast_signed();
2082        // 10 years = ~3650 days
2083        let age_secs = 3650i64 * 86400;
2084        let ts = now - age_secs;
2085        let mut ranked = vec![(MessageId(1), 1.0f64)];
2086        let mut timestamps = std::collections::HashMap::new();
2087        timestamps.insert(MessageId(1), ts);
2088        apply_temporal_decay(&mut ranked, &timestamps, 30);
2089        // After 3650 days with half_life=30, score should be essentially 0
2090        assert!(ranked[0].1 < 0.001, "score was {}", ranked[0].1);
2091    }
2092
2093    #[test]
2094    fn temporal_decay_small_half_life() {
2095        // Very small half_life (1 day), age = 7 days → 2^(-7) ≈ 0.0078
2096        let now = std::time::SystemTime::now()
2097            .duration_since(std::time::UNIX_EPOCH)
2098            .unwrap_or_default()
2099            .as_secs()
2100            .cast_signed();
2101        let ts = now - 7 * 86400i64;
2102        let mut ranked = vec![(MessageId(1), 1.0f64)];
2103        let mut timestamps = std::collections::HashMap::new();
2104        timestamps.insert(MessageId(1), ts);
2105        apply_temporal_decay(&mut ranked, &timestamps, 1);
2106        assert!(ranked[0].1 < 0.01, "score was {}", ranked[0].1);
2107    }
2108
2109    #[test]
2110    fn mmr_lambda_zero_max_diversity() {
2111        // lambda=0 → pure diversity: second item should be most dissimilar
2112        let ranked = vec![
2113            (MessageId(1), 1.0f64),  // selected first (always highest relevance)
2114            (MessageId(2), 0.9f64),  // orthogonal to 1
2115            (MessageId(3), 0.85f64), // parallel to 1 (max_sim=1.0)
2116        ];
2117        let mut vectors = std::collections::HashMap::new();
2118        vectors.insert(MessageId(1), vec![1.0f32, 0.0]);
2119        vectors.insert(MessageId(2), vec![0.0f32, 1.0]); // orthogonal
2120        vectors.insert(MessageId(3), vec![1.0f32, 0.0]); // same direction
2121        let result = apply_mmr(&ranked, &vectors, 0.0, 3);
2122        assert_eq!(result.len(), 3);
2123        // After 1 is selected: mmr(2) = 0 - (1-0)*0 = 0, mmr(3) = 0 - 1*1 = -1 → 2 wins
2124        assert_eq!(result[1].0, MessageId(2));
2125    }
2126
2127    #[test]
2128    fn mmr_lambda_one_pure_relevance() {
2129        // lambda=1 → pure relevance, should pick in relevance order
2130        let ranked = vec![
2131            (MessageId(1), 1.0f64),
2132            (MessageId(2), 0.8f64),
2133            (MessageId(3), 0.6f64),
2134        ];
2135        let mut vectors = std::collections::HashMap::new();
2136        vectors.insert(MessageId(1), vec![1.0f32, 0.0]);
2137        vectors.insert(MessageId(2), vec![0.0f32, 1.0]);
2138        vectors.insert(MessageId(3), vec![0.5f32, 0.5]);
2139        let result = apply_mmr(&ranked, &vectors, 1.0, 3);
2140        assert_eq!(result.len(), 3);
2141        assert_eq!(result[0].0, MessageId(1));
2142        assert_eq!(result[1].0, MessageId(2));
2143        assert_eq!(result[2].0, MessageId(3));
2144    }
2145
2146    #[test]
2147    fn mmr_limit_zero_returns_empty() {
2148        let ranked = vec![(MessageId(1), 1.0f64), (MessageId(2), 0.8f64)];
2149        let mut vectors = std::collections::HashMap::new();
2150        vectors.insert(MessageId(1), vec![1.0f32, 0.0]);
2151        vectors.insert(MessageId(2), vec![0.0f32, 1.0]);
2152        let result = apply_mmr(&ranked, &vectors, 0.7, 0);
2153        assert!(result.is_empty());
2154    }
2155
2156    #[test]
2157    fn mmr_duplicate_vectors_penalizes_second() {
2158        // Two items with identical embeddings: second should be heavily penalized
2159        let ranked = vec![
2160            (MessageId(1), 1.0f64),
2161            (MessageId(2), 1.0f64), // same relevance, same direction
2162            (MessageId(3), 0.9f64), // orthogonal, lower relevance
2163        ];
2164        let mut vectors = std::collections::HashMap::new();
2165        vectors.insert(MessageId(1), vec![1.0f32, 0.0]);
2166        vectors.insert(MessageId(2), vec![1.0f32, 0.0]); // duplicate
2167        vectors.insert(MessageId(3), vec![0.0f32, 1.0]); // orthogonal
2168        let result = apply_mmr(&ranked, &vectors, 0.5, 3);
2169        assert_eq!(result.len(), 3);
2170        assert_eq!(result[0].0, MessageId(1));
2171        // msg3 (orthogonal) should be preferred over msg2 (duplicate) with lambda=0.5
2172        assert_eq!(result[1].0, MessageId(3));
2173    }
2174
    // Property-based tests.
2176
2177    use proptest::prelude::*;
2178
2179    proptest! {
2180        #[test]
2181        fn count_tokens_never_panics(s in ".*") {
2182            let counter = crate::token_counter::TokenCounter::new();
2183            let _ = counter.count_tokens(&s);
2184        }
2185    }
2186}