// zeph_memory/semantic.rs
1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use zeph_llm::any::AnyProvider;
5use zeph_llm::provider::{LlmProvider, Message, MessageMetadata, Role};
6
7use std::sync::Arc;
8
9use crate::embedding_store::{EmbeddingStore, MessageKind, SearchFilter};
10use crate::error::MemoryError;
11use crate::sqlite::SqliteStore;
12use crate::token_counter::TokenCounter;
13use crate::types::{ConversationId, MessageId};
14use crate::vector_store::{FieldCondition, FieldValue, VectorFilter};
15
16const SESSION_SUMMARIES_COLLECTION: &str = "zeph_session_summaries";
17const KEY_FACTS_COLLECTION: &str = "zeph_key_facts";
18const CORRECTIONS_COLLECTION: &str = "zeph_corrections";
19
/// JSON shape requested from the LLM during summarization.
///
/// Deserialized from the provider's typed-chat response; `schemars::JsonSchema`
/// allows the schema to be supplied to schema-constrained generation.
#[derive(Debug, Clone, serde::Deserialize, serde::Serialize, schemars::JsonSchema)]
pub struct StructuredSummary {
    /// Free-text summary of the conversation window.
    pub summary: String,
    /// Discrete facts extracted from the conversation (may be empty on fallback).
    pub key_facts: Vec<String>,
    /// Named entities mentioned in the conversation (may be empty on fallback).
    pub entities: Vec<String>,
}
26
/// A message returned by hybrid recall, paired with its combined relevance score.
#[derive(Debug)]
pub struct RecalledMessage {
    /// The stored chat message.
    pub message: Message,
    /// Weighted vector + keyword score; higher means more relevant.
    pub score: f32,
}
32
/// A persisted conversation summary covering a contiguous range of messages.
#[derive(Debug, Clone)]
pub struct Summary {
    /// Row ID assigned by `SQLite`.
    pub id: i64,
    /// Conversation this summary belongs to.
    pub conversation_id: ConversationId,
    /// The summary text.
    pub content: String,
    /// First message covered by this summary.
    pub first_message_id: MessageId,
    /// Last message covered by this summary.
    pub last_message_id: MessageId,
    /// Approximate token count of `content`.
    pub token_estimate: i64,
}
42
/// A cross-session summary hit returned by session-summary search.
#[derive(Debug, Clone)]
pub struct SessionSummaryResult {
    /// Stored summary text of the matching session.
    pub summary_text: String,
    /// Similarity score reported by the vector backend.
    pub score: f32,
    /// Conversation the summary was generated from.
    pub conversation_id: ConversationId,
}
49
/// Cosine similarity between two vectors.
///
/// Returns `0.0` for length mismatches, empty input, or a zero-magnitude vector.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.is_empty() || a.len() != b.len() {
        return 0.0;
    }
    // Single pass accumulating the dot product and both squared magnitudes.
    let mut dot = 0.0_f32;
    let mut sq_a = 0.0_f32;
    let mut sq_b = 0.0_f32;
    for (&x, &y) in a.iter().zip(b.iter()) {
        dot += x * y;
        sq_a += x * x;
        sq_b += y * y;
    }
    let norm_a = sq_a.sqrt();
    let norm_b = sq_b.sqrt();
    if norm_a == 0.0 || norm_b == 0.0 {
        return 0.0;
    }
    dot / (norm_a * norm_b)
}
62
63fn apply_temporal_decay(
64    ranked: &mut [(MessageId, f64)],
65    timestamps: &std::collections::HashMap<MessageId, i64>,
66    half_life_days: u32,
67) {
68    if half_life_days == 0 {
69        return;
70    }
71    let now = std::time::SystemTime::now()
72        .duration_since(std::time::UNIX_EPOCH)
73        .unwrap_or_default()
74        .as_secs()
75        .cast_signed();
76    let lambda = std::f64::consts::LN_2 / f64::from(half_life_days);
77
78    for (msg_id, score) in ranked.iter_mut() {
79        if let Some(&ts) = timestamps.get(msg_id) {
80            #[allow(clippy::cast_precision_loss)]
81            let age_days = (now - ts).max(0) as f64 / 86400.0;
82            *score *= (-lambda * age_days).exp();
83        }
84    }
85}
86
/// Maximal Marginal Relevance re-ranking: greedily select up to `limit` items,
/// trading relevance against similarity to already-selected items.
///
/// `lambda` near 1.0 favors relevance; near 0.0 favors diversity. `ranked` is
/// expected to be sorted by relevance descending — the first pick is taken at
/// index 0. Candidates without an entry in `vectors` incur no similarity
/// penalty (treated as maximally diverse).
fn apply_mmr(
    ranked: &[(MessageId, f64)],
    vectors: &std::collections::HashMap<MessageId, Vec<f32>>,
    lambda: f32,
    limit: usize,
) -> Vec<(MessageId, f64)> {
    if ranked.is_empty() || limit == 0 {
        return Vec::new();
    }

    let lambda = f64::from(lambda);
    let mut selected: Vec<(MessageId, f64)> = Vec::with_capacity(limit);
    let mut remaining: Vec<(MessageId, f64)> = ranked.to_vec();

    while selected.len() < limit && !remaining.is_empty() {
        let best_idx = if selected.is_empty() {
            // Pick highest relevance first
            0
        } else {
            let mut best = 0usize;
            let mut best_score = f64::NEG_INFINITY;

            for (i, &(cand_id, relevance)) in remaining.iter().enumerate() {
                // Highest similarity between this candidate and anything picked so far.
                let max_sim = if let Some(cand_vec) = vectors.get(&cand_id) {
                    selected
                        .iter()
                        .filter_map(|(sel_id, _)| vectors.get(sel_id))
                        .map(|sel_vec| f64::from(cosine_similarity(cand_vec, sel_vec)))
                        .fold(f64::NEG_INFINITY, f64::max)
                } else {
                    0.0
                };
                // NEG_INFINITY here means no selected item had a vector: no penalty.
                let max_sim = if max_sim == f64::NEG_INFINITY {
                    0.0
                } else {
                    max_sim
                };
                let mmr_score = lambda * relevance - (1.0 - lambda) * max_sim;
                if mmr_score > best_score {
                    best_score = mmr_score;
                    best = i;
                }
            }
            best
        };

        selected.push(remaining.remove(best_idx));
    }

    selected
}
138
139fn build_summarization_prompt(messages: &[(MessageId, String, String)]) -> String {
140    let mut prompt = String::from(
141        "Summarize the following conversation. Extract key facts, decisions, entities, \
142         and context needed to continue the conversation.\n\n\
143         Respond in JSON with fields: summary (string), key_facts (list of strings), \
144         entities (list of strings).\n\nConversation:\n",
145    );
146
147    for (_, role, content) in messages {
148        prompt.push_str(role);
149        prompt.push_str(": ");
150        prompt.push_str(content);
151        prompt.push('\n');
152    }
153
154    prompt
155}
156
/// Long-term conversational memory: durable `SQLite` storage plus optional
/// vector-based semantic recall and LLM-backed summarization.
pub struct SemanticMemory {
    /// Durable message/summary store; also backs FTS5 keyword search.
    sqlite: SqliteStore,
    /// Vector backend (Qdrant or `SQLite`-embedded). `None` disables semantic search.
    qdrant: Option<EmbeddingStore>,
    /// LLM provider used for embeddings and summarization.
    provider: AnyProvider,
    /// Model name recorded alongside stored embeddings.
    embedding_model: String,
    /// Weight of normalized vector scores in hybrid recall.
    vector_weight: f64,
    /// Weight of normalized keyword (FTS5) scores in hybrid recall.
    keyword_weight: f64,
    /// When true, recall exponentially down-weights older messages.
    temporal_decay_enabled: bool,
    /// Half-life in days for temporal decay; 0 disables it.
    temporal_decay_half_life_days: u32,
    /// When true, recall applies MMR diversity re-ranking.
    mmr_enabled: bool,
    /// MMR relevance/diversity trade-off (closer to 1.0 = favor relevance).
    mmr_lambda: f32,
    /// Shared token counter used for summary token estimates.
    pub token_counter: Arc<TokenCounter>,
}
170
171impl SemanticMemory {
    /// Create a new `SemanticMemory` instance with default hybrid search weights (0.7/0.3).
    ///
    /// Qdrant connection is best-effort: if unavailable, semantic search is disabled.
    ///
    /// # Errors
    ///
    /// Returns an error if `SQLite` cannot be initialized.
    pub async fn new(
        sqlite_path: &str,
        qdrant_url: &str,
        provider: AnyProvider,
        embedding_model: &str,
    ) -> Result<Self, MemoryError> {
        // 0.7 vector / 0.3 keyword: favors semantic similarity over exact terms.
        Self::with_weights(sqlite_path, qdrant_url, provider, embedding_model, 0.7, 0.3).await
    }
187
188    /// Create a new `SemanticMemory` with custom vector/keyword weights for hybrid search.
189    ///
190    /// # Errors
191    ///
192    /// Returns an error if `SQLite` cannot be initialized.
193    pub async fn with_weights(
194        sqlite_path: &str,
195        qdrant_url: &str,
196        provider: AnyProvider,
197        embedding_model: &str,
198        vector_weight: f64,
199        keyword_weight: f64,
200    ) -> Result<Self, MemoryError> {
201        Self::with_weights_and_pool_size(
202            sqlite_path,
203            qdrant_url,
204            provider,
205            embedding_model,
206            vector_weight,
207            keyword_weight,
208            5,
209        )
210        .await
211    }
212
213    /// Create a new `SemanticMemory` with custom weights and configurable pool size.
214    ///
215    /// # Errors
216    ///
217    /// Returns an error if `SQLite` cannot be initialized.
218    pub async fn with_weights_and_pool_size(
219        sqlite_path: &str,
220        qdrant_url: &str,
221        provider: AnyProvider,
222        embedding_model: &str,
223        vector_weight: f64,
224        keyword_weight: f64,
225        pool_size: u32,
226    ) -> Result<Self, MemoryError> {
227        let sqlite = SqliteStore::with_pool_size(sqlite_path, pool_size).await?;
228        let pool = sqlite.pool().clone();
229
230        let qdrant = match EmbeddingStore::new(qdrant_url, pool) {
231            Ok(store) => Some(store),
232            Err(e) => {
233                tracing::warn!("Qdrant unavailable, semantic search disabled: {e:#}");
234                None
235            }
236        };
237
238        Ok(Self {
239            sqlite,
240            qdrant,
241            provider,
242            embedding_model: embedding_model.into(),
243            vector_weight,
244            keyword_weight,
245            temporal_decay_enabled: false,
246            temporal_decay_half_life_days: 30,
247            mmr_enabled: false,
248            mmr_lambda: 0.7,
249            token_counter: Arc::new(TokenCounter::new()),
250        })
251    }
252
253    /// Configure temporal decay and MMR re-ranking options.
254    #[must_use]
255    pub fn with_ranking_options(
256        mut self,
257        temporal_decay_enabled: bool,
258        temporal_decay_half_life_days: u32,
259        mmr_enabled: bool,
260        mmr_lambda: f32,
261    ) -> Self {
262        self.temporal_decay_enabled = temporal_decay_enabled;
263        self.temporal_decay_half_life_days = temporal_decay_half_life_days;
264        self.mmr_enabled = mmr_enabled;
265        self.mmr_lambda = mmr_lambda;
266        self
267    }
268
269    /// Create a `SemanticMemory` using the `SQLite`-embedded vector backend.
270    ///
271    /// # Errors
272    ///
273    /// Returns an error if `SQLite` cannot be initialized.
274    pub async fn with_sqlite_backend(
275        sqlite_path: &str,
276        provider: AnyProvider,
277        embedding_model: &str,
278        vector_weight: f64,
279        keyword_weight: f64,
280    ) -> Result<Self, MemoryError> {
281        Self::with_sqlite_backend_and_pool_size(
282            sqlite_path,
283            provider,
284            embedding_model,
285            vector_weight,
286            keyword_weight,
287            5,
288        )
289        .await
290    }
291
292    /// Create a `SemanticMemory` using the `SQLite`-embedded vector backend with configurable pool size.
293    ///
294    /// # Errors
295    ///
296    /// Returns an error if `SQLite` cannot be initialized.
297    pub async fn with_sqlite_backend_and_pool_size(
298        sqlite_path: &str,
299        provider: AnyProvider,
300        embedding_model: &str,
301        vector_weight: f64,
302        keyword_weight: f64,
303        pool_size: u32,
304    ) -> Result<Self, MemoryError> {
305        let sqlite = SqliteStore::with_pool_size(sqlite_path, pool_size).await?;
306        let pool = sqlite.pool().clone();
307        let store = EmbeddingStore::new_sqlite(pool);
308
309        Ok(Self {
310            sqlite,
311            qdrant: Some(store),
312            provider,
313            embedding_model: embedding_model.into(),
314            vector_weight,
315            keyword_weight,
316            temporal_decay_enabled: false,
317            temporal_decay_half_life_days: 30,
318            mmr_enabled: false,
319            mmr_lambda: 0.7,
320            token_counter: Arc::new(TokenCounter::new()),
321        })
322    }
323
324    /// Save a message to `SQLite` and optionally embed and store in Qdrant.
325    ///
326    /// Returns the message ID assigned by `SQLite`.
327    ///
328    /// # Errors
329    ///
330    /// Returns an error if the `SQLite` save fails. Embedding failures are logged but not
331    /// propagated.
332    pub async fn remember(
333        &self,
334        conversation_id: ConversationId,
335        role: &str,
336        content: &str,
337    ) -> Result<MessageId, MemoryError> {
338        let message_id = self
339            .sqlite
340            .save_message(conversation_id, role, content)
341            .await?;
342
343        if let Some(qdrant) = &self.qdrant
344            && self.provider.supports_embeddings()
345        {
346            match self.provider.embed(content).await {
347                Ok(vector) => {
348                    // Ensure collection exists before storing
349                    let vector_size = u64::try_from(vector.len()).unwrap_or(896);
350                    if let Err(e) = qdrant.ensure_collection(vector_size).await {
351                        tracing::warn!("Failed to ensure Qdrant collection: {e:#}");
352                    } else if let Err(e) = qdrant
353                        .store(
354                            message_id,
355                            conversation_id,
356                            role,
357                            vector,
358                            MessageKind::Regular,
359                            &self.embedding_model,
360                        )
361                        .await
362                    {
363                        tracing::warn!("Failed to store embedding: {e:#}");
364                    }
365                }
366                Err(e) => {
367                    tracing::warn!("Failed to generate embedding: {e:#}");
368                }
369            }
370        }
371
372        Ok(message_id)
373    }
374
375    /// Save a message with pre-serialized parts JSON to `SQLite` and optionally embed in Qdrant.
376    ///
377    /// Returns `(message_id, embedding_stored)` tuple where `embedding_stored` is `true` if
378    /// an embedding was successfully generated and stored in Qdrant.
379    ///
380    /// # Errors
381    ///
382    /// Returns an error if the `SQLite` save fails.
383    pub async fn remember_with_parts(
384        &self,
385        conversation_id: ConversationId,
386        role: &str,
387        content: &str,
388        parts_json: &str,
389    ) -> Result<(MessageId, bool), MemoryError> {
390        let message_id = self
391            .sqlite
392            .save_message_with_parts(conversation_id, role, content, parts_json)
393            .await?;
394
395        let mut embedding_stored = false;
396
397        if let Some(qdrant) = &self.qdrant
398            && self.provider.supports_embeddings()
399        {
400            match self.provider.embed(content).await {
401                Ok(vector) => {
402                    let vector_size = u64::try_from(vector.len()).unwrap_or(896);
403                    if let Err(e) = qdrant.ensure_collection(vector_size).await {
404                        tracing::warn!("Failed to ensure Qdrant collection: {e:#}");
405                    } else if let Err(e) = qdrant
406                        .store(
407                            message_id,
408                            conversation_id,
409                            role,
410                            vector,
411                            MessageKind::Regular,
412                            &self.embedding_model,
413                        )
414                        .await
415                    {
416                        tracing::warn!("Failed to store embedding: {e:#}");
417                    } else {
418                        embedding_stored = true;
419                    }
420                }
421                Err(e) => {
422                    tracing::warn!("Failed to generate embedding: {e:#}");
423                }
424            }
425        }
426
427        Ok((message_id, embedding_stored))
428    }
429
430    /// Save a message to `SQLite` without generating an embedding.
431    ///
432    /// Use this when embedding is intentionally skipped (e.g. autosave disabled for assistant).
433    ///
434    /// # Errors
435    ///
436    /// Returns an error if the `SQLite` save fails.
437    pub async fn save_only(
438        &self,
439        conversation_id: ConversationId,
440        role: &str,
441        content: &str,
442        parts_json: &str,
443    ) -> Result<MessageId, MemoryError> {
444        self.sqlite
445            .save_message_with_parts(conversation_id, role, content, parts_json)
446            .await
447    }
448
    /// Recall relevant messages using hybrid search (vector + FTS5 keyword).
    ///
    /// When Qdrant is available, runs both vector and keyword searches, then merges
    /// results using weighted scoring. When Qdrant is unavailable, falls back to
    /// FTS5-only keyword search.
    ///
    /// Optional post-processing: temporal decay (older messages lose score) and
    /// MMR re-ranking (diversity), both applied before truncating to `limit`.
    ///
    /// # Errors
    ///
    /// Returns an error if embedding generation, Qdrant search, or FTS5 query fails.
    #[allow(clippy::too_many_lines)]
    pub async fn recall(
        &self,
        query: &str,
        limit: usize,
        filter: Option<SearchFilter>,
    ) -> Result<Vec<RecalledMessage>, MemoryError> {
        let conversation_id = filter.as_ref().and_then(|f| f.conversation_id);

        // FTS5 keyword search (always available). Over-fetches 2x candidates so
        // the merge step has room; a failure degrades to vector-only results.
        let keyword_results = match self
            .sqlite
            .keyword_search(query, limit * 2, conversation_id)
            .await
        {
            Ok(results) => results,
            Err(e) => {
                tracing::warn!("FTS5 keyword search failed: {e:#}");
                Vec::new()
            }
        };

        // Vector search (only when Qdrant available)
        let vector_results = if let Some(qdrant) = &self.qdrant
            && self.provider.supports_embeddings()
        {
            let query_vector = self.provider.embed(query).await?;
            let vector_size = u64::try_from(query_vector.len()).unwrap_or(896);
            qdrant.ensure_collection(vector_size).await?;
            qdrant.search(&query_vector, limit * 2, filter).await?
        } else {
            Vec::new()
        };

        // Merge results with weighted scoring. Each source's scores are first
        // normalized by their own maximum so the two scales are comparable.
        let mut scores: std::collections::HashMap<MessageId, f64> =
            std::collections::HashMap::new();

        if !vector_results.is_empty() {
            let max_vs = vector_results
                .iter()
                .map(|r| r.score)
                .fold(f32::NEG_INFINITY, f32::max);
            // Guard against a non-positive maximum (avoids sign flips / div by zero).
            let norm = if max_vs > 0.0 { max_vs } else { 1.0 };
            for r in &vector_results {
                let normalized = f64::from(r.score / norm);
                *scores.entry(r.message_id).or_default() += normalized * self.vector_weight;
            }
        }

        if !keyword_results.is_empty() {
            let max_ks = keyword_results
                .iter()
                .map(|r| r.1)
                .fold(f64::NEG_INFINITY, f64::max);
            let norm = if max_ks > 0.0 { max_ks } else { 1.0 };
            for &(msg_id, score) in &keyword_results {
                let normalized = score / norm;
                *scores.entry(msg_id).or_default() += normalized * self.keyword_weight;
            }
        }

        if scores.is_empty() {
            return Ok(Vec::new());
        }

        // Sort by combined score descending
        let mut ranked: Vec<(MessageId, f64)> = scores.into_iter().collect();
        ranked.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));

        // Apply temporal decay (before MMR). Timestamp fetch failure is non-fatal.
        if self.temporal_decay_enabled && self.temporal_decay_half_life_days > 0 {
            let ids: Vec<MessageId> = ranked.iter().map(|r| r.0).collect();
            match self.sqlite.message_timestamps(&ids).await {
                Ok(timestamps) => {
                    apply_temporal_decay(
                        &mut ranked,
                        &timestamps,
                        self.temporal_decay_half_life_days,
                    );
                    // Decay can change relative order, so re-sort.
                    ranked
                        .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
                }
                Err(e) => {
                    tracing::warn!("temporal decay: failed to fetch timestamps: {e:#}");
                }
            }
        }

        // Apply MMR re-ranking (after decay, before truncation). Falls back to
        // plain truncation whenever vectors cannot be fetched.
        if self.mmr_enabled && !vector_results.is_empty() {
            if let Some(qdrant) = &self.qdrant {
                let ids: Vec<MessageId> = ranked.iter().map(|r| r.0).collect();
                match qdrant.get_vectors(&ids).await {
                    Ok(vec_map) if !vec_map.is_empty() => {
                        ranked = apply_mmr(&ranked, &vec_map, self.mmr_lambda, limit);
                    }
                    Ok(_) => {
                        ranked.truncate(limit);
                    }
                    Err(e) => {
                        tracing::warn!("MMR: failed to fetch vectors: {e:#}");
                        ranked.truncate(limit);
                    }
                }
            } else {
                ranked.truncate(limit);
            }
        } else {
            ranked.truncate(limit);
        }

        // Hydrate full messages from SQLite, preserving ranked order; IDs whose
        // rows no longer exist are silently dropped.
        let ids: Vec<MessageId> = ranked.iter().map(|r| r.0).collect();
        let messages = self.sqlite.messages_by_ids(&ids).await?;
        let msg_map: std::collections::HashMap<MessageId, _> = messages.into_iter().collect();

        let recalled = ranked
            .iter()
            .filter_map(|(msg_id, score)| {
                msg_map.get(msg_id).map(|msg| RecalledMessage {
                    message: msg.clone(),
                    #[expect(clippy::cast_possible_truncation)]
                    score: *score as f32,
                })
            })
            .collect();

        Ok(recalled)
    }
587
588    /// Check whether an embedding exists for a given message ID.
589    ///
590    /// # Errors
591    ///
592    /// Returns an error if the `SQLite` query fails.
593    pub async fn has_embedding(&self, message_id: MessageId) -> Result<bool, MemoryError> {
594        match &self.qdrant {
595            Some(qdrant) => qdrant.has_embedding(message_id).await,
596            None => Ok(false),
597        }
598    }
599
    /// Embed all messages that do not yet have embeddings.
    ///
    /// Processes at most 1000 messages per call; callers can re-invoke to drain
    /// a larger backlog.
    ///
    /// Returns the count of successfully embedded messages.
    ///
    /// # Errors
    ///
    /// Returns an error if collection initialization or database query fails.
    /// Individual embedding failures are logged but do not stop processing.
    pub async fn embed_missing(&self) -> Result<usize, MemoryError> {
        let Some(qdrant) = &self.qdrant else {
            return Ok(0);
        };
        if !self.provider.supports_embeddings() {
            return Ok(0);
        }

        // Cap the batch so one call stays bounded.
        let unembedded = self.sqlite.unembedded_message_ids(Some(1000)).await?;

        if unembedded.is_empty() {
            return Ok(0);
        }

        // Probe once to learn the embedding dimensionality before ensuring
        // the collection exists.
        let probe = self.provider.embed("probe").await?;
        let vector_size = u64::try_from(probe.len())?;
        qdrant.ensure_collection(vector_size).await?;

        let mut count = 0;
        for (msg_id, conversation_id, role, content) in &unembedded {
            match self.provider.embed(content).await {
                Ok(vector) => {
                    if let Err(e) = qdrant
                        .store(
                            *msg_id,
                            *conversation_id,
                            role,
                            vector,
                            MessageKind::Regular,
                            &self.embedding_model,
                        )
                        .await
                    {
                        tracing::warn!("Failed to store embedding for msg {msg_id}: {e:#}");
                        continue;
                    }
                    count += 1;
                }
                Err(e) => {
                    tracing::warn!("Failed to embed msg {msg_id}: {e:#}");
                }
            }
        }

        tracing::info!("Embedded {count}/{} missing messages", unembedded.len());
        Ok(count)
    }
655
656    /// Store a session summary into the dedicated `zeph_session_summaries` Qdrant collection.
657    ///
658    /// # Errors
659    ///
660    /// Returns an error if embedding or Qdrant storage fails.
661    pub async fn store_session_summary(
662        &self,
663        conversation_id: ConversationId,
664        summary_text: &str,
665    ) -> Result<(), MemoryError> {
666        let Some(qdrant) = &self.qdrant else {
667            return Ok(());
668        };
669        if !self.provider.supports_embeddings() {
670            return Ok(());
671        }
672
673        let vector = self.provider.embed(summary_text).await?;
674        let vector_size = u64::try_from(vector.len()).unwrap_or(896);
675        qdrant
676            .ensure_named_collection(SESSION_SUMMARIES_COLLECTION, vector_size)
677            .await?;
678
679        let payload = serde_json::json!({
680            "conversation_id": conversation_id.0,
681            "summary_text": summary_text,
682        });
683
684        qdrant
685            .store_to_collection(SESSION_SUMMARIES_COLLECTION, payload, vector)
686            .await?;
687
688        tracing::debug!(
689            conversation_id = conversation_id.0,
690            "stored session summary"
691        );
692        Ok(())
693    }
694
695    /// Search session summaries from other conversations.
696    ///
697    /// # Errors
698    ///
699    /// Returns an error if embedding or Qdrant search fails.
700    pub async fn search_session_summaries(
701        &self,
702        query: &str,
703        limit: usize,
704        exclude_conversation_id: Option<ConversationId>,
705    ) -> Result<Vec<SessionSummaryResult>, MemoryError> {
706        let Some(qdrant) = &self.qdrant else {
707            return Ok(Vec::new());
708        };
709        if !self.provider.supports_embeddings() {
710            return Ok(Vec::new());
711        }
712
713        let vector = self.provider.embed(query).await?;
714        let vector_size = u64::try_from(vector.len()).unwrap_or(896);
715        qdrant
716            .ensure_named_collection(SESSION_SUMMARIES_COLLECTION, vector_size)
717            .await?;
718
719        let filter = exclude_conversation_id.map(|cid| VectorFilter {
720            must: vec![],
721            must_not: vec![FieldCondition {
722                field: "conversation_id".into(),
723                value: FieldValue::Integer(cid.0),
724            }],
725        });
726
727        let points = qdrant
728            .search_collection(SESSION_SUMMARIES_COLLECTION, &vector, limit, filter)
729            .await?;
730
731        let results = points
732            .into_iter()
733            .filter_map(|point| {
734                let summary_text = point.payload.get("summary_text")?.as_str()?.to_owned();
735                let conversation_id =
736                    ConversationId(point.payload.get("conversation_id")?.as_i64()?);
737                Some(SessionSummaryResult {
738                    summary_text,
739                    score: point.score,
740                    conversation_id,
741                })
742            })
743            .collect();
744
745        Ok(results)
746    }
747
    /// Access the underlying `SqliteStore` for operations that don't involve semantics.
    ///
    /// Useful for raw message/summary queries that bypass embedding logic.
    #[must_use]
    pub fn sqlite(&self) -> &SqliteStore {
        &self.sqlite
    }
753
754    /// Check if the vector store backend is reachable.
755    ///
756    /// Performs a real health check (Qdrant gRPC ping or `SQLite` query)
757    /// instead of just checking whether the client was created.
758    pub async fn is_vector_store_connected(&self) -> bool {
759        match self.qdrant.as_ref() {
760            Some(store) => store.health_check().await,
761            None => false,
762        }
763    }
764
    /// Check if a vector store client is configured (may not be connected).
    ///
    /// Unlike [`Self::is_vector_store_connected`], this performs no health
    /// check — it only reports whether a client was created.
    #[must_use]
    pub fn has_vector_store(&self) -> bool {
        self.qdrant.is_some()
    }
770
    /// Count messages in a conversation.
    ///
    /// Thin delegation to the underlying `SQLite` store.
    ///
    /// # Errors
    ///
    /// Returns an error if the query fails.
    pub async fn message_count(&self, conversation_id: ConversationId) -> Result<i64, MemoryError> {
        self.sqlite.count_messages(conversation_id).await
    }
779
780    /// Count messages not yet covered by any summary.
781    ///
782    /// # Errors
783    ///
784    /// Returns an error if the query fails.
785    pub async fn unsummarized_message_count(
786        &self,
787        conversation_id: ConversationId,
788    ) -> Result<i64, MemoryError> {
789        let after_id = self
790            .sqlite
791            .latest_summary_last_message_id(conversation_id)
792            .await?
793            .unwrap_or(MessageId(0));
794        self.sqlite
795            .count_messages_after(conversation_id, after_id)
796            .await
797    }
798
799    /// Load all summaries for a conversation.
800    ///
801    /// # Errors
802    ///
803    /// Returns an error if the query fails.
804    pub async fn load_summaries(
805        &self,
806        conversation_id: ConversationId,
807    ) -> Result<Vec<Summary>, MemoryError> {
808        let rows = self.sqlite.load_summaries(conversation_id).await?;
809        let summaries = rows
810            .into_iter()
811            .map(
812                |(
813                    id,
814                    conversation_id,
815                    content,
816                    first_message_id,
817                    last_message_id,
818                    token_estimate,
819                )| {
820                    Summary {
821                        id,
822                        conversation_id,
823                        content,
824                        first_message_id,
825                        last_message_id,
826                        token_estimate,
827                    }
828                },
829            )
830            .collect();
831        Ok(summaries)
832    }
833
    /// Generate a summary of the oldest unsummarized messages.
    ///
    /// Returns `Ok(None)` if there are not enough messages to summarize.
    ///
    /// The summary is persisted to `SQLite`; when embeddings are available it is
    /// also stored in the vector store, and any extracted key facts are stored
    /// as individual points.
    ///
    /// # Errors
    ///
    /// Returns an error if LLM call or database operation fails.
    pub async fn summarize(
        &self,
        conversation_id: ConversationId,
        message_count: usize,
    ) -> Result<Option<i64>, MemoryError> {
        let total = self.sqlite.count_messages(conversation_id).await?;

        // Don't summarize until the conversation exceeds the window size.
        if total <= i64::try_from(message_count)? {
            return Ok(None);
        }

        // Resume after the last summarized message; MessageId(0) = none yet.
        let after_id = self
            .sqlite
            .latest_summary_last_message_id(conversation_id)
            .await?
            .unwrap_or(MessageId(0));

        let messages = self
            .sqlite
            .load_messages_range(conversation_id, after_id, message_count)
            .await?;

        if messages.is_empty() {
            return Ok(None);
        }

        let prompt = build_summarization_prompt(&messages);
        let chat_messages = vec![Message {
            role: Role::User,
            content: prompt,
            parts: vec![],
            metadata: MessageMetadata::default(),
        }];

        // Prefer schema-constrained output; fall back to a plain-text summary
        // (without key facts/entities) when structured generation fails.
        let structured = match self
            .provider
            .chat_typed_erased::<StructuredSummary>(&chat_messages)
            .await
        {
            Ok(s) => s,
            Err(e) => {
                tracing::warn!(
                    "structured summarization failed, falling back to plain text: {e:#}"
                );
                let plain = self.provider.chat(&chat_messages).await?;
                StructuredSummary {
                    summary: plain,
                    key_facts: vec![],
                    entities: vec![],
                }
            }
        };
        let summary_text = &structured.summary;

        let token_estimate = i64::try_from(self.token_counter.count_tokens(summary_text))?;
        let first_message_id = messages[0].0;
        let last_message_id = messages[messages.len() - 1].0;

        let summary_id = self
            .sqlite
            .save_summary(
                conversation_id,
                summary_text,
                first_message_id,
                last_message_id,
                token_estimate,
            )
            .await?;

        // Embedding the summary is best-effort; failures are logged only.
        if let Some(qdrant) = &self.qdrant
            && self.provider.supports_embeddings()
        {
            match self.provider.embed(summary_text).await {
                Ok(vector) => {
                    // Ensure collection exists before storing
                    let vector_size = u64::try_from(vector.len()).unwrap_or(896);
                    if let Err(e) = qdrant.ensure_collection(vector_size).await {
                        tracing::warn!("Failed to ensure Qdrant collection: {e:#}");
                    } else if let Err(e) = qdrant
                        .store(
                            MessageId(summary_id),
                            conversation_id,
                            "system",
                            vector,
                            MessageKind::Summary,
                            &self.embedding_model,
                        )
                        .await
                    {
                        tracing::warn!("Failed to embed summary: {e:#}");
                    }
                }
                Err(e) => {
                    tracing::warn!("Failed to generate summary embedding: {e:#}");
                }
            }
        }

        // Store key facts as individual Qdrant points
        if !structured.key_facts.is_empty() {
            self.store_key_facts(conversation_id, summary_id, &structured.key_facts)
                .await;
        }

        Ok(Some(summary_id))
    }
947
948    async fn store_key_facts(
949        &self,
950        conversation_id: ConversationId,
951        source_summary_id: i64,
952        key_facts: &[String],
953    ) {
954        let Some(qdrant) = &self.qdrant else {
955            return;
956        };
957        if !self.provider.supports_embeddings() {
958            return;
959        }
960
961        let Some(first_fact) = key_facts.first() else {
962            return;
963        };
964        let first_vector = match self.provider.embed(first_fact).await {
965            Ok(v) => v,
966            Err(e) => {
967                tracing::warn!("Failed to embed key fact: {e:#}");
968                return;
969            }
970        };
971        let vector_size = u64::try_from(first_vector.len()).unwrap_or(896);
972        if let Err(e) = qdrant
973            .ensure_named_collection(KEY_FACTS_COLLECTION, vector_size)
974            .await
975        {
976            tracing::warn!("Failed to ensure key_facts collection: {e:#}");
977            return;
978        }
979
980        let first_payload = serde_json::json!({
981            "conversation_id": conversation_id.0,
982            "fact_text": first_fact,
983            "source_summary_id": source_summary_id,
984        });
985        if let Err(e) = qdrant
986            .store_to_collection(KEY_FACTS_COLLECTION, first_payload, first_vector)
987            .await
988        {
989            tracing::warn!("Failed to store key fact: {e:#}");
990        }
991
992        for fact in &key_facts[1..] {
993            match self.provider.embed(fact).await {
994                Ok(vector) => {
995                    let payload = serde_json::json!({
996                        "conversation_id": conversation_id.0,
997                        "fact_text": fact,
998                        "source_summary_id": source_summary_id,
999                    });
1000                    if let Err(e) = qdrant
1001                        .store_to_collection(KEY_FACTS_COLLECTION, payload, vector)
1002                        .await
1003                    {
1004                        tracing::warn!("Failed to store key fact: {e:#}");
1005                    }
1006                }
1007                Err(e) => {
1008                    tracing::warn!("Failed to embed key fact: {e:#}");
1009                }
1010            }
1011        }
1012    }
1013
1014    /// Search key facts extracted from conversation summaries.
1015    ///
1016    /// # Errors
1017    ///
1018    /// Returns an error if embedding or Qdrant search fails.
1019    pub async fn search_key_facts(
1020        &self,
1021        query: &str,
1022        limit: usize,
1023    ) -> Result<Vec<String>, MemoryError> {
1024        let Some(qdrant) = &self.qdrant else {
1025            return Ok(Vec::new());
1026        };
1027        if !self.provider.supports_embeddings() {
1028            return Ok(Vec::new());
1029        }
1030
1031        let vector = self.provider.embed(query).await?;
1032        let vector_size = u64::try_from(vector.len()).unwrap_or(896);
1033        qdrant
1034            .ensure_named_collection(KEY_FACTS_COLLECTION, vector_size)
1035            .await?;
1036
1037        let points = qdrant
1038            .search_collection(KEY_FACTS_COLLECTION, &vector, limit, None)
1039            .await?;
1040
1041        let facts = points
1042            .into_iter()
1043            .filter_map(|p| p.payload.get("fact_text")?.as_str().map(String::from))
1044            .collect();
1045
1046        Ok(facts)
1047    }
1048
1049    /// Search a named document collection by semantic similarity.
1050    ///
1051    /// Returns up to `limit` scored vector points whose payloads contain ingested document chunks.
1052    /// Returns an empty vec when Qdrant is unavailable, the collection does not exist,
1053    /// or the provider does not support embeddings.
1054    ///
1055    /// # Errors
1056    ///
1057    /// Returns an error if embedding generation or Qdrant search fails.
1058    pub async fn search_document_collection(
1059        &self,
1060        collection: &str,
1061        query: &str,
1062        limit: usize,
1063    ) -> Result<Vec<crate::ScoredVectorPoint>, MemoryError> {
1064        let Some(qdrant) = &self.qdrant else {
1065            return Ok(Vec::new());
1066        };
1067        if !self.provider.supports_embeddings() {
1068            return Ok(Vec::new());
1069        }
1070        if !qdrant.collection_exists(collection).await? {
1071            return Ok(Vec::new());
1072        }
1073        let vector = self.provider.embed(query).await?;
1074        qdrant
1075            .search_collection(collection, &vector, limit, None)
1076            .await
1077    }
1078
1079    /// Store an embedding for a user correction in the vector store.
1080    ///
1081    /// Silently skips if no vector store is configured or embeddings are unsupported.
1082    ///
1083    /// # Errors
1084    ///
1085    /// Returns an error if embedding generation or vector store write fails.
1086    pub async fn store_correction_embedding(
1087        &self,
1088        correction_id: i64,
1089        correction_text: &str,
1090    ) -> Result<(), MemoryError> {
1091        let Some(ref store) = self.qdrant else {
1092            return Ok(());
1093        };
1094        if !self.provider.supports_embeddings() {
1095            return Ok(());
1096        }
1097        let embedding = self
1098            .provider
1099            .embed(correction_text)
1100            .await
1101            .map_err(|e| MemoryError::Other(e.to_string()))?;
1102        let payload = serde_json::json!({ "correction_id": correction_id });
1103        store
1104            .store_to_collection(CORRECTIONS_COLLECTION, payload, embedding)
1105            .await?;
1106        Ok(())
1107    }
1108
1109    /// Retrieve corrections semantically similar to `query`.
1110    ///
1111    /// Returns up to `limit` corrections scoring above `min_score`.
1112    /// Returns an empty vec if no vector store is configured.
1113    ///
1114    /// # Errors
1115    ///
1116    /// Returns an error if embedding generation or vector search fails.
1117    pub async fn retrieve_similar_corrections(
1118        &self,
1119        query: &str,
1120        limit: usize,
1121        min_score: f32,
1122    ) -> Result<Vec<crate::sqlite::corrections::UserCorrectionRow>, MemoryError> {
1123        let Some(ref store) = self.qdrant else {
1124            return Ok(vec![]);
1125        };
1126        if !self.provider.supports_embeddings() {
1127            return Ok(vec![]);
1128        }
1129        let embedding = self
1130            .provider
1131            .embed(query)
1132            .await
1133            .map_err(|e| MemoryError::Other(e.to_string()))?;
1134        let scored = store
1135            .search_collection(CORRECTIONS_COLLECTION, &embedding, limit, None)
1136            .await
1137            .unwrap_or_default();
1138
1139        let mut results = Vec::new();
1140        for point in scored {
1141            if point.score < min_score {
1142                continue;
1143            }
1144            if let Some(id_val) = point.payload.get("correction_id")
1145                && let Some(id) = id_val.as_i64()
1146            {
1147                let rows = self.sqlite.load_corrections_for_id(id).await?;
1148                results.extend(rows);
1149            }
1150        }
1151        Ok(results)
1152    }
1153}
1154
1155#[cfg(test)]
1156mod tests {
1157    use zeph_llm::mock::MockProvider;
1158    use zeph_llm::provider::Role;
1159
1160    use super::*;
1161
1162    fn test_provider() -> AnyProvider {
1163        AnyProvider::Mock(MockProvider::default())
1164    }
1165
1166    async fn test_semantic_memory(_supports_embeddings: bool) -> SemanticMemory {
1167        let provider = test_provider();
1168        let sqlite = SqliteStore::new(":memory:").await.unwrap();
1169
1170        SemanticMemory {
1171            sqlite,
1172            qdrant: None,
1173            provider,
1174            embedding_model: "test-model".into(),
1175            vector_weight: 0.7,
1176            keyword_weight: 0.3,
1177            temporal_decay_enabled: false,
1178            temporal_decay_half_life_days: 30,
1179            mmr_enabled: false,
1180            mmr_lambda: 0.7,
1181            token_counter: Arc::new(TokenCounter::new()),
1182        }
1183    }
1184
1185    #[tokio::test]
1186    async fn remember_saves_to_sqlite() {
1187        let memory = test_semantic_memory(false).await;
1188
1189        let cid = memory.sqlite.create_conversation().await.unwrap();
1190        let msg_id = memory.remember(cid, "user", "hello").await.unwrap();
1191
1192        assert_eq!(msg_id, MessageId(1));
1193
1194        let history = memory.sqlite.load_history(cid, 50).await.unwrap();
1195        assert_eq!(history.len(), 1);
1196        assert_eq!(history[0].role, Role::User);
1197        assert_eq!(history[0].content, "hello");
1198    }
1199
1200    #[tokio::test]
1201    async fn remember_with_parts_saves_parts_json() {
1202        let memory = test_semantic_memory(false).await;
1203        let cid = memory.sqlite.create_conversation().await.unwrap();
1204
1205        let parts_json =
1206            r#"[{"kind":"ToolOutput","tool_name":"shell","body":"hello","compacted_at":null}]"#;
1207        let (msg_id, _embedding_stored) = memory
1208            .remember_with_parts(cid, "assistant", "tool output", parts_json)
1209            .await
1210            .unwrap();
1211        assert!(msg_id > MessageId(0));
1212
1213        let history = memory.sqlite.load_history(cid, 50).await.unwrap();
1214        assert_eq!(history.len(), 1);
1215        assert_eq!(history[0].content, "tool output");
1216    }
1217
1218    #[tokio::test]
1219    async fn recall_returns_empty_without_qdrant() {
1220        let memory = test_semantic_memory(true).await;
1221
1222        let recalled = memory.recall("test", 5, None).await.unwrap();
1223        assert!(recalled.is_empty());
1224    }
1225
1226    #[tokio::test]
1227    async fn has_embedding_without_qdrant() {
1228        let memory = test_semantic_memory(true).await;
1229
1230        let has_embedding = memory.has_embedding(MessageId(1)).await.unwrap();
1231        assert!(!has_embedding);
1232    }
1233
1234    #[tokio::test]
1235    async fn embed_missing_without_qdrant() {
1236        let memory = test_semantic_memory(true).await;
1237
1238        let count = memory.embed_missing().await.unwrap();
1239        assert_eq!(count, 0);
1240    }
1241
1242    #[tokio::test]
1243    async fn sqlite_accessor() {
1244        let memory = test_semantic_memory(false).await;
1245
1246        let cid = memory.sqlite().create_conversation().await.unwrap();
1247        assert_eq!(cid, ConversationId(1));
1248
1249        memory
1250            .sqlite()
1251            .save_message(cid, "user", "test")
1252            .await
1253            .unwrap();
1254
1255        let history = memory.sqlite().load_history(cid, 50).await.unwrap();
1256        assert_eq!(history.len(), 1);
1257    }
1258
1259    #[tokio::test]
1260    async fn has_vector_store_returns_false_when_unavailable() {
1261        let memory = test_semantic_memory(false).await;
1262        assert!(!memory.has_vector_store());
1263    }
1264
1265    #[tokio::test]
1266    async fn is_vector_store_connected_returns_false_when_unavailable() {
1267        let memory = test_semantic_memory(false).await;
1268        assert!(!memory.is_vector_store_connected().await);
1269    }
1270
1271    #[tokio::test]
1272    async fn recall_returns_empty_when_embeddings_not_supported() {
1273        let memory = test_semantic_memory(false).await;
1274
1275        let recalled = memory.recall("test", 5, None).await.unwrap();
1276        assert!(recalled.is_empty());
1277    }
1278
1279    #[tokio::test]
1280    async fn embed_missing_returns_zero_when_embeddings_not_supported() {
1281        let memory = test_semantic_memory(false).await;
1282
1283        let cid = memory.sqlite().create_conversation().await.unwrap();
1284        memory
1285            .sqlite()
1286            .save_message(cid, "user", "test")
1287            .await
1288            .unwrap();
1289
1290        let count = memory.embed_missing().await.unwrap();
1291        assert_eq!(count, 0);
1292    }
1293
1294    #[tokio::test]
1295    async fn message_count_empty_conversation() {
1296        let memory = test_semantic_memory(false).await;
1297        let cid = memory.sqlite().create_conversation().await.unwrap();
1298
1299        let count = memory.message_count(cid).await.unwrap();
1300        assert_eq!(count, 0);
1301    }
1302
1303    #[tokio::test]
1304    async fn message_count_after_saves() {
1305        let memory = test_semantic_memory(false).await;
1306        let cid = memory.sqlite().create_conversation().await.unwrap();
1307
1308        memory.remember(cid, "user", "msg1").await.unwrap();
1309        memory.remember(cid, "assistant", "msg2").await.unwrap();
1310
1311        let count = memory.message_count(cid).await.unwrap();
1312        assert_eq!(count, 2);
1313    }
1314
1315    #[tokio::test]
1316    async fn unsummarized_count_decreases_after_summary() {
1317        let memory = test_semantic_memory(false).await;
1318        let cid = memory.sqlite().create_conversation().await.unwrap();
1319
1320        for i in 0..10 {
1321            memory
1322                .remember(cid, "user", &format!("msg{i}"))
1323                .await
1324                .unwrap();
1325        }
1326        assert_eq!(memory.unsummarized_message_count(cid).await.unwrap(), 10);
1327
1328        memory.summarize(cid, 5).await.unwrap();
1329
1330        assert!(memory.unsummarized_message_count(cid).await.unwrap() < 10);
1331        assert_eq!(memory.message_count(cid).await.unwrap(), 10);
1332    }
1333
1334    #[tokio::test]
1335    async fn load_summaries_empty() {
1336        let memory = test_semantic_memory(false).await;
1337        let cid = memory.sqlite().create_conversation().await.unwrap();
1338
1339        let summaries = memory.load_summaries(cid).await.unwrap();
1340        assert!(summaries.is_empty());
1341    }
1342
1343    #[tokio::test]
1344    async fn load_summaries_ordered() {
1345        let memory = test_semantic_memory(false).await;
1346        let cid = memory.sqlite().create_conversation().await.unwrap();
1347
1348        let msg_id1 = memory.remember(cid, "user", "m1").await.unwrap();
1349        let msg_id2 = memory.remember(cid, "assistant", "m2").await.unwrap();
1350        let msg_id3 = memory.remember(cid, "user", "m3").await.unwrap();
1351
1352        let s1 = memory
1353            .sqlite()
1354            .save_summary(cid, "summary1", msg_id1, msg_id2, 3)
1355            .await
1356            .unwrap();
1357        let s2 = memory
1358            .sqlite()
1359            .save_summary(cid, "summary2", msg_id2, msg_id3, 3)
1360            .await
1361            .unwrap();
1362
1363        let summaries = memory.load_summaries(cid).await.unwrap();
1364        assert_eq!(summaries.len(), 2);
1365        assert_eq!(summaries[0].id, s1);
1366        assert_eq!(summaries[0].content, "summary1");
1367        assert_eq!(summaries[1].id, s2);
1368        assert_eq!(summaries[1].content, "summary2");
1369    }
1370
1371    #[tokio::test]
1372    async fn summarize_below_threshold() {
1373        let memory = test_semantic_memory(false).await;
1374        let cid = memory.sqlite().create_conversation().await.unwrap();
1375
1376        memory.remember(cid, "user", "hello").await.unwrap();
1377
1378        let result = memory.summarize(cid, 10).await.unwrap();
1379        assert!(result.is_none());
1380    }
1381
1382    #[tokio::test]
1383    async fn summarize_stores_summary() {
1384        let memory = test_semantic_memory(false).await;
1385        let cid = memory.sqlite().create_conversation().await.unwrap();
1386
1387        for i in 0..5 {
1388            memory
1389                .remember(cid, "user", &format!("message {i}"))
1390                .await
1391                .unwrap();
1392        }
1393
1394        let summary_id = memory.summarize(cid, 3).await.unwrap();
1395        assert!(summary_id.is_some());
1396
1397        let summaries = memory.load_summaries(cid).await.unwrap();
1398        assert_eq!(summaries.len(), 1);
1399        assert_eq!(summaries[0].id, summary_id.unwrap());
1400        assert!(!summaries[0].content.is_empty());
1401    }
1402
1403    #[tokio::test]
1404    async fn summarize_respects_previous_summaries() {
1405        let memory = test_semantic_memory(false).await;
1406        let cid = memory.sqlite().create_conversation().await.unwrap();
1407
1408        for i in 0..10 {
1409            memory
1410                .remember(cid, "user", &format!("message {i}"))
1411                .await
1412                .unwrap();
1413        }
1414
1415        let s1 = memory.summarize(cid, 3).await.unwrap();
1416        assert!(s1.is_some());
1417
1418        let s2 = memory.summarize(cid, 3).await.unwrap();
1419        assert!(s2.is_some());
1420
1421        let summaries = memory.load_summaries(cid).await.unwrap();
1422        assert_eq!(summaries.len(), 2);
1423        assert!(summaries[0].last_message_id < summaries[1].first_message_id);
1424    }
1425
1426    #[tokio::test]
1427    async fn remember_multiple_messages_increments_ids() {
1428        let memory = test_semantic_memory(false).await;
1429        let cid = memory.sqlite.create_conversation().await.unwrap();
1430
1431        let id1 = memory.remember(cid, "user", "first").await.unwrap();
1432        let id2 = memory.remember(cid, "assistant", "second").await.unwrap();
1433        let id3 = memory.remember(cid, "user", "third").await.unwrap();
1434
1435        assert!(id1 < id2);
1436        assert!(id2 < id3);
1437    }
1438
1439    #[tokio::test]
1440    async fn message_count_across_conversations() {
1441        let memory = test_semantic_memory(false).await;
1442        let cid1 = memory.sqlite().create_conversation().await.unwrap();
1443        let cid2 = memory.sqlite().create_conversation().await.unwrap();
1444
1445        memory.remember(cid1, "user", "msg1").await.unwrap();
1446        memory.remember(cid1, "user", "msg2").await.unwrap();
1447        memory.remember(cid2, "user", "msg3").await.unwrap();
1448
1449        assert_eq!(memory.message_count(cid1).await.unwrap(), 2);
1450        assert_eq!(memory.message_count(cid2).await.unwrap(), 1);
1451    }
1452
1453    #[tokio::test]
1454    async fn summarize_exact_threshold_returns_none() {
1455        let memory = test_semantic_memory(false).await;
1456        let cid = memory.sqlite().create_conversation().await.unwrap();
1457
1458        for i in 0..3 {
1459            memory
1460                .remember(cid, "user", &format!("msg {i}"))
1461                .await
1462                .unwrap();
1463        }
1464
1465        let result = memory.summarize(cid, 3).await.unwrap();
1466        assert!(result.is_none());
1467    }
1468
1469    #[tokio::test]
1470    async fn summarize_one_above_threshold_produces_summary() {
1471        let memory = test_semantic_memory(false).await;
1472        let cid = memory.sqlite().create_conversation().await.unwrap();
1473
1474        for i in 0..4 {
1475            memory
1476                .remember(cid, "user", &format!("msg {i}"))
1477                .await
1478                .unwrap();
1479        }
1480
1481        let result = memory.summarize(cid, 3).await.unwrap();
1482        assert!(result.is_some());
1483    }
1484
1485    #[tokio::test]
1486    async fn summary_fields_populated() {
1487        let memory = test_semantic_memory(false).await;
1488        let cid = memory.sqlite().create_conversation().await.unwrap();
1489
1490        for i in 0..5 {
1491            memory
1492                .remember(cid, "user", &format!("msg {i}"))
1493                .await
1494                .unwrap();
1495        }
1496
1497        memory.summarize(cid, 3).await.unwrap();
1498        let summaries = memory.load_summaries(cid).await.unwrap();
1499        let s = &summaries[0];
1500
1501        assert_eq!(s.conversation_id, cid);
1502        assert!(s.first_message_id > MessageId(0));
1503        assert!(s.last_message_id >= s.first_message_id);
1504        assert!(s.token_estimate >= 0);
1505        assert!(!s.content.is_empty());
1506    }
1507
1508    #[test]
1509    fn build_summarization_prompt_format() {
1510        let messages = vec![
1511            (MessageId(1), "user".into(), "Hello".into()),
1512            (MessageId(2), "assistant".into(), "Hi there".into()),
1513        ];
1514        let prompt = build_summarization_prompt(&messages);
1515        assert!(prompt.contains("user: Hello"));
1516        assert!(prompt.contains("assistant: Hi there"));
1517        assert!(prompt.contains("key_facts"));
1518    }
1519
1520    #[test]
1521    fn build_summarization_prompt_empty() {
1522        let messages: Vec<(MessageId, String, String)> = vec![];
1523        let prompt = build_summarization_prompt(&messages);
1524        assert!(prompt.contains("key_facts"));
1525    }
1526
1527    #[test]
1528    fn structured_summary_deserialize() {
1529        let json = r#"{"summary":"s","key_facts":["f1","f2"],"entities":["e1"]}"#;
1530        let ss: StructuredSummary = serde_json::from_str(json).unwrap();
1531        assert_eq!(ss.summary, "s");
1532        assert_eq!(ss.key_facts.len(), 2);
1533        assert_eq!(ss.entities.len(), 1);
1534    }
1535
1536    #[test]
1537    fn structured_summary_empty_facts() {
1538        let json = r#"{"summary":"s","key_facts":[],"entities":[]}"#;
1539        let ss: StructuredSummary = serde_json::from_str(json).unwrap();
1540        assert!(ss.key_facts.is_empty());
1541        assert!(ss.entities.is_empty());
1542    }
1543
1544    #[tokio::test]
1545    async fn search_key_facts_no_qdrant_empty() {
1546        let memory = test_semantic_memory(false).await;
1547        let facts = memory.search_key_facts("query", 5).await.unwrap();
1548        assert!(facts.is_empty());
1549    }
1550
1551    #[test]
1552    fn recalled_message_debug() {
1553        let recalled = RecalledMessage {
1554            message: Message {
1555                role: Role::User,
1556                content: "test".into(),
1557                parts: vec![],
1558                metadata: MessageMetadata::default(),
1559            },
1560            score: 0.95,
1561        };
1562        let dbg = format!("{recalled:?}");
1563        assert!(dbg.contains("RecalledMessage"));
1564        assert!(dbg.contains("0.95"));
1565    }
1566
1567    #[test]
1568    fn summary_clone() {
1569        let summary = Summary {
1570            id: 1,
1571            conversation_id: ConversationId(2),
1572            content: "test summary".into(),
1573            first_message_id: MessageId(1),
1574            last_message_id: MessageId(5),
1575            token_estimate: 10,
1576        };
1577        let cloned = summary.clone();
1578        assert_eq!(summary.id, cloned.id);
1579        assert_eq!(summary.content, cloned.content);
1580    }
1581
1582    #[tokio::test]
1583    async fn remember_preserves_role_mapping() {
1584        let memory = test_semantic_memory(false).await;
1585        let cid = memory.sqlite.create_conversation().await.unwrap();
1586
1587        memory.remember(cid, "user", "u").await.unwrap();
1588        memory.remember(cid, "assistant", "a").await.unwrap();
1589        memory.remember(cid, "system", "s").await.unwrap();
1590
1591        let history = memory.sqlite.load_history(cid, 50).await.unwrap();
1592        assert_eq!(history.len(), 3);
1593        assert_eq!(history[0].role, Role::User);
1594        assert_eq!(history[1].role, Role::Assistant);
1595        assert_eq!(history[2].role, Role::System);
1596    }
1597
1598    #[tokio::test]
1599    async fn new_with_invalid_qdrant_url_graceful() {
1600        let mut mock = MockProvider::default();
1601        mock.supports_embeddings = true;
1602        let provider = AnyProvider::Mock(mock);
1603        let result =
1604            SemanticMemory::new(":memory:", "http://127.0.0.1:1", provider, "test-model").await;
1605        assert!(result.is_ok());
1606    }
1607
1608    #[tokio::test]
1609    async fn test_semantic_memory_sqlite_remember_recall_roundtrip() {
1610        // Build SemanticMemory with EmbeddingStore backed by SQLite instead of Qdrant
1611        let mut mock = MockProvider::default();
1612        mock.supports_embeddings = true;
1613        // Provide deterministic embedding vectors: embed returns a fixed 4-element vector
1614        // MockProvider.embed always returns the same vector, so cosine similarity = 1.0
1615        let provider = AnyProvider::Mock(mock);
1616
1617        let sqlite = SqliteStore::new(":memory:").await.unwrap();
1618        let pool = sqlite.pool().clone();
1619        let qdrant = Some(crate::embedding_store::EmbeddingStore::new_sqlite(pool));
1620
1621        let memory = SemanticMemory {
1622            sqlite,
1623            qdrant,
1624            provider,
1625            embedding_model: "test-model".into(),
1626            vector_weight: 0.7,
1627            keyword_weight: 0.3,
1628            temporal_decay_enabled: false,
1629            temporal_decay_half_life_days: 30,
1630            mmr_enabled: false,
1631            mmr_lambda: 0.7,
1632            token_counter: Arc::new(TokenCounter::new()),
1633        };
1634
1635        let cid = memory.sqlite().create_conversation().await.unwrap();
1636
1637        // remember → stores in SQLite + SQLite vector store
1638        let id1 = memory
1639            .remember(cid, "user", "rust async programming")
1640            .await
1641            .unwrap();
1642        let id2 = memory
1643            .remember(cid, "assistant", "use tokio for async")
1644            .await
1645            .unwrap();
1646        assert!(id1 < id2);
1647
1648        // recall → should return results via FTS5 keyword search
1649        let recalled = memory.recall("rust", 5, None).await.unwrap();
1650        assert!(
1651            !recalled.is_empty(),
1652            "recall must return at least one result"
1653        );
1654
1655        // Verify history is accessible
1656        let history = memory.sqlite().load_history(cid, 50).await.unwrap();
1657        assert_eq!(history.len(), 2);
1658        assert_eq!(history[0].content, "rust async programming");
1659    }
1660
1661    #[tokio::test]
1662    async fn remember_with_embeddings_supported_but_no_qdrant() {
1663        let memory = test_semantic_memory(true).await;
1664        let cid = memory.sqlite.create_conversation().await.unwrap();
1665
1666        let msg_id = memory.remember(cid, "user", "hello embed").await.unwrap();
1667        assert!(msg_id > MessageId(0));
1668
1669        let history = memory.sqlite.load_history(cid, 50).await.unwrap();
1670        assert_eq!(history.len(), 1);
1671        assert_eq!(history[0].content, "hello embed");
1672    }
1673
1674    #[tokio::test]
1675    async fn remember_verifies_content_via_load_history() {
1676        let memory = test_semantic_memory(false).await;
1677        let cid = memory.sqlite.create_conversation().await.unwrap();
1678
1679        memory.remember(cid, "user", "alpha").await.unwrap();
1680        memory.remember(cid, "assistant", "beta").await.unwrap();
1681        memory.remember(cid, "user", "gamma").await.unwrap();
1682
1683        let history = memory.sqlite().load_history(cid, 50).await.unwrap();
1684        assert_eq!(history.len(), 3);
1685        assert_eq!(history[0].content, "alpha");
1686        assert_eq!(history[1].content, "beta");
1687        assert_eq!(history[2].content, "gamma");
1688    }
1689
1690    #[tokio::test]
1691    async fn message_count_multiple_conversations_isolated() {
1692        let memory = test_semantic_memory(false).await;
1693        let cid1 = memory.sqlite().create_conversation().await.unwrap();
1694        let cid2 = memory.sqlite().create_conversation().await.unwrap();
1695        let cid3 = memory.sqlite().create_conversation().await.unwrap();
1696
1697        for _ in 0..5 {
1698            memory.remember(cid1, "user", "msg").await.unwrap();
1699        }
1700        for _ in 0..3 {
1701            memory.remember(cid2, "user", "msg").await.unwrap();
1702        }
1703
1704        assert_eq!(memory.message_count(cid1).await.unwrap(), 5);
1705        assert_eq!(memory.message_count(cid2).await.unwrap(), 3);
1706        assert_eq!(memory.message_count(cid3).await.unwrap(), 0);
1707    }
1708
1709    #[tokio::test]
1710    async fn summarize_empty_messages_range_returns_none() {
1711        let memory = test_semantic_memory(false).await;
1712        let cid = memory.sqlite().create_conversation().await.unwrap();
1713
1714        for i in 0..6 {
1715            memory
1716                .remember(cid, "user", &format!("msg {i}"))
1717                .await
1718                .unwrap();
1719        }
1720
1721        memory.summarize(cid, 3).await.unwrap();
1722        memory.summarize(cid, 3).await.unwrap();
1723
1724        let summaries = memory.load_summaries(cid).await.unwrap();
1725        assert_eq!(summaries.len(), 2);
1726    }
1727
1728    #[tokio::test]
1729    async fn summarize_token_estimate_populated() {
1730        let memory = test_semantic_memory(false).await;
1731        let cid = memory.sqlite().create_conversation().await.unwrap();
1732
1733        for i in 0..5 {
1734            memory
1735                .remember(cid, "user", &format!("message {i}"))
1736                .await
1737                .unwrap();
1738        }
1739
1740        memory.summarize(cid, 3).await.unwrap();
1741        let summaries = memory.load_summaries(cid).await.unwrap();
1742        let token_est = summaries[0].token_estimate;
1743        assert!(token_est > 0);
1744    }
1745
    #[tokio::test]
    async fn summarize_fails_when_provider_chat_fails() {
        // Build the memory by hand (no test helper) so a deliberately broken
        // provider is wired in: port 1 on localhost is unreachable, so every
        // chat call errors.
        let sqlite = SqliteStore::new(":memory:").await.unwrap();
        let provider = AnyProvider::Ollama(zeph_llm::ollama::OllamaProvider::new(
            "http://127.0.0.1:1",
            "test".into(),
            "embed".into(),
        ));
        let memory = SemanticMemory {
            sqlite,
            qdrant: None,
            provider,
            embedding_model: "test".into(),
            vector_weight: 0.7,
            keyword_weight: 0.3,
            temporal_decay_enabled: false,
            temporal_decay_half_life_days: 30,
            mmr_enabled: false,
            mmr_lambda: 0.7,
            token_counter: Arc::new(TokenCounter::new()),
        };
        let cid = memory.sqlite().create_conversation().await.unwrap();

        for i in 0..5 {
            memory
                .remember(cid, "user", &format!("msg {i}"))
                .await
                .unwrap();
        }

        // Summarization requires the LLM; with an unreachable provider the
        // error must propagate to the caller rather than being swallowed.
        let result = memory.summarize(cid, 3).await;
        assert!(result.is_err());
    }
1779
1780    #[tokio::test]
1781    async fn embed_missing_without_embedding_support_returns_zero() {
1782        let memory = test_semantic_memory(false).await;
1783        let cid = memory.sqlite().create_conversation().await.unwrap();
1784        memory
1785            .sqlite()
1786            .save_message(cid, "user", "test message")
1787            .await
1788            .unwrap();
1789
1790        let count = memory.embed_missing().await.unwrap();
1791        assert_eq!(count, 0);
1792    }
1793
1794    #[tokio::test]
1795    async fn has_embedding_returns_false_when_no_qdrant() {
1796        let memory = test_semantic_memory(false).await;
1797        let cid = memory.sqlite.create_conversation().await.unwrap();
1798        let msg_id = memory.remember(cid, "user", "test").await.unwrap();
1799        assert!(!memory.has_embedding(msg_id).await.unwrap());
1800    }
1801
1802    #[tokio::test]
1803    async fn recall_empty_without_qdrant_regardless_of_filter() {
1804        let memory = test_semantic_memory(true).await;
1805        let filter = SearchFilter {
1806            conversation_id: Some(ConversationId(1)),
1807            role: None,
1808        };
1809        let recalled = memory.recall("query", 10, Some(filter)).await.unwrap();
1810        assert!(recalled.is_empty());
1811    }
1812
1813    #[tokio::test]
1814    async fn summarize_message_range_bounds() {
1815        let memory = test_semantic_memory(false).await;
1816        let cid = memory.sqlite().create_conversation().await.unwrap();
1817
1818        for i in 0..8 {
1819            memory
1820                .remember(cid, "user", &format!("msg {i}"))
1821                .await
1822                .unwrap();
1823        }
1824
1825        let summary_id = memory.summarize(cid, 4).await.unwrap().unwrap();
1826        let summaries = memory.load_summaries(cid).await.unwrap();
1827        assert_eq!(summaries.len(), 1);
1828        assert_eq!(summaries[0].id, summary_id);
1829        assert!(summaries[0].first_message_id >= MessageId(1));
1830        assert!(summaries[0].last_message_id >= summaries[0].first_message_id);
1831    }
1832
1833    #[test]
1834    fn build_summarization_prompt_preserves_order() {
1835        let messages = vec![
1836            (MessageId(1), "user".into(), "first".into()),
1837            (MessageId(2), "assistant".into(), "second".into()),
1838            (MessageId(3), "user".into(), "third".into()),
1839        ];
1840        let prompt = build_summarization_prompt(&messages);
1841        let first_pos = prompt.find("user: first").unwrap();
1842        let second_pos = prompt.find("assistant: second").unwrap();
1843        let third_pos = prompt.find("user: third").unwrap();
1844        assert!(first_pos < second_pos);
1845        assert!(second_pos < third_pos);
1846    }
1847
1848    #[test]
1849    fn summary_debug() {
1850        let summary = Summary {
1851            id: 1,
1852            conversation_id: ConversationId(2),
1853            content: "test".into(),
1854            first_message_id: MessageId(1),
1855            last_message_id: MessageId(5),
1856            token_estimate: 10,
1857        };
1858        let dbg = format!("{summary:?}");
1859        assert!(dbg.contains("Summary"));
1860    }
1861
1862    #[tokio::test]
1863    async fn message_count_nonexistent_conversation() {
1864        let memory = test_semantic_memory(false).await;
1865        let count = memory.message_count(ConversationId(999)).await.unwrap();
1866        assert_eq!(count, 0);
1867    }
1868
1869    #[tokio::test]
1870    async fn load_summaries_nonexistent_conversation() {
1871        let memory = test_semantic_memory(false).await;
1872        let summaries = memory.load_summaries(ConversationId(999)).await.unwrap();
1873        assert!(summaries.is_empty());
1874    }
1875
1876    #[tokio::test]
1877    async fn store_session_summary_no_qdrant_noop() {
1878        let memory = test_semantic_memory(true).await;
1879        let result = memory
1880            .store_session_summary(ConversationId(1), "test summary")
1881            .await;
1882        assert!(result.is_ok());
1883    }
1884
1885    #[tokio::test]
1886    async fn store_session_summary_no_embeddings_noop() {
1887        let memory = test_semantic_memory(false).await;
1888        let result = memory
1889            .store_session_summary(ConversationId(1), "test summary")
1890            .await;
1891        assert!(result.is_ok());
1892    }
1893
1894    #[tokio::test]
1895    async fn search_session_summaries_no_qdrant_empty() {
1896        let memory = test_semantic_memory(true).await;
1897        let results = memory
1898            .search_session_summaries("query", 5, None)
1899            .await
1900            .unwrap();
1901        assert!(results.is_empty());
1902    }
1903
1904    #[tokio::test]
1905    async fn search_session_summaries_no_embeddings_empty() {
1906        let memory = test_semantic_memory(false).await;
1907        let results = memory
1908            .search_session_summaries("query", 5, Some(ConversationId(1)))
1909            .await
1910            .unwrap();
1911        assert!(results.is_empty());
1912    }
1913
1914    #[test]
1915    fn session_summary_result_debug() {
1916        let result = SessionSummaryResult {
1917            summary_text: "test".into(),
1918            score: 0.9,
1919            conversation_id: ConversationId(1),
1920        };
1921        let dbg = format!("{result:?}");
1922        assert!(dbg.contains("SessionSummaryResult"));
1923    }
1924
1925    #[test]
1926    fn session_summary_result_clone() {
1927        let result = SessionSummaryResult {
1928            summary_text: "test".into(),
1929            score: 0.9,
1930            conversation_id: ConversationId(1),
1931        };
1932        let cloned = result.clone();
1933        assert_eq!(result.summary_text, cloned.summary_text);
1934        assert_eq!(result.conversation_id, cloned.conversation_id);
1935    }
1936
1937    #[tokio::test]
1938    async fn recall_fts5_fallback_without_qdrant() {
1939        let memory = test_semantic_memory(false).await;
1940        let cid = memory.sqlite.create_conversation().await.unwrap();
1941
1942        memory
1943            .remember(cid, "user", "rust programming guide")
1944            .await
1945            .unwrap();
1946        memory
1947            .remember(cid, "assistant", "python tutorial")
1948            .await
1949            .unwrap();
1950        memory
1951            .remember(cid, "user", "advanced rust patterns")
1952            .await
1953            .unwrap();
1954
1955        let recalled = memory.recall("rust", 5, None).await.unwrap();
1956        assert_eq!(recalled.len(), 2);
1957        assert!(recalled[0].score >= recalled[1].score);
1958    }
1959
1960    #[tokio::test]
1961    async fn recall_fts5_fallback_with_filter() {
1962        let memory = test_semantic_memory(false).await;
1963        let cid1 = memory.sqlite.create_conversation().await.unwrap();
1964        let cid2 = memory.sqlite.create_conversation().await.unwrap();
1965
1966        memory.remember(cid1, "user", "hello world").await.unwrap();
1967        memory
1968            .remember(cid2, "user", "hello universe")
1969            .await
1970            .unwrap();
1971
1972        let filter = SearchFilter {
1973            conversation_id: Some(cid1),
1974            role: None,
1975        };
1976        let recalled = memory.recall("hello", 5, Some(filter)).await.unwrap();
1977        assert_eq!(recalled.len(), 1);
1978    }
1979
1980    #[tokio::test]
1981    async fn recall_fts5_no_matches_returns_empty() {
1982        let memory = test_semantic_memory(false).await;
1983        let cid = memory.sqlite.create_conversation().await.unwrap();
1984
1985        memory.remember(cid, "user", "hello world").await.unwrap();
1986
1987        let recalled = memory.recall("nonexistent", 5, None).await.unwrap();
1988        assert!(recalled.is_empty());
1989    }
1990
1991    #[tokio::test]
1992    async fn recall_fts5_respects_limit() {
1993        let memory = test_semantic_memory(false).await;
1994        let cid = memory.sqlite.create_conversation().await.unwrap();
1995
1996        for i in 0..10 {
1997            memory
1998                .remember(cid, "user", &format!("test message number {i}"))
1999                .await
2000                .unwrap();
2001        }
2002
2003        let recalled = memory.recall("test", 3, None).await.unwrap();
2004        assert_eq!(recalled.len(), 3);
2005    }
2006
2007    // Priority 2: summarize fallback path
2008
    #[tokio::test]
    async fn summarize_fallback_to_plain_text_when_structured_fails() {
        // MockProvider's chat() returns `default_response` verbatim. Since that
        // text is not valid JSON, the structured-summary path (chat_typed)
        // fails to parse on every attempt, and `summarize` should fall back to
        // a plain chat() call and store the raw text as the summary.
        let sqlite = SqliteStore::new(":memory:").await.unwrap();
        let mut mock = MockProvider::default();
        // Non-JSON response: breaks structured parsing, succeeds as plain text.
        mock.default_response = "plain text summary".into();
        let provider = AnyProvider::Mock(mock);

        // Hand-built memory so the mock provider is used; the retrieval tuning
        // fields (weights, decay, MMR) are irrelevant to this test.
        let memory = SemanticMemory {
            sqlite,
            qdrant: None,
            provider,
            embedding_model: "test".into(),
            vector_weight: 0.7,
            keyword_weight: 0.3,
            temporal_decay_enabled: false,
            temporal_decay_half_life_days: 30,
            mmr_enabled: false,
            mmr_lambda: 0.7,
            token_counter: Arc::new(TokenCounter::new()),
        };

        let cid = memory.sqlite().create_conversation().await.unwrap();
        for i in 0..5 {
            memory
                .remember(cid, "user", &format!("msg {i}"))
                .await
                .unwrap();
        }

        let result = memory.summarize(cid, 3).await;
        // The plain-text fallback must succeed and persist exactly one
        // non-empty summary.
        assert!(result.is_ok());
        let summaries = memory.load_summaries(cid).await.unwrap();
        assert_eq!(summaries.len(), 1);
        assert!(!summaries[0].content.is_empty());
    }
2057
2058    // Temporal decay tests
2059
2060    #[test]
2061    fn temporal_decay_disabled_leaves_scores_unchanged() {
2062        let mut ranked = vec![(MessageId(1), 1.0f64), (MessageId(2), 0.5f64)];
2063        let timestamps = std::collections::HashMap::new();
2064        apply_temporal_decay(&mut ranked, &timestamps, 30);
2065        assert!((ranked[0].1 - 1.0).abs() < f64::EPSILON);
2066        assert!((ranked[1].1 - 0.5).abs() < f64::EPSILON);
2067    }
2068
2069    #[test]
2070    fn temporal_decay_zero_age_preserves_score() {
2071        let now = std::time::SystemTime::now()
2072            .duration_since(std::time::UNIX_EPOCH)
2073            .unwrap_or_default()
2074            .as_secs()
2075            .cast_signed();
2076        let mut ranked = vec![(MessageId(1), 1.0f64)];
2077        let mut timestamps = std::collections::HashMap::new();
2078        timestamps.insert(MessageId(1), now);
2079        apply_temporal_decay(&mut ranked, &timestamps, 30);
2080        // age = 0 days, exp(0) = 1.0 → no change
2081        assert!((ranked[0].1 - 1.0).abs() < 0.01);
2082    }
2083
2084    #[test]
2085    fn temporal_decay_half_life_halves_score() {
2086        // Age exactly half_life_days → score should be halved
2087        let half_life = 30u32;
2088        let age_secs = i64::from(half_life) * 86400;
2089        let now = std::time::SystemTime::now()
2090            .duration_since(std::time::UNIX_EPOCH)
2091            .unwrap_or_default()
2092            .as_secs()
2093            .cast_signed();
2094        let ts = now - age_secs;
2095        let mut ranked = vec![(MessageId(1), 1.0f64)];
2096        let mut timestamps = std::collections::HashMap::new();
2097        timestamps.insert(MessageId(1), ts);
2098        apply_temporal_decay(&mut ranked, &timestamps, half_life);
2099        // exp(-ln2) = 0.5
2100        assert!(
2101            (ranked[0].1 - 0.5).abs() < 0.01,
2102            "score was {}",
2103            ranked[0].1
2104        );
2105    }
2106
2107    // MMR tests
2108
2109    #[test]
2110    fn mmr_empty_input_returns_empty() {
2111        let ranked = vec![];
2112        let vectors = std::collections::HashMap::new();
2113        let result = apply_mmr(&ranked, &vectors, 0.7, 5);
2114        assert!(result.is_empty());
2115    }
2116
2117    #[test]
2118    fn mmr_returns_up_to_limit() {
2119        let ranked = vec![
2120            (MessageId(1), 1.0f64),
2121            (MessageId(2), 0.9f64),
2122            (MessageId(3), 0.8f64),
2123        ];
2124        let mut vectors = std::collections::HashMap::new();
2125        vectors.insert(MessageId(1), vec![1.0f32, 0.0]);
2126        vectors.insert(MessageId(2), vec![0.0f32, 1.0]);
2127        vectors.insert(MessageId(3), vec![1.0f32, 0.0]);
2128        let result = apply_mmr(&ranked, &vectors, 0.7, 2);
2129        assert_eq!(result.len(), 2);
2130    }
2131
2132    #[test]
2133    fn mmr_without_vectors_picks_by_relevance() {
2134        let ranked = vec![(MessageId(1), 1.0f64), (MessageId(2), 0.5f64)];
2135        let vectors = std::collections::HashMap::new();
2136        let result = apply_mmr(&ranked, &vectors, 0.7, 2);
2137        assert_eq!(result.len(), 2);
2138        assert_eq!(result[0].0, MessageId(1));
2139    }
2140
2141    #[test]
2142    fn mmr_prefers_diverse_over_redundant() {
2143        // Two candidates with same relevance but msg 2 is orthogonal (more diverse)
2144        let ranked = vec![
2145            (MessageId(1), 1.0f64), // selected first
2146            (MessageId(2), 0.9f64), // orthogonal to 1
2147            (MessageId(3), 0.9f64), // parallel to 1 (redundant)
2148        ];
2149        let mut vectors = std::collections::HashMap::new();
2150        vectors.insert(MessageId(1), vec![1.0f32, 0.0]);
2151        vectors.insert(MessageId(2), vec![0.0f32, 1.0]); // orthogonal
2152        vectors.insert(MessageId(3), vec![1.0f32, 0.0]); // same as 1
2153        let result = apply_mmr(&ranked, &vectors, 0.5, 2);
2154        assert_eq!(result.len(), 2);
2155        assert_eq!(result[0].0, MessageId(1));
2156        // msg 2 should be preferred over msg 3 (diverse)
2157        assert_eq!(result[1].0, MessageId(2));
2158    }
2159
2160    #[test]
2161    fn temporal_decay_half_life_zero_is_noop() {
2162        let now = std::time::SystemTime::now()
2163            .duration_since(std::time::UNIX_EPOCH)
2164            .unwrap_or_default()
2165            .as_secs()
2166            .cast_signed();
2167        let age_secs = 30i64 * 86400;
2168        let ts = now - age_secs;
2169        let mut ranked = vec![(MessageId(1), 1.0f64)];
2170        let mut timestamps = std::collections::HashMap::new();
2171        timestamps.insert(MessageId(1), ts);
2172        // half_life=0 → guard returns early, score must remain 1.0
2173        apply_temporal_decay(&mut ranked, &timestamps, 0);
2174        assert!(
2175            (ranked[0].1 - 1.0).abs() < f64::EPSILON,
2176            "score was {}",
2177            ranked[0].1
2178        );
2179    }
2180
2181    #[test]
2182    fn temporal_decay_huge_age_near_zero() {
2183        let now = std::time::SystemTime::now()
2184            .duration_since(std::time::UNIX_EPOCH)
2185            .unwrap_or_default()
2186            .as_secs()
2187            .cast_signed();
2188        // 10 years = ~3650 days
2189        let age_secs = 3650i64 * 86400;
2190        let ts = now - age_secs;
2191        let mut ranked = vec![(MessageId(1), 1.0f64)];
2192        let mut timestamps = std::collections::HashMap::new();
2193        timestamps.insert(MessageId(1), ts);
2194        apply_temporal_decay(&mut ranked, &timestamps, 30);
2195        // After 3650 days with half_life=30, score should be essentially 0
2196        assert!(ranked[0].1 < 0.001, "score was {}", ranked[0].1);
2197    }
2198
2199    #[test]
2200    fn temporal_decay_small_half_life() {
2201        // Very small half_life (1 day), age = 7 days → 2^(-7) ≈ 0.0078
2202        let now = std::time::SystemTime::now()
2203            .duration_since(std::time::UNIX_EPOCH)
2204            .unwrap_or_default()
2205            .as_secs()
2206            .cast_signed();
2207        let ts = now - 7 * 86400i64;
2208        let mut ranked = vec![(MessageId(1), 1.0f64)];
2209        let mut timestamps = std::collections::HashMap::new();
2210        timestamps.insert(MessageId(1), ts);
2211        apply_temporal_decay(&mut ranked, &timestamps, 1);
2212        assert!(ranked[0].1 < 0.01, "score was {}", ranked[0].1);
2213    }
2214
2215    #[test]
2216    fn mmr_lambda_zero_max_diversity() {
2217        // lambda=0 → pure diversity: second item should be most dissimilar
2218        let ranked = vec![
2219            (MessageId(1), 1.0f64),  // selected first (always highest relevance)
2220            (MessageId(2), 0.9f64),  // orthogonal to 1
2221            (MessageId(3), 0.85f64), // parallel to 1 (max_sim=1.0)
2222        ];
2223        let mut vectors = std::collections::HashMap::new();
2224        vectors.insert(MessageId(1), vec![1.0f32, 0.0]);
2225        vectors.insert(MessageId(2), vec![0.0f32, 1.0]); // orthogonal
2226        vectors.insert(MessageId(3), vec![1.0f32, 0.0]); // same direction
2227        let result = apply_mmr(&ranked, &vectors, 0.0, 3);
2228        assert_eq!(result.len(), 3);
2229        // After 1 is selected: mmr(2) = 0 - (1-0)*0 = 0, mmr(3) = 0 - 1*1 = -1 → 2 wins
2230        assert_eq!(result[1].0, MessageId(2));
2231    }
2232
2233    #[test]
2234    fn mmr_lambda_one_pure_relevance() {
2235        // lambda=1 → pure relevance, should pick in relevance order
2236        let ranked = vec![
2237            (MessageId(1), 1.0f64),
2238            (MessageId(2), 0.8f64),
2239            (MessageId(3), 0.6f64),
2240        ];
2241        let mut vectors = std::collections::HashMap::new();
2242        vectors.insert(MessageId(1), vec![1.0f32, 0.0]);
2243        vectors.insert(MessageId(2), vec![0.0f32, 1.0]);
2244        vectors.insert(MessageId(3), vec![0.5f32, 0.5]);
2245        let result = apply_mmr(&ranked, &vectors, 1.0, 3);
2246        assert_eq!(result.len(), 3);
2247        assert_eq!(result[0].0, MessageId(1));
2248        assert_eq!(result[1].0, MessageId(2));
2249        assert_eq!(result[2].0, MessageId(3));
2250    }
2251
2252    #[test]
2253    fn mmr_limit_zero_returns_empty() {
2254        let ranked = vec![(MessageId(1), 1.0f64), (MessageId(2), 0.8f64)];
2255        let mut vectors = std::collections::HashMap::new();
2256        vectors.insert(MessageId(1), vec![1.0f32, 0.0]);
2257        vectors.insert(MessageId(2), vec![0.0f32, 1.0]);
2258        let result = apply_mmr(&ranked, &vectors, 0.7, 0);
2259        assert!(result.is_empty());
2260    }
2261
2262    #[test]
2263    fn mmr_duplicate_vectors_penalizes_second() {
2264        // Two items with identical embeddings: second should be heavily penalized
2265        let ranked = vec![
2266            (MessageId(1), 1.0f64),
2267            (MessageId(2), 1.0f64), // same relevance, same direction
2268            (MessageId(3), 0.9f64), // orthogonal, lower relevance
2269        ];
2270        let mut vectors = std::collections::HashMap::new();
2271        vectors.insert(MessageId(1), vec![1.0f32, 0.0]);
2272        vectors.insert(MessageId(2), vec![1.0f32, 0.0]); // duplicate
2273        vectors.insert(MessageId(3), vec![0.0f32, 1.0]); // orthogonal
2274        let result = apply_mmr(&ranked, &vectors, 0.5, 3);
2275        assert_eq!(result.len(), 3);
2276        assert_eq!(result[0].0, MessageId(1));
2277        // msg3 (orthogonal) should be preferred over msg2 (duplicate) with lambda=0.5
2278        assert_eq!(result[1].0, MessageId(3));
2279    }
2280
2281    // Priority 3: proptest
2282
2283    use proptest::prelude::*;
2284
    proptest! {
        // Property: token counting must not panic for any input string —
        // including empty, multi-byte, and control-character inputs. The
        // count itself is discarded; only panic-freedom is under test.
        #[test]
        fn count_tokens_never_panics(s in ".*") {
            let counter = crate::token_counter::TokenCounter::new();
            let _ = counter.count_tokens(&s);
        }
    }
2292}