Skip to main content

zeph_memory/
tiered_retrieval.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4//! `MemFlow` tiered intent-driven retrieval pipeline (issue #3712).
5//!
6//! Classifies each recall query into one of three intent tiers and dispatches to the
7//! cheapest sufficient backend, assembling evidence within a configurable token budget.
8//!
9//! # Tiers
10//!
11//! | Tier | Backend | Top-k | Graph hops |
12//! |------|---------|-------|-----------|
13//! | `ProfileLookup` | Keyword / persona | 3 | 0 |
14//! | `TargetedRetrieval` | Hybrid | 10 | 1 |
15//! | `DeepReasoning` | Hybrid + graph | 20 | 2 |
16//!
17//! The classifier maps the existing [`MemoryRoute`] to an [`IntentClass`]:
18//! - `Keyword | Episodic` → `ProfileLookup`
19//! - `Semantic | Hybrid` → `TargetedRetrieval`
20//! - `Graph` → `DeepReasoning`
21//!
22//! When `classifier_provider` is set and the LLM call fails, the pipeline falls back to
23//! [`HeuristicRouter`] (fail-open, logged at `warn`).
24//!
25//! # Token-budget assembly
26//!
27//! Recall results are truncated to fit within `token_budget`. An optional validation step
28//! asks a lightweight LLM whether the gathered evidence is sufficient; on low confidence,
29//! the pipeline escalates to the next heavier tier (up to `max_escalations`).
30
31use std::collections::HashMap;
32use std::sync::Arc;
33
34use tracing::Instrument as _;
35pub use zeph_config::memory::TieredRetrievalConfig;
36use zeph_llm::any::AnyProvider;
37
38use crate::embedding_store::SearchFilter;
39use crate::error::MemoryError;
40use crate::router::{HeuristicRouter, HybridRouter, MemoryRoute, MemoryRouter};
41use crate::semantic::RecalledMessage;
42use crate::semantic::SemanticMemory;
43use crate::types::{ConversationId, MessageId};
44
45// ── Intent classification ─────────────────────────────────────────────────────
46
/// Intent tier assigned to a recall query by the `MemFlow` tiered pipeline.
///
/// Variants are declared from cheapest to most expensive, which is also the order in
/// which evidence validation escalates through them.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum IntentClass {
    /// Cheapest tier, meant for profile/attribute-style lookups; requests 3 candidates.
    ProfileLookup,
    /// Default tier for ordinary recall; requests 10 candidates.
    TargetedRetrieval,
    /// Heaviest tier, meant for multi-hop questions; requests 20 candidates.
    DeepReasoning,
}
60
61impl IntentClass {
62    fn from_route(route: MemoryRoute) -> Self {
63        match route {
64            MemoryRoute::Keyword | MemoryRoute::Episodic => Self::ProfileLookup,
65            MemoryRoute::Graph => Self::DeepReasoning,
66            _ => Self::TargetedRetrieval,
67        }
68    }
69
70    fn top_k(self) -> usize {
71        match self {
72            Self::ProfileLookup => 3,
73            Self::TargetedRetrieval => 10,
74            Self::DeepReasoning => 20,
75        }
76    }
77
78    /// Returns the next heavier tier for escalation, or `None` if already at maximum.
79    fn escalate(self) -> Option<Self> {
80        match self {
81            Self::ProfileLookup => Some(Self::TargetedRetrieval),
82            Self::TargetedRetrieval => Some(Self::DeepReasoning),
83            Self::DeepReasoning => None,
84        }
85    }
86}
87
88impl std::fmt::Display for IntentClass {
89    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
90        match self {
91            Self::ProfileLookup => f.write_str("ProfileLookup"),
92            Self::TargetedRetrieval => f.write_str("TargetedRetrieval"),
93            Self::DeepReasoning => f.write_str("DeepReasoning"),
94        }
95    }
96}
97
98// ── Result ────────────────────────────────────────────────────────────────────
99
100/// Result of tiered retrieval, including evidence and tier metadata.
101#[derive(Debug)]
102pub struct TieredRetrievalResult {
103    /// Retrieved memory entries ordered by relevance score.
104    pub messages: Vec<RecalledMessage>,
105    /// The intent class that produced this result.
106    pub intent: IntentClass,
107    /// Approximate token count of all message content.
108    pub tokens_used: usize,
109    /// Whether the pipeline escalated to a heavier tier due to validation.
110    pub tier_escalated: bool,
111}
112
113// ── Tiered retrieval logic ─────────────────────────────────────────────────────
114
115/// Execute `MemFlow` tiered retrieval for a single query.
116///
117/// Classifies intent, retrieves tier candidates, assembles evidence within budget, and
118/// optionally validates + escalates if evidence is insufficient.
119///
120/// `classifier` should be the provider resolved from
121/// [`TieredRetrievalConfig::classifier_provider`]. When `Some`, a [`HybridRouter`] is
122/// used for LLM-backed intent classification (with [`HeuristicRouter`] as fallback on
123/// LLM failure). When `None`, only the heuristic router is used.
124///
125/// `validator` should be the provider resolved from
126/// [`TieredRetrievalConfig::validator_provider`]. When `Some` and
127/// `config.validation_enabled` is `true`, the validator LLM judges evidence quality and
128/// triggers tier escalation when confidence is low.
129///
130/// `conversation_id` scopes the search to a single conversation. Pass `None` to search globally.
131///
132/// # Errors
133///
134/// Returns an error if any underlying search or database operation fails.
135#[tracing::instrument(name = "memory.tiered.retrieve", skip_all, fields(intent = tracing::field::Empty))]
136pub async fn recall_tiered(
137    memory: &SemanticMemory,
138    query: &str,
139    conversation_id: Option<ConversationId>,
140    classifier: Option<&Arc<AnyProvider>>,
141    validator: Option<&Arc<AnyProvider>>,
142    config: &TieredRetrievalConfig,
143    remaining_budget: Option<usize>,
144) -> Result<TieredRetrievalResult, MemoryError> {
145    let effective_budget =
146        remaining_budget.map_or(config.token_budget, |rb| rb.min(config.token_budget));
147
148    let initial_intent = if let Some(classifier_provider) = classifier {
149        let hybrid = HybridRouter::new(
150            Arc::clone(classifier_provider),
151            MemoryRoute::Hybrid,
152            // 0.7 is the codebase-wide default for HybridRouter confidence threshold
153            0.7,
154        );
155        let decision = if let Ok(d) = tokio::time::timeout(
156            std::time::Duration::from_secs(config.classifier_timeout_secs),
157            hybrid.classify_async(query),
158        )
159        .await
160        {
161            d
162        } else {
163            tracing::warn!("tiered: classifier LLM timed out, falling back to heuristic");
164            HeuristicRouter.route_with_confidence(query)
165        };
166        IntentClass::from_route(decision.route)
167    } else {
168        let decision = HeuristicRouter.route_with_confidence(query);
169        IntentClass::from_route(decision.route)
170    };
171
172    tracing::debug!(intent = %initial_intent, query_len = query.len(), "tiered: classified intent");
173
174    escalation_loop(
175        memory,
176        query,
177        conversation_id,
178        initial_intent,
179        validator,
180        config,
181        effective_budget,
182    )
183    .await
184}
185
186/// Inner escalation loop shared across retrieval entry points.
187///
188/// Iterates through tiers starting at `initial_intent`, retrieving candidates and
189/// validating evidence quality. Escalates to heavier tiers when validation indicates
190/// insufficient evidence.
191async fn escalation_loop(
192    memory: &SemanticMemory,
193    query: &str,
194    conversation_id: Option<ConversationId>,
195    initial_intent: IntentClass,
196    validator: Option<&Arc<AnyProvider>>,
197    config: &TieredRetrievalConfig,
198    effective_budget: usize,
199) -> Result<TieredRetrievalResult, MemoryError> {
200    let mut intent = initial_intent;
201    let mut escalations: u8 = 0;
202    let mut tier_escalated = false;
203
204    loop {
205        let raw_candidates = retrieve_tier(memory, query, conversation_id, intent)
206            .instrument(tracing::debug_span!("memory.tiered.retrieve_tier", tier = %intent))
207            .await?;
208
209        let candidates = score_candidates(memory, query, raw_candidates, config)
210            .instrument(tracing::debug_span!("memory.tiered.score_candidates", tier = %intent))
211            .await?;
212
213        let (messages, tokens_used) = {
214            let _span = tracing::debug_span!("memory.tiered.assemble").entered();
215            assemble_within_budget(candidates, effective_budget)
216        };
217
218        // Validate evidence quality if enabled and a validator is available.
219        if config.validation_enabled
220            && escalations < config.max_escalations
221            && let Some(validator_provider) = validator
222            && let Some(next_tier) = intent.escalate()
223        {
224            let sufficient = validate_evidence(
225                validator_provider,
226                query,
227                &messages,
228                config.validation_threshold,
229                config.validator_timeout_secs,
230            )
231            .instrument(tracing::debug_span!("memory.tiered.validate"))
232            .await;
233            if !sufficient {
234                tracing::debug!(
235                    current_tier = %intent,
236                    next_tier = %next_tier,
237                    escalations,
238                    "tiered: evidence insufficient, escalating tier"
239                );
240                intent = next_tier;
241                escalations += 1;
242                tier_escalated = true;
243                continue;
244            }
245        }
246
247        return Ok(TieredRetrievalResult {
248            messages,
249            intent,
250            tokens_used,
251            tier_escalated,
252        });
253    }
254}
255
256/// Retrieve candidates for the given intent tier from `SemanticMemory`.
257async fn retrieve_tier(
258    memory: &SemanticMemory,
259    query: &str,
260    conversation_id: Option<ConversationId>,
261    intent: IntentClass,
262) -> Result<Vec<RecalledMessage>, MemoryError> {
263    let top_k = intent.top_k();
264    let heuristic = HeuristicRouter;
265
266    let filter = conversation_id.map(|cid| SearchFilter {
267        conversation_id: Some(cid),
268        role: None,
269        category: None,
270    });
271
272    // All tiers route through recall_routed; the heuristic router maps intent-appropriate
273    // routes. Graph traversal for DeepReasoning is left to the caller via recall_graph.
274    memory
275        .recall_routed(query, top_k, filter, &heuristic, None)
276        .await
277}
278
279// ── Five-signal retrieval scoring ─────────────────────────────────────────────
280
281/// Intermediate struct for multi-signal scoring of a single candidate.
282struct ScoredCandidate {
283    recalled: RecalledMessage,
284}
285
286/// Re-score `candidates` using up to five signals and return them sorted by final score.
287///
288/// Overwrites `RecalledMessage::score` with the combined weighted score.
289/// When all signal weights are zero (mis-configuration), logs a debug warning and
290/// returns candidates in original order with scores unchanged.
291///
292/// # Errors
293///
294/// Propagates store errors from timestamp or tier lookups.
295#[allow(clippy::too_many_lines)]
296#[tracing::instrument(name = "memory.tiered.score_candidates", skip_all)]
297async fn score_candidates(
298    memory: &SemanticMemory,
299    query: &str,
300    candidates: Vec<RecalledMessage>,
301    config: &TieredRetrievalConfig,
302) -> Result<Vec<RecalledMessage>, MemoryError> {
303    if candidates.is_empty() {
304        return Ok(candidates);
305    }
306
307    let total_weight = config.similarity_weight
308        + config.recency_weight
309        + config.tfidf_weight
310        + config.cognitive_signal_weight
311        + config.tier_boost_weight;
312
313    if total_weight < f64::EPSILON {
314        tracing::debug!("score_candidates: all signal weights are zero, returning original order");
315        return Ok(candidates);
316    }
317
318    let ids: Vec<MessageId> = candidates
319        .iter()
320        .map(|c| MessageId(c.message.metadata.db_id.unwrap_or(0)))
321        .collect();
322
323    // Fetch timestamps and tiers only when their respective signals are active.
324    let (timestamps_res, tiers_res) = tokio::join!(
325        async {
326            if config.recency_weight > 0.0 {
327                memory.sqlite().message_timestamps(&ids).await
328            } else {
329                Ok(HashMap::new())
330            }
331        },
332        async {
333            if config.tier_boost_weight > 0.0 {
334                memory.sqlite().fetch_tiers(&ids).await
335            } else {
336                Ok(HashMap::new())
337            }
338        },
339    );
340    let timestamps: HashMap<MessageId, i64> = timestamps_res.unwrap_or_else(|e| {
341        tracing::warn!("score_candidates: failed to fetch timestamps: {e:#}");
342        HashMap::new()
343    });
344    let tiers: HashMap<MessageId, String> = tiers_res.unwrap_or_else(|e| {
345        tracing::warn!("score_candidates: failed to fetch tiers: {e:#}");
346        HashMap::new()
347    });
348
349    // Fetch access counts for cognitive signal.
350    let access_counts: HashMap<MessageId, i64> = if config.cognitive_signal_weight > 0.0 {
351        memory
352            .sqlite()
353            .message_access_counts(&ids)
354            .await
355            .unwrap_or_else(|e| {
356                tracing::warn!("score_candidates: failed to fetch access counts: {e:#}");
357                HashMap::new()
358            })
359    } else {
360        HashMap::new()
361    };
362
363    let tfidf_scores = if config.tfidf_weight > 0.0 {
364        compute_tfidf_scores(query, &candidates)
365    } else {
366        vec![0.0_f64; candidates.len()]
367    };
368
369    let max_access: i64 = access_counts.values().copied().max().unwrap_or(0);
370
371    let now_secs = std::time::SystemTime::now()
372        .duration_since(std::time::UNIX_EPOCH)
373        .map_or(0_i64, |d| i64::try_from(d.as_secs()).unwrap_or(i64::MAX));
374
375    let mut scored: Vec<ScoredCandidate> = candidates
376        .into_iter()
377        .zip(tfidf_scores)
378        .map(|(recalled, tfidf)| {
379            let msg_id = MessageId(recalled.message.metadata.db_id.unwrap_or(0));
380
381            let similarity = f64::from(recalled.score);
382            let recency = if config.recency_weight > 0.0 && config.recency_half_life_days > 0 {
383                let ts = timestamps.get(&msg_id).copied().unwrap_or(now_secs);
384                compute_recency(ts, now_secs, config.recency_half_life_days)
385            } else {
386                0.0
387            };
388
389            let cognitive = if config.cognitive_signal_weight > 0.0 && max_access > 0 {
390                let count = access_counts.get(&msg_id).copied().unwrap_or(0);
391                // Both are i64; precision loss is acceptable for a normalized ratio.
392                #[allow(clippy::cast_precision_loss)]
393                let ratio = count as f64 / max_access as f64;
394                ratio
395            } else {
396                0.0
397            };
398
399            let tier_signal = if config.tier_boost_weight > 0.0 {
400                let tier = tiers.get(&msg_id).map_or("episodic", String::as_str);
401                if tier == "semantic" {
402                    config.semantic_tier_boost
403                } else {
404                    0.0
405                }
406            } else {
407                0.0
408            };
409
410            let final_score = config.similarity_weight * similarity
411                + config.recency_weight * recency
412                + config.tfidf_weight * tfidf
413                + config.cognitive_signal_weight * cognitive
414                + config.tier_boost_weight * tier_signal;
415
416            ScoredCandidate {
417                recalled: RecalledMessage {
418                    // f64 → f32: deliberate truncation, score precision is adequate.
419                    #[allow(clippy::cast_possible_truncation)]
420                    score: final_score as f32,
421                    ..recalled
422                },
423            }
424        })
425        .collect();
426
427    scored.sort_by(|a, b| {
428        b.recalled
429            .score
430            .partial_cmp(&a.recalled.score)
431            .unwrap_or(std::cmp::Ordering::Equal)
432    });
433
434    Ok(scored.into_iter().map(|s| s.recalled).collect())
435}
436
/// Compute recency score in `[0.0, 1.0]` using exponential half-life decay.
///
/// Returns `1.0` for a message created right now and approaches `0.0` for very old messages.
/// A message that is exactly `half_life_days` old receives a score of `0.5`. Timestamps in
/// the future are treated as age zero (score `1.0`).
///
/// The age is computed with saturating arithmetic, so extreme or corrupt timestamps (e.g. a
/// large negative `created_at_secs`) yield a score near `0.0` instead of overflowing.
///
/// # Precondition
///
/// `half_life_days` must be greater than zero. Passing `0` is a programming error and will
/// panic in debug builds.
fn compute_recency(created_at_secs: i64, now_secs: i64, half_life_days: u32) -> f64 {
    debug_assert!(half_life_days > 0, "half_life_days must be > 0");
    // Precision loss is acceptable: age is a time delta in days, not a financial value.
    #[allow(clippy::cast_precision_loss)]
    let age_days = now_secs.saturating_sub(created_at_secs).max(0) as f64 / 86_400.0;
    let lambda = std::f64::consts::LN_2 / f64::from(half_life_days);
    (-lambda * age_days).exp()
}
454
455/// Compute per-candidate TF-IDF scores against `query`, normalised to `[0.0, 1.0]`.
456///
457/// Uses a simplified TF-IDF with BM25-style parameters (k1 = 1.2, b = 0.75).
458/// Scores are normalised by dividing by the maximum score in the batch.
459fn compute_tfidf_scores(query: &str, candidates: &[RecalledMessage]) -> Vec<f64> {
460    const K1: f64 = 1.2;
461    const B: f64 = 0.75;
462
463    let query_terms: Vec<String> = query.split_whitespace().map(str::to_lowercase).collect();
464
465    if query_terms.is_empty() || candidates.is_empty() {
466        return vec![0.0; candidates.len()];
467    }
468
469    // Tokenise each candidate document.
470    let docs: Vec<Vec<String>> = candidates
471        .iter()
472        .map(|c| {
473            c.message
474                .content
475                .split_whitespace()
476                .map(str::to_lowercase)
477                .collect()
478        })
479        .collect();
480
481    // Precision loss is acceptable for term-frequency ratios over small candidate sets.
482    #[allow(clippy::cast_precision_loss)]
483    let n = docs.len() as f64;
484    #[allow(clippy::cast_precision_loss)]
485    let avg_dl = docs.iter().map(|d| d.len() as f64).sum::<f64>().max(1.0) / n;
486
487    let mut scores = vec![0.0_f64; docs.len()];
488
489    for term in &query_terms {
490        // Document frequency across the candidate set.
491        #[allow(clippy::cast_precision_loss)]
492        let df = docs.iter().filter(|d| d.contains(term)).count() as f64;
493        if df == 0.0 {
494            continue;
495        }
496        // IDF with smoothing.
497        let idf = ((n - df + 0.5) / (df + 0.5) + 1.0).ln();
498
499        for (i, doc) in docs.iter().enumerate() {
500            #[allow(clippy::cast_precision_loss)]
501            let dl = doc.len() as f64;
502            #[allow(clippy::cast_precision_loss)]
503            let tf = doc.iter().filter(|t| *t == term).count() as f64;
504            let bm25_tf = (tf * (K1 + 1.0)) / (tf + K1 * (1.0 - B + B * dl / avg_dl));
505            scores[i] += idf * bm25_tf;
506        }
507    }
508
509    // Normalise to [0.0, 1.0].
510    let max_score = scores.iter().copied().fold(0.0_f64, f64::max);
511    if max_score > 0.0 {
512        for s in &mut scores {
513            *s /= max_score;
514        }
515    }
516
517    scores
518}
519
520/// Truncate `candidates` to fit within `budget` tokens.
521///
522/// Uses the same 4 chars-per-token approximation as the rest of the codebase.
523/// Returns the retained messages and the total token count consumed.
524fn assemble_within_budget(
525    candidates: Vec<RecalledMessage>,
526    budget: usize,
527) -> (Vec<RecalledMessage>, usize) {
528    let mut retained = Vec::with_capacity(candidates.len());
529    let mut total_tokens: usize = 0;
530
531    for msg in candidates {
532        let msg_tokens = zeph_common::text::estimate_tokens(&msg.message.content);
533        if total_tokens.saturating_add(msg_tokens) > budget {
534            break;
535        }
536        total_tokens += msg_tokens;
537        retained.push(msg);
538    }
539
540    (retained, total_tokens)
541}
542
543/// Ask the validator LLM whether the gathered evidence is sufficient for the query.
544///
545/// Returns `true` when the validator's confidence is >= `threshold` or when the
546/// call fails (fail-open: prefer serving potentially incomplete evidence over blocking).
547async fn validate_evidence(
548    provider: &Arc<AnyProvider>,
549    query: &str,
550    messages: &[RecalledMessage],
551    threshold: f32,
552    timeout_secs: u64,
553) -> bool {
554    use zeph_llm::provider::{LlmProvider as _, Message, MessageMetadata, Role};
555
556    if messages.is_empty() {
557        return false;
558    }
559
560    let evidence_snippet = messages
561        .iter()
562        .take(5)
563        .map(|m| {
564            zeph_common::sanitize::strip_control_chars_preserve_whitespace(&m.message.content)
565                .chars()
566                .take(200)
567                .collect::<String>()
568        })
569        .collect::<Vec<_>>()
570        .join("\n---\n");
571
572    let system = "You are an evidence quality judge. \
573        Given a query and evidence snippets, decide if the evidence is sufficient to answer the query. \
574        Respond ONLY with a JSON object: {\"sufficient\": true|false, \"confidence\": 0.0-1.0}";
575
576    let sanitized_query = zeph_common::sanitize::strip_control_chars_preserve_whitespace(query);
577    let user = format!(
578        "<query>{}</query>\n<evidence>{}</evidence>",
579        sanitized_query.chars().take(500).collect::<String>(),
580        evidence_snippet
581    );
582
583    let msgs = vec![
584        Message {
585            role: Role::System,
586            content: system.to_owned(),
587            parts: vec![],
588            metadata: MessageMetadata::default(),
589        },
590        Message {
591            role: Role::User,
592            content: user,
593            parts: vec![],
594            metadata: MessageMetadata::default(),
595        },
596    ];
597
598    match tokio::time::timeout(
599        std::time::Duration::from_secs(timeout_secs),
600        provider.chat(&msgs),
601    )
602    .await
603    {
604        Ok(Ok(raw)) => parse_validation_response(&raw, threshold),
605        Ok(Err(e)) => {
606            tracing::warn!(error = %e, "tiered: validator LLM call failed, treating as sufficient");
607            true
608        }
609        Err(_) => {
610            tracing::warn!("tiered: validator LLM call timed out, treating as sufficient");
611            true
612        }
613    }
614}
615
616fn parse_validation_response(raw: &str, threshold: f32) -> bool {
617    let json_str = raw
618        .find('{')
619        .and_then(|s| raw[s..].rfind('}').map(|e| &raw[s..=s + e]))
620        .unwrap_or("");
621
622    if let Ok(v) = serde_json::from_str::<serde_json::Value>(json_str) {
623        let sufficient = v
624            .get("sufficient")
625            .and_then(serde_json::Value::as_bool)
626            .unwrap_or(true);
627        #[allow(clippy::cast_possible_truncation)]
628        let confidence = v
629            .get("confidence")
630            .and_then(serde_json::Value::as_f64)
631            .map_or(1.0, |c| c.clamp(0.0, 1.0) as f32);
632
633        return sufficient && confidence >= threshold;
634    }
635
636    tracing::debug!("tiered: could not parse validator response, treating as sufficient");
637    true
638}
639
640// ── Tests ─────────────────────────────────────────────────────────────────────
641
642#[cfg(test)]
643mod tests {
644    use super::*;
645    use crate::router::MemoryRoute;
646    use crate::semantic::RecalledMessage;
647    use zeph_llm::provider::{Message, MessageMetadata, Role};
648
649    fn make_message(content: &str) -> RecalledMessage {
650        RecalledMessage {
651            message: Message {
652                role: Role::User,
653                content: content.to_owned(),
654                parts: vec![],
655                metadata: MessageMetadata::default(),
656            },
657            score: 1.0,
658        }
659    }
660
661    // ── Signal scoring unit tests ─────────────────────────────────────────────
662
663    #[test]
664    fn compute_recency_zero_age_returns_one() {
665        let now = 1_000_000_i64;
666        let score = compute_recency(now, now, 7);
667        assert!((score - 1.0).abs() < 1e-9);
668    }
669
670    #[test]
671    fn compute_recency_half_life_returns_half() {
672        let now = 1_000_000_i64;
673        let half_life_days = 7_u32;
674        let age_secs = i64::from(half_life_days) * 86_400;
675        let score = compute_recency(now - age_secs, now, half_life_days);
676        assert!((score - 0.5).abs() < 1e-9);
677    }
678
679    #[test]
680    fn compute_recency_large_age_approaches_zero() {
681        // 1000 days with 7-day half-life: score ≈ 2^(-1000/7) ≈ 1e-43
682        let now = 1_000_i64 * 86_400;
683        let score = compute_recency(0, now, 7);
684        assert!(score < 1e-6, "score was {score}");
685    }
686
687    #[test]
688    fn compute_recency_future_timestamp_clamped_to_one() {
689        let now = 1_000_000_i64;
690        // created_at in the future → age < 0 → clamped to 0 → score = 1.0
691        let score = compute_recency(now + 86_400, now, 7);
692        assert!((score - 1.0).abs() < 1e-9);
693    }
694
695    #[test]
696    fn compute_tfidf_empty_candidates_returns_empty() {
697        let scores = compute_tfidf_scores("hello", &[]);
698        assert!(scores.is_empty());
699    }
700
701    #[test]
702    fn compute_tfidf_empty_query_returns_zeros() {
703        let candidates = vec![make_message("hello world")];
704        let scores = compute_tfidf_scores("", &candidates);
705        assert_eq!(scores.len(), 1);
706        assert!(scores[0].abs() < f64::EPSILON);
707    }
708
709    #[test]
710    fn compute_tfidf_exact_match_scores_nonzero() {
711        let candidates = vec![
712            make_message("the quick brown fox"),
713            make_message("completely unrelated content"),
714        ];
715        let scores = compute_tfidf_scores("fox", &candidates);
716        assert_eq!(scores.len(), 2);
717        // The message containing "fox" must score higher.
718        assert!(scores[0] > scores[1]);
719    }
720
721    #[test]
722    fn compute_tfidf_no_match_returns_zeros() {
723        let candidates = vec![make_message("apple banana cherry")];
724        let scores = compute_tfidf_scores("zzz xyz", &candidates);
725        assert_eq!(scores.len(), 1);
726        assert!(scores[0].abs() < f64::EPSILON);
727    }
728
729    #[test]
730    fn compute_tfidf_max_score_normalised_to_one() {
731        let candidates = vec![
732            make_message("rust programming language"),
733            make_message("python programming language"),
734            make_message("java is a drink"),
735        ];
736        let scores = compute_tfidf_scores("rust programming", &candidates);
737        let max = scores.iter().copied().fold(f64::NEG_INFINITY, f64::max);
738        assert!((max - 1.0).abs() < 1e-9, "max score must be 1.0, got {max}");
739    }
740
741    #[test]
742    fn score_candidates_empty_input_returns_empty() {
743        // Pure sync test via tokio runtime.
744        let rt = tokio::runtime::Builder::new_current_thread()
745            .enable_all()
746            .build()
747            .unwrap();
748        rt.block_on(async {
749            let memory = crate::testing::mock_semantic_memory()
750                .await
751                .expect("mock_semantic_memory");
752            let config = TieredRetrievalConfig::default();
753            let result = score_candidates(&memory, "query", vec![], &config)
754                .await
755                .expect("score_candidates must not fail on empty input");
756            assert!(result.is_empty());
757        });
758    }
759
760    #[test]
761    fn score_candidates_similarity_weight_reorders_by_score() {
762        let rt = tokio::runtime::Builder::new_current_thread()
763            .enable_all()
764            .build()
765            .unwrap();
766        rt.block_on(async {
767            let memory = crate::testing::mock_semantic_memory()
768                .await
769                .expect("mock_semantic_memory");
770            // similarity_weight = 1.0 activates the scoring formula; candidates provided in
771            // ascending score order to verify that sort_by reorders them descending.
772            let config = TieredRetrievalConfig {
773                similarity_weight: 1.0,
774                ..TieredRetrievalConfig::default()
775            };
776            let candidates = vec![
777                RecalledMessage {
778                    message: make_message("low score").message,
779                    score: 0.1,
780                },
781                RecalledMessage {
782                    message: make_message("high score").message,
783                    score: 0.9,
784                },
785                RecalledMessage {
786                    message: make_message("mid score").message,
787                    score: 0.5,
788                },
789            ];
790            let result = score_candidates(&memory, "query", candidates, &config)
791                .await
792                .expect("score_candidates must not fail");
793            assert_eq!(result.len(), 3);
794            // Descending order: 0.9 → 0.5 → 0.1
795            assert!(
796                result[0].score >= result[1].score,
797                "first score {} must be >= second score {}",
798                result[0].score,
799                result[1].score
800            );
801            assert!(
802                result[1].score >= result[2].score,
803                "second score {} must be >= third score {}",
804                result[1].score,
805                result[2].score
806            );
807            // Highest original score should be ranked first.
808            assert!(
809                (result[0].score - 0.9_f32).abs() < 1e-4,
810                "expected first score ~0.9, got {}",
811                result[0].score
812            );
813        });
814    }
815
816    #[test]
817    fn score_candidates_all_zero_weights_returns_original_order() {
818        let rt = tokio::runtime::Builder::new_current_thread()
819            .enable_all()
820            .build()
821            .unwrap();
822        rt.block_on(async {
823            let memory = crate::testing::mock_semantic_memory()
824                .await
825                .expect("mock_semantic_memory");
826            // All weights zero: score_candidates must return candidates unchanged.
827            let config = TieredRetrievalConfig {
828                similarity_weight: 0.0,
829                recency_weight: 0.0,
830                tfidf_weight: 0.0,
831                cognitive_signal_weight: 0.0,
832                tier_boost_weight: 0.0,
833                ..TieredRetrievalConfig::default()
834            };
835            let candidates = vec![
836                RecalledMessage {
837                    message: make_message("first").message,
838                    score: 0.9,
839                },
840                RecalledMessage {
841                    message: make_message("second").message,
842                    score: 0.1,
843                },
844            ];
845            let result = score_candidates(&memory, "query", candidates, &config)
846                .await
847                .expect("score_candidates must not fail");
848            // Original order preserved because all-zero weights triggers early return.
849            assert!((f64::from(result[0].score) - 0.9).abs() < 1e-6);
850            assert!((f64::from(result[1].score) - 0.1).abs() < 1e-6);
851        });
852    }
853
854    #[test]
855    fn tiered_retrieval_config_signal_weight_defaults() {
856        let cfg = TieredRetrievalConfig::default();
857        assert!((cfg.similarity_weight - 1.0).abs() < f64::EPSILON);
858        assert!(cfg.recency_weight.abs() < f64::EPSILON);
859        assert_eq!(cfg.recency_half_life_days, 7);
860        assert!(cfg.tfidf_weight.abs() < f64::EPSILON);
861        assert!(cfg.cognitive_signal_weight.abs() < f64::EPSILON);
862        assert!(cfg.tier_boost_weight.abs() < f64::EPSILON);
863        assert!((cfg.semantic_tier_boost - 1.0).abs() < f64::EPSILON);
864    }
865
866    #[test]
867    fn intent_class_from_route_mapping() {
868        assert_eq!(
869            IntentClass::from_route(MemoryRoute::Keyword),
870            IntentClass::ProfileLookup
871        );
872        assert_eq!(
873            IntentClass::from_route(MemoryRoute::Episodic),
874            IntentClass::ProfileLookup
875        );
876        assert_eq!(
877            IntentClass::from_route(MemoryRoute::Semantic),
878            IntentClass::TargetedRetrieval
879        );
880        assert_eq!(
881            IntentClass::from_route(MemoryRoute::Hybrid),
882            IntentClass::TargetedRetrieval
883        );
884        assert_eq!(
885            IntentClass::from_route(MemoryRoute::Graph),
886            IntentClass::DeepReasoning
887        );
888    }
889
890    #[test]
891    fn intent_class_top_k() {
892        assert_eq!(IntentClass::ProfileLookup.top_k(), 3);
893        assert_eq!(IntentClass::TargetedRetrieval.top_k(), 10);
894        assert_eq!(IntentClass::DeepReasoning.top_k(), 20);
895    }
896
897    #[test]
898    fn intent_class_escalate_chain() {
899        assert_eq!(
900            IntentClass::ProfileLookup.escalate(),
901            Some(IntentClass::TargetedRetrieval)
902        );
903        assert_eq!(
904            IntentClass::TargetedRetrieval.escalate(),
905            Some(IntentClass::DeepReasoning)
906        );
907        assert_eq!(IntentClass::DeepReasoning.escalate(), None);
908    }
909
910    #[test]
911    fn assemble_within_budget_empty_input() {
912        let (retained, tokens) = assemble_within_budget(vec![], 4096);
913        assert!(retained.is_empty());
914        assert_eq!(tokens, 0);
915    }
916
917    #[test]
918    fn assemble_within_budget_zero_budget_returns_nothing() {
919        let candidates = vec![make_message("hello"), make_message("world")];
920        let (retained, tokens) = assemble_within_budget(candidates, 0);
921        assert!(retained.is_empty(), "budget=0 must retain no messages");
922        assert_eq!(tokens, 0);
923    }
924
925    #[test]
926    fn assemble_within_budget_truncates_at_limit() {
927        // estimate_tokens = chars / 4. Each message: "a " * 400 = 800 chars = 200 tokens.
928        // Budget 250 fits exactly one (200 <= 250) but not two (200 + 200 = 400 > 250).
929        let msg = "a ".repeat(400);
930        let candidates = vec![make_message(&msg), make_message(&msg)];
931        let (retained, tokens) = assemble_within_budget(candidates, 250);
932        assert_eq!(
933            retained.len(),
934            1,
935            "tight budget must keep only first message"
936        );
937        assert_eq!(tokens, 200);
938    }
939
940    #[test]
941    fn parse_validation_response_missing_fields_defaults_to_sufficient() {
942        // Neither "sufficient" nor "confidence" present → defaults: sufficient=true, confidence=1.0
943        let raw = "{}";
944        assert!(
945            parse_validation_response(raw, 0.6),
946            "missing fields must default to sufficient"
947        );
948    }
949
950    #[test]
951    fn tiered_retrieval_config_defaults() {
952        let cfg = TieredRetrievalConfig::default();
953        assert!(!cfg.enabled);
954        assert_eq!(cfg.token_budget, 4096);
955        assert!(!cfg.validation_enabled);
956        assert_eq!(cfg.max_escalations, 1);
957        // Verify config-driven timeout defaults (fix #4250).
958        assert_eq!(cfg.classifier_timeout_secs, 5);
959        assert_eq!(cfg.validator_timeout_secs, 5);
960    }
961
962    #[test]
963    fn tiered_retrieval_config_timeout_fields_propagate() {
964        // Verify that custom timeout values survive a round-trip through the struct.
965        let cfg = TieredRetrievalConfig {
966            classifier_timeout_secs: 10,
967            validator_timeout_secs: 15,
968            ..TieredRetrievalConfig::default()
969        };
970        assert_eq!(cfg.classifier_timeout_secs, 10);
971        assert_eq!(cfg.validator_timeout_secs, 15);
972        // Confirm the durations would be built correctly from the fields.
973        let classifier_dur = std::time::Duration::from_secs(cfg.classifier_timeout_secs);
974        let validator_dur = std::time::Duration::from_secs(cfg.validator_timeout_secs);
975        assert_eq!(classifier_dur.as_secs(), 10);
976        assert_eq!(validator_dur.as_secs(), 15);
977    }
978
979    #[test]
980    fn parse_validation_response_sufficient() {
981        let raw = r#"{"sufficient": true, "confidence": 0.9}"#;
982        assert!(parse_validation_response(raw, 0.6));
983    }
984
985    #[test]
986    fn parse_validation_response_insufficient() {
987        let raw = r#"{"sufficient": false, "confidence": 0.4}"#;
988        assert!(!parse_validation_response(raw, 0.6));
989    }
990
991    #[test]
992    fn parse_validation_response_low_confidence() {
993        let raw = r#"{"sufficient": true, "confidence": 0.3}"#;
994        // threshold = 0.6, confidence 0.3 < 0.6 → insufficient
995        assert!(!parse_validation_response(raw, 0.6));
996    }
997
998    #[test]
999    fn parse_validation_response_malformed_json_treats_as_sufficient() {
1000        let raw = "not json at all";
1001        assert!(parse_validation_response(raw, 0.6));
1002    }
1003
1004    #[test]
1005    fn intent_class_display() {
1006        assert_eq!(IntentClass::ProfileLookup.to_string(), "ProfileLookup");
1007        assert_eq!(
1008            IntentClass::TargetedRetrieval.to_string(),
1009            "TargetedRetrieval"
1010        );
1011        assert_eq!(IntentClass::DeepReasoning.to_string(), "DeepReasoning");
1012    }
1013
1014    // ── Async tests ───────────────────────────────────────────────────────────
1015
1016    /// Test 1: `recall_tiered` with `classifier = None` uses the `HeuristicRouter` path.
1017    ///
1018    /// With no classifier provider, the pipeline must route via heuristic, complete without
1019    /// error, and return a result whose intent maps from the heuristic route.
1020    #[tokio::test]
1021    async fn recall_tiered_no_classifier_uses_heuristic_router() {
1022        let memory = crate::testing::mock_semantic_memory()
1023            .await
1024            .expect("mock_semantic_memory");
1025        let config = TieredRetrievalConfig {
1026            enabled: true,
1027            validation_enabled: false,
1028            ..TieredRetrievalConfig::default()
1029        };
1030
1031        let result = recall_tiered(&memory, "what is my name", None, None, None, &config, None)
1032            .await
1033            .expect("recall_tiered must not fail");
1034
1035        // HeuristicRouter classifies "what is my name" via keyword/semantic heuristic.
1036        // The exact tier depends on the heuristic, but the pipeline must complete.
1037        assert!(
1038            !result.tier_escalated,
1039            "no escalation when validation is off"
1040        );
1041        assert!(result.tokens_used <= config.token_budget);
1042    }
1043
1044    /// Test 2: `recall_tiered` with `classifier = Some(...)` exercises the `HybridRouter` path.
1045    ///
1046    /// The mock LLM returns a JSON route decision; the pipeline must parse it and use the
1047    /// resulting intent class.
1048    #[tokio::test]
1049    async fn recall_tiered_with_classifier_uses_hybrid_router() {
1050        use zeph_llm::mock::MockProvider;
1051
1052        let memory = crate::testing::mock_semantic_memory()
1053            .await
1054            .expect("mock_semantic_memory");
1055
1056        // HybridRouter asks the LLM for a route; respond with a valid JSON route decision.
1057        let route_json = r#"{"route": "Semantic", "confidence": 0.9}"#.to_owned();
1058        let mut mock = MockProvider::with_responses(vec![route_json]);
1059        mock.supports_embeddings = true;
1060        mock.embedding = vec![0.1_f32; 384];
1061        let classifier = Arc::new(AnyProvider::Mock(mock));
1062
1063        let config = TieredRetrievalConfig {
1064            enabled: true,
1065            validation_enabled: false,
1066            ..TieredRetrievalConfig::default()
1067        };
1068
1069        let result = recall_tiered(
1070            &memory,
1071            "semantic query about the user",
1072            None,
1073            Some(&classifier),
1074            None,
1075            &config,
1076            None,
1077        )
1078        .await
1079        .expect("recall_tiered with classifier must not fail");
1080
1081        assert!(!result.tier_escalated);
1082        assert!(result.tokens_used <= config.token_budget);
1083    }
1084
1085    /// Test 3: Escalation loop sets `tier_escalated = true` when the validator returns
1086    /// insufficient evidence and a heavier tier is available.
1087    ///
1088    /// Validator response with `{"sufficient": false, "confidence": 0.2}` triggers escalation.
1089    /// After escalation, the second-tier retrieve runs and the result has `tier_escalated = true`.
1090    #[tokio::test]
1091    async fn recall_tiered_escalates_when_evidence_insufficient() {
1092        use zeph_llm::mock::MockProvider;
1093
1094        let memory = crate::testing::mock_semantic_memory()
1095            .await
1096            .expect("mock_semantic_memory");
1097
1098        // First validator response: insufficient. Second: sufficient (prevents infinite loop).
1099        let insufficient = r#"{"sufficient": false, "confidence": 0.1}"#.to_owned();
1100        let sufficient = r#"{"sufficient": true, "confidence": 0.95}"#.to_owned();
1101        let mut validator_mock = MockProvider::with_responses(vec![insufficient, sufficient]);
1102        validator_mock.supports_embeddings = true;
1103        let validator = Arc::new(AnyProvider::Mock(validator_mock));
1104
1105        let config = TieredRetrievalConfig {
1106            enabled: true,
1107            validation_enabled: true,
1108            validation_threshold: 0.6,
1109            max_escalations: 2,
1110            ..TieredRetrievalConfig::default()
1111        };
1112
1113        let result = recall_tiered(
1114            &memory,
1115            "deep query",
1116            None,
1117            None,
1118            Some(&validator),
1119            &config,
1120            None,
1121        )
1122        .await
1123        .expect("escalation path must not fail");
1124
1125        assert!(
1126            result.tier_escalated,
1127            "must set tier_escalated when validator triggers escalation"
1128        );
1129    }
1130
1131    /// Test 4a: `validate_evidence` returns `true` (fail-open) when the validator LLM times out.
1132    ///
1133    /// Uses `with_delay` to force the validator past the configured timeout threshold.
1134    /// The pipeline must treat a timed-out validator as sufficient (fail-open) and not escalate.
1135    #[tokio::test]
1136    async fn validate_evidence_timeout_is_fail_open() {
1137        use zeph_llm::mock::MockProvider;
1138
1139        let memory = crate::testing::mock_semantic_memory()
1140            .await
1141            .expect("mock_semantic_memory");
1142
1143        // Store a message so validate_evidence gets a non-empty slice and actually calls the LLM.
1144        let conv_id = memory
1145            .sqlite()
1146            .create_conversation()
1147            .await
1148            .expect("create_conversation");
1149        memory
1150            .remember(conv_id, "user", "some evidence content", None)
1151            .await
1152            .expect("remember");
1153
1154        // Delay > validator_timeout_secs causes the internal tokio::time::timeout to fire.
1155        let slow_mock = MockProvider::default().with_delay(6_000);
1156        let validator = Arc::new(AnyProvider::Mock(slow_mock));
1157
1158        let config = TieredRetrievalConfig {
1159            enabled: true,
1160            validation_enabled: true,
1161            validation_threshold: 0.6,
1162            max_escalations: 1,
1163            validator_timeout_secs: 5,
1164            ..TieredRetrievalConfig::default()
1165        };
1166
1167        // The slow validator should time out and be treated as sufficient → no escalation.
1168        let result = recall_tiered(
1169            &memory,
1170            "evidence",
1171            None,
1172            None,
1173            Some(&validator),
1174            &config,
1175            None,
1176        )
1177        .await
1178        .expect("timeout path must not propagate as error");
1179
1180        // Fail-open: timed-out validator means no escalation.
1181        assert!(
1182            !result.tier_escalated,
1183            "validator timeout must be treated as sufficient (fail-open)"
1184        );
1185    }
1186
1187    /// Test 4b: `validate_evidence` returns `true` (fail-open) when the validator LLM errors.
1188    ///
1189    /// A failing provider simulates a transient API error. The pipeline must not escalate.
1190    #[tokio::test]
1191    async fn validate_evidence_llm_error_is_fail_open() {
1192        use zeph_llm::mock::MockProvider;
1193
1194        let memory = crate::testing::mock_semantic_memory()
1195            .await
1196            .expect("mock_semantic_memory");
1197
1198        // Store a message so validate_evidence gets a non-empty slice and actually calls the LLM.
1199        let conv_id = memory
1200            .sqlite()
1201            .create_conversation()
1202            .await
1203            .expect("create_conversation");
1204        memory
1205            .remember(conv_id, "user", "some evidence content", None)
1206            .await
1207            .expect("remember");
1208
1209        let failing_mock = MockProvider::failing();
1210        let validator = Arc::new(AnyProvider::Mock(failing_mock));
1211
1212        let config = TieredRetrievalConfig {
1213            enabled: true,
1214            validation_enabled: true,
1215            validation_threshold: 0.6,
1216            max_escalations: 1,
1217            ..TieredRetrievalConfig::default()
1218        };
1219
1220        let result = recall_tiered(
1221            &memory,
1222            "evidence",
1223            None,
1224            None,
1225            Some(&validator),
1226            &config,
1227            None,
1228        )
1229        .await
1230        .expect("LLM error path must not propagate as retrieval error");
1231
1232        assert!(
1233            !result.tier_escalated,
1234            "validator LLM error must be treated as sufficient (fail-open)"
1235        );
1236    }
1237
1238    /// Test 5: `recall_tiered` with a `conversation_id` filter passes it to `retrieve_tier`,
1239    /// which in turn applies a `SearchFilter` scoping the search to that conversation.
1240    ///
1241    /// The pipeline must complete successfully even when the filter yields zero results.
1242    #[tokio::test]
1243    async fn recall_tiered_with_conversation_id_filter() {
1244        let memory = crate::testing::mock_semantic_memory()
1245            .await
1246            .expect("mock_semantic_memory");
1247
1248        let conv_id = ConversationId(42);
1249        let config = TieredRetrievalConfig {
1250            enabled: true,
1251            validation_enabled: false,
1252            ..TieredRetrievalConfig::default()
1253        };
1254
1255        let result = recall_tiered(
1256            &memory,
1257            "what did we discuss",
1258            Some(conv_id),
1259            None,
1260            None,
1261            &config,
1262            None,
1263        )
1264        .await
1265        .expect("conversation-scoped recall must not fail");
1266
1267        // No messages stored for this conversation — result must be empty but valid.
1268        assert!(result.messages.is_empty());
1269        assert_eq!(result.tokens_used, 0);
1270        assert!(!result.tier_escalated);
1271    }
1272}