1use std::collections::HashMap;
32use std::sync::Arc;
33
34use tracing::Instrument as _;
35pub use zeph_config::memory::TieredRetrievalConfig;
36use zeph_llm::any::AnyProvider;
37
38use crate::embedding_store::SearchFilter;
39use crate::error::MemoryError;
40use crate::router::{HeuristicRouter, HybridRouter, MemoryRoute, MemoryRouter};
41use crate::semantic::RecalledMessage;
42use crate::semantic::SemanticMemory;
43use crate::types::{ConversationId, MessageId};
44
/// Retrieval tier selected for a query.
///
/// Tiers are ordered from narrowest to broadest; each one asks the store for a different
/// number of candidates, and escalation moves one step towards `DeepReasoning`.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
#[non_exhaustive]
pub enum IntentClass {
    /// Narrowest tier (keyword / episodic routes); requests the fewest candidates.
    ProfileLookup,
    /// Middle tier; used for every route not mapped to another tier.
    TargetedRetrieval,
    /// Broadest tier (graph route); requests the most candidates and cannot escalate further.
    DeepReasoning,
}
60
61impl IntentClass {
62 fn from_route(route: MemoryRoute) -> Self {
63 match route {
64 MemoryRoute::Keyword | MemoryRoute::Episodic => Self::ProfileLookup,
65 MemoryRoute::Graph => Self::DeepReasoning,
66 _ => Self::TargetedRetrieval,
67 }
68 }
69
70 fn top_k(self) -> usize {
71 match self {
72 Self::ProfileLookup => 3,
73 Self::TargetedRetrieval => 10,
74 Self::DeepReasoning => 20,
75 }
76 }
77
78 fn escalate(self) -> Option<Self> {
80 match self {
81 Self::ProfileLookup => Some(Self::TargetedRetrieval),
82 Self::TargetedRetrieval => Some(Self::DeepReasoning),
83 Self::DeepReasoning => None,
84 }
85 }
86}
87
88impl std::fmt::Display for IntentClass {
89 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
90 match self {
91 Self::ProfileLookup => f.write_str("ProfileLookup"),
92 Self::TargetedRetrieval => f.write_str("TargetedRetrieval"),
93 Self::DeepReasoning => f.write_str("DeepReasoning"),
94 }
95 }
96}
97
98#[derive(Debug)]
102pub struct TieredRetrievalResult {
103 pub messages: Vec<RecalledMessage>,
105 pub intent: IntentClass,
107 pub tokens_used: usize,
109 pub tier_escalated: bool,
111}
112
113#[tracing::instrument(name = "memory.tiered.retrieve", skip_all, fields(intent = tracing::field::Empty))]
136pub async fn recall_tiered(
137 memory: &SemanticMemory,
138 query: &str,
139 conversation_id: Option<ConversationId>,
140 classifier: Option<&Arc<AnyProvider>>,
141 validator: Option<&Arc<AnyProvider>>,
142 config: &TieredRetrievalConfig,
143 remaining_budget: Option<usize>,
144) -> Result<TieredRetrievalResult, MemoryError> {
145 let effective_budget =
146 remaining_budget.map_or(config.token_budget, |rb| rb.min(config.token_budget));
147
148 let initial_intent = if let Some(classifier_provider) = classifier {
149 let hybrid = HybridRouter::new(
150 Arc::clone(classifier_provider),
151 MemoryRoute::Hybrid,
152 0.7,
154 );
155 let decision = if let Ok(d) = tokio::time::timeout(
156 std::time::Duration::from_secs(config.classifier_timeout_secs),
157 hybrid.classify_async(query),
158 )
159 .await
160 {
161 d
162 } else {
163 tracing::warn!("tiered: classifier LLM timed out, falling back to heuristic");
164 HeuristicRouter.route_with_confidence(query)
165 };
166 IntentClass::from_route(decision.route)
167 } else {
168 let decision = HeuristicRouter.route_with_confidence(query);
169 IntentClass::from_route(decision.route)
170 };
171
172 tracing::debug!(intent = %initial_intent, query_len = query.len(), "tiered: classified intent");
173
174 escalation_loop(
175 memory,
176 query,
177 conversation_id,
178 initial_intent,
179 validator,
180 config,
181 effective_budget,
182 )
183 .await
184}
185
186async fn escalation_loop(
192 memory: &SemanticMemory,
193 query: &str,
194 conversation_id: Option<ConversationId>,
195 initial_intent: IntentClass,
196 validator: Option<&Arc<AnyProvider>>,
197 config: &TieredRetrievalConfig,
198 effective_budget: usize,
199) -> Result<TieredRetrievalResult, MemoryError> {
200 let mut intent = initial_intent;
201 let mut escalations: u8 = 0;
202 let mut tier_escalated = false;
203
204 loop {
205 let raw_candidates = retrieve_tier(memory, query, conversation_id, intent)
206 .instrument(tracing::debug_span!("memory.tiered.retrieve_tier", tier = %intent))
207 .await?;
208
209 let candidates = score_candidates(memory, query, raw_candidates, config)
210 .instrument(tracing::debug_span!("memory.tiered.score_candidates", tier = %intent))
211 .await?;
212
213 let (messages, tokens_used) = {
214 let _span = tracing::debug_span!("memory.tiered.assemble").entered();
215 assemble_within_budget(candidates, effective_budget)
216 };
217
218 if config.validation_enabled
220 && escalations < config.max_escalations
221 && let Some(validator_provider) = validator
222 && let Some(next_tier) = intent.escalate()
223 {
224 let sufficient = validate_evidence(
225 validator_provider,
226 query,
227 &messages,
228 config.validation_threshold,
229 config.validator_timeout_secs,
230 )
231 .instrument(tracing::debug_span!("memory.tiered.validate"))
232 .await;
233 if !sufficient {
234 tracing::debug!(
235 current_tier = %intent,
236 next_tier = %next_tier,
237 escalations,
238 "tiered: evidence insufficient, escalating tier"
239 );
240 intent = next_tier;
241 escalations += 1;
242 tier_escalated = true;
243 continue;
244 }
245 }
246
247 return Ok(TieredRetrievalResult {
248 messages,
249 intent,
250 tokens_used,
251 tier_escalated,
252 });
253 }
254}
255
256async fn retrieve_tier(
258 memory: &SemanticMemory,
259 query: &str,
260 conversation_id: Option<ConversationId>,
261 intent: IntentClass,
262) -> Result<Vec<RecalledMessage>, MemoryError> {
263 let top_k = intent.top_k();
264 let heuristic = HeuristicRouter;
265
266 let filter = conversation_id.map(|cid| SearchFilter {
267 conversation_id: Some(cid),
268 role: None,
269 category: None,
270 });
271
272 memory
275 .recall_routed(query, top_k, filter, &heuristic, None)
276 .await
277}
278
279struct ScoredCandidate {
283 recalled: RecalledMessage,
284}
285
286#[allow(clippy::too_many_lines)]
296#[tracing::instrument(name = "memory.tiered.score_candidates", skip_all)]
297async fn score_candidates(
298 memory: &SemanticMemory,
299 query: &str,
300 candidates: Vec<RecalledMessage>,
301 config: &TieredRetrievalConfig,
302) -> Result<Vec<RecalledMessage>, MemoryError> {
303 if candidates.is_empty() {
304 return Ok(candidates);
305 }
306
307 let total_weight = config.similarity_weight
308 + config.recency_weight
309 + config.tfidf_weight
310 + config.cognitive_signal_weight
311 + config.tier_boost_weight;
312
313 if total_weight < f64::EPSILON {
314 tracing::debug!("score_candidates: all signal weights are zero, returning original order");
315 return Ok(candidates);
316 }
317
318 let ids: Vec<MessageId> = candidates
319 .iter()
320 .map(|c| MessageId(c.message.metadata.db_id.unwrap_or(0)))
321 .collect();
322
323 let (timestamps_res, tiers_res) = tokio::join!(
325 async {
326 if config.recency_weight > 0.0 {
327 memory.sqlite().message_timestamps(&ids).await
328 } else {
329 Ok(HashMap::new())
330 }
331 },
332 async {
333 if config.tier_boost_weight > 0.0 {
334 memory.sqlite().fetch_tiers(&ids).await
335 } else {
336 Ok(HashMap::new())
337 }
338 },
339 );
340 let timestamps: HashMap<MessageId, i64> = timestamps_res.unwrap_or_else(|e| {
341 tracing::warn!("score_candidates: failed to fetch timestamps: {e:#}");
342 HashMap::new()
343 });
344 let tiers: HashMap<MessageId, String> = tiers_res.unwrap_or_else(|e| {
345 tracing::warn!("score_candidates: failed to fetch tiers: {e:#}");
346 HashMap::new()
347 });
348
349 let access_counts: HashMap<MessageId, i64> = if config.cognitive_signal_weight > 0.0 {
351 memory
352 .sqlite()
353 .message_access_counts(&ids)
354 .await
355 .unwrap_or_else(|e| {
356 tracing::warn!("score_candidates: failed to fetch access counts: {e:#}");
357 HashMap::new()
358 })
359 } else {
360 HashMap::new()
361 };
362
363 let tfidf_scores = if config.tfidf_weight > 0.0 {
364 compute_tfidf_scores(query, &candidates)
365 } else {
366 vec![0.0_f64; candidates.len()]
367 };
368
369 let max_access: i64 = access_counts.values().copied().max().unwrap_or(0);
370
371 let now_secs = std::time::SystemTime::now()
372 .duration_since(std::time::UNIX_EPOCH)
373 .map_or(0_i64, |d| i64::try_from(d.as_secs()).unwrap_or(i64::MAX));
374
375 let mut scored: Vec<ScoredCandidate> = candidates
376 .into_iter()
377 .zip(tfidf_scores)
378 .map(|(recalled, tfidf)| {
379 let msg_id = MessageId(recalled.message.metadata.db_id.unwrap_or(0));
380
381 let similarity = f64::from(recalled.score);
382 let recency = if config.recency_weight > 0.0 && config.recency_half_life_days > 0 {
383 let ts = timestamps.get(&msg_id).copied().unwrap_or(now_secs);
384 compute_recency(ts, now_secs, config.recency_half_life_days)
385 } else {
386 0.0
387 };
388
389 let cognitive = if config.cognitive_signal_weight > 0.0 && max_access > 0 {
390 let count = access_counts.get(&msg_id).copied().unwrap_or(0);
391 #[allow(clippy::cast_precision_loss)]
393 let ratio = count as f64 / max_access as f64;
394 ratio
395 } else {
396 0.0
397 };
398
399 let tier_signal = if config.tier_boost_weight > 0.0 {
400 let tier = tiers.get(&msg_id).map_or("episodic", String::as_str);
401 if tier == "semantic" {
402 config.semantic_tier_boost
403 } else {
404 0.0
405 }
406 } else {
407 0.0
408 };
409
410 let final_score = config.similarity_weight * similarity
411 + config.recency_weight * recency
412 + config.tfidf_weight * tfidf
413 + config.cognitive_signal_weight * cognitive
414 + config.tier_boost_weight * tier_signal;
415
416 ScoredCandidate {
417 recalled: RecalledMessage {
418 #[allow(clippy::cast_possible_truncation)]
420 score: final_score as f32,
421 ..recalled
422 },
423 }
424 })
425 .collect();
426
427 scored.sort_by(|a, b| {
428 b.recalled
429 .score
430 .partial_cmp(&a.recalled.score)
431 .unwrap_or(std::cmp::Ordering::Equal)
432 });
433
434 Ok(scored.into_iter().map(|s| s.recalled).collect())
435}
436
/// Exponential-decay recency score for a message created at `created_at_secs`, evaluated at
/// `now_secs` (both in seconds).
///
/// Returns `exp(-ln(2) * age_days / half_life_days)`:
/// - `1.0` at age zero;
/// - `0.5` after one half-life;
/// - tending to `0.0` for old messages.
///
/// Timestamps in the future count as age zero. The age uses saturating subtraction, so extreme
/// or corrupt timestamps cannot overflow (which would panic in debug builds and wrap to a wrong
/// age in release builds).
///
/// `half_life_days` must be non-zero; this is only checked with `debug_assert!`.
fn compute_recency(created_at_secs: i64, now_secs: i64, half_life_days: u32) -> f64 {
    debug_assert!(half_life_days > 0, "half_life_days must be > 0");
    #[allow(clippy::cast_precision_loss)]
    let age_days = now_secs.saturating_sub(created_at_secs).max(0) as f64 / 86_400.0;
    let lambda = std::f64::consts::LN_2 / f64::from(half_life_days);
    (-lambda * age_days).exp()
}
454
455fn compute_tfidf_scores(query: &str, candidates: &[RecalledMessage]) -> Vec<f64> {
460 const K1: f64 = 1.2;
461 const B: f64 = 0.75;
462
463 let query_terms: Vec<String> = query.split_whitespace().map(str::to_lowercase).collect();
464
465 if query_terms.is_empty() || candidates.is_empty() {
466 return vec![0.0; candidates.len()];
467 }
468
469 let docs: Vec<Vec<String>> = candidates
471 .iter()
472 .map(|c| {
473 c.message
474 .content
475 .split_whitespace()
476 .map(str::to_lowercase)
477 .collect()
478 })
479 .collect();
480
481 #[allow(clippy::cast_precision_loss)]
483 let n = docs.len() as f64;
484 #[allow(clippy::cast_precision_loss)]
485 let avg_dl = docs.iter().map(|d| d.len() as f64).sum::<f64>().max(1.0) / n;
486
487 let mut scores = vec![0.0_f64; docs.len()];
488
489 for term in &query_terms {
490 #[allow(clippy::cast_precision_loss)]
492 let df = docs.iter().filter(|d| d.contains(term)).count() as f64;
493 if df == 0.0 {
494 continue;
495 }
496 let idf = ((n - df + 0.5) / (df + 0.5) + 1.0).ln();
498
499 for (i, doc) in docs.iter().enumerate() {
500 #[allow(clippy::cast_precision_loss)]
501 let dl = doc.len() as f64;
502 #[allow(clippy::cast_precision_loss)]
503 let tf = doc.iter().filter(|t| *t == term).count() as f64;
504 let bm25_tf = (tf * (K1 + 1.0)) / (tf + K1 * (1.0 - B + B * dl / avg_dl));
505 scores[i] += idf * bm25_tf;
506 }
507 }
508
509 let max_score = scores.iter().copied().fold(0.0_f64, f64::max);
511 if max_score > 0.0 {
512 for s in &mut scores {
513 *s /= max_score;
514 }
515 }
516
517 scores
518}
519
520fn assemble_within_budget(
525 candidates: Vec<RecalledMessage>,
526 budget: usize,
527) -> (Vec<RecalledMessage>, usize) {
528 let mut retained = Vec::with_capacity(candidates.len());
529 let mut total_tokens: usize = 0;
530
531 for msg in candidates {
532 let msg_tokens = zeph_common::text::estimate_tokens(&msg.message.content);
533 if total_tokens.saturating_add(msg_tokens) > budget {
534 break;
535 }
536 total_tokens += msg_tokens;
537 retained.push(msg);
538 }
539
540 (retained, total_tokens)
541}
542
543async fn validate_evidence(
548 provider: &Arc<AnyProvider>,
549 query: &str,
550 messages: &[RecalledMessage],
551 threshold: f32,
552 timeout_secs: u64,
553) -> bool {
554 use zeph_llm::provider::{LlmProvider as _, Message, MessageMetadata, Role};
555
556 if messages.is_empty() {
557 return false;
558 }
559
560 let evidence_snippet = messages
561 .iter()
562 .take(5)
563 .map(|m| {
564 zeph_common::sanitize::strip_control_chars_preserve_whitespace(&m.message.content)
565 .chars()
566 .take(200)
567 .collect::<String>()
568 })
569 .collect::<Vec<_>>()
570 .join("\n---\n");
571
572 let system = "You are an evidence quality judge. \
573 Given a query and evidence snippets, decide if the evidence is sufficient to answer the query. \
574 Respond ONLY with a JSON object: {\"sufficient\": true|false, \"confidence\": 0.0-1.0}";
575
576 let sanitized_query = zeph_common::sanitize::strip_control_chars_preserve_whitespace(query);
577 let user = format!(
578 "<query>{}</query>\n<evidence>{}</evidence>",
579 sanitized_query.chars().take(500).collect::<String>(),
580 evidence_snippet
581 );
582
583 let msgs = vec![
584 Message {
585 role: Role::System,
586 content: system.to_owned(),
587 parts: vec![],
588 metadata: MessageMetadata::default(),
589 },
590 Message {
591 role: Role::User,
592 content: user,
593 parts: vec![],
594 metadata: MessageMetadata::default(),
595 },
596 ];
597
598 match tokio::time::timeout(
599 std::time::Duration::from_secs(timeout_secs),
600 provider.chat(&msgs),
601 )
602 .await
603 {
604 Ok(Ok(raw)) => parse_validation_response(&raw, threshold),
605 Ok(Err(e)) => {
606 tracing::warn!(error = %e, "tiered: validator LLM call failed, treating as sufficient");
607 true
608 }
609 Err(_) => {
610 tracing::warn!("tiered: validator LLM call timed out, treating as sufficient");
611 true
612 }
613 }
614}
615
616fn parse_validation_response(raw: &str, threshold: f32) -> bool {
617 let json_str = raw
618 .find('{')
619 .and_then(|s| raw[s..].rfind('}').map(|e| &raw[s..=s + e]))
620 .unwrap_or("");
621
622 if let Ok(v) = serde_json::from_str::<serde_json::Value>(json_str) {
623 let sufficient = v
624 .get("sufficient")
625 .and_then(serde_json::Value::as_bool)
626 .unwrap_or(true);
627 #[allow(clippy::cast_possible_truncation)]
628 let confidence = v
629 .get("confidence")
630 .and_then(serde_json::Value::as_f64)
631 .map_or(1.0, |c| c.clamp(0.0, 1.0) as f32);
632
633 return sufficient && confidence >= threshold;
634 }
635
636 tracing::debug!("tiered: could not parse validator response, treating as sufficient");
637 true
638}
639
#[cfg(test)]
mod tests {
    // Unit tests for the pure helpers (recency, BM25, budgeting, response parsing) and
    // integration-style tests for `score_candidates` / `recall_tiered` against the crate's
    // mock `SemanticMemory`.
    use super::*;
    use crate::router::MemoryRoute;
    use crate::semantic::RecalledMessage;
    use zeph_llm::provider::{Message, MessageMetadata, Role};

    /// Builds a user-role `RecalledMessage` with the given content, default metadata and a
    /// score of `1.0`.
    fn make_message(content: &str) -> RecalledMessage {
        RecalledMessage {
            message: Message {
                role: Role::User,
                content: content.to_owned(),
                parts: vec![],
                metadata: MessageMetadata::default(),
            },
            score: 1.0,
        }
    }

    // ---- compute_recency ----

    #[test]
    fn compute_recency_zero_age_returns_one() {
        let now = 1_000_000_i64;
        let score = compute_recency(now, now, 7);
        assert!((score - 1.0).abs() < 1e-9);
    }

    #[test]
    fn compute_recency_half_life_returns_half() {
        let now = 1_000_000_i64;
        let half_life_days = 7_u32;
        let age_secs = i64::from(half_life_days) * 86_400;
        let score = compute_recency(now - age_secs, now, half_life_days);
        assert!((score - 0.5).abs() < 1e-9);
    }

    #[test]
    fn compute_recency_large_age_approaches_zero() {
        // 1000 days old with a 7-day half-life.
        let now = 1_000_i64 * 86_400;
        let score = compute_recency(0, now, 7);
        assert!(score < 1e-6, "score was {score}");
    }

    #[test]
    fn compute_recency_future_timestamp_clamped_to_one() {
        let now = 1_000_000_i64;
        let score = compute_recency(now + 86_400, now, 7);
        // A negative age is clamped to zero, so a future timestamp scores as brand new.
        assert!((score - 1.0).abs() < 1e-9);
    }

    // ---- compute_tfidf_scores ----

    #[test]
    fn compute_tfidf_empty_candidates_returns_empty() {
        let scores = compute_tfidf_scores("hello", &[]);
        assert!(scores.is_empty());
    }

    #[test]
    fn compute_tfidf_empty_query_returns_zeros() {
        let candidates = vec![make_message("hello world")];
        let scores = compute_tfidf_scores("", &candidates);
        assert_eq!(scores.len(), 1);
        assert!(scores[0].abs() < f64::EPSILON);
    }

    #[test]
    fn compute_tfidf_exact_match_scores_nonzero() {
        let candidates = vec![
            make_message("the quick brown fox"),
            make_message("completely unrelated content"),
        ];
        let scores = compute_tfidf_scores("fox", &candidates);
        assert_eq!(scores.len(), 2);
        // Only the first document contains the query term.
        assert!(scores[0] > scores[1]);
    }

    #[test]
    fn compute_tfidf_no_match_returns_zeros() {
        let candidates = vec![make_message("apple banana cherry")];
        let scores = compute_tfidf_scores("zzz xyz", &candidates);
        assert_eq!(scores.len(), 1);
        assert!(scores[0].abs() < f64::EPSILON);
    }

    #[test]
    fn compute_tfidf_max_score_normalised_to_one() {
        let candidates = vec![
            make_message("rust programming language"),
            make_message("python programming language"),
            make_message("java is a drink"),
        ];
        let scores = compute_tfidf_scores("rust programming", &candidates);
        let max = scores.iter().copied().fold(f64::NEG_INFINITY, f64::max);
        assert!((max - 1.0).abs() < 1e-9, "max score must be 1.0, got {max}");
    }

    // ---- score_candidates ----

    #[test]
    fn score_candidates_empty_input_returns_empty() {
        // Explicit current-thread runtime instead of `#[tokio::test]`.
        let rt = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap();
        rt.block_on(async {
            let memory = crate::testing::mock_semantic_memory()
                .await
                .expect("mock_semantic_memory");
            let config = TieredRetrievalConfig::default();
            let result = score_candidates(&memory, "query", vec![], &config)
                .await
                .expect("score_candidates must not fail on empty input");
            assert!(result.is_empty());
        });
    }

    #[test]
    fn score_candidates_similarity_weight_reorders_by_score() {
        let rt = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap();
        rt.block_on(async {
            let memory = crate::testing::mock_semantic_memory()
                .await
                .expect("mock_semantic_memory");
            // Only the similarity weight is set explicitly; the remaining weights come from
            // `Default` (zero, per `tiered_retrieval_config_signal_weight_defaults`).
            let config = TieredRetrievalConfig {
                similarity_weight: 1.0,
                ..TieredRetrievalConfig::default()
            };
            let candidates = vec![
                RecalledMessage {
                    message: make_message("low score").message,
                    score: 0.1,
                },
                RecalledMessage {
                    message: make_message("high score").message,
                    score: 0.9,
                },
                RecalledMessage {
                    message: make_message("mid score").message,
                    score: 0.5,
                },
            ];
            let result = score_candidates(&memory, "query", candidates, &config)
                .await
                .expect("score_candidates must not fail");
            assert_eq!(result.len(), 3);
            // The output must be sorted by descending score.
            assert!(
                result[0].score >= result[1].score,
                "first score {} must be >= second score {}",
                result[0].score,
                result[1].score
            );
            assert!(
                result[1].score >= result[2].score,
                "second score {} must be >= third score {}",
                result[1].score,
                result[2].score
            );
            // With similarity as the only weighted signal (weight 1.0), the top score equals the
            // highest input score.
            assert!(
                (result[0].score - 0.9_f32).abs() < 1e-4,
                "expected first score ~0.9, got {}",
                result[0].score
            );
        });
    }

    #[test]
    fn score_candidates_all_zero_weights_returns_original_order() {
        let rt = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap();
        rt.block_on(async {
            let memory = crate::testing::mock_semantic_memory()
                .await
                .expect("mock_semantic_memory");
            // Every weight is zero, which triggers the early return that leaves the input
            // untouched.
            let config = TieredRetrievalConfig {
                similarity_weight: 0.0,
                recency_weight: 0.0,
                tfidf_weight: 0.0,
                cognitive_signal_weight: 0.0,
                tier_boost_weight: 0.0,
                ..TieredRetrievalConfig::default()
            };
            let candidates = vec![
                RecalledMessage {
                    message: make_message("first").message,
                    score: 0.9,
                },
                RecalledMessage {
                    message: make_message("second").message,
                    score: 0.1,
                },
            ];
            let result = score_candidates(&memory, "query", candidates, &config)
                .await
                .expect("score_candidates must not fail");
            // Scores and order are unchanged.
            assert!((f64::from(result[0].score) - 0.9).abs() < 1e-6);
            assert!((f64::from(result[1].score) - 0.1).abs() < 1e-6);
        });
    }

    // ---- config defaults ----

    #[test]
    fn tiered_retrieval_config_signal_weight_defaults() {
        let cfg = TieredRetrievalConfig::default();
        assert!((cfg.similarity_weight - 1.0).abs() < f64::EPSILON);
        assert!(cfg.recency_weight.abs() < f64::EPSILON);
        assert_eq!(cfg.recency_half_life_days, 7);
        assert!(cfg.tfidf_weight.abs() < f64::EPSILON);
        assert!(cfg.cognitive_signal_weight.abs() < f64::EPSILON);
        assert!(cfg.tier_boost_weight.abs() < f64::EPSILON);
        assert!((cfg.semantic_tier_boost - 1.0).abs() < f64::EPSILON);
    }

    // ---- IntentClass ----

    #[test]
    fn intent_class_from_route_mapping() {
        assert_eq!(
            IntentClass::from_route(MemoryRoute::Keyword),
            IntentClass::ProfileLookup
        );
        assert_eq!(
            IntentClass::from_route(MemoryRoute::Episodic),
            IntentClass::ProfileLookup
        );
        assert_eq!(
            IntentClass::from_route(MemoryRoute::Semantic),
            IntentClass::TargetedRetrieval
        );
        assert_eq!(
            IntentClass::from_route(MemoryRoute::Hybrid),
            IntentClass::TargetedRetrieval
        );
        assert_eq!(
            IntentClass::from_route(MemoryRoute::Graph),
            IntentClass::DeepReasoning
        );
    }

    #[test]
    fn intent_class_top_k() {
        assert_eq!(IntentClass::ProfileLookup.top_k(), 3);
        assert_eq!(IntentClass::TargetedRetrieval.top_k(), 10);
        assert_eq!(IntentClass::DeepReasoning.top_k(), 20);
    }

    #[test]
    fn intent_class_escalate_chain() {
        assert_eq!(
            IntentClass::ProfileLookup.escalate(),
            Some(IntentClass::TargetedRetrieval)
        );
        assert_eq!(
            IntentClass::TargetedRetrieval.escalate(),
            Some(IntentClass::DeepReasoning)
        );
        assert_eq!(IntentClass::DeepReasoning.escalate(), None);
    }

    // ---- assemble_within_budget ----

    #[test]
    fn assemble_within_budget_empty_input() {
        let (retained, tokens) = assemble_within_budget(vec![], 4096);
        assert!(retained.is_empty());
        assert_eq!(tokens, 0);
    }

    #[test]
    fn assemble_within_budget_zero_budget_returns_nothing() {
        let candidates = vec![make_message("hello"), make_message("world")];
        let (retained, tokens) = assemble_within_budget(candidates, 0);
        assert!(retained.is_empty(), "budget=0 must retain no messages");
        assert_eq!(tokens, 0);
    }

    #[test]
    fn assemble_within_budget_truncates_at_limit() {
        // 800-character message; the assertions below imply `estimate_tokens` counts it as
        // 200 tokens, so two messages (400) exceed the 250 budget but one fits.
        let msg = "a ".repeat(400);
        let candidates = vec![make_message(&msg), make_message(&msg)];
        let (retained, tokens) = assemble_within_budget(candidates, 250);
        assert_eq!(
            retained.len(),
            1,
            "tight budget must keep only first message"
        );
        assert_eq!(tokens, 200);
    }

    // ---- parse_validation_response ----

    #[test]
    fn parse_validation_response_missing_fields_defaults_to_sufficient() {
        let raw = "{}";
        assert!(
            parse_validation_response(raw, 0.6),
            "missing fields must default to sufficient"
        );
    }

    #[test]
    fn tiered_retrieval_config_defaults() {
        let cfg = TieredRetrievalConfig::default();
        assert!(!cfg.enabled);
        assert_eq!(cfg.token_budget, 4096);
        assert!(!cfg.validation_enabled);
        assert_eq!(cfg.max_escalations, 1);
        assert_eq!(cfg.classifier_timeout_secs, 5);
        assert_eq!(cfg.validator_timeout_secs, 5);
    }

    #[test]
    fn tiered_retrieval_config_timeout_fields_propagate() {
        let cfg = TieredRetrievalConfig {
            classifier_timeout_secs: 10,
            validator_timeout_secs: 15,
            ..TieredRetrievalConfig::default()
        };
        assert_eq!(cfg.classifier_timeout_secs, 10);
        assert_eq!(cfg.validator_timeout_secs, 15);
        // Mirrors how the timeouts are converted at their call sites.
        let classifier_dur = std::time::Duration::from_secs(cfg.classifier_timeout_secs);
        let validator_dur = std::time::Duration::from_secs(cfg.validator_timeout_secs);
        assert_eq!(classifier_dur.as_secs(), 10);
        assert_eq!(validator_dur.as_secs(), 15);
    }

    #[test]
    fn parse_validation_response_sufficient() {
        let raw = r#"{"sufficient": true, "confidence": 0.9}"#;
        assert!(parse_validation_response(raw, 0.6));
    }

    #[test]
    fn parse_validation_response_insufficient() {
        let raw = r#"{"sufficient": false, "confidence": 0.4}"#;
        assert!(!parse_validation_response(raw, 0.6));
    }

    #[test]
    fn parse_validation_response_low_confidence() {
        // "sufficient" is true, but the confidence is below the threshold.
        let raw = r#"{"sufficient": true, "confidence": 0.3}"#;
        assert!(!parse_validation_response(raw, 0.6));
    }

    #[test]
    fn parse_validation_response_malformed_json_treats_as_sufficient() {
        let raw = "not json at all";
        assert!(parse_validation_response(raw, 0.6));
    }

    #[test]
    fn intent_class_display() {
        assert_eq!(IntentClass::ProfileLookup.to_string(), "ProfileLookup");
        assert_eq!(
            IntentClass::TargetedRetrieval.to_string(),
            "TargetedRetrieval"
        );
        assert_eq!(IntentClass::DeepReasoning.to_string(), "DeepReasoning");
    }

    // ---- recall_tiered (end to end against the mock memory) ----

    #[tokio::test]
    async fn recall_tiered_no_classifier_uses_heuristic_router() {
        let memory = crate::testing::mock_semantic_memory()
            .await
            .expect("mock_semantic_memory");
        let config = TieredRetrievalConfig {
            enabled: true,
            validation_enabled: false,
            ..TieredRetrievalConfig::default()
        };

        let result = recall_tiered(&memory, "what is my name", None, None, None, &config, None)
            .await
            .expect("recall_tiered must not fail");

        // Validation is disabled, so no escalation can happen.
        assert!(
            !result.tier_escalated,
            "no escalation when validation is off"
        );
        assert!(result.tokens_used <= config.token_budget);
    }

    #[tokio::test]
    async fn recall_tiered_with_classifier_uses_hybrid_router() {
        use zeph_llm::mock::MockProvider;

        let memory = crate::testing::mock_semantic_memory()
            .await
            .expect("mock_semantic_memory");

        // Canned classifier reply; presumably consumed by the HybridRouter's LLM call.
        let route_json = r#"{"route": "Semantic", "confidence": 0.9}"#.to_owned();
        let mut mock = MockProvider::with_responses(vec![route_json]);
        mock.supports_embeddings = true;
        mock.embedding = vec![0.1_f32; 384];
        let classifier = Arc::new(AnyProvider::Mock(mock));

        let config = TieredRetrievalConfig {
            enabled: true,
            validation_enabled: false,
            ..TieredRetrievalConfig::default()
        };

        let result = recall_tiered(
            &memory,
            "semantic query about the user",
            None,
            Some(&classifier),
            None,
            &config,
            None,
        )
        .await
        .expect("recall_tiered with classifier must not fail");

        assert!(!result.tier_escalated);
        assert!(result.tokens_used <= config.token_budget);
    }

    #[tokio::test]
    async fn recall_tiered_escalates_when_evidence_insufficient() {
        use zeph_llm::mock::MockProvider;

        let memory = crate::testing::mock_semantic_memory()
            .await
            .expect("mock_semantic_memory");

        // First validation call says "insufficient", the second says "sufficient".
        // NOTE(review): no messages are stored here. If the mock memory recalls nothing,
        // `validate_evidence` returns `false` for the empty message list without calling the
        // validator, so these canned responses may never be consumed — confirm.
        let insufficient = r#"{"sufficient": false, "confidence": 0.1}"#.to_owned();
        let sufficient = r#"{"sufficient": true, "confidence": 0.95}"#.to_owned();
        let mut validator_mock = MockProvider::with_responses(vec![insufficient, sufficient]);
        validator_mock.supports_embeddings = true;
        let validator = Arc::new(AnyProvider::Mock(validator_mock));

        let config = TieredRetrievalConfig {
            enabled: true,
            validation_enabled: true,
            validation_threshold: 0.6,
            max_escalations: 2,
            ..TieredRetrievalConfig::default()
        };

        let result = recall_tiered(
            &memory,
            "deep query",
            None,
            None,
            Some(&validator),
            &config,
            None,
        )
        .await
        .expect("escalation path must not fail");

        assert!(
            result.tier_escalated,
            "must set tier_escalated when validator triggers escalation"
        );
    }

    #[tokio::test]
    async fn validate_evidence_timeout_is_fail_open() {
        use zeph_llm::mock::MockProvider;

        let memory = crate::testing::mock_semantic_memory()
            .await
            .expect("mock_semantic_memory");

        // Store one message so retrieval has evidence and the validator is actually consulted
        // (an empty message list short-circuits to "insufficient").
        let conv_id = memory
            .sqlite()
            .create_conversation()
            .await
            .expect("create_conversation");
        memory
            .remember(conv_id, "user", "some evidence content", None)
            .await
            .expect("remember");

        // NOTE(review): the delay unit is presumably milliseconds (6 s > the 5 s timeout), which
        // means this test waits on real wall-clock time.
        let slow_mock = MockProvider::default().with_delay(6_000);
        let validator = Arc::new(AnyProvider::Mock(slow_mock));

        let config = TieredRetrievalConfig {
            enabled: true,
            validation_enabled: true,
            validation_threshold: 0.6,
            max_escalations: 1,
            validator_timeout_secs: 5,
            ..TieredRetrievalConfig::default()
        };

        let result = recall_tiered(
            &memory,
            "evidence",
            None,
            None,
            Some(&validator),
            &config,
            None,
        )
        .await
        .expect("timeout path must not propagate as error");

        // A validator timeout counts as "sufficient", so no escalation happens.
        assert!(
            !result.tier_escalated,
            "validator timeout must be treated as sufficient (fail-open)"
        );
    }

    #[tokio::test]
    async fn validate_evidence_llm_error_is_fail_open() {
        use zeph_llm::mock::MockProvider;

        let memory = crate::testing::mock_semantic_memory()
            .await
            .expect("mock_semantic_memory");

        // Stored evidence ensures the validator is reached rather than short-circuited.
        let conv_id = memory
            .sqlite()
            .create_conversation()
            .await
            .expect("create_conversation");
        memory
            .remember(conv_id, "user", "some evidence content", None)
            .await
            .expect("remember");

        let failing_mock = MockProvider::failing();
        let validator = Arc::new(AnyProvider::Mock(failing_mock));

        let config = TieredRetrievalConfig {
            enabled: true,
            validation_enabled: true,
            validation_threshold: 0.6,
            max_escalations: 1,
            ..TieredRetrievalConfig::default()
        };

        let result = recall_tiered(
            &memory,
            "evidence",
            None,
            None,
            Some(&validator),
            &config,
            None,
        )
        .await
        .expect("LLM error path must not propagate as retrieval error");

        assert!(
            !result.tier_escalated,
            "validator LLM error must be treated as sufficient (fail-open)"
        );
    }

    #[tokio::test]
    async fn recall_tiered_with_conversation_id_filter() {
        let memory = crate::testing::mock_semantic_memory()
            .await
            .expect("mock_semantic_memory");

        // A conversation id with no stored messages in this test.
        let conv_id = ConversationId(42);
        let config = TieredRetrievalConfig {
            enabled: true,
            validation_enabled: false,
            ..TieredRetrievalConfig::default()
        };

        let result = recall_tiered(
            &memory,
            "what did we discuss",
            Some(conv_id),
            None,
            None,
            &config,
            None,
        )
        .await
        .expect("conversation-scoped recall must not fail");

        assert!(result.messages.is_empty());
        // Nothing recalled means nothing spent from the budget.
        assert_eq!(result.tokens_used, 0);
        assert!(!result.tier_escalated);
    }
}