// zeph_memory/semantic.rs

// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
// SPDX-License-Identifier: MIT OR Apache-2.0

use zeph_llm::any::AnyProvider;
use zeph_llm::provider::{LlmProvider, Message, MessageMetadata, Role};

use std::sync::Arc;

use crate::embedding_store::{EmbeddingStore, MessageKind, SearchFilter};
use crate::error::MemoryError;
use crate::sqlite::SqliteStore;
use crate::token_counter::TokenCounter;
use crate::types::{ConversationId, MessageId};
use crate::vector_store::{FieldCondition, FieldValue, VectorFilter};

const SESSION_SUMMARIES_COLLECTION: &str = "zeph_session_summaries";
const KEY_FACTS_COLLECTION: &str = "zeph_key_facts";
const CORRECTIONS_COLLECTION: &str = "zeph_corrections";
/// JSON shape the LLM is asked to produce when summarizing a conversation
/// segment (see `build_summarization_prompt`); also serialized for storage.
#[derive(Debug, Clone, serde::Deserialize, serde::Serialize, schemars::JsonSchema)]
pub struct StructuredSummary {
    // Free-text summary of the conversation segment.
    pub summary: String,
    // Discrete facts extracted from the conversation.
    pub key_facts: Vec<String>,
    // Named entities (people, projects, tools, ...) mentioned in the conversation.
    pub entities: Vec<String>,
}

/// A message returned by hybrid recall, paired with its combined relevance score.
#[derive(Debug)]
pub struct RecalledMessage {
    // The original chat message loaded from SQLite.
    pub message: Message,
    // Combined (vector + keyword) score after decay/MMR re-ranking; higher is better.
    pub score: f32,
}

/// A stored conversation summary row, covering a contiguous range of messages.
#[derive(Debug, Clone)]
pub struct Summary {
    // Row id of the summary in SQLite.
    pub id: i64,
    // Conversation this summary belongs to.
    pub conversation_id: ConversationId,
    // The summary text itself.
    pub content: String,
    // First message covered by this summary (inclusive).
    pub first_message_id: MessageId,
    // Last message covered by this summary (inclusive).
    pub last_message_id: MessageId,
    // Approximate token count of `content` (from `TokenCounter`).
    pub token_estimate: i64,
}

/// A cross-session summary hit from the `zeph_session_summaries` collection.
#[derive(Debug, Clone)]
pub struct SessionSummaryResult {
    // Summary text stored in the vector-store payload.
    pub summary_text: String,
    // Vector similarity score from the collection search.
    pub score: f32,
    // Conversation the summary was generated from.
    pub conversation_id: ConversationId,
}

/// Cosine similarity between two vectors.
///
/// Returns `0.0` for mismatched lengths, empty inputs, or zero-magnitude
/// vectors, so callers never divide by zero.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.is_empty() || a.len() != b.len() {
        return 0.0;
    }
    // Single pass: accumulate dot product and both squared magnitudes together.
    let mut dot = 0.0f32;
    let mut sq_a = 0.0f32;
    let mut sq_b = 0.0f32;
    for (&x, &y) in a.iter().zip(b) {
        dot += x * y;
        sq_a += x * x;
        sq_b += y * y;
    }
    let denom = sq_a.sqrt() * sq_b.sqrt();
    if denom == 0.0 { 0.0 } else { dot / denom }
}

63fn apply_temporal_decay(
64    ranked: &mut [(MessageId, f64)],
65    timestamps: &std::collections::HashMap<MessageId, i64>,
66    half_life_days: u32,
67) {
68    if half_life_days == 0 {
69        return;
70    }
71    let now = std::time::SystemTime::now()
72        .duration_since(std::time::UNIX_EPOCH)
73        .unwrap_or_default()
74        .as_secs()
75        .cast_signed();
76    let lambda = std::f64::consts::LN_2 / f64::from(half_life_days);
77
78    for (msg_id, score) in ranked.iter_mut() {
79        if let Some(&ts) = timestamps.get(msg_id) {
80            #[allow(clippy::cast_precision_loss)]
81            let age_days = (now - ts).max(0) as f64 / 86400.0;
82            *score *= (-lambda * age_days).exp();
83        }
84    }
85}
86
87fn apply_mmr(
88    ranked: &[(MessageId, f64)],
89    vectors: &std::collections::HashMap<MessageId, Vec<f32>>,
90    lambda: f32,
91    limit: usize,
92) -> Vec<(MessageId, f64)> {
93    if ranked.is_empty() || limit == 0 {
94        return Vec::new();
95    }
96
97    let lambda = f64::from(lambda);
98    let mut selected: Vec<(MessageId, f64)> = Vec::with_capacity(limit);
99    let mut remaining: Vec<(MessageId, f64)> = ranked.to_vec();
100
101    while selected.len() < limit && !remaining.is_empty() {
102        let best_idx = if selected.is_empty() {
103            // Pick highest relevance first
104            0
105        } else {
106            let mut best = 0usize;
107            let mut best_score = f64::NEG_INFINITY;
108
109            for (i, &(cand_id, relevance)) in remaining.iter().enumerate() {
110                let max_sim = if let Some(cand_vec) = vectors.get(&cand_id) {
111                    selected
112                        .iter()
113                        .filter_map(|(sel_id, _)| vectors.get(sel_id))
114                        .map(|sel_vec| f64::from(cosine_similarity(cand_vec, sel_vec)))
115                        .fold(f64::NEG_INFINITY, f64::max)
116                } else {
117                    0.0
118                };
119                let max_sim = if max_sim == f64::NEG_INFINITY {
120                    0.0
121                } else {
122                    max_sim
123                };
124                let mmr_score = lambda * relevance - (1.0 - lambda) * max_sim;
125                if mmr_score > best_score {
126                    best_score = mmr_score;
127                    best = i;
128                }
129            }
130            best
131        };
132
133        selected.push(remaining.remove(best_idx));
134    }
135
136    selected
137}
138
139fn build_summarization_prompt(messages: &[(MessageId, String, String)]) -> String {
140    let mut prompt = String::from(
141        "Summarize the following conversation. Extract key facts, decisions, entities, \
142         and context needed to continue the conversation.\n\n\
143         Respond in JSON with fields: summary (string), key_facts (list of strings), \
144         entities (list of strings).\n\nConversation:\n",
145    );
146
147    for (_, role, content) in messages {
148        prompt.push_str(role);
149        prompt.push_str(": ");
150        prompt.push_str(content);
151        prompt.push('\n');
152    }
153
154    prompt
155}
156
/// Long-term conversational memory combining SQLite persistence with an
/// optional vector store for semantic recall.
pub struct SemanticMemory {
    // Durable message/summary storage plus FTS5 keyword search.
    sqlite: SqliteStore,
    // Vector store backend; `None` disables semantic (embedding) search.
    qdrant: Option<EmbeddingStore>,
    // LLM provider used for embeddings and summarization.
    provider: AnyProvider,
    // Model identifier recorded alongside stored embeddings.
    embedding_model: String,
    // Hybrid-search weight applied to normalized vector scores.
    vector_weight: f64,
    // Hybrid-search weight applied to normalized keyword (FTS5) scores.
    keyword_weight: f64,
    // When true, recall scores decay exponentially with message age.
    temporal_decay_enabled: bool,
    // Half-life (days) for temporal decay; 0 disables decay.
    temporal_decay_half_life_days: u32,
    // When true, recall results are re-ranked with Maximal Marginal Relevance.
    mmr_enabled: bool,
    // MMR relevance/diversity trade-off in [0, 1]; higher favors relevance.
    mmr_lambda: f32,
    // Shared token counter used for summary token estimates.
    pub token_counter: Arc<TokenCounter>,
}

impl SemanticMemory {
172    /// Create a new `SemanticMemory` instance with default hybrid search weights (0.7/0.3).
173    ///
174    /// Qdrant connection is best-effort: if unavailable, semantic search is disabled.
175    ///
176    /// # Errors
177    ///
178    /// Returns an error if `SQLite` cannot be initialized.
179    pub async fn new(
180        sqlite_path: &str,
181        qdrant_url: &str,
182        provider: AnyProvider,
183        embedding_model: &str,
184    ) -> Result<Self, MemoryError> {
185        Self::with_weights(sqlite_path, qdrant_url, provider, embedding_model, 0.7, 0.3).await
186    }
187
188    /// Create a new `SemanticMemory` with custom vector/keyword weights for hybrid search.
189    ///
190    /// # Errors
191    ///
192    /// Returns an error if `SQLite` cannot be initialized.
193    pub async fn with_weights(
194        sqlite_path: &str,
195        qdrant_url: &str,
196        provider: AnyProvider,
197        embedding_model: &str,
198        vector_weight: f64,
199        keyword_weight: f64,
200    ) -> Result<Self, MemoryError> {
201        Self::with_weights_and_pool_size(
202            sqlite_path,
203            qdrant_url,
204            provider,
205            embedding_model,
206            vector_weight,
207            keyword_weight,
208            5,
209        )
210        .await
211    }
212
213    /// Create a new `SemanticMemory` with custom weights and configurable pool size.
214    ///
215    /// # Errors
216    ///
217    /// Returns an error if `SQLite` cannot be initialized.
218    pub async fn with_weights_and_pool_size(
219        sqlite_path: &str,
220        qdrant_url: &str,
221        provider: AnyProvider,
222        embedding_model: &str,
223        vector_weight: f64,
224        keyword_weight: f64,
225        pool_size: u32,
226    ) -> Result<Self, MemoryError> {
227        let sqlite = SqliteStore::with_pool_size(sqlite_path, pool_size).await?;
228        let pool = sqlite.pool().clone();
229
230        let qdrant = match EmbeddingStore::new(qdrant_url, pool) {
231            Ok(store) => Some(store),
232            Err(e) => {
233                tracing::warn!("Qdrant unavailable, semantic search disabled: {e:#}");
234                None
235            }
236        };
237
238        Ok(Self {
239            sqlite,
240            qdrant,
241            provider,
242            embedding_model: embedding_model.into(),
243            vector_weight,
244            keyword_weight,
245            temporal_decay_enabled: false,
246            temporal_decay_half_life_days: 30,
247            mmr_enabled: false,
248            mmr_lambda: 0.7,
249            token_counter: Arc::new(TokenCounter::new()),
250        })
251    }
252
253    /// Configure temporal decay and MMR re-ranking options.
254    #[must_use]
255    pub fn with_ranking_options(
256        mut self,
257        temporal_decay_enabled: bool,
258        temporal_decay_half_life_days: u32,
259        mmr_enabled: bool,
260        mmr_lambda: f32,
261    ) -> Self {
262        self.temporal_decay_enabled = temporal_decay_enabled;
263        self.temporal_decay_half_life_days = temporal_decay_half_life_days;
264        self.mmr_enabled = mmr_enabled;
265        self.mmr_lambda = mmr_lambda;
266        self
267    }
268
    /// Construct a `SemanticMemory` from pre-built parts.
    ///
    /// Intended for tests that need full control over the backing stores.
    /// Ranking options start at their defaults (decay and MMR disabled);
    /// use `with_ranking_options` to change them.
    #[cfg(any(test, feature = "mock"))]
    #[must_use]
    pub fn from_parts(
        sqlite: SqliteStore,
        qdrant: Option<EmbeddingStore>,
        provider: AnyProvider,
        embedding_model: impl Into<String>,
        vector_weight: f64,
        keyword_weight: f64,
        token_counter: Arc<TokenCounter>,
    ) -> Self {
        Self {
            sqlite,
            qdrant,
            provider,
            embedding_model: embedding_model.into(),
            vector_weight,
            keyword_weight,
            // Defaults mirror the public constructors.
            temporal_decay_enabled: false,
            temporal_decay_half_life_days: 30,
            mmr_enabled: false,
            mmr_lambda: 0.7,
            token_counter,
        }
    }

298    /// Create a `SemanticMemory` using the `SQLite`-embedded vector backend.
299    ///
300    /// # Errors
301    ///
302    /// Returns an error if `SQLite` cannot be initialized.
303    pub async fn with_sqlite_backend(
304        sqlite_path: &str,
305        provider: AnyProvider,
306        embedding_model: &str,
307        vector_weight: f64,
308        keyword_weight: f64,
309    ) -> Result<Self, MemoryError> {
310        Self::with_sqlite_backend_and_pool_size(
311            sqlite_path,
312            provider,
313            embedding_model,
314            vector_weight,
315            keyword_weight,
316            5,
317        )
318        .await
319    }
320
321    /// Create a `SemanticMemory` using the `SQLite`-embedded vector backend with configurable pool size.
322    ///
323    /// # Errors
324    ///
325    /// Returns an error if `SQLite` cannot be initialized.
326    pub async fn with_sqlite_backend_and_pool_size(
327        sqlite_path: &str,
328        provider: AnyProvider,
329        embedding_model: &str,
330        vector_weight: f64,
331        keyword_weight: f64,
332        pool_size: u32,
333    ) -> Result<Self, MemoryError> {
334        let sqlite = SqliteStore::with_pool_size(sqlite_path, pool_size).await?;
335        let pool = sqlite.pool().clone();
336        let store = EmbeddingStore::new_sqlite(pool);
337
338        Ok(Self {
339            sqlite,
340            qdrant: Some(store),
341            provider,
342            embedding_model: embedding_model.into(),
343            vector_weight,
344            keyword_weight,
345            temporal_decay_enabled: false,
346            temporal_decay_half_life_days: 30,
347            mmr_enabled: false,
348            mmr_lambda: 0.7,
349            token_counter: Arc::new(TokenCounter::new()),
350        })
351    }
352
353    /// Save a message to `SQLite` and optionally embed and store in Qdrant.
354    ///
355    /// Returns the message ID assigned by `SQLite`.
356    ///
357    /// # Errors
358    ///
359    /// Returns an error if the `SQLite` save fails. Embedding failures are logged but not
360    /// propagated.
361    pub async fn remember(
362        &self,
363        conversation_id: ConversationId,
364        role: &str,
365        content: &str,
366    ) -> Result<MessageId, MemoryError> {
367        let message_id = self
368            .sqlite
369            .save_message(conversation_id, role, content)
370            .await?;
371
372        if let Some(qdrant) = &self.qdrant
373            && self.provider.supports_embeddings()
374        {
375            match self.provider.embed(content).await {
376                Ok(vector) => {
377                    // Ensure collection exists before storing
378                    let vector_size = u64::try_from(vector.len()).unwrap_or(896);
379                    if let Err(e) = qdrant.ensure_collection(vector_size).await {
380                        tracing::warn!("Failed to ensure Qdrant collection: {e:#}");
381                    } else if let Err(e) = qdrant
382                        .store(
383                            message_id,
384                            conversation_id,
385                            role,
386                            vector,
387                            MessageKind::Regular,
388                            &self.embedding_model,
389                        )
390                        .await
391                    {
392                        tracing::warn!("Failed to store embedding: {e:#}");
393                    }
394                }
395                Err(e) => {
396                    tracing::warn!("Failed to generate embedding: {e:#}");
397                }
398            }
399        }
400
401        Ok(message_id)
402    }
403
404    /// Save a message with pre-serialized parts JSON to `SQLite` and optionally embed in Qdrant.
405    ///
406    /// Returns `(message_id, embedding_stored)` tuple where `embedding_stored` is `true` if
407    /// an embedding was successfully generated and stored in Qdrant.
408    ///
409    /// # Errors
410    ///
411    /// Returns an error if the `SQLite` save fails.
412    pub async fn remember_with_parts(
413        &self,
414        conversation_id: ConversationId,
415        role: &str,
416        content: &str,
417        parts_json: &str,
418    ) -> Result<(MessageId, bool), MemoryError> {
419        let message_id = self
420            .sqlite
421            .save_message_with_parts(conversation_id, role, content, parts_json)
422            .await?;
423
424        let mut embedding_stored = false;
425
426        if let Some(qdrant) = &self.qdrant
427            && self.provider.supports_embeddings()
428        {
429            match self.provider.embed(content).await {
430                Ok(vector) => {
431                    let vector_size = u64::try_from(vector.len()).unwrap_or(896);
432                    if let Err(e) = qdrant.ensure_collection(vector_size).await {
433                        tracing::warn!("Failed to ensure Qdrant collection: {e:#}");
434                    } else if let Err(e) = qdrant
435                        .store(
436                            message_id,
437                            conversation_id,
438                            role,
439                            vector,
440                            MessageKind::Regular,
441                            &self.embedding_model,
442                        )
443                        .await
444                    {
445                        tracing::warn!("Failed to store embedding: {e:#}");
446                    } else {
447                        embedding_stored = true;
448                    }
449                }
450                Err(e) => {
451                    tracing::warn!("Failed to generate embedding: {e:#}");
452                }
453            }
454        }
455
456        Ok((message_id, embedding_stored))
457    }
458
459    /// Save a message to `SQLite` without generating an embedding.
460    ///
461    /// Use this when embedding is intentionally skipped (e.g. autosave disabled for assistant).
462    ///
463    /// # Errors
464    ///
465    /// Returns an error if the `SQLite` save fails.
466    pub async fn save_only(
467        &self,
468        conversation_id: ConversationId,
469        role: &str,
470        content: &str,
471        parts_json: &str,
472    ) -> Result<MessageId, MemoryError> {
473        self.sqlite
474            .save_message_with_parts(conversation_id, role, content, parts_json)
475            .await
476    }
477
    /// Recall relevant messages using hybrid search (vector + FTS5 keyword).
    ///
    /// When Qdrant is available, runs both vector and keyword searches, then merges
    /// results using weighted scoring. When Qdrant is unavailable, falls back to
    /// FTS5-only keyword search.
    ///
    /// Ranking pipeline: weighted merge -> optional temporal decay ->
    /// optional MMR re-ranking -> truncation to `limit` -> hydration from `SQLite`.
    ///
    /// # Errors
    ///
    /// Returns an error if embedding generation, Qdrant search, or FTS5 query fails.
    #[allow(clippy::too_many_lines)]
    pub async fn recall(
        &self,
        query: &str,
        limit: usize,
        filter: Option<SearchFilter>,
    ) -> Result<Vec<RecalledMessage>, MemoryError> {
        let conversation_id = filter.as_ref().and_then(|f| f.conversation_id);

        // FTS5 keyword search (always available)
        // Over-fetch 2x so the merge has candidates from both sources; a
        // keyword-search failure degrades to vector-only rather than erroring.
        let keyword_results = match self
            .sqlite
            .keyword_search(query, limit * 2, conversation_id)
            .await
        {
            Ok(results) => results,
            Err(e) => {
                tracing::warn!("FTS5 keyword search failed: {e:#}");
                Vec::new()
            }
        };

        // Vector search (only when Qdrant available)
        // Unlike keyword search, embedding/search failures here propagate.
        let vector_results = if let Some(qdrant) = &self.qdrant
            && self.provider.supports_embeddings()
        {
            let query_vector = self.provider.embed(query).await?;
            // NOTE(review): try_from only fails for lengths > u64::MAX, so the
            // 896 fallback is effectively unreachable on 64-bit targets.
            let vector_size = u64::try_from(query_vector.len()).unwrap_or(896);
            qdrant.ensure_collection(vector_size).await?;
            qdrant.search(&query_vector, limit * 2, filter).await?
        } else {
            Vec::new()
        };

        // Merge results with weighted scoring
        // Each source's scores are normalized by that source's max, then
        // combined as vector_weight * v + keyword_weight * k per message.
        let mut scores: std::collections::HashMap<MessageId, f64> =
            std::collections::HashMap::new();

        if !vector_results.is_empty() {
            let max_vs = vector_results
                .iter()
                .map(|r| r.score)
                .fold(f32::NEG_INFINITY, f32::max);
            // Guard against dividing by zero (or a non-positive max).
            let norm = if max_vs > 0.0 { max_vs } else { 1.0 };
            for r in &vector_results {
                let normalized = f64::from(r.score / norm);
                *scores.entry(r.message_id).or_default() += normalized * self.vector_weight;
            }
        }

        if !keyword_results.is_empty() {
            let max_ks = keyword_results
                .iter()
                .map(|r| r.1)
                .fold(f64::NEG_INFINITY, f64::max);
            let norm = if max_ks > 0.0 { max_ks } else { 1.0 };
            for &(msg_id, score) in &keyword_results {
                let normalized = score / norm;
                *scores.entry(msg_id).or_default() += normalized * self.keyword_weight;
            }
        }

        if scores.is_empty() {
            return Ok(Vec::new());
        }

        // Sort by combined score descending
        let mut ranked: Vec<(MessageId, f64)> = scores.into_iter().collect();
        ranked.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));

        // Apply temporal decay (before MMR)
        // A timestamp-fetch failure skips decay rather than failing the recall.
        if self.temporal_decay_enabled && self.temporal_decay_half_life_days > 0 {
            let ids: Vec<MessageId> = ranked.iter().map(|r| r.0).collect();
            match self.sqlite.message_timestamps(&ids).await {
                Ok(timestamps) => {
                    apply_temporal_decay(
                        &mut ranked,
                        &timestamps,
                        self.temporal_decay_half_life_days,
                    );
                    // Re-sort: decay changes relative ordering.
                    ranked
                        .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
                }
                Err(e) => {
                    tracing::warn!("temporal decay: failed to fetch timestamps: {e:#}");
                }
            }
        }

        // Apply MMR re-ranking (after decay, before truncation)
        // Every fallback path still truncates to `limit`.
        if self.mmr_enabled && !vector_results.is_empty() {
            if let Some(qdrant) = &self.qdrant {
                let ids: Vec<MessageId> = ranked.iter().map(|r| r.0).collect();
                match qdrant.get_vectors(&ids).await {
                    Ok(vec_map) if !vec_map.is_empty() => {
                        ranked = apply_mmr(&ranked, &vec_map, self.mmr_lambda, limit);
                    }
                    Ok(_) => {
                        ranked.truncate(limit);
                    }
                    Err(e) => {
                        tracing::warn!("MMR: failed to fetch vectors: {e:#}");
                        ranked.truncate(limit);
                    }
                }
            } else {
                ranked.truncate(limit);
            }
        } else {
            ranked.truncate(limit);
        }

        // Hydrate full messages from SQLite; ids missing from SQLite are
        // silently dropped while the ranked ordering is preserved.
        let ids: Vec<MessageId> = ranked.iter().map(|r| r.0).collect();
        let messages = self.sqlite.messages_by_ids(&ids).await?;
        let msg_map: std::collections::HashMap<MessageId, _> = messages.into_iter().collect();

        let recalled = ranked
            .iter()
            .filter_map(|(msg_id, score)| {
                msg_map.get(msg_id).map(|msg| RecalledMessage {
                    message: msg.clone(),
                    #[expect(clippy::cast_possible_truncation)]
                    score: *score as f32,
                })
            })
            .collect();

        Ok(recalled)
    }

617    /// Check whether an embedding exists for a given message ID.
618    ///
619    /// # Errors
620    ///
621    /// Returns an error if the `SQLite` query fails.
622    pub async fn has_embedding(&self, message_id: MessageId) -> Result<bool, MemoryError> {
623        match &self.qdrant {
624            Some(qdrant) => qdrant.has_embedding(message_id).await,
625            None => Ok(false),
626        }
627    }
628
629    /// Embed all messages that do not yet have embeddings.
630    ///
631    /// Returns the count of successfully embedded messages.
632    ///
633    /// # Errors
634    ///
635    /// Returns an error if collection initialization or database query fails.
636    /// Individual embedding failures are logged but do not stop processing.
637    pub async fn embed_missing(&self) -> Result<usize, MemoryError> {
638        let Some(qdrant) = &self.qdrant else {
639            return Ok(0);
640        };
641        if !self.provider.supports_embeddings() {
642            return Ok(0);
643        }
644
645        let unembedded = self.sqlite.unembedded_message_ids(Some(1000)).await?;
646
647        if unembedded.is_empty() {
648            return Ok(0);
649        }
650
651        let probe = self.provider.embed("probe").await?;
652        let vector_size = u64::try_from(probe.len())?;
653        qdrant.ensure_collection(vector_size).await?;
654
655        let mut count = 0;
656        for (msg_id, conversation_id, role, content) in &unembedded {
657            match self.provider.embed(content).await {
658                Ok(vector) => {
659                    if let Err(e) = qdrant
660                        .store(
661                            *msg_id,
662                            *conversation_id,
663                            role,
664                            vector,
665                            MessageKind::Regular,
666                            &self.embedding_model,
667                        )
668                        .await
669                    {
670                        tracing::warn!("Failed to store embedding for msg {msg_id}: {e:#}");
671                        continue;
672                    }
673                    count += 1;
674                }
675                Err(e) => {
676                    tracing::warn!("Failed to embed msg {msg_id}: {e:#}");
677                }
678            }
679        }
680
681        tracing::info!("Embedded {count}/{} missing messages", unembedded.len());
682        Ok(count)
683    }
684
685    /// Store a session summary into the dedicated `zeph_session_summaries` Qdrant collection.
686    ///
687    /// # Errors
688    ///
689    /// Returns an error if embedding or Qdrant storage fails.
690    pub async fn store_session_summary(
691        &self,
692        conversation_id: ConversationId,
693        summary_text: &str,
694    ) -> Result<(), MemoryError> {
695        let Some(qdrant) = &self.qdrant else {
696            return Ok(());
697        };
698        if !self.provider.supports_embeddings() {
699            return Ok(());
700        }
701
702        let vector = self.provider.embed(summary_text).await?;
703        let vector_size = u64::try_from(vector.len()).unwrap_or(896);
704        qdrant
705            .ensure_named_collection(SESSION_SUMMARIES_COLLECTION, vector_size)
706            .await?;
707
708        let payload = serde_json::json!({
709            "conversation_id": conversation_id.0,
710            "summary_text": summary_text,
711        });
712
713        qdrant
714            .store_to_collection(SESSION_SUMMARIES_COLLECTION, payload, vector)
715            .await?;
716
717        tracing::debug!(
718            conversation_id = conversation_id.0,
719            "stored session summary"
720        );
721        Ok(())
722    }
723
724    /// Search session summaries from other conversations.
725    ///
726    /// # Errors
727    ///
728    /// Returns an error if embedding or Qdrant search fails.
729    pub async fn search_session_summaries(
730        &self,
731        query: &str,
732        limit: usize,
733        exclude_conversation_id: Option<ConversationId>,
734    ) -> Result<Vec<SessionSummaryResult>, MemoryError> {
735        let Some(qdrant) = &self.qdrant else {
736            return Ok(Vec::new());
737        };
738        if !self.provider.supports_embeddings() {
739            return Ok(Vec::new());
740        }
741
742        let vector = self.provider.embed(query).await?;
743        let vector_size = u64::try_from(vector.len()).unwrap_or(896);
744        qdrant
745            .ensure_named_collection(SESSION_SUMMARIES_COLLECTION, vector_size)
746            .await?;
747
748        let filter = exclude_conversation_id.map(|cid| VectorFilter {
749            must: vec![],
750            must_not: vec![FieldCondition {
751                field: "conversation_id".into(),
752                value: FieldValue::Integer(cid.0),
753            }],
754        });
755
756        let points = qdrant
757            .search_collection(SESSION_SUMMARIES_COLLECTION, &vector, limit, filter)
758            .await?;
759
760        let results = points
761            .into_iter()
762            .filter_map(|point| {
763                let summary_text = point.payload.get("summary_text")?.as_str()?.to_owned();
764                let conversation_id =
765                    ConversationId(point.payload.get("conversation_id")?.as_i64()?);
766                Some(SessionSummaryResult {
767                    summary_text,
768                    score: point.score,
769                    conversation_id,
770                })
771            })
772            .collect();
773
774        Ok(results)
775    }
776
    /// Access the underlying `SqliteStore` for operations that don't involve semantics.
    ///
    /// Borrowed accessor; the store stays owned by `SemanticMemory`.
    #[must_use]
    pub fn sqlite(&self) -> &SqliteStore {
        &self.sqlite
    }

783    /// Check if the vector store backend is reachable.
784    ///
785    /// Performs a real health check (Qdrant gRPC ping or `SQLite` query)
786    /// instead of just checking whether the client was created.
787    pub async fn is_vector_store_connected(&self) -> bool {
788        match self.qdrant.as_ref() {
789            Some(store) => store.health_check().await,
790            None => false,
791        }
792    }
793
    /// Check if a vector store client is configured (may not be connected).
    ///
    /// Use `is_vector_store_connected` for a real reachability check.
    #[must_use]
    pub fn has_vector_store(&self) -> bool {
        self.qdrant.is_some()
    }

    /// Count messages in a conversation.
    ///
    /// Thin delegation to the `SQLite` store.
    ///
    /// # Errors
    ///
    /// Returns an error if the query fails.
    pub async fn message_count(&self, conversation_id: ConversationId) -> Result<i64, MemoryError> {
        self.sqlite.count_messages(conversation_id).await
    }

809    /// Count messages not yet covered by any summary.
810    ///
811    /// # Errors
812    ///
813    /// Returns an error if the query fails.
814    pub async fn unsummarized_message_count(
815        &self,
816        conversation_id: ConversationId,
817    ) -> Result<i64, MemoryError> {
818        let after_id = self
819            .sqlite
820            .latest_summary_last_message_id(conversation_id)
821            .await?
822            .unwrap_or(MessageId(0));
823        self.sqlite
824            .count_messages_after(conversation_id, after_id)
825            .await
826    }
827
828    /// Load all summaries for a conversation.
829    ///
830    /// # Errors
831    ///
832    /// Returns an error if the query fails.
833    pub async fn load_summaries(
834        &self,
835        conversation_id: ConversationId,
836    ) -> Result<Vec<Summary>, MemoryError> {
837        let rows = self.sqlite.load_summaries(conversation_id).await?;
838        let summaries = rows
839            .into_iter()
840            .map(
841                |(
842                    id,
843                    conversation_id,
844                    content,
845                    first_message_id,
846                    last_message_id,
847                    token_estimate,
848                )| {
849                    Summary {
850                        id,
851                        conversation_id,
852                        content,
853                        first_message_id,
854                        last_message_id,
855                        token_estimate,
856                    }
857                },
858            )
859            .collect();
860        Ok(summaries)
861    }
862
    /// Generate a summary of the oldest unsummarized messages.
    ///
    /// Asks the LLM for a [`StructuredSummary`] (summary text, key facts,
    /// entities), falling back to a plain-text chat completion if structured
    /// output fails. The summary row is saved to SQLite; embedding the summary
    /// and storing key facts in the vector store are best-effort side steps
    /// that only log warnings on failure.
    ///
    /// Returns `Ok(None)` if there are not enough messages to summarize.
    ///
    /// # Errors
    ///
    /// Returns an error if LLM call or database operation fails.
    pub async fn summarize(
        &self,
        conversation_id: ConversationId,
        message_count: usize,
    ) -> Result<Option<i64>, MemoryError> {
        let total = self.sqlite.count_messages(conversation_id).await?;

        // Strict threshold: summarize only once MORE than `message_count`
        // messages exist in the conversation.
        if total <= i64::try_from(message_count)? {
            return Ok(None);
        }

        // Resume after the last message covered by a previous summary;
        // MessageId(0) means "start from the beginning".
        let after_id = self
            .sqlite
            .latest_summary_last_message_id(conversation_id)
            .await?
            .unwrap_or(MessageId(0));

        let messages = self
            .sqlite
            .load_messages_range(conversation_id, after_id, message_count)
            .await?;

        if messages.is_empty() {
            return Ok(None);
        }

        let prompt = build_summarization_prompt(&messages);
        let chat_messages = vec![Message {
            role: Role::User,
            content: prompt,
            parts: vec![],
            metadata: MessageMetadata::default(),
        }];

        // Prefer schema-validated structured output; degrade to a plain chat
        // completion (no key facts / entities) when the provider can't comply.
        let structured = match self
            .provider
            .chat_typed_erased::<StructuredSummary>(&chat_messages)
            .await
        {
            Ok(s) => s,
            Err(e) => {
                tracing::warn!(
                    "structured summarization failed, falling back to plain text: {e:#}"
                );
                let plain = self.provider.chat(&chat_messages).await?;
                StructuredSummary {
                    summary: plain,
                    key_facts: vec![],
                    entities: vec![],
                }
            }
        };
        let summary_text = &structured.summary;

        let token_estimate = i64::try_from(self.token_counter.count_tokens(summary_text))?;
        // `messages` is non-empty here (checked above), so indexing is safe.
        let first_message_id = messages[0].0;
        let last_message_id = messages[messages.len() - 1].0;

        let summary_id = self
            .sqlite
            .save_summary(
                conversation_id,
                summary_text,
                first_message_id,
                last_message_id,
                token_estimate,
            )
            .await?;

        // Best-effort: embed the summary for semantic recall. Any failure is
        // logged and never aborts the summarization itself.
        if let Some(qdrant) = &self.qdrant
            && self.provider.supports_embeddings()
        {
            match self.provider.embed(summary_text).await {
                Ok(vector) => {
                    // Ensure collection exists before storing
                    // NOTE(review): the 896 fallback is presumably the default
                    // embedding dimension of the configured model — confirm.
                    let vector_size = u64::try_from(vector.len()).unwrap_or(896);
                    if let Err(e) = qdrant.ensure_collection(vector_size).await {
                        tracing::warn!("Failed to ensure Qdrant collection: {e:#}");
                    } else if let Err(e) = qdrant
                        .store(
                            MessageId(summary_id),
                            conversation_id,
                            "system",
                            vector,
                            MessageKind::Summary,
                            &self.embedding_model,
                        )
                        .await
                    {
                        tracing::warn!("Failed to embed summary: {e:#}");
                    }
                }
                Err(e) => {
                    tracing::warn!("Failed to generate summary embedding: {e:#}");
                }
            }
        }

        // Store key facts as individual Qdrant points
        if !structured.key_facts.is_empty() {
            self.store_key_facts(conversation_id, summary_id, &structured.key_facts)
                .await;
        }

        Ok(Some(summary_id))
    }
976
977    async fn store_key_facts(
978        &self,
979        conversation_id: ConversationId,
980        source_summary_id: i64,
981        key_facts: &[String],
982    ) {
983        let Some(qdrant) = &self.qdrant else {
984            return;
985        };
986        if !self.provider.supports_embeddings() {
987            return;
988        }
989
990        let Some(first_fact) = key_facts.first() else {
991            return;
992        };
993        let first_vector = match self.provider.embed(first_fact).await {
994            Ok(v) => v,
995            Err(e) => {
996                tracing::warn!("Failed to embed key fact: {e:#}");
997                return;
998            }
999        };
1000        let vector_size = u64::try_from(first_vector.len()).unwrap_or(896);
1001        if let Err(e) = qdrant
1002            .ensure_named_collection(KEY_FACTS_COLLECTION, vector_size)
1003            .await
1004        {
1005            tracing::warn!("Failed to ensure key_facts collection: {e:#}");
1006            return;
1007        }
1008
1009        let first_payload = serde_json::json!({
1010            "conversation_id": conversation_id.0,
1011            "fact_text": first_fact,
1012            "source_summary_id": source_summary_id,
1013        });
1014        if let Err(e) = qdrant
1015            .store_to_collection(KEY_FACTS_COLLECTION, first_payload, first_vector)
1016            .await
1017        {
1018            tracing::warn!("Failed to store key fact: {e:#}");
1019        }
1020
1021        for fact in &key_facts[1..] {
1022            match self.provider.embed(fact).await {
1023                Ok(vector) => {
1024                    let payload = serde_json::json!({
1025                        "conversation_id": conversation_id.0,
1026                        "fact_text": fact,
1027                        "source_summary_id": source_summary_id,
1028                    });
1029                    if let Err(e) = qdrant
1030                        .store_to_collection(KEY_FACTS_COLLECTION, payload, vector)
1031                        .await
1032                    {
1033                        tracing::warn!("Failed to store key fact: {e:#}");
1034                    }
1035                }
1036                Err(e) => {
1037                    tracing::warn!("Failed to embed key fact: {e:#}");
1038                }
1039            }
1040        }
1041    }
1042
1043    /// Search key facts extracted from conversation summaries.
1044    ///
1045    /// # Errors
1046    ///
1047    /// Returns an error if embedding or Qdrant search fails.
1048    pub async fn search_key_facts(
1049        &self,
1050        query: &str,
1051        limit: usize,
1052    ) -> Result<Vec<String>, MemoryError> {
1053        let Some(qdrant) = &self.qdrant else {
1054            return Ok(Vec::new());
1055        };
1056        if !self.provider.supports_embeddings() {
1057            return Ok(Vec::new());
1058        }
1059
1060        let vector = self.provider.embed(query).await?;
1061        let vector_size = u64::try_from(vector.len()).unwrap_or(896);
1062        qdrant
1063            .ensure_named_collection(KEY_FACTS_COLLECTION, vector_size)
1064            .await?;
1065
1066        let points = qdrant
1067            .search_collection(KEY_FACTS_COLLECTION, &vector, limit, None)
1068            .await?;
1069
1070        let facts = points
1071            .into_iter()
1072            .filter_map(|p| p.payload.get("fact_text")?.as_str().map(String::from))
1073            .collect();
1074
1075        Ok(facts)
1076    }
1077
1078    /// Search a named document collection by semantic similarity.
1079    ///
1080    /// Returns up to `limit` scored vector points whose payloads contain ingested document chunks.
1081    /// Returns an empty vec when Qdrant is unavailable, the collection does not exist,
1082    /// or the provider does not support embeddings.
1083    ///
1084    /// # Errors
1085    ///
1086    /// Returns an error if embedding generation or Qdrant search fails.
1087    pub async fn search_document_collection(
1088        &self,
1089        collection: &str,
1090        query: &str,
1091        limit: usize,
1092    ) -> Result<Vec<crate::ScoredVectorPoint>, MemoryError> {
1093        let Some(qdrant) = &self.qdrant else {
1094            return Ok(Vec::new());
1095        };
1096        if !self.provider.supports_embeddings() {
1097            return Ok(Vec::new());
1098        }
1099        if !qdrant.collection_exists(collection).await? {
1100            return Ok(Vec::new());
1101        }
1102        let vector = self.provider.embed(query).await?;
1103        qdrant
1104            .search_collection(collection, &vector, limit, None)
1105            .await
1106    }
1107
1108    /// Store an embedding for a user correction in the vector store.
1109    ///
1110    /// Silently skips if no vector store is configured or embeddings are unsupported.
1111    ///
1112    /// # Errors
1113    ///
1114    /// Returns an error if embedding generation or vector store write fails.
1115    pub async fn store_correction_embedding(
1116        &self,
1117        correction_id: i64,
1118        correction_text: &str,
1119    ) -> Result<(), MemoryError> {
1120        let Some(ref store) = self.qdrant else {
1121            return Ok(());
1122        };
1123        if !self.provider.supports_embeddings() {
1124            return Ok(());
1125        }
1126        let embedding = self
1127            .provider
1128            .embed(correction_text)
1129            .await
1130            .map_err(|e| MemoryError::Other(e.to_string()))?;
1131        let payload = serde_json::json!({ "correction_id": correction_id });
1132        store
1133            .store_to_collection(CORRECTIONS_COLLECTION, payload, embedding)
1134            .await?;
1135        Ok(())
1136    }
1137
    /// Retrieve corrections semantically similar to `query`.
    ///
    /// Embeds `query`, searches the corrections collection, and resolves each
    /// matching point back to its SQLite correction row(s). Points scoring at
    /// or above `min_score` are kept, up to `limit` search hits.
    /// Returns an empty vec if no vector store is configured or the provider
    /// does not support embeddings.
    ///
    /// # Errors
    ///
    /// Returns an error if embedding generation or the SQLite row lookup
    /// fails. Vector-search failures are deliberately swallowed
    /// (`unwrap_or_default`) and treated as "no matches" — e.g. when the
    /// corrections collection does not exist yet.
    pub async fn retrieve_similar_corrections(
        &self,
        query: &str,
        limit: usize,
        min_score: f32,
    ) -> Result<Vec<crate::sqlite::corrections::UserCorrectionRow>, MemoryError> {
        let Some(ref store) = self.qdrant else {
            return Ok(vec![]);
        };
        if !self.provider.supports_embeddings() {
            return Ok(vec![]);
        }
        let embedding = self
            .provider
            .embed(query)
            .await
            .map_err(|e| MemoryError::Other(e.to_string()))?;
        // Best-effort search: an unavailable store or missing collection
        // simply yields no hits rather than an error.
        let scored = store
            .search_collection(CORRECTIONS_COLLECTION, &embedding, limit, None)
            .await
            .unwrap_or_default();

        let mut results = Vec::new();
        for point in scored {
            if point.score < min_score {
                continue;
            }
            // Each point's payload carries the SQLite correction row id.
            if let Some(id_val) = point.payload.get("correction_id")
                && let Some(id) = id_val.as_i64()
            {
                let rows = self.sqlite.load_corrections_for_id(id).await?;
                results.extend(rows);
            }
        }
        Ok(results)
    }
1182}
1183
1184#[cfg(test)]
1185mod tests {
1186    use zeph_llm::mock::MockProvider;
1187    use zeph_llm::provider::Role;
1188
1189    use super::*;
1190
    // Mock LLM provider; by default it reports no embedding support.
    fn test_provider() -> AnyProvider {
        AnyProvider::Mock(MockProvider::default())
    }

    // Build a SemanticMemory over an in-memory SQLite store with no vector store.
    // NOTE(review): the flag parameter is currently unused — the mock's embedding
    // support is never toggled here, so callers passing `true` still exercise
    // the no-embedding paths; confirm whether this is intentional.
    async fn test_semantic_memory(_supports_embeddings: bool) -> SemanticMemory {
        let provider = test_provider();
        let sqlite = SqliteStore::new(":memory:").await.unwrap();

        SemanticMemory {
            sqlite,
            qdrant: None,
            provider,
            embedding_model: "test-model".into(),
            vector_weight: 0.7,
            keyword_weight: 0.3,
            temporal_decay_enabled: false,
            temporal_decay_half_life_days: 30,
            mmr_enabled: false,
            mmr_lambda: 0.7,
            token_counter: Arc::new(TokenCounter::new()),
        }
    }
1213
    // remember() persists the message; load_history reflects role and content.
    #[tokio::test]
    async fn remember_saves_to_sqlite() {
        let memory = test_semantic_memory(false).await;

        let cid = memory.sqlite.create_conversation().await.unwrap();
        let msg_id = memory.remember(cid, "user", "hello").await.unwrap();

        assert_eq!(msg_id, MessageId(1));

        let history = memory.sqlite.load_history(cid, 50).await.unwrap();
        assert_eq!(history.len(), 1);
        assert_eq!(history[0].role, Role::User);
        assert_eq!(history[0].content, "hello");
    }

    // remember_with_parts stores structured parts JSON alongside the content.
    #[tokio::test]
    async fn remember_with_parts_saves_parts_json() {
        let memory = test_semantic_memory(false).await;
        let cid = memory.sqlite.create_conversation().await.unwrap();

        let parts_json =
            r#"[{"kind":"ToolOutput","tool_name":"shell","body":"hello","compacted_at":null}]"#;
        let (msg_id, _embedding_stored) = memory
            .remember_with_parts(cid, "assistant", "tool output", parts_json)
            .await
            .unwrap();
        assert!(msg_id > MessageId(0));

        let history = memory.sqlite.load_history(cid, 50).await.unwrap();
        assert_eq!(history.len(), 1);
        assert_eq!(history[0].content, "tool output");
    }

    // Without a vector store, recall degrades to an empty result.
    #[tokio::test]
    async fn recall_returns_empty_without_qdrant() {
        let memory = test_semantic_memory(true).await;

        let recalled = memory.recall("test", 5, None).await.unwrap();
        assert!(recalled.is_empty());
    }

    // has_embedding is false for any message when no vector store exists.
    #[tokio::test]
    async fn has_embedding_without_qdrant() {
        let memory = test_semantic_memory(true).await;

        let has_embedding = memory.has_embedding(MessageId(1)).await.unwrap();
        assert!(!has_embedding);
    }

    // embed_missing is a no-op (0 embedded) without a vector store.
    #[tokio::test]
    async fn embed_missing_without_qdrant() {
        let memory = test_semantic_memory(true).await;

        let count = memory.embed_missing().await.unwrap();
        assert_eq!(count, 0);
    }

    // The sqlite() accessor exposes the underlying store for direct reads/writes.
    #[tokio::test]
    async fn sqlite_accessor() {
        let memory = test_semantic_memory(false).await;

        let cid = memory.sqlite().create_conversation().await.unwrap();
        assert_eq!(cid, ConversationId(1));

        memory
            .sqlite()
            .save_message(cid, "user", "test")
            .await
            .unwrap();

        let history = memory.sqlite().load_history(cid, 50).await.unwrap();
        assert_eq!(history.len(), 1);
    }

    #[tokio::test]
    async fn has_vector_store_returns_false_when_unavailable() {
        let memory = test_semantic_memory(false).await;
        assert!(!memory.has_vector_store());
    }

    #[tokio::test]
    async fn is_vector_store_connected_returns_false_when_unavailable() {
        let memory = test_semantic_memory(false).await;
        assert!(!memory.is_vector_store_connected().await);
    }

    // Provider without embedding support also yields an empty recall.
    #[tokio::test]
    async fn recall_returns_empty_when_embeddings_not_supported() {
        let memory = test_semantic_memory(false).await;

        let recalled = memory.recall("test", 5, None).await.unwrap();
        assert!(recalled.is_empty());
    }

    // Even with pending messages, embed_missing embeds nothing when unsupported.
    #[tokio::test]
    async fn embed_missing_returns_zero_when_embeddings_not_supported() {
        let memory = test_semantic_memory(false).await;

        let cid = memory.sqlite().create_conversation().await.unwrap();
        memory
            .sqlite()
            .save_message(cid, "user", "test")
            .await
            .unwrap();

        let count = memory.embed_missing().await.unwrap();
        assert_eq!(count, 0);
    }

    #[tokio::test]
    async fn message_count_empty_conversation() {
        let memory = test_semantic_memory(false).await;
        let cid = memory.sqlite().create_conversation().await.unwrap();

        let count = memory.message_count(cid).await.unwrap();
        assert_eq!(count, 0);
    }
1331
    #[tokio::test]
    async fn message_count_after_saves() {
        let memory = test_semantic_memory(false).await;
        let cid = memory.sqlite().create_conversation().await.unwrap();

        memory.remember(cid, "user", "msg1").await.unwrap();
        memory.remember(cid, "assistant", "msg2").await.unwrap();

        let count = memory.message_count(cid).await.unwrap();
        assert_eq!(count, 2);
    }

    // Summarization marks messages as covered without deleting them.
    #[tokio::test]
    async fn unsummarized_count_decreases_after_summary() {
        let memory = test_semantic_memory(false).await;
        let cid = memory.sqlite().create_conversation().await.unwrap();

        for i in 0..10 {
            memory
                .remember(cid, "user", &format!("msg{i}"))
                .await
                .unwrap();
        }
        assert_eq!(memory.unsummarized_message_count(cid).await.unwrap(), 10);

        memory.summarize(cid, 5).await.unwrap();

        assert!(memory.unsummarized_message_count(cid).await.unwrap() < 10);
        assert_eq!(memory.message_count(cid).await.unwrap(), 10);
    }

    #[tokio::test]
    async fn load_summaries_empty() {
        let memory = test_semantic_memory(false).await;
        let cid = memory.sqlite().create_conversation().await.unwrap();

        let summaries = memory.load_summaries(cid).await.unwrap();
        assert!(summaries.is_empty());
    }

    // Summaries come back in insertion order with content intact.
    #[tokio::test]
    async fn load_summaries_ordered() {
        let memory = test_semantic_memory(false).await;
        let cid = memory.sqlite().create_conversation().await.unwrap();

        let msg_id1 = memory.remember(cid, "user", "m1").await.unwrap();
        let msg_id2 = memory.remember(cid, "assistant", "m2").await.unwrap();
        let msg_id3 = memory.remember(cid, "user", "m3").await.unwrap();

        let s1 = memory
            .sqlite()
            .save_summary(cid, "summary1", msg_id1, msg_id2, 3)
            .await
            .unwrap();
        let s2 = memory
            .sqlite()
            .save_summary(cid, "summary2", msg_id2, msg_id3, 3)
            .await
            .unwrap();

        let summaries = memory.load_summaries(cid).await.unwrap();
        assert_eq!(summaries.len(), 2);
        assert_eq!(summaries[0].id, s1);
        assert_eq!(summaries[0].content, "summary1");
        assert_eq!(summaries[1].id, s2);
        assert_eq!(summaries[1].content, "summary2");
    }

    // summarize() is a no-op while total messages <= threshold.
    #[tokio::test]
    async fn summarize_below_threshold() {
        let memory = test_semantic_memory(false).await;
        let cid = memory.sqlite().create_conversation().await.unwrap();

        memory.remember(cid, "user", "hello").await.unwrap();

        let result = memory.summarize(cid, 10).await.unwrap();
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn summarize_stores_summary() {
        let memory = test_semantic_memory(false).await;
        let cid = memory.sqlite().create_conversation().await.unwrap();

        for i in 0..5 {
            memory
                .remember(cid, "user", &format!("message {i}"))
                .await
                .unwrap();
        }

        let summary_id = memory.summarize(cid, 3).await.unwrap();
        assert!(summary_id.is_some());

        let summaries = memory.load_summaries(cid).await.unwrap();
        assert_eq!(summaries.len(), 1);
        assert_eq!(summaries[0].id, summary_id.unwrap());
        assert!(!summaries[0].content.is_empty());
    }

    // A second summarize() resumes after the first: message ranges never overlap.
    #[tokio::test]
    async fn summarize_respects_previous_summaries() {
        let memory = test_semantic_memory(false).await;
        let cid = memory.sqlite().create_conversation().await.unwrap();

        for i in 0..10 {
            memory
                .remember(cid, "user", &format!("message {i}"))
                .await
                .unwrap();
        }

        let s1 = memory.summarize(cid, 3).await.unwrap();
        assert!(s1.is_some());

        let s2 = memory.summarize(cid, 3).await.unwrap();
        assert!(s2.is_some());

        let summaries = memory.load_summaries(cid).await.unwrap();
        assert_eq!(summaries.len(), 2);
        assert!(summaries[0].last_message_id < summaries[1].first_message_id);
    }
1454
    #[tokio::test]
    async fn remember_multiple_messages_increments_ids() {
        let memory = test_semantic_memory(false).await;
        let cid = memory.sqlite.create_conversation().await.unwrap();

        let id1 = memory.remember(cid, "user", "first").await.unwrap();
        let id2 = memory.remember(cid, "assistant", "second").await.unwrap();
        let id3 = memory.remember(cid, "user", "third").await.unwrap();

        assert!(id1 < id2);
        assert!(id2 < id3);
    }

    #[tokio::test]
    async fn message_count_across_conversations() {
        let memory = test_semantic_memory(false).await;
        let cid1 = memory.sqlite().create_conversation().await.unwrap();
        let cid2 = memory.sqlite().create_conversation().await.unwrap();

        memory.remember(cid1, "user", "msg1").await.unwrap();
        memory.remember(cid1, "user", "msg2").await.unwrap();
        memory.remember(cid2, "user", "msg3").await.unwrap();

        assert_eq!(memory.message_count(cid1).await.unwrap(), 2);
        assert_eq!(memory.message_count(cid2).await.unwrap(), 1);
    }

    // The threshold is strict: exactly `message_count` messages produce no summary.
    #[tokio::test]
    async fn summarize_exact_threshold_returns_none() {
        let memory = test_semantic_memory(false).await;
        let cid = memory.sqlite().create_conversation().await.unwrap();

        for i in 0..3 {
            memory
                .remember(cid, "user", &format!("msg {i}"))
                .await
                .unwrap();
        }

        let result = memory.summarize(cid, 3).await.unwrap();
        assert!(result.is_none());
    }

    // One message over the threshold is enough to trigger a summary.
    #[tokio::test]
    async fn summarize_one_above_threshold_produces_summary() {
        let memory = test_semantic_memory(false).await;
        let cid = memory.sqlite().create_conversation().await.unwrap();

        for i in 0..4 {
            memory
                .remember(cid, "user", &format!("msg {i}"))
                .await
                .unwrap();
        }

        let result = memory.summarize(cid, 3).await.unwrap();
        assert!(result.is_some());
    }

    // Every Summary field carries a plausible value after summarize().
    #[tokio::test]
    async fn summary_fields_populated() {
        let memory = test_semantic_memory(false).await;
        let cid = memory.sqlite().create_conversation().await.unwrap();

        for i in 0..5 {
            memory
                .remember(cid, "user", &format!("msg {i}"))
                .await
                .unwrap();
        }

        memory.summarize(cid, 3).await.unwrap();
        let summaries = memory.load_summaries(cid).await.unwrap();
        let s = &summaries[0];

        assert_eq!(s.conversation_id, cid);
        assert!(s.first_message_id > MessageId(0));
        assert!(s.last_message_id >= s.first_message_id);
        assert!(s.token_estimate >= 0);
        assert!(!s.content.is_empty());
    }

    // The prompt renders each message as "role: content" and asks for key_facts.
    #[test]
    fn build_summarization_prompt_format() {
        let messages = vec![
            (MessageId(1), "user".into(), "Hello".into()),
            (MessageId(2), "assistant".into(), "Hi there".into()),
        ];
        let prompt = build_summarization_prompt(&messages);
        assert!(prompt.contains("user: Hello"));
        assert!(prompt.contains("assistant: Hi there"));
        assert!(prompt.contains("key_facts"));
    }

    // Even with no messages the prompt still requests the key_facts field.
    #[test]
    fn build_summarization_prompt_empty() {
        let messages: Vec<(MessageId, String, String)> = vec![];
        let prompt = build_summarization_prompt(&messages);
        assert!(prompt.contains("key_facts"));
    }

    #[test]
    fn structured_summary_deserialize() {
        let json = r#"{"summary":"s","key_facts":["f1","f2"],"entities":["e1"]}"#;
        let ss: StructuredSummary = serde_json::from_str(json).unwrap();
        assert_eq!(ss.summary, "s");
        assert_eq!(ss.key_facts.len(), 2);
        assert_eq!(ss.entities.len(), 1);
    }

    #[test]
    fn structured_summary_empty_facts() {
        let json = r#"{"summary":"s","key_facts":[],"entities":[]}"#;
        let ss: StructuredSummary = serde_json::from_str(json).unwrap();
        assert!(ss.key_facts.is_empty());
        assert!(ss.entities.is_empty());
    }
1572
    #[tokio::test]
    async fn search_key_facts_no_qdrant_empty() {
        let memory = test_semantic_memory(false).await;
        let facts = memory.search_key_facts("query", 5).await.unwrap();
        assert!(facts.is_empty());
    }

    // Debug formatting exposes the type name and the score.
    #[test]
    fn recalled_message_debug() {
        let recalled = RecalledMessage {
            message: Message {
                role: Role::User,
                content: "test".into(),
                parts: vec![],
                metadata: MessageMetadata::default(),
            },
            score: 0.95,
        };
        let dbg = format!("{recalled:?}");
        assert!(dbg.contains("RecalledMessage"));
        assert!(dbg.contains("0.95"));
    }

    #[test]
    fn summary_clone() {
        let summary = Summary {
            id: 1,
            conversation_id: ConversationId(2),
            content: "test summary".into(),
            first_message_id: MessageId(1),
            last_message_id: MessageId(5),
            token_estimate: 10,
        };
        let cloned = summary.clone();
        assert_eq!(summary.id, cloned.id);
        assert_eq!(summary.content, cloned.content);
    }

    // "user"/"assistant"/"system" strings round-trip to their Role variants.
    #[tokio::test]
    async fn remember_preserves_role_mapping() {
        let memory = test_semantic_memory(false).await;
        let cid = memory.sqlite.create_conversation().await.unwrap();

        memory.remember(cid, "user", "u").await.unwrap();
        memory.remember(cid, "assistant", "a").await.unwrap();
        memory.remember(cid, "system", "s").await.unwrap();

        let history = memory.sqlite.load_history(cid, 50).await.unwrap();
        assert_eq!(history.len(), 3);
        assert_eq!(history[0].role, Role::User);
        assert_eq!(history[1].role, Role::Assistant);
        assert_eq!(history[2].role, Role::System);
    }

    // An unreachable Qdrant URL must not fail construction — memory degrades
    // gracefully to SQLite-only operation.
    #[tokio::test]
    async fn new_with_invalid_qdrant_url_graceful() {
        let mut mock = MockProvider::default();
        mock.supports_embeddings = true;
        let provider = AnyProvider::Mock(mock);
        let result =
            SemanticMemory::new(":memory:", "http://127.0.0.1:1", provider, "test-model").await;
        assert!(result.is_ok());
    }

    // End-to-end roundtrip using the SQLite-backed EmbeddingStore instead of Qdrant.
    #[tokio::test]
    async fn test_semantic_memory_sqlite_remember_recall_roundtrip() {
        // Build SemanticMemory with EmbeddingStore backed by SQLite instead of Qdrant
        let mut mock = MockProvider::default();
        mock.supports_embeddings = true;
        // Provide deterministic embedding vectors: embed returns a fixed 4-element vector
        // MockProvider.embed always returns the same vector, so cosine similarity = 1.0
        let provider = AnyProvider::Mock(mock);

        let sqlite = SqliteStore::new(":memory:").await.unwrap();
        let pool = sqlite.pool().clone();
        let qdrant = Some(crate::embedding_store::EmbeddingStore::new_sqlite(pool));

        let memory = SemanticMemory {
            sqlite,
            qdrant,
            provider,
            embedding_model: "test-model".into(),
            vector_weight: 0.7,
            keyword_weight: 0.3,
            temporal_decay_enabled: false,
            temporal_decay_half_life_days: 30,
            mmr_enabled: false,
            mmr_lambda: 0.7,
            token_counter: Arc::new(TokenCounter::new()),
        };

        let cid = memory.sqlite().create_conversation().await.unwrap();

        // remember → stores in SQLite + SQLite vector store
        let id1 = memory
            .remember(cid, "user", "rust async programming")
            .await
            .unwrap();
        let id2 = memory
            .remember(cid, "assistant", "use tokio for async")
            .await
            .unwrap();
        assert!(id1 < id2);

        // recall → should return results via FTS5 keyword search
        let recalled = memory.recall("rust", 5, None).await.unwrap();
        assert!(
            !recalled.is_empty(),
            "recall must return at least one result"
        );

        // Verify history is accessible
        let history = memory.sqlite().load_history(cid, 50).await.unwrap();
        assert_eq!(history.len(), 2);
        assert_eq!(history[0].content, "rust async programming");
    }
1689
    // Embedding-capable provider but no vector store: remember still succeeds.
    #[tokio::test]
    async fn remember_with_embeddings_supported_but_no_qdrant() {
        let memory = test_semantic_memory(true).await;
        let cid = memory.sqlite.create_conversation().await.unwrap();

        let msg_id = memory.remember(cid, "user", "hello embed").await.unwrap();
        assert!(msg_id > MessageId(0));

        let history = memory.sqlite.load_history(cid, 50).await.unwrap();
        assert_eq!(history.len(), 1);
        assert_eq!(history[0].content, "hello embed");
    }

    // Insertion order is preserved in load_history.
    #[tokio::test]
    async fn remember_verifies_content_via_load_history() {
        let memory = test_semantic_memory(false).await;
        let cid = memory.sqlite.create_conversation().await.unwrap();

        memory.remember(cid, "user", "alpha").await.unwrap();
        memory.remember(cid, "assistant", "beta").await.unwrap();
        memory.remember(cid, "user", "gamma").await.unwrap();

        let history = memory.sqlite().load_history(cid, 50).await.unwrap();
        assert_eq!(history.len(), 3);
        assert_eq!(history[0].content, "alpha");
        assert_eq!(history[1].content, "beta");
        assert_eq!(history[2].content, "gamma");
    }

    #[tokio::test]
    async fn message_count_multiple_conversations_isolated() {
        let memory = test_semantic_memory(false).await;
        let cid1 = memory.sqlite().create_conversation().await.unwrap();
        let cid2 = memory.sqlite().create_conversation().await.unwrap();
        let cid3 = memory.sqlite().create_conversation().await.unwrap();

        for _ in 0..5 {
            memory.remember(cid1, "user", "msg").await.unwrap();
        }
        for _ in 0..3 {
            memory.remember(cid2, "user", "msg").await.unwrap();
        }

        assert_eq!(memory.message_count(cid1).await.unwrap(), 5);
        assert_eq!(memory.message_count(cid2).await.unwrap(), 3);
        assert_eq!(memory.message_count(cid3).await.unwrap(), 0);
    }

    // NOTE(review): the name suggests an empty second range, but with 6
    // messages and count=3 both calls find messages and both summaries are
    // stored — confirm the test name still matches its intent.
    #[tokio::test]
    async fn summarize_empty_messages_range_returns_none() {
        let memory = test_semantic_memory(false).await;
        let cid = memory.sqlite().create_conversation().await.unwrap();

        for i in 0..6 {
            memory
                .remember(cid, "user", &format!("msg {i}"))
                .await
                .unwrap();
        }

        memory.summarize(cid, 3).await.unwrap();
        memory.summarize(cid, 3).await.unwrap();

        let summaries = memory.load_summaries(cid).await.unwrap();
        assert_eq!(summaries.len(), 2);
    }

    // The stored token estimate is computed from the summary text (non-zero).
    #[tokio::test]
    async fn summarize_token_estimate_populated() {
        let memory = test_semantic_memory(false).await;
        let cid = memory.sqlite().create_conversation().await.unwrap();

        for i in 0..5 {
            memory
                .remember(cid, "user", &format!("message {i}"))
                .await
                .unwrap();
        }

        memory.summarize(cid, 3).await.unwrap();
        let summaries = memory.load_summaries(cid).await.unwrap();
        let token_est = summaries[0].token_estimate;
        assert!(token_est > 0);
    }
1774
    #[tokio::test]
    async fn summarize_fails_when_provider_chat_fails() {
        // Point the provider at an unreachable endpoint so every chat call
        // errors out; summarize must propagate that failure instead of
        // silently storing an empty summary.
        let sqlite = SqliteStore::new(":memory:").await.unwrap();
        let provider = AnyProvider::Ollama(zeph_llm::ollama::OllamaProvider::new(
            "http://127.0.0.1:1",
            "test".into(),
            "embed".into(),
        ));
        let memory = SemanticMemory {
            sqlite,
            qdrant: None,
            provider,
            embedding_model: "test".into(),
            vector_weight: 0.7,
            keyword_weight: 0.3,
            temporal_decay_enabled: false,
            temporal_decay_half_life_days: 30,
            mmr_enabled: false,
            mmr_lambda: 0.7,
            token_counter: Arc::new(TokenCounter::new()),
        };
        let cid = memory.sqlite().create_conversation().await.unwrap();

        // Enough messages that summarize(cid, 3) actually attempts an LLM call.
        for i in 0..5 {
            memory
                .remember(cid, "user", &format!("msg {i}"))
                .await
                .unwrap();
        }

        let result = memory.summarize(cid, 3).await;
        assert!(result.is_err());
    }
1808
1809    #[tokio::test]
1810    async fn embed_missing_without_embedding_support_returns_zero() {
1811        let memory = test_semantic_memory(false).await;
1812        let cid = memory.sqlite().create_conversation().await.unwrap();
1813        memory
1814            .sqlite()
1815            .save_message(cid, "user", "test message")
1816            .await
1817            .unwrap();
1818
1819        let count = memory.embed_missing().await.unwrap();
1820        assert_eq!(count, 0);
1821    }
1822
1823    #[tokio::test]
1824    async fn has_embedding_returns_false_when_no_qdrant() {
1825        let memory = test_semantic_memory(false).await;
1826        let cid = memory.sqlite.create_conversation().await.unwrap();
1827        let msg_id = memory.remember(cid, "user", "test").await.unwrap();
1828        assert!(!memory.has_embedding(msg_id).await.unwrap());
1829    }
1830
1831    #[tokio::test]
1832    async fn recall_empty_without_qdrant_regardless_of_filter() {
1833        let memory = test_semantic_memory(true).await;
1834        let filter = SearchFilter {
1835            conversation_id: Some(ConversationId(1)),
1836            role: None,
1837        };
1838        let recalled = memory.recall("query", 10, Some(filter)).await.unwrap();
1839        assert!(recalled.is_empty());
1840    }
1841
1842    #[tokio::test]
1843    async fn summarize_message_range_bounds() {
1844        let memory = test_semantic_memory(false).await;
1845        let cid = memory.sqlite().create_conversation().await.unwrap();
1846
1847        for i in 0..8 {
1848            memory
1849                .remember(cid, "user", &format!("msg {i}"))
1850                .await
1851                .unwrap();
1852        }
1853
1854        let summary_id = memory.summarize(cid, 4).await.unwrap().unwrap();
1855        let summaries = memory.load_summaries(cid).await.unwrap();
1856        assert_eq!(summaries.len(), 1);
1857        assert_eq!(summaries[0].id, summary_id);
1858        assert!(summaries[0].first_message_id >= MessageId(1));
1859        assert!(summaries[0].last_message_id >= summaries[0].first_message_id);
1860    }
1861
1862    #[test]
1863    fn build_summarization_prompt_preserves_order() {
1864        let messages = vec![
1865            (MessageId(1), "user".into(), "first".into()),
1866            (MessageId(2), "assistant".into(), "second".into()),
1867            (MessageId(3), "user".into(), "third".into()),
1868        ];
1869        let prompt = build_summarization_prompt(&messages);
1870        let first_pos = prompt.find("user: first").unwrap();
1871        let second_pos = prompt.find("assistant: second").unwrap();
1872        let third_pos = prompt.find("user: third").unwrap();
1873        assert!(first_pos < second_pos);
1874        assert!(second_pos < third_pos);
1875    }
1876
1877    #[test]
1878    fn summary_debug() {
1879        let summary = Summary {
1880            id: 1,
1881            conversation_id: ConversationId(2),
1882            content: "test".into(),
1883            first_message_id: MessageId(1),
1884            last_message_id: MessageId(5),
1885            token_estimate: 10,
1886        };
1887        let dbg = format!("{summary:?}");
1888        assert!(dbg.contains("Summary"));
1889    }
1890
1891    #[tokio::test]
1892    async fn message_count_nonexistent_conversation() {
1893        let memory = test_semantic_memory(false).await;
1894        let count = memory.message_count(ConversationId(999)).await.unwrap();
1895        assert_eq!(count, 0);
1896    }
1897
1898    #[tokio::test]
1899    async fn load_summaries_nonexistent_conversation() {
1900        let memory = test_semantic_memory(false).await;
1901        let summaries = memory.load_summaries(ConversationId(999)).await.unwrap();
1902        assert!(summaries.is_empty());
1903    }
1904
1905    #[tokio::test]
1906    async fn store_session_summary_no_qdrant_noop() {
1907        let memory = test_semantic_memory(true).await;
1908        let result = memory
1909            .store_session_summary(ConversationId(1), "test summary")
1910            .await;
1911        assert!(result.is_ok());
1912    }
1913
1914    #[tokio::test]
1915    async fn store_session_summary_no_embeddings_noop() {
1916        let memory = test_semantic_memory(false).await;
1917        let result = memory
1918            .store_session_summary(ConversationId(1), "test summary")
1919            .await;
1920        assert!(result.is_ok());
1921    }
1922
1923    #[tokio::test]
1924    async fn search_session_summaries_no_qdrant_empty() {
1925        let memory = test_semantic_memory(true).await;
1926        let results = memory
1927            .search_session_summaries("query", 5, None)
1928            .await
1929            .unwrap();
1930        assert!(results.is_empty());
1931    }
1932
1933    #[tokio::test]
1934    async fn search_session_summaries_no_embeddings_empty() {
1935        let memory = test_semantic_memory(false).await;
1936        let results = memory
1937            .search_session_summaries("query", 5, Some(ConversationId(1)))
1938            .await
1939            .unwrap();
1940        assert!(results.is_empty());
1941    }
1942
1943    #[test]
1944    fn session_summary_result_debug() {
1945        let result = SessionSummaryResult {
1946            summary_text: "test".into(),
1947            score: 0.9,
1948            conversation_id: ConversationId(1),
1949        };
1950        let dbg = format!("{result:?}");
1951        assert!(dbg.contains("SessionSummaryResult"));
1952    }
1953
1954    #[test]
1955    fn session_summary_result_clone() {
1956        let result = SessionSummaryResult {
1957            summary_text: "test".into(),
1958            score: 0.9,
1959            conversation_id: ConversationId(1),
1960        };
1961        let cloned = result.clone();
1962        assert_eq!(result.summary_text, cloned.summary_text);
1963        assert_eq!(result.conversation_id, cloned.conversation_id);
1964    }
1965
1966    #[tokio::test]
1967    async fn recall_fts5_fallback_without_qdrant() {
1968        let memory = test_semantic_memory(false).await;
1969        let cid = memory.sqlite.create_conversation().await.unwrap();
1970
1971        memory
1972            .remember(cid, "user", "rust programming guide")
1973            .await
1974            .unwrap();
1975        memory
1976            .remember(cid, "assistant", "python tutorial")
1977            .await
1978            .unwrap();
1979        memory
1980            .remember(cid, "user", "advanced rust patterns")
1981            .await
1982            .unwrap();
1983
1984        let recalled = memory.recall("rust", 5, None).await.unwrap();
1985        assert_eq!(recalled.len(), 2);
1986        assert!(recalled[0].score >= recalled[1].score);
1987    }
1988
1989    #[tokio::test]
1990    async fn recall_fts5_fallback_with_filter() {
1991        let memory = test_semantic_memory(false).await;
1992        let cid1 = memory.sqlite.create_conversation().await.unwrap();
1993        let cid2 = memory.sqlite.create_conversation().await.unwrap();
1994
1995        memory.remember(cid1, "user", "hello world").await.unwrap();
1996        memory
1997            .remember(cid2, "user", "hello universe")
1998            .await
1999            .unwrap();
2000
2001        let filter = SearchFilter {
2002            conversation_id: Some(cid1),
2003            role: None,
2004        };
2005        let recalled = memory.recall("hello", 5, Some(filter)).await.unwrap();
2006        assert_eq!(recalled.len(), 1);
2007    }
2008
2009    #[tokio::test]
2010    async fn recall_fts5_no_matches_returns_empty() {
2011        let memory = test_semantic_memory(false).await;
2012        let cid = memory.sqlite.create_conversation().await.unwrap();
2013
2014        memory.remember(cid, "user", "hello world").await.unwrap();
2015
2016        let recalled = memory.recall("nonexistent", 5, None).await.unwrap();
2017        assert!(recalled.is_empty());
2018    }
2019
2020    #[tokio::test]
2021    async fn recall_fts5_respects_limit() {
2022        let memory = test_semantic_memory(false).await;
2023        let cid = memory.sqlite.create_conversation().await.unwrap();
2024
2025        for i in 0..10 {
2026            memory
2027                .remember(cid, "user", &format!("test message number {i}"))
2028                .await
2029                .unwrap();
2030        }
2031
2032        let recalled = memory.recall("test", 3, None).await.unwrap();
2033        assert_eq!(recalled.len(), 3);
2034    }
2035
2036    // Priority 2: summarize fallback path
2037
    #[tokio::test]
    async fn summarize_fallback_to_plain_text_when_structured_fails() {
        // MockProvider answers every chat() with non-JSON plain text, so the
        // structured path (chat_typed → JSON parse, with retry) fails with a
        // StructuredParse error. summarize must then fall back to a plain
        // chat() call and store the raw text as the summary.
        let sqlite = SqliteStore::new(":memory:").await.unwrap();
        let mut mock = MockProvider::default();
        // Served for the structured attempts AND the plain fallback.
        mock.default_response = "plain text summary".into();
        let provider = AnyProvider::Mock(mock);

        let memory = SemanticMemory {
            sqlite,
            qdrant: None,
            provider,
            embedding_model: "test".into(),
            vector_weight: 0.7,
            keyword_weight: 0.3,
            temporal_decay_enabled: false,
            temporal_decay_half_life_days: 30,
            mmr_enabled: false,
            mmr_lambda: 0.7,
            token_counter: Arc::new(TokenCounter::new()),
        };

        let cid = memory.sqlite().create_conversation().await.unwrap();
        // Enough messages that summarize(cid, 3) actually summarizes something.
        for i in 0..5 {
            memory
                .remember(cid, "user", &format!("msg {i}"))
                .await
                .unwrap();
        }

        let result = memory.summarize(cid, 3).await;
        // The fallback plain chat() succeeds, so summarize must return Ok and
        // exactly one non-empty summary must be persisted.
        assert!(result.is_ok());
        let summaries = memory.load_summaries(cid).await.unwrap();
        assert_eq!(summaries.len(), 1);
        assert!(!summaries[0].content.is_empty());
    }
2086
2087    // Temporal decay tests
2088
2089    #[test]
2090    fn temporal_decay_disabled_leaves_scores_unchanged() {
2091        let mut ranked = vec![(MessageId(1), 1.0f64), (MessageId(2), 0.5f64)];
2092        let timestamps = std::collections::HashMap::new();
2093        apply_temporal_decay(&mut ranked, &timestamps, 30);
2094        assert!((ranked[0].1 - 1.0).abs() < f64::EPSILON);
2095        assert!((ranked[1].1 - 0.5).abs() < f64::EPSILON);
2096    }
2097
2098    #[test]
2099    fn temporal_decay_zero_age_preserves_score() {
2100        let now = std::time::SystemTime::now()
2101            .duration_since(std::time::UNIX_EPOCH)
2102            .unwrap_or_default()
2103            .as_secs()
2104            .cast_signed();
2105        let mut ranked = vec![(MessageId(1), 1.0f64)];
2106        let mut timestamps = std::collections::HashMap::new();
2107        timestamps.insert(MessageId(1), now);
2108        apply_temporal_decay(&mut ranked, &timestamps, 30);
2109        // age = 0 days, exp(0) = 1.0 → no change
2110        assert!((ranked[0].1 - 1.0).abs() < 0.01);
2111    }
2112
2113    #[test]
2114    fn temporal_decay_half_life_halves_score() {
2115        // Age exactly half_life_days → score should be halved
2116        let half_life = 30u32;
2117        let age_secs = i64::from(half_life) * 86400;
2118        let now = std::time::SystemTime::now()
2119            .duration_since(std::time::UNIX_EPOCH)
2120            .unwrap_or_default()
2121            .as_secs()
2122            .cast_signed();
2123        let ts = now - age_secs;
2124        let mut ranked = vec![(MessageId(1), 1.0f64)];
2125        let mut timestamps = std::collections::HashMap::new();
2126        timestamps.insert(MessageId(1), ts);
2127        apply_temporal_decay(&mut ranked, &timestamps, half_life);
2128        // exp(-ln2) = 0.5
2129        assert!(
2130            (ranked[0].1 - 0.5).abs() < 0.01,
2131            "score was {}",
2132            ranked[0].1
2133        );
2134    }
2135
2136    // MMR tests
2137
2138    #[test]
2139    fn mmr_empty_input_returns_empty() {
2140        let ranked = vec![];
2141        let vectors = std::collections::HashMap::new();
2142        let result = apply_mmr(&ranked, &vectors, 0.7, 5);
2143        assert!(result.is_empty());
2144    }
2145
2146    #[test]
2147    fn mmr_returns_up_to_limit() {
2148        let ranked = vec![
2149            (MessageId(1), 1.0f64),
2150            (MessageId(2), 0.9f64),
2151            (MessageId(3), 0.8f64),
2152        ];
2153        let mut vectors = std::collections::HashMap::new();
2154        vectors.insert(MessageId(1), vec![1.0f32, 0.0]);
2155        vectors.insert(MessageId(2), vec![0.0f32, 1.0]);
2156        vectors.insert(MessageId(3), vec![1.0f32, 0.0]);
2157        let result = apply_mmr(&ranked, &vectors, 0.7, 2);
2158        assert_eq!(result.len(), 2);
2159    }
2160
2161    #[test]
2162    fn mmr_without_vectors_picks_by_relevance() {
2163        let ranked = vec![(MessageId(1), 1.0f64), (MessageId(2), 0.5f64)];
2164        let vectors = std::collections::HashMap::new();
2165        let result = apply_mmr(&ranked, &vectors, 0.7, 2);
2166        assert_eq!(result.len(), 2);
2167        assert_eq!(result[0].0, MessageId(1));
2168    }
2169
2170    #[test]
2171    fn mmr_prefers_diverse_over_redundant() {
2172        // Two candidates with same relevance but msg 2 is orthogonal (more diverse)
2173        let ranked = vec![
2174            (MessageId(1), 1.0f64), // selected first
2175            (MessageId(2), 0.9f64), // orthogonal to 1
2176            (MessageId(3), 0.9f64), // parallel to 1 (redundant)
2177        ];
2178        let mut vectors = std::collections::HashMap::new();
2179        vectors.insert(MessageId(1), vec![1.0f32, 0.0]);
2180        vectors.insert(MessageId(2), vec![0.0f32, 1.0]); // orthogonal
2181        vectors.insert(MessageId(3), vec![1.0f32, 0.0]); // same as 1
2182        let result = apply_mmr(&ranked, &vectors, 0.5, 2);
2183        assert_eq!(result.len(), 2);
2184        assert_eq!(result[0].0, MessageId(1));
2185        // msg 2 should be preferred over msg 3 (diverse)
2186        assert_eq!(result[1].0, MessageId(2));
2187    }
2188
2189    #[test]
2190    fn temporal_decay_half_life_zero_is_noop() {
2191        let now = std::time::SystemTime::now()
2192            .duration_since(std::time::UNIX_EPOCH)
2193            .unwrap_or_default()
2194            .as_secs()
2195            .cast_signed();
2196        let age_secs = 30i64 * 86400;
2197        let ts = now - age_secs;
2198        let mut ranked = vec![(MessageId(1), 1.0f64)];
2199        let mut timestamps = std::collections::HashMap::new();
2200        timestamps.insert(MessageId(1), ts);
2201        // half_life=0 → guard returns early, score must remain 1.0
2202        apply_temporal_decay(&mut ranked, &timestamps, 0);
2203        assert!(
2204            (ranked[0].1 - 1.0).abs() < f64::EPSILON,
2205            "score was {}",
2206            ranked[0].1
2207        );
2208    }
2209
2210    #[test]
2211    fn temporal_decay_huge_age_near_zero() {
2212        let now = std::time::SystemTime::now()
2213            .duration_since(std::time::UNIX_EPOCH)
2214            .unwrap_or_default()
2215            .as_secs()
2216            .cast_signed();
2217        // 10 years = ~3650 days
2218        let age_secs = 3650i64 * 86400;
2219        let ts = now - age_secs;
2220        let mut ranked = vec![(MessageId(1), 1.0f64)];
2221        let mut timestamps = std::collections::HashMap::new();
2222        timestamps.insert(MessageId(1), ts);
2223        apply_temporal_decay(&mut ranked, &timestamps, 30);
2224        // After 3650 days with half_life=30, score should be essentially 0
2225        assert!(ranked[0].1 < 0.001, "score was {}", ranked[0].1);
2226    }
2227
2228    #[test]
2229    fn temporal_decay_small_half_life() {
2230        // Very small half_life (1 day), age = 7 days → 2^(-7) ≈ 0.0078
2231        let now = std::time::SystemTime::now()
2232            .duration_since(std::time::UNIX_EPOCH)
2233            .unwrap_or_default()
2234            .as_secs()
2235            .cast_signed();
2236        let ts = now - 7 * 86400i64;
2237        let mut ranked = vec![(MessageId(1), 1.0f64)];
2238        let mut timestamps = std::collections::HashMap::new();
2239        timestamps.insert(MessageId(1), ts);
2240        apply_temporal_decay(&mut ranked, &timestamps, 1);
2241        assert!(ranked[0].1 < 0.01, "score was {}", ranked[0].1);
2242    }
2243
2244    #[test]
2245    fn mmr_lambda_zero_max_diversity() {
2246        // lambda=0 → pure diversity: second item should be most dissimilar
2247        let ranked = vec![
2248            (MessageId(1), 1.0f64),  // selected first (always highest relevance)
2249            (MessageId(2), 0.9f64),  // orthogonal to 1
2250            (MessageId(3), 0.85f64), // parallel to 1 (max_sim=1.0)
2251        ];
2252        let mut vectors = std::collections::HashMap::new();
2253        vectors.insert(MessageId(1), vec![1.0f32, 0.0]);
2254        vectors.insert(MessageId(2), vec![0.0f32, 1.0]); // orthogonal
2255        vectors.insert(MessageId(3), vec![1.0f32, 0.0]); // same direction
2256        let result = apply_mmr(&ranked, &vectors, 0.0, 3);
2257        assert_eq!(result.len(), 3);
2258        // After 1 is selected: mmr(2) = 0 - (1-0)*0 = 0, mmr(3) = 0 - 1*1 = -1 → 2 wins
2259        assert_eq!(result[1].0, MessageId(2));
2260    }
2261
2262    #[test]
2263    fn mmr_lambda_one_pure_relevance() {
2264        // lambda=1 → pure relevance, should pick in relevance order
2265        let ranked = vec![
2266            (MessageId(1), 1.0f64),
2267            (MessageId(2), 0.8f64),
2268            (MessageId(3), 0.6f64),
2269        ];
2270        let mut vectors = std::collections::HashMap::new();
2271        vectors.insert(MessageId(1), vec![1.0f32, 0.0]);
2272        vectors.insert(MessageId(2), vec![0.0f32, 1.0]);
2273        vectors.insert(MessageId(3), vec![0.5f32, 0.5]);
2274        let result = apply_mmr(&ranked, &vectors, 1.0, 3);
2275        assert_eq!(result.len(), 3);
2276        assert_eq!(result[0].0, MessageId(1));
2277        assert_eq!(result[1].0, MessageId(2));
2278        assert_eq!(result[2].0, MessageId(3));
2279    }
2280
2281    #[test]
2282    fn mmr_limit_zero_returns_empty() {
2283        let ranked = vec![(MessageId(1), 1.0f64), (MessageId(2), 0.8f64)];
2284        let mut vectors = std::collections::HashMap::new();
2285        vectors.insert(MessageId(1), vec![1.0f32, 0.0]);
2286        vectors.insert(MessageId(2), vec![0.0f32, 1.0]);
2287        let result = apply_mmr(&ranked, &vectors, 0.7, 0);
2288        assert!(result.is_empty());
2289    }
2290
2291    #[test]
2292    fn mmr_duplicate_vectors_penalizes_second() {
2293        // Two items with identical embeddings: second should be heavily penalized
2294        let ranked = vec![
2295            (MessageId(1), 1.0f64),
2296            (MessageId(2), 1.0f64), // same relevance, same direction
2297            (MessageId(3), 0.9f64), // orthogonal, lower relevance
2298        ];
2299        let mut vectors = std::collections::HashMap::new();
2300        vectors.insert(MessageId(1), vec![1.0f32, 0.0]);
2301        vectors.insert(MessageId(2), vec![1.0f32, 0.0]); // duplicate
2302        vectors.insert(MessageId(3), vec![0.0f32, 1.0]); // orthogonal
2303        let result = apply_mmr(&ranked, &vectors, 0.5, 3);
2304        assert_eq!(result.len(), 3);
2305        assert_eq!(result[0].0, MessageId(1));
2306        // msg3 (orthogonal) should be preferred over msg2 (duplicate) with lambda=0.5
2307        assert_eq!(result[1].0, MessageId(3));
2308    }
2309
2310    // Priority 3: proptest
2311
    use proptest::prelude::*;

    proptest! {
        // Fuzz: token counting must be total — it may never panic, whatever
        // string (including empty and non-ASCII) it is handed.
        #[test]
        fn count_tokens_never_panics(s in ".*") {
            let counter = crate::token_counter::TokenCounter::new();
            let _ = counter.count_tokens(&s);
        }
    }
2321}