1use std::collections::HashMap;
32use std::sync::Arc;
33
34use tracing::Instrument as _;
35pub use zeph_config::memory::TieredRetrievalConfig;
36use zeph_llm::any::AnyProvider;
37
38use crate::embedding_store::SearchFilter;
39use crate::error::MemoryError;
40use crate::router::{HeuristicRouter, HybridRouter, MemoryRoute, MemoryRouter};
41use crate::semantic::RecalledMessage;
42use crate::semantic::SemanticMemory;
43use crate::types::{ConversationId, MessageId};
44
/// Retrieval tier assigned to a query.
///
/// A tier is derived from a [`MemoryRoute`] by `IntentClass::from_route` and decides how many
/// candidates are requested from the store (`IntentClass::top_k`). The tiers form an escalation
/// chain: `ProfileLookup` → `TargetedRetrieval` → `DeepReasoning`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum IntentClass {
    /// Lightest tier; keyword and episodic routes map here. Requests 3 candidates.
    ProfileLookup,
    /// Default tier for every route that is not mapped elsewhere. Requests 10 candidates.
    TargetedRetrieval,
    /// Heaviest tier; graph routes map here. Requests 20 candidates and cannot escalate further.
    DeepReasoning,
}
60
61impl IntentClass {
62 fn from_route(route: MemoryRoute) -> Self {
63 match route {
64 MemoryRoute::Keyword | MemoryRoute::Episodic => Self::ProfileLookup,
65 MemoryRoute::Graph => Self::DeepReasoning,
66 _ => Self::TargetedRetrieval,
67 }
68 }
69
70 fn top_k(self) -> usize {
71 match self {
72 Self::ProfileLookup => 3,
73 Self::TargetedRetrieval => 10,
74 Self::DeepReasoning => 20,
75 }
76 }
77
78 fn escalate(self) -> Option<Self> {
80 match self {
81 Self::ProfileLookup => Some(Self::TargetedRetrieval),
82 Self::TargetedRetrieval => Some(Self::DeepReasoning),
83 Self::DeepReasoning => None,
84 }
85 }
86}
87
88impl std::fmt::Display for IntentClass {
89 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
90 match self {
91 Self::ProfileLookup => f.write_str("ProfileLookup"),
92 Self::TargetedRetrieval => f.write_str("TargetedRetrieval"),
93 Self::DeepReasoning => f.write_str("DeepReasoning"),
94 }
95 }
96}
97
/// Outcome of a tiered recall.
#[derive(Debug)]
pub struct TieredRetrievalResult {
    /// Messages that survived scoring and fit inside the token budget, best score first.
    pub messages: Vec<RecalledMessage>,
    /// Tier that produced `messages` (the escalated tier when an escalation happened).
    pub intent: IntentClass,
    /// Sum of the estimated token counts of `messages`.
    pub tokens_used: usize,
    /// `true` when the validator judged the evidence insufficient at least once and a heavier
    /// tier was tried.
    pub tier_escalated: bool,
}
112
113#[tracing::instrument(name = "memory.tiered.retrieve", skip_all, fields(intent = tracing::field::Empty))]
136pub async fn recall_tiered(
137 memory: &SemanticMemory,
138 query: &str,
139 conversation_id: Option<ConversationId>,
140 classifier: Option<&Arc<AnyProvider>>,
141 validator: Option<&Arc<AnyProvider>>,
142 config: &TieredRetrievalConfig,
143 remaining_budget: Option<usize>,
144) -> Result<TieredRetrievalResult, MemoryError> {
145 let effective_budget =
146 remaining_budget.map_or(config.token_budget, |rb| rb.min(config.token_budget));
147
148 let initial_intent = if let Some(classifier_provider) = classifier {
149 let hybrid = HybridRouter::new(
150 Arc::clone(classifier_provider),
151 MemoryRoute::Hybrid,
152 0.7,
154 );
155 let decision = if let Ok(d) = tokio::time::timeout(
156 std::time::Duration::from_secs(config.classifier_timeout_secs),
157 hybrid.classify_async(query),
158 )
159 .await
160 {
161 d
162 } else {
163 tracing::warn!("tiered: classifier LLM timed out, falling back to heuristic");
164 HeuristicRouter.route_with_confidence(query)
165 };
166 IntentClass::from_route(decision.route)
167 } else {
168 let decision = HeuristicRouter.route_with_confidence(query);
169 IntentClass::from_route(decision.route)
170 };
171
172 tracing::debug!(intent = %initial_intent, query_len = query.len(), "tiered: classified intent");
173
174 escalation_loop(
175 memory,
176 query,
177 conversation_id,
178 initial_intent,
179 validator,
180 config,
181 effective_budget,
182 )
183 .await
184}
185
186async fn escalation_loop(
192 memory: &SemanticMemory,
193 query: &str,
194 conversation_id: Option<ConversationId>,
195 initial_intent: IntentClass,
196 validator: Option<&Arc<AnyProvider>>,
197 config: &TieredRetrievalConfig,
198 effective_budget: usize,
199) -> Result<TieredRetrievalResult, MemoryError> {
200 let mut intent = initial_intent;
201 let mut escalations: u8 = 0;
202 let mut tier_escalated = false;
203
204 loop {
205 let raw_candidates = retrieve_tier(memory, query, conversation_id, intent, config)
206 .instrument(tracing::debug_span!("memory.tiered.retrieve_tier", tier = %intent))
207 .await?;
208
209 let candidates = score_candidates(memory, query, raw_candidates, config)
210 .instrument(tracing::debug_span!("memory.tiered.score_candidates", tier = %intent))
211 .await?;
212
213 let (messages, tokens_used) = {
214 let _span = tracing::debug_span!("memory.tiered.assemble").entered();
215 assemble_within_budget(candidates, effective_budget)
216 };
217
218 if config.validation_enabled
220 && escalations < config.max_escalations
221 && let Some(validator_provider) = validator
222 && let Some(next_tier) = intent.escalate()
223 {
224 let sufficient = validate_evidence(
225 validator_provider,
226 query,
227 &messages,
228 config.validation_threshold,
229 config.validator_timeout_secs,
230 )
231 .instrument(tracing::debug_span!("memory.tiered.validate"))
232 .await;
233 if !sufficient {
234 tracing::debug!(
235 current_tier = %intent,
236 next_tier = %next_tier,
237 escalations,
238 "tiered: evidence insufficient, escalating tier"
239 );
240 intent = next_tier;
241 escalations += 1;
242 tier_escalated = true;
243 continue;
244 }
245 }
246
247 return Ok(TieredRetrievalResult {
248 messages,
249 intent,
250 tokens_used,
251 tier_escalated,
252 });
253 }
254}
255
256async fn retrieve_tier(
261 memory: &SemanticMemory,
262 query: &str,
263 conversation_id: Option<ConversationId>,
264 intent: IntentClass,
265 config: &TieredRetrievalConfig,
266) -> Result<Vec<RecalledMessage>, MemoryError> {
267 let top_k = intent.top_k();
268 let heuristic = HeuristicRouter;
269
270 let filter = conversation_id.map(|cid| SearchFilter {
271 conversation_id: Some(cid),
272 role: None,
273 category: None,
274 });
275
276 if intent == IntentClass::DeepReasoning && config.deep_reasoning_query_conditioned {
278 use crate::graph::HelaSpreadParams;
279 use zeph_llm::provider::{Message, MessageMetadata, Role};
280 let params = HelaSpreadParams::default();
281 match memory.recall_graph_hela(query, top_k, params).await {
282 Ok(hela_facts) if !hela_facts.is_empty() => {
283 let messages: Vec<RecalledMessage> = hela_facts
284 .into_iter()
285 .map(|f| {
286 let content = format!(
287 "{} — {} — {}",
288 f.edge.relation, f.edge.fact, f.edge.canonical_relation
289 );
290 RecalledMessage {
291 message: Message {
292 role: Role::Assistant,
293 content,
294 parts: vec![],
295 metadata: MessageMetadata::default(),
296 },
297 score: f.score,
298 }
299 })
300 .collect();
301 tracing::debug!(
302 count = messages.len(),
303 "tiered: DeepReasoning via query-conditioned HELA recall"
304 );
305 return Ok(messages);
306 }
307 Ok(_) => {
308 tracing::debug!("tiered: HELA returned no results, falling back to recall_routed");
309 }
310 Err(e) => {
311 tracing::warn!("tiered: HELA recall failed ({e:#}), falling back to recall_routed");
312 }
313 }
314 }
315
316 memory
318 .recall_routed(query, top_k, filter, &heuristic, None)
319 .await
320}
321
/// A recalled message whose `score` has been replaced by the combined multi-signal score.
///
/// Used as the sort element inside `score_candidates`.
struct ScoredCandidate {
    /// The message together with its final (re-weighted) score.
    recalled: RecalledMessage,
}
328
329#[allow(clippy::too_many_lines)]
339#[tracing::instrument(name = "memory.tiered.score_candidates", skip_all)]
340async fn score_candidates(
341 memory: &SemanticMemory,
342 query: &str,
343 candidates: Vec<RecalledMessage>,
344 config: &TieredRetrievalConfig,
345) -> Result<Vec<RecalledMessage>, MemoryError> {
346 if candidates.is_empty() {
347 return Ok(candidates);
348 }
349
350 let total_weight = config.similarity_weight
351 + config.recency_weight
352 + config.tfidf_weight
353 + config.cognitive_signal_weight
354 + config.tier_boost_weight;
355
356 if total_weight < f64::EPSILON {
357 tracing::debug!("score_candidates: all signal weights are zero, returning original order");
358 return Ok(candidates);
359 }
360
361 let ids: Vec<MessageId> = candidates
362 .iter()
363 .map(|c| MessageId(c.message.metadata.db_id.unwrap_or(0)))
364 .collect();
365
366 let (timestamps_res, tiers_res) = tokio::join!(
368 async {
369 if config.recency_weight > 0.0 {
370 memory.sqlite().message_timestamps(&ids).await
371 } else {
372 Ok(HashMap::new())
373 }
374 },
375 async {
376 if config.tier_boost_weight > 0.0 {
377 memory.sqlite().fetch_tiers(&ids).await
378 } else {
379 Ok(HashMap::new())
380 }
381 },
382 );
383 let timestamps: HashMap<MessageId, i64> = timestamps_res.unwrap_or_else(|e| {
384 tracing::warn!("score_candidates: failed to fetch timestamps: {e:#}");
385 HashMap::new()
386 });
387 let tiers: HashMap<MessageId, String> = tiers_res.unwrap_or_else(|e| {
388 tracing::warn!("score_candidates: failed to fetch tiers: {e:#}");
389 HashMap::new()
390 });
391
392 let access_counts: HashMap<MessageId, i64> = if config.cognitive_signal_weight > 0.0 {
394 memory
395 .sqlite()
396 .message_access_counts(&ids)
397 .await
398 .unwrap_or_else(|e| {
399 tracing::warn!("score_candidates: failed to fetch access counts: {e:#}");
400 HashMap::new()
401 })
402 } else {
403 HashMap::new()
404 };
405
406 let tfidf_scores = if config.tfidf_weight > 0.0 {
407 compute_tfidf_scores(query, &candidates)
408 } else {
409 vec![0.0_f64; candidates.len()]
410 };
411
412 let max_access: i64 = access_counts.values().copied().max().unwrap_or(0);
413
414 let now_secs = std::time::SystemTime::now()
415 .duration_since(std::time::UNIX_EPOCH)
416 .map_or(0_i64, |d| i64::try_from(d.as_secs()).unwrap_or(i64::MAX));
417
418 let mut scored: Vec<ScoredCandidate> = candidates
419 .into_iter()
420 .zip(tfidf_scores)
421 .map(|(recalled, tfidf)| {
422 let msg_id = MessageId(recalled.message.metadata.db_id.unwrap_or(0));
423
424 let similarity = f64::from(recalled.score);
425 let recency = if config.recency_weight > 0.0 && config.recency_half_life_days > 0 {
426 let ts = timestamps.get(&msg_id).copied().unwrap_or(now_secs);
427 compute_recency(ts, now_secs, config.recency_half_life_days)
428 } else {
429 0.0
430 };
431
432 let cognitive = if config.cognitive_signal_weight > 0.0 && max_access > 0 {
433 let count = access_counts.get(&msg_id).copied().unwrap_or(0);
434 #[allow(clippy::cast_precision_loss)]
436 let ratio = count as f64 / max_access as f64;
437 ratio
438 } else {
439 0.0
440 };
441
442 let tier_signal = if config.tier_boost_weight > 0.0 {
443 let tier = tiers.get(&msg_id).map_or("episodic", String::as_str);
444 if tier == "semantic" {
445 config.semantic_tier_boost
446 } else {
447 0.0
448 }
449 } else {
450 0.0
451 };
452
453 let final_score = config.similarity_weight * similarity
454 + config.recency_weight * recency
455 + config.tfidf_weight * tfidf
456 + config.cognitive_signal_weight * cognitive
457 + config.tier_boost_weight * tier_signal;
458
459 ScoredCandidate {
460 recalled: RecalledMessage {
461 #[allow(clippy::cast_possible_truncation)]
463 score: final_score as f32,
464 ..recalled
465 },
466 }
467 })
468 .collect();
469
470 scored.sort_by(|a, b| {
471 b.recalled
472 .score
473 .partial_cmp(&a.recalled.score)
474 .unwrap_or(std::cmp::Ordering::Equal)
475 });
476
477 Ok(scored.into_iter().map(|s| s.recalled).collect())
478}
479
/// Exponential recency decay in `(0, 1]`.
///
/// Returns `exp(-ln(2) / half_life_days * age_days)`, so a message exactly one half-life old
/// scores `0.5` and a brand-new message scores `1.0`. Timestamps in the future are clamped to
/// an age of zero (score `1.0`).
///
/// * `created_at_secs` - creation time in seconds since the Unix epoch.
/// * `now_secs` - current time in seconds since the Unix epoch.
/// * `half_life_days` - half-life of the decay; must be greater than zero.
fn compute_recency(created_at_secs: i64, now_secs: i64, half_life_days: u32) -> f64 {
    debug_assert!(half_life_days > 0, "half_life_days must be > 0");
    // Saturating subtraction: a corrupt or extreme timestamp must not overflow (panic in debug
    // builds, wrap to a nonsense age in release builds).
    let age_secs = now_secs.saturating_sub(created_at_secs).max(0);
    #[allow(clippy::cast_precision_loss)]
    let age_days = age_secs as f64 / 86_400.0;
    let lambda = std::f64::consts::LN_2 / f64::from(half_life_days);
    (-lambda * age_days).exp()
}
497
498fn compute_tfidf_scores(query: &str, candidates: &[RecalledMessage]) -> Vec<f64> {
503 const K1: f64 = 1.2;
504 const B: f64 = 0.75;
505
506 let query_terms: Vec<String> = query.split_whitespace().map(str::to_lowercase).collect();
507
508 if query_terms.is_empty() || candidates.is_empty() {
509 return vec![0.0; candidates.len()];
510 }
511
512 let docs: Vec<Vec<String>> = candidates
514 .iter()
515 .map(|c| {
516 c.message
517 .content
518 .split_whitespace()
519 .map(str::to_lowercase)
520 .collect()
521 })
522 .collect();
523
524 #[allow(clippy::cast_precision_loss)]
526 let n = docs.len() as f64;
527 #[allow(clippy::cast_precision_loss)]
528 let avg_dl = docs.iter().map(|d| d.len() as f64).sum::<f64>().max(1.0) / n;
529
530 let mut scores = vec![0.0_f64; docs.len()];
531
532 for term in &query_terms {
533 #[allow(clippy::cast_precision_loss)]
535 let df = docs.iter().filter(|d| d.contains(term)).count() as f64;
536 if df == 0.0 {
537 continue;
538 }
539 let idf = ((n - df + 0.5) / (df + 0.5) + 1.0).ln();
541
542 for (i, doc) in docs.iter().enumerate() {
543 #[allow(clippy::cast_precision_loss)]
544 let dl = doc.len() as f64;
545 #[allow(clippy::cast_precision_loss)]
546 let tf = doc.iter().filter(|t| *t == term).count() as f64;
547 let bm25_tf = (tf * (K1 + 1.0)) / (tf + K1 * (1.0 - B + B * dl / avg_dl));
548 scores[i] += idf * bm25_tf;
549 }
550 }
551
552 let max_score = scores.iter().copied().fold(0.0_f64, f64::max);
554 if max_score > 0.0 {
555 for s in &mut scores {
556 *s /= max_score;
557 }
558 }
559
560 scores
561}
562
563fn assemble_within_budget(
568 candidates: Vec<RecalledMessage>,
569 budget: usize,
570) -> (Vec<RecalledMessage>, usize) {
571 let mut retained = Vec::with_capacity(candidates.len());
572 let mut total_tokens: usize = 0;
573
574 for msg in candidates {
575 let msg_tokens = zeph_common::text::estimate_tokens(&msg.message.content);
576 if total_tokens.saturating_add(msg_tokens) > budget {
577 break;
578 }
579 total_tokens += msg_tokens;
580 retained.push(msg);
581 }
582
583 (retained, total_tokens)
584}
585
586async fn validate_evidence(
591 provider: &Arc<AnyProvider>,
592 query: &str,
593 messages: &[RecalledMessage],
594 threshold: f32,
595 timeout_secs: u64,
596) -> bool {
597 use zeph_llm::provider::{LlmProvider as _, Message, MessageMetadata, Role};
598
599 if messages.is_empty() {
600 return false;
601 }
602
603 let evidence_snippet = messages
604 .iter()
605 .take(5)
606 .map(|m| {
607 zeph_common::sanitize::strip_control_chars_preserve_whitespace(&m.message.content)
608 .chars()
609 .take(200)
610 .collect::<String>()
611 })
612 .collect::<Vec<_>>()
613 .join("\n---\n");
614
615 let system = "You are an evidence quality judge. \
616 Given a query and evidence snippets, decide if the evidence is sufficient to answer the query. \
617 Respond ONLY with a JSON object: {\"sufficient\": true|false, \"confidence\": 0.0-1.0}";
618
619 let sanitized_query = zeph_common::sanitize::strip_control_chars_preserve_whitespace(query);
620 let user = format!(
621 "<query>{}</query>\n<evidence>{}</evidence>",
622 sanitized_query.chars().take(500).collect::<String>(),
623 evidence_snippet
624 );
625
626 let msgs = vec![
627 Message {
628 role: Role::System,
629 content: system.to_owned(),
630 parts: vec![],
631 metadata: MessageMetadata::default(),
632 },
633 Message {
634 role: Role::User,
635 content: user,
636 parts: vec![],
637 metadata: MessageMetadata::default(),
638 },
639 ];
640
641 match tokio::time::timeout(
642 std::time::Duration::from_secs(timeout_secs),
643 provider.chat(&msgs),
644 )
645 .await
646 {
647 Ok(Ok(raw)) => parse_validation_response(&raw, threshold),
648 Ok(Err(e)) => {
649 tracing::warn!(error = %e, "tiered: validator LLM call failed, treating as sufficient");
650 true
651 }
652 Err(_) => {
653 tracing::warn!("tiered: validator LLM call timed out, treating as sufficient");
654 true
655 }
656 }
657}
658
659fn parse_validation_response(raw: &str, threshold: f32) -> bool {
660 let json_str = raw
661 .find('{')
662 .and_then(|s| raw[s..].rfind('}').map(|e| &raw[s..=s + e]))
663 .unwrap_or("");
664
665 if let Ok(v) = serde_json::from_str::<serde_json::Value>(json_str) {
666 let sufficient = v
667 .get("sufficient")
668 .and_then(serde_json::Value::as_bool)
669 .unwrap_or(true);
670 #[allow(clippy::cast_possible_truncation)]
671 let confidence = v
672 .get("confidence")
673 .and_then(serde_json::Value::as_f64)
674 .map_or(1.0, |c| c.clamp(0.0, 1.0) as f32);
675
676 return sufficient && confidence >= threshold;
677 }
678
679 tracing::debug!("tiered: could not parse validator response, treating as sufficient");
680 true
681}
682
// Unit and integration-style tests for the tiered retrieval pipeline.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::router::MemoryRoute;
    use crate::semantic::RecalledMessage;
    use zeph_llm::provider::{Message, MessageMetadata, Role};

    /// Builds a user-role `RecalledMessage` with the given content, default metadata and a
    /// score of `1.0`.
    fn make_message(content: &str) -> RecalledMessage {
        RecalledMessage {
            message: Message {
                role: Role::User,
                content: content.to_owned(),
                parts: vec![],
                metadata: MessageMetadata::default(),
            },
            score: 1.0,
        }
    }

    // --- compute_recency ---

    /// A message created "now" has full recency.
    #[test]
    fn compute_recency_zero_age_returns_one() {
        let now = 1_000_000_i64;
        let score = compute_recency(now, now, 7);
        assert!((score - 1.0).abs() < 1e-9);
    }

    /// A message exactly one half-life old scores 0.5.
    #[test]
    fn compute_recency_half_life_returns_half() {
        let now = 1_000_000_i64;
        let half_life_days = 7_u32;
        let age_secs = i64::from(half_life_days) * 86_400;
        let score = compute_recency(now - age_secs, now, half_life_days);
        assert!((score - 0.5).abs() < 1e-9);
    }

    /// 1000 days at a 7-day half-life decays to practically zero.
    #[test]
    fn compute_recency_large_age_approaches_zero() {
        let now = 1_000_i64 * 86_400;
        let score = compute_recency(0, now, 7);
        assert!(score < 1e-6, "score was {score}");
    }

    /// A creation time in the future is clamped to age zero.
    #[test]
    fn compute_recency_future_timestamp_clamped_to_one() {
        let now = 1_000_000_i64;
        let score = compute_recency(now + 86_400, now, 7);
        assert!((score - 1.0).abs() < 1e-9);
    }

    // --- compute_tfidf_scores ---

    /// No candidates → no scores.
    #[test]
    fn compute_tfidf_empty_candidates_returns_empty() {
        let scores = compute_tfidf_scores("hello", &[]);
        assert!(scores.is_empty());
    }

    /// An empty query yields one zero score per candidate.
    #[test]
    fn compute_tfidf_empty_query_returns_zeros() {
        let candidates = vec![make_message("hello world")];
        let scores = compute_tfidf_scores("", &candidates);
        assert_eq!(scores.len(), 1);
        assert!(scores[0].abs() < f64::EPSILON);
    }

    /// The document containing the query term outranks the one that does not.
    #[test]
    fn compute_tfidf_exact_match_scores_nonzero() {
        let candidates = vec![
            make_message("the quick brown fox"),
            make_message("completely unrelated content"),
        ];
        let scores = compute_tfidf_scores("fox", &candidates);
        assert_eq!(scores.len(), 2);
        assert!(scores[0] > scores[1]);
    }

    /// Query terms absent from every document give a zero score.
    #[test]
    fn compute_tfidf_no_match_returns_zeros() {
        let candidates = vec![make_message("apple banana cherry")];
        let scores = compute_tfidf_scores("zzz xyz", &candidates);
        assert_eq!(scores.len(), 1);
        assert!(scores[0].abs() < f64::EPSILON);
    }

    /// Scores are normalised so the best match is exactly 1.0.
    #[test]
    fn compute_tfidf_max_score_normalised_to_one() {
        let candidates = vec![
            make_message("rust programming language"),
            make_message("python programming language"),
            make_message("java is a drink"),
        ];
        let scores = compute_tfidf_scores("rust programming", &candidates);
        let max = scores.iter().copied().fold(f64::NEG_INFINITY, f64::max);
        assert!((max - 1.0).abs() < 1e-9, "max score must be 1.0, got {max}");
    }

    // --- score_candidates ---

    /// Empty input short-circuits to an empty result.
    #[test]
    fn score_candidates_empty_input_returns_empty() {
        let rt = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap();
        rt.block_on(async {
            let memory = crate::testing::mock_semantic_memory()
                .await
                .expect("mock_semantic_memory");
            let config = TieredRetrievalConfig::default();
            let result = score_candidates(&memory, "query", vec![], &config)
                .await
                .expect("score_candidates must not fail on empty input");
            assert!(result.is_empty());
        });
    }

    /// With only the similarity weight active, output is sorted by the incoming score,
    /// highest first.
    #[test]
    fn score_candidates_similarity_weight_reorders_by_score() {
        let rt = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap();
        rt.block_on(async {
            let memory = crate::testing::mock_semantic_memory()
                .await
                .expect("mock_semantic_memory");
            let config = TieredRetrievalConfig {
                similarity_weight: 1.0,
                ..TieredRetrievalConfig::default()
            };
            let candidates = vec![
                RecalledMessage {
                    message: make_message("low score").message,
                    score: 0.1,
                },
                RecalledMessage {
                    message: make_message("high score").message,
                    score: 0.9,
                },
                RecalledMessage {
                    message: make_message("mid score").message,
                    score: 0.5,
                },
            ];
            let result = score_candidates(&memory, "query", candidates, &config)
                .await
                .expect("score_candidates must not fail");
            assert_eq!(result.len(), 3);
            assert!(
                result[0].score >= result[1].score,
                "first score {} must be >= second score {}",
                result[0].score,
                result[1].score
            );
            assert!(
                result[1].score >= result[2].score,
                "second score {} must be >= third score {}",
                result[1].score,
                result[2].score
            );
            assert!(
                (result[0].score - 0.9_f32).abs() < 1e-4,
                "expected first score ~0.9, got {}",
                result[0].score
            );
        });
    }

    /// When every weight is zero the candidates are returned untouched, in input order.
    #[test]
    fn score_candidates_all_zero_weights_returns_original_order() {
        let rt = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap();
        rt.block_on(async {
            let memory = crate::testing::mock_semantic_memory()
                .await
                .expect("mock_semantic_memory");
            let config = TieredRetrievalConfig {
                similarity_weight: 0.0,
                recency_weight: 0.0,
                tfidf_weight: 0.0,
                cognitive_signal_weight: 0.0,
                tier_boost_weight: 0.0,
                ..TieredRetrievalConfig::default()
            };
            let candidates = vec![
                RecalledMessage {
                    message: make_message("first").message,
                    score: 0.9,
                },
                RecalledMessage {
                    message: make_message("second").message,
                    score: 0.1,
                },
            ];
            let result = score_candidates(&memory, "query", candidates, &config)
                .await
                .expect("score_candidates must not fail");
            assert!((f64::from(result[0].score) - 0.9).abs() < 1e-6);
            assert!((f64::from(result[1].score) - 0.1).abs() < 1e-6);
        });
    }

    /// Pins the default signal weights: similarity only, everything else off.
    #[test]
    fn tiered_retrieval_config_signal_weight_defaults() {
        let cfg = TieredRetrievalConfig::default();
        assert!((cfg.similarity_weight - 1.0).abs() < f64::EPSILON);
        assert!(cfg.recency_weight.abs() < f64::EPSILON);
        assert_eq!(cfg.recency_half_life_days, 7);
        assert!(cfg.tfidf_weight.abs() < f64::EPSILON);
        assert!(cfg.cognitive_signal_weight.abs() < f64::EPSILON);
        assert!(cfg.tier_boost_weight.abs() < f64::EPSILON);
        assert!((cfg.semantic_tier_boost - 1.0).abs() < f64::EPSILON);
    }

    // --- IntentClass ---

    /// Every known route maps to the expected tier.
    #[test]
    fn intent_class_from_route_mapping() {
        assert_eq!(
            IntentClass::from_route(MemoryRoute::Keyword),
            IntentClass::ProfileLookup
        );
        assert_eq!(
            IntentClass::from_route(MemoryRoute::Episodic),
            IntentClass::ProfileLookup
        );
        assert_eq!(
            IntentClass::from_route(MemoryRoute::Semantic),
            IntentClass::TargetedRetrieval
        );
        assert_eq!(
            IntentClass::from_route(MemoryRoute::Hybrid),
            IntentClass::TargetedRetrieval
        );
        assert_eq!(
            IntentClass::from_route(MemoryRoute::Graph),
            IntentClass::DeepReasoning
        );
    }

    /// Candidate counts per tier.
    #[test]
    fn intent_class_top_k() {
        assert_eq!(IntentClass::ProfileLookup.top_k(), 3);
        assert_eq!(IntentClass::TargetedRetrieval.top_k(), 10);
        assert_eq!(IntentClass::DeepReasoning.top_k(), 20);
    }

    /// Escalation walks ProfileLookup → TargetedRetrieval → DeepReasoning and then stops.
    #[test]
    fn intent_class_escalate_chain() {
        assert_eq!(
            IntentClass::ProfileLookup.escalate(),
            Some(IntentClass::TargetedRetrieval)
        );
        assert_eq!(
            IntentClass::TargetedRetrieval.escalate(),
            Some(IntentClass::DeepReasoning)
        );
        assert_eq!(IntentClass::DeepReasoning.escalate(), None);
    }

    // --- assemble_within_budget ---

    /// No candidates → nothing retained, zero tokens.
    #[test]
    fn assemble_within_budget_empty_input() {
        let (retained, tokens) = assemble_within_budget(vec![], 4096);
        assert!(retained.is_empty());
        assert_eq!(tokens, 0);
    }

    /// A zero budget retains nothing.
    #[test]
    fn assemble_within_budget_zero_budget_returns_nothing() {
        let candidates = vec![make_message("hello"), make_message("world")];
        let (retained, tokens) = assemble_within_budget(candidates, 0);
        assert!(retained.is_empty(), "budget=0 must retain no messages");
        assert_eq!(tokens, 0);
    }

    /// Two 800-character messages against a budget of 250: only the first fits.
    /// NOTE(review): the expected 200 tokens implies `estimate_tokens` counts roughly four
    /// characters per token — confirm against `zeph_common::text::estimate_tokens`.
    #[test]
    fn assemble_within_budget_truncates_at_limit() {
        let msg = "a ".repeat(400);
        let candidates = vec![make_message(&msg), make_message(&msg)];
        let (retained, tokens) = assemble_within_budget(candidates, 250);
        assert_eq!(
            retained.len(),
            1,
            "tight budget must keep only first message"
        );
        assert_eq!(tokens, 200);
    }

    // --- parse_validation_response / config defaults ---

    /// An empty JSON object falls back to `sufficient = true`, `confidence = 1.0`.
    #[test]
    fn parse_validation_response_missing_fields_defaults_to_sufficient() {
        let raw = "{}";
        assert!(
            parse_validation_response(raw, 0.6),
            "missing fields must default to sufficient"
        );
    }

    /// Pins the default feature flags, budget and timeouts.
    #[test]
    fn tiered_retrieval_config_defaults() {
        let cfg = TieredRetrievalConfig::default();
        assert!(!cfg.enabled);
        assert_eq!(cfg.token_budget, 4096);
        assert!(!cfg.validation_enabled);
        assert_eq!(cfg.max_escalations, 1);
        assert_eq!(cfg.classifier_timeout_secs, 5);
        assert_eq!(cfg.validator_timeout_secs, 5);
    }

    /// Overridden timeout fields are kept and convert to the expected `Duration`s.
    #[test]
    fn tiered_retrieval_config_timeout_fields_propagate() {
        let cfg = TieredRetrievalConfig {
            classifier_timeout_secs: 10,
            validator_timeout_secs: 15,
            ..TieredRetrievalConfig::default()
        };
        assert_eq!(cfg.classifier_timeout_secs, 10);
        assert_eq!(cfg.validator_timeout_secs, 15);
        let classifier_dur = std::time::Duration::from_secs(cfg.classifier_timeout_secs);
        let validator_dur = std::time::Duration::from_secs(cfg.validator_timeout_secs);
        assert_eq!(classifier_dur.as_secs(), 10);
        assert_eq!(validator_dur.as_secs(), 15);
    }

    /// `sufficient: true` with confidence above the threshold passes.
    #[test]
    fn parse_validation_response_sufficient() {
        let raw = r#"{"sufficient": true, "confidence": 0.9}"#;
        assert!(parse_validation_response(raw, 0.6));
    }

    /// `sufficient: false` fails regardless of confidence.
    #[test]
    fn parse_validation_response_insufficient() {
        let raw = r#"{"sufficient": false, "confidence": 0.4}"#;
        assert!(!parse_validation_response(raw, 0.6));
    }

    /// `sufficient: true` with confidence below the threshold fails.
    #[test]
    fn parse_validation_response_low_confidence() {
        let raw = r#"{"sufficient": true, "confidence": 0.3}"#;
        assert!(!parse_validation_response(raw, 0.6));
    }

    /// A reply without any JSON object is treated as sufficient.
    #[test]
    fn parse_validation_response_malformed_json_treats_as_sufficient() {
        let raw = "not json at all";
        assert!(parse_validation_response(raw, 0.6));
    }

    /// `Display` prints the variant name.
    #[test]
    fn intent_class_display() {
        assert_eq!(IntentClass::ProfileLookup.to_string(), "ProfileLookup");
        assert_eq!(
            IntentClass::TargetedRetrieval.to_string(),
            "TargetedRetrieval"
        );
        assert_eq!(IntentClass::DeepReasoning.to_string(), "DeepReasoning");
    }

    // --- recall_tiered end-to-end (mock memory) ---

    /// Without a classifier the heuristic router is used; with validation off no escalation
    /// can happen and the budget is respected.
    #[tokio::test]
    async fn recall_tiered_no_classifier_uses_heuristic_router() {
        let memory = crate::testing::mock_semantic_memory()
            .await
            .expect("mock_semantic_memory");
        let config = TieredRetrievalConfig {
            enabled: true,
            validation_enabled: false,
            ..TieredRetrievalConfig::default()
        };

        let result = recall_tiered(&memory, "what is my name", None, None, None, &config, None)
            .await
            .expect("recall_tiered must not fail");

        assert!(
            !result.tier_escalated,
            "no escalation when validation is off"
        );
        assert!(result.tokens_used <= config.token_budget);
    }

    /// With a classifier provider the hybrid-router path runs to completion.
    #[tokio::test]
    async fn recall_tiered_with_classifier_uses_hybrid_router() {
        use zeph_llm::mock::MockProvider;

        let memory = crate::testing::mock_semantic_memory()
            .await
            .expect("mock_semantic_memory");

        // Canned classifier reply consumed by the hybrid router.
        let route_json = r#"{"route": "Semantic", "confidence": 0.9}"#.to_owned();
        let mut mock = MockProvider::with_responses(vec![route_json]);
        mock.supports_embeddings = true;
        mock.embedding = vec![0.1_f32; 384];
        let classifier = Arc::new(AnyProvider::Mock(mock));

        let config = TieredRetrievalConfig {
            enabled: true,
            validation_enabled: false,
            ..TieredRetrievalConfig::default()
        };

        let result = recall_tiered(
            &memory,
            "semantic query about the user",
            None,
            Some(&classifier),
            None,
            &config,
            None,
        )
        .await
        .expect("recall_tiered with classifier must not fail");

        assert!(!result.tier_escalated);
        assert!(result.tokens_used <= config.token_budget);
    }

    /// An "insufficient" verdict from the validator escalates to the next tier.
    #[tokio::test]
    async fn recall_tiered_escalates_when_evidence_insufficient() {
        use zeph_llm::mock::MockProvider;

        let memory = crate::testing::mock_semantic_memory()
            .await
            .expect("mock_semantic_memory");

        // First validator reply rejects the evidence, the second accepts it.
        let insufficient = r#"{"sufficient": false, "confidence": 0.1}"#.to_owned();
        let sufficient = r#"{"sufficient": true, "confidence": 0.95}"#.to_owned();
        let mut validator_mock = MockProvider::with_responses(vec![insufficient, sufficient]);
        validator_mock.supports_embeddings = true;
        let validator = Arc::new(AnyProvider::Mock(validator_mock));

        let config = TieredRetrievalConfig {
            enabled: true,
            validation_enabled: true,
            validation_threshold: 0.6,
            max_escalations: 2,
            ..TieredRetrievalConfig::default()
        };

        let result = recall_tiered(
            &memory,
            "deep query",
            None,
            None,
            Some(&validator),
            &config,
            None,
        )
        .await
        .expect("escalation path must not fail");

        assert!(
            result.tier_escalated,
            "must set tier_escalated when validator triggers escalation"
        );
    }

    /// A validator slower than `validator_timeout_secs` is treated as "sufficient".
    #[tokio::test]
    async fn validate_evidence_timeout_is_fail_open() {
        use zeph_llm::mock::MockProvider;

        let memory = crate::testing::mock_semantic_memory()
            .await
            .expect("mock_semantic_memory");

        // Store one message so the validator is reached with non-empty evidence.
        let conv_id = memory
            .sqlite()
            .create_conversation()
            .await
            .expect("create_conversation");
        memory
            .remember(conv_id, "user", "some evidence content", None)
            .await
            .expect("remember");

        // NOTE(review): the delay is presumably in milliseconds (6 s > 5 s timeout) — confirm
        // against `MockProvider::with_delay`.
        let slow_mock = MockProvider::default().with_delay(6_000);
        let validator = Arc::new(AnyProvider::Mock(slow_mock));

        let config = TieredRetrievalConfig {
            enabled: true,
            validation_enabled: true,
            validation_threshold: 0.6,
            max_escalations: 1,
            validator_timeout_secs: 5,
            ..TieredRetrievalConfig::default()
        };

        let result = recall_tiered(
            &memory,
            "evidence",
            None,
            None,
            Some(&validator),
            &config,
            None,
        )
        .await
        .expect("timeout path must not propagate as error");

        assert!(
            !result.tier_escalated,
            "validator timeout must be treated as sufficient (fail-open)"
        );
    }

    /// A validator that returns an error is treated as "sufficient".
    #[tokio::test]
    async fn validate_evidence_llm_error_is_fail_open() {
        use zeph_llm::mock::MockProvider;

        let memory = crate::testing::mock_semantic_memory()
            .await
            .expect("mock_semantic_memory");

        // Store one message so the validator is reached with non-empty evidence.
        let conv_id = memory
            .sqlite()
            .create_conversation()
            .await
            .expect("create_conversation");
        memory
            .remember(conv_id, "user", "some evidence content", None)
            .await
            .expect("remember");

        let failing_mock = MockProvider::failing();
        let validator = Arc::new(AnyProvider::Mock(failing_mock));

        let config = TieredRetrievalConfig {
            enabled: true,
            validation_enabled: true,
            validation_threshold: 0.6,
            max_escalations: 1,
            ..TieredRetrievalConfig::default()
        };

        let result = recall_tiered(
            &memory,
            "evidence",
            None,
            None,
            Some(&validator),
            &config,
            None,
        )
        .await
        .expect("LLM error path must not propagate as retrieval error");

        assert!(
            !result.tier_escalated,
            "validator LLM error must be treated as sufficient (fail-open)"
        );
    }

    /// Scoping recall to a conversation with no stored messages returns nothing.
    #[tokio::test]
    async fn recall_tiered_with_conversation_id_filter() {
        let memory = crate::testing::mock_semantic_memory()
            .await
            .expect("mock_semantic_memory");

        let conv_id = ConversationId(42);
        let config = TieredRetrievalConfig {
            enabled: true,
            validation_enabled: false,
            ..TieredRetrievalConfig::default()
        };

        let result = recall_tiered(
            &memory,
            "what did we discuss",
            Some(conv_id),
            None,
            None,
            &config,
            None,
        )
        .await
        .expect("conversation-scoped recall must not fail");

        assert!(result.messages.is_empty());
        assert_eq!(result.tokens_used, 0);
        assert!(!result.tier_escalated);
    }

    // --- retrieve_tier: DeepReasoning paths ---

    /// With query-conditioned recall enabled and nothing in the graph, `retrieve_tier` falls
    /// back to the routed recall, which is empty for the fresh mock memory.
    #[tokio::test]
    async fn deep_reasoning_query_conditioned_true_falls_back_when_hela_empty() {
        let memory = crate::testing::mock_semantic_memory()
            .await
            .expect("mock_semantic_memory");

        let config = TieredRetrievalConfig {
            deep_reasoning_query_conditioned: true,
            validation_enabled: false,
            ..TieredRetrievalConfig::default()
        };

        let result = retrieve_tier(
            &memory,
            "multi-hop reasoning query",
            None,
            IntentClass::DeepReasoning,
            &config,
        )
        .await
        .expect("retrieve_tier with empty HELA must not fail");

        assert!(
            result.is_empty(),
            "expected empty result from fallback recall_routed, got {}",
            result.len()
        );
    }

    /// With query-conditioned recall disabled the HELA branch is skipped and the routed
    /// recall runs directly.
    #[tokio::test]
    async fn deep_reasoning_query_conditioned_false_completes_without_panic() {
        let memory = crate::testing::mock_semantic_memory()
            .await
            .expect("mock_semantic_memory");

        let config = TieredRetrievalConfig {
            deep_reasoning_query_conditioned: false,
            validation_enabled: false,
            ..TieredRetrievalConfig::default()
        };

        let result = retrieve_tier(
            &memory,
            "multi-hop reasoning query",
            None,
            IntentClass::DeepReasoning,
            &config,
        )
        .await
        .expect("retrieve_tier with deep_reasoning_query_conditioned=false must not fail");

        assert!(
            result.is_empty(),
            "expected empty result from recall_routed path, got {}",
            result.len()
        );
    }
}