1use std::sync::Arc;
32
33use tracing::Instrument as _;
34pub use zeph_config::memory::TieredRetrievalConfig;
35use zeph_llm::any::AnyProvider;
36
37use crate::embedding_store::SearchFilter;
38use crate::error::MemoryError;
39use crate::router::{HeuristicRouter, HybridRouter, MemoryRoute, MemoryRouter};
40use crate::semantic::RecalledMessage;
41use crate::semantic::SemanticMemory;
42use crate::types::ConversationId;
43
/// Retrieval tier selected for a query.
///
/// Tiers are ordered from shallowest to deepest; deeper tiers fetch more candidates and are
/// used either when the router classifies a query as needing them or when evidence validation
/// escalates a shallower tier.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum IntentClass {
    /// Narrow lookup (e.g. a fact about the user); the smallest candidate set.
    ProfileLookup,
    /// Focused retrieval for a specific topic; a medium-sized candidate set.
    TargetedRetrieval,
    /// Broad retrieval for multi-step reasoning; the largest candidate set.
    DeepReasoning,
}
58
59impl IntentClass {
60 fn from_route(route: MemoryRoute) -> Self {
61 match route {
62 MemoryRoute::Keyword | MemoryRoute::Episodic => Self::ProfileLookup,
63 MemoryRoute::Semantic | MemoryRoute::Hybrid => Self::TargetedRetrieval,
64 MemoryRoute::Graph => Self::DeepReasoning,
65 }
66 }
67
68 fn top_k(self) -> usize {
69 match self {
70 Self::ProfileLookup => 3,
71 Self::TargetedRetrieval => 10,
72 Self::DeepReasoning => 20,
73 }
74 }
75
76 fn escalate(self) -> Option<Self> {
78 match self {
79 Self::ProfileLookup => Some(Self::TargetedRetrieval),
80 Self::TargetedRetrieval => Some(Self::DeepReasoning),
81 Self::DeepReasoning => None,
82 }
83 }
84}
85
86impl std::fmt::Display for IntentClass {
87 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
88 match self {
89 Self::ProfileLookup => f.write_str("ProfileLookup"),
90 Self::TargetedRetrieval => f.write_str("TargetedRetrieval"),
91 Self::DeepReasoning => f.write_str("DeepReasoning"),
92 }
93 }
94}
95
96#[derive(Debug)]
100pub struct TieredRetrievalResult {
101 pub messages: Vec<RecalledMessage>,
103 pub intent: IntentClass,
105 pub tokens_used: usize,
107 pub tier_escalated: bool,
109}
110
111#[tracing::instrument(name = "memory.tiered.retrieve", skip_all, fields(intent = tracing::field::Empty))]
134pub async fn recall_tiered(
135 memory: &SemanticMemory,
136 query: &str,
137 conversation_id: Option<ConversationId>,
138 classifier: Option<&Arc<AnyProvider>>,
139 validator: Option<&Arc<AnyProvider>>,
140 config: &TieredRetrievalConfig,
141 remaining_budget: Option<usize>,
142) -> Result<TieredRetrievalResult, MemoryError> {
143 let effective_budget =
144 remaining_budget.map_or(config.token_budget, |rb| rb.min(config.token_budget));
145
146 let initial_intent = if let Some(classifier_provider) = classifier {
147 let hybrid = HybridRouter::new(
148 Arc::clone(classifier_provider),
149 MemoryRoute::Hybrid,
150 0.7,
152 );
153 let decision = if let Ok(d) = tokio::time::timeout(
154 std::time::Duration::from_secs(config.classifier_timeout_secs),
155 hybrid.classify_async(query),
156 )
157 .await
158 {
159 d
160 } else {
161 tracing::warn!("tiered: classifier LLM timed out, falling back to heuristic");
162 HeuristicRouter.route_with_confidence(query)
163 };
164 IntentClass::from_route(decision.route)
165 } else {
166 let decision = HeuristicRouter.route_with_confidence(query);
167 IntentClass::from_route(decision.route)
168 };
169
170 tracing::debug!(intent = %initial_intent, query_len = query.len(), "tiered: classified intent");
171
172 escalation_loop(
173 memory,
174 query,
175 conversation_id,
176 initial_intent,
177 validator,
178 config,
179 effective_budget,
180 )
181 .await
182}
183
184async fn escalation_loop(
190 memory: &SemanticMemory,
191 query: &str,
192 conversation_id: Option<ConversationId>,
193 initial_intent: IntentClass,
194 validator: Option<&Arc<AnyProvider>>,
195 config: &TieredRetrievalConfig,
196 effective_budget: usize,
197) -> Result<TieredRetrievalResult, MemoryError> {
198 let mut intent = initial_intent;
199 let mut escalations: u8 = 0;
200 let mut tier_escalated = false;
201
202 loop {
203 let candidates = retrieve_tier(memory, query, conversation_id, intent)
204 .instrument(tracing::debug_span!("memory.tiered.retrieve_tier", tier = %intent))
205 .await?;
206
207 let (messages, tokens_used) = {
208 let _span = tracing::debug_span!("memory.tiered.assemble").entered();
209 assemble_within_budget(candidates, effective_budget)
210 };
211
212 if config.validation_enabled
214 && escalations < config.max_escalations
215 && let Some(validator_provider) = validator
216 && let Some(next_tier) = intent.escalate()
217 {
218 let sufficient = validate_evidence(
219 validator_provider,
220 query,
221 &messages,
222 config.validation_threshold,
223 config.validator_timeout_secs,
224 )
225 .instrument(tracing::debug_span!("memory.tiered.validate"))
226 .await;
227 if !sufficient {
228 tracing::debug!(
229 current_tier = %intent,
230 next_tier = %next_tier,
231 escalations,
232 "tiered: evidence insufficient, escalating tier"
233 );
234 intent = next_tier;
235 escalations += 1;
236 tier_escalated = true;
237 continue;
238 }
239 }
240
241 return Ok(TieredRetrievalResult {
242 messages,
243 intent,
244 tokens_used,
245 tier_escalated,
246 });
247 }
248}
249
250async fn retrieve_tier(
252 memory: &SemanticMemory,
253 query: &str,
254 conversation_id: Option<ConversationId>,
255 intent: IntentClass,
256) -> Result<Vec<RecalledMessage>, MemoryError> {
257 let top_k = intent.top_k();
258 let heuristic = HeuristicRouter;
259
260 let filter = conversation_id.map(|cid| SearchFilter {
261 conversation_id: Some(cid),
262 role: None,
263 category: None,
264 });
265
266 memory.recall_routed(query, top_k, filter, &heuristic).await
269}
270
271fn assemble_within_budget(
276 candidates: Vec<RecalledMessage>,
277 budget: usize,
278) -> (Vec<RecalledMessage>, usize) {
279 let mut retained = Vec::with_capacity(candidates.len());
280 let mut total_tokens: usize = 0;
281
282 for msg in candidates {
283 let msg_tokens = zeph_common::text::estimate_tokens(&msg.message.content);
284 if total_tokens.saturating_add(msg_tokens) > budget {
285 break;
286 }
287 total_tokens += msg_tokens;
288 retained.push(msg);
289 }
290
291 (retained, total_tokens)
292}
293
294async fn validate_evidence(
299 provider: &Arc<AnyProvider>,
300 query: &str,
301 messages: &[RecalledMessage],
302 threshold: f32,
303 timeout_secs: u64,
304) -> bool {
305 use zeph_llm::provider::{LlmProvider as _, Message, MessageMetadata, Role};
306
307 if messages.is_empty() {
308 return false;
309 }
310
311 let evidence_snippet = messages
312 .iter()
313 .take(5)
314 .map(|m| {
315 zeph_common::sanitize::strip_control_chars_preserve_whitespace(&m.message.content)
316 .chars()
317 .take(200)
318 .collect::<String>()
319 })
320 .collect::<Vec<_>>()
321 .join("\n---\n");
322
323 let system = "You are an evidence quality judge. \
324 Given a query and evidence snippets, decide if the evidence is sufficient to answer the query. \
325 Respond ONLY with a JSON object: {\"sufficient\": true|false, \"confidence\": 0.0-1.0}";
326
327 let sanitized_query = zeph_common::sanitize::strip_control_chars_preserve_whitespace(query);
328 let user = format!(
329 "<query>{}</query>\n<evidence>{}</evidence>",
330 sanitized_query.chars().take(500).collect::<String>(),
331 evidence_snippet
332 );
333
334 let msgs = vec![
335 Message {
336 role: Role::System,
337 content: system.to_owned(),
338 parts: vec![],
339 metadata: MessageMetadata::default(),
340 },
341 Message {
342 role: Role::User,
343 content: user,
344 parts: vec![],
345 metadata: MessageMetadata::default(),
346 },
347 ];
348
349 match tokio::time::timeout(
350 std::time::Duration::from_secs(timeout_secs),
351 provider.chat(&msgs),
352 )
353 .await
354 {
355 Ok(Ok(raw)) => parse_validation_response(&raw, threshold),
356 Ok(Err(e)) => {
357 tracing::warn!(error = %e, "tiered: validator LLM call failed, treating as sufficient");
358 true
359 }
360 Err(_) => {
361 tracing::warn!("tiered: validator LLM call timed out, treating as sufficient");
362 true
363 }
364 }
365}
366
367fn parse_validation_response(raw: &str, threshold: f32) -> bool {
368 let json_str = raw
369 .find('{')
370 .and_then(|s| raw[s..].rfind('}').map(|e| &raw[s..=s + e]))
371 .unwrap_or("");
372
373 if let Ok(v) = serde_json::from_str::<serde_json::Value>(json_str) {
374 let sufficient = v
375 .get("sufficient")
376 .and_then(serde_json::Value::as_bool)
377 .unwrap_or(true);
378 #[allow(clippy::cast_possible_truncation)]
379 let confidence = v
380 .get("confidence")
381 .and_then(serde_json::Value::as_f64)
382 .map_or(1.0, |c| c.clamp(0.0, 1.0) as f32);
383
384 return sufficient && confidence >= threshold;
385 }
386
387 tracing::debug!("tiered: could not parse validator response, treating as sufficient");
388 true
389}
390
/// Tests for tiered retrieval.
///
/// The pure helpers (`IntentClass`, `assemble_within_budget`, `parse_validation_response`) are
/// tested directly. `recall_tiered` is exercised end-to-end against the crate's mock semantic
/// memory and `MockProvider` LLM stand-ins.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::router::MemoryRoute;
    use crate::semantic::RecalledMessage;
    use zeph_llm::provider::{Message, MessageMetadata, Role};

    /// Build a user-role recalled message with the given content and a fixed score of 1.0.
    fn make_message(content: &str) -> RecalledMessage {
        RecalledMessage {
            message: Message {
                role: Role::User,
                content: content.to_owned(),
                parts: vec![],
                metadata: MessageMetadata::default(),
            },
            score: 1.0,
        }
    }

    #[test]
    fn intent_class_from_route_mapping() {
        assert_eq!(
            IntentClass::from_route(MemoryRoute::Keyword),
            IntentClass::ProfileLookup
        );
        assert_eq!(
            IntentClass::from_route(MemoryRoute::Episodic),
            IntentClass::ProfileLookup
        );
        assert_eq!(
            IntentClass::from_route(MemoryRoute::Semantic),
            IntentClass::TargetedRetrieval
        );
        assert_eq!(
            IntentClass::from_route(MemoryRoute::Hybrid),
            IntentClass::TargetedRetrieval
        );
        assert_eq!(
            IntentClass::from_route(MemoryRoute::Graph),
            IntentClass::DeepReasoning
        );
    }

    #[test]
    fn intent_class_top_k() {
        assert_eq!(IntentClass::ProfileLookup.top_k(), 3);
        assert_eq!(IntentClass::TargetedRetrieval.top_k(), 10);
        assert_eq!(IntentClass::DeepReasoning.top_k(), 20);
    }

    #[test]
    fn intent_class_escalate_chain() {
        assert_eq!(
            IntentClass::ProfileLookup.escalate(),
            Some(IntentClass::TargetedRetrieval)
        );
        assert_eq!(
            IntentClass::TargetedRetrieval.escalate(),
            Some(IntentClass::DeepReasoning)
        );
        assert_eq!(IntentClass::DeepReasoning.escalate(), None);
    }

    #[test]
    fn assemble_within_budget_empty_input() {
        let (retained, tokens) = assemble_within_budget(vec![], 4096);
        assert!(retained.is_empty());
        assert_eq!(tokens, 0);
    }

    #[test]
    fn assemble_within_budget_zero_budget_returns_nothing() {
        let candidates = vec![make_message("hello"), make_message("world")];
        let (retained, tokens) = assemble_within_budget(candidates, 0);
        assert!(retained.is_empty(), "budget=0 must retain no messages");
        assert_eq!(tokens, 0);
    }

    #[test]
    fn assemble_within_budget_truncates_at_limit() {
        // 800 characters per message; the assertions below pin the estimate at 200 tokens each.
        let msg = "a ".repeat(400);
        let candidates = vec![make_message(&msg), make_message(&msg)];
        let (retained, tokens) = assemble_within_budget(candidates, 250);
        assert_eq!(
            retained.len(),
            1,
            "tight budget must keep only first message"
        );
        assert_eq!(tokens, 200);
    }

    #[test]
    fn assemble_within_budget_exact_fit_keeps_all() {
        // Same 200-token messages as above: a budget equal to their sum must keep both.
        let msg = "a ".repeat(400);
        let candidates = vec![make_message(&msg), make_message(&msg)];
        let (retained, tokens) = assemble_within_budget(candidates, 400);
        assert_eq!(
            retained.len(),
            2,
            "budget equal to the total must keep every message"
        );
        assert_eq!(tokens, 400);
    }

    #[test]
    fn parse_validation_response_missing_fields_defaults_to_sufficient() {
        // Valid JSON with neither field: both defaults (true / 1.0) apply.
        let raw = "{}";
        assert!(
            parse_validation_response(raw, 0.6),
            "missing fields must default to sufficient"
        );
    }

    #[test]
    fn tiered_retrieval_config_defaults() {
        let cfg = TieredRetrievalConfig::default();
        assert!(!cfg.enabled);
        assert_eq!(cfg.token_budget, 4096);
        assert!(!cfg.validation_enabled);
        assert_eq!(cfg.max_escalations, 1);
        assert_eq!(cfg.classifier_timeout_secs, 5);
        assert_eq!(cfg.validator_timeout_secs, 5);
    }

    #[test]
    fn tiered_retrieval_config_timeout_fields_propagate() {
        let cfg = TieredRetrievalConfig {
            classifier_timeout_secs: 10,
            validator_timeout_secs: 15,
            ..TieredRetrievalConfig::default()
        };
        assert_eq!(cfg.classifier_timeout_secs, 10);
        assert_eq!(cfg.validator_timeout_secs, 15);
        let classifier_dur = std::time::Duration::from_secs(cfg.classifier_timeout_secs);
        let validator_dur = std::time::Duration::from_secs(cfg.validator_timeout_secs);
        assert_eq!(classifier_dur.as_secs(), 10);
        assert_eq!(validator_dur.as_secs(), 15);
    }

    #[test]
    fn parse_validation_response_sufficient() {
        let raw = r#"{"sufficient": true, "confidence": 0.9}"#;
        assert!(parse_validation_response(raw, 0.6));
    }

    #[test]
    fn parse_validation_response_insufficient() {
        let raw = r#"{"sufficient": false, "confidence": 0.4}"#;
        assert!(!parse_validation_response(raw, 0.6));
    }

    #[test]
    fn parse_validation_response_low_confidence() {
        // A positive verdict below the confidence threshold still counts as insufficient.
        let raw = r#"{"sufficient": true, "confidence": 0.3}"#;
        assert!(!parse_validation_response(raw, 0.6));
    }

    #[test]
    fn parse_validation_response_confidence_equal_to_threshold_is_sufficient() {
        // The threshold comparison is inclusive (`>=`).
        let raw = r#"{"sufficient": true, "confidence": 0.6}"#;
        assert!(parse_validation_response(raw, 0.6));
    }

    #[test]
    fn parse_validation_response_extracts_json_from_code_fence() {
        // LLMs often wrap JSON in a markdown fence; the verdict inside must still be honoured.
        let raw = "```json\n{\"sufficient\": false, \"confidence\": 0.2}\n```";
        assert!(!parse_validation_response(raw, 0.6));
    }

    #[test]
    fn parse_validation_response_malformed_json_treats_as_sufficient() {
        let raw = "not json at all";
        assert!(parse_validation_response(raw, 0.6));
    }

    #[test]
    fn intent_class_display() {
        assert_eq!(IntentClass::ProfileLookup.to_string(), "ProfileLookup");
        assert_eq!(
            IntentClass::TargetedRetrieval.to_string(),
            "TargetedRetrieval"
        );
        assert_eq!(IntentClass::DeepReasoning.to_string(), "DeepReasoning");
    }

    #[tokio::test]
    async fn recall_tiered_no_classifier_uses_heuristic_router() {
        let memory = crate::testing::mock_semantic_memory()
            .await
            .expect("mock_semantic_memory");
        let config = TieredRetrievalConfig {
            enabled: true,
            validation_enabled: false,
            ..TieredRetrievalConfig::default()
        };

        let result = recall_tiered(&memory, "what is my name", None, None, None, &config, None)
            .await
            .expect("recall_tiered must not fail");

        assert!(
            !result.tier_escalated,
            "no escalation when validation is off"
        );
        assert!(result.tokens_used <= config.token_budget);
    }

    #[tokio::test]
    async fn recall_tiered_with_classifier_uses_hybrid_router() {
        use zeph_llm::mock::MockProvider;

        let memory = crate::testing::mock_semantic_memory()
            .await
            .expect("mock_semantic_memory");

        // Canned classifier reply; presumably parsed by `HybridRouter` into a Semantic route.
        let route_json = r#"{"route": "Semantic", "confidence": 0.9}"#.to_owned();
        let mut mock = MockProvider::with_responses(vec![route_json]);
        mock.supports_embeddings = true;
        mock.embedding = vec![0.1_f32; 384];
        let classifier = Arc::new(AnyProvider::Mock(mock));

        let config = TieredRetrievalConfig {
            enabled: true,
            validation_enabled: false,
            ..TieredRetrievalConfig::default()
        };

        let result = recall_tiered(
            &memory,
            "semantic query about the user",
            None,
            Some(&classifier),
            None,
            &config,
            None,
        )
        .await
        .expect("recall_tiered with classifier must not fail");

        assert!(!result.tier_escalated);
        assert!(result.tokens_used <= config.token_budget);
    }

    #[tokio::test]
    async fn recall_tiered_escalates_when_evidence_insufficient() {
        use zeph_llm::mock::MockProvider;

        let memory = crate::testing::mock_semantic_memory()
            .await
            .expect("mock_semantic_memory");

        // First validation says "insufficient" (forcing one escalation), the second accepts.
        let insufficient = r#"{"sufficient": false, "confidence": 0.1}"#.to_owned();
        let sufficient = r#"{"sufficient": true, "confidence": 0.95}"#.to_owned();
        let mut validator_mock = MockProvider::with_responses(vec![insufficient, sufficient]);
        validator_mock.supports_embeddings = true;
        let validator = Arc::new(AnyProvider::Mock(validator_mock));

        let config = TieredRetrievalConfig {
            enabled: true,
            validation_enabled: true,
            validation_threshold: 0.6,
            max_escalations: 2,
            ..TieredRetrievalConfig::default()
        };

        let result = recall_tiered(
            &memory,
            "deep query",
            None,
            None,
            Some(&validator),
            &config,
            None,
        )
        .await
        .expect("escalation path must not fail");

        assert!(
            result.tier_escalated,
            "must set tier_escalated when validator triggers escalation"
        );
    }

    #[tokio::test]
    async fn validate_evidence_timeout_is_fail_open() {
        use zeph_llm::mock::MockProvider;

        let memory = crate::testing::mock_semantic_memory()
            .await
            .expect("mock_semantic_memory");

        // Store evidence so retrieval presumably returns something: with no messages,
        // `validate_evidence` short-circuits to `false` and the LLM timeout is never reached.
        let conv_id = memory
            .sqlite()
            .create_conversation()
            .await
            .expect("create_conversation");
        memory
            .remember(conv_id, "user", "some evidence content", None)
            .await
            .expect("remember");

        // Delay (presumably 6 000 ms) outlasts the 5 s validator timeout below; this test
        // therefore waits in real time for roughly the timeout duration.
        let slow_mock = MockProvider::default().with_delay(6_000);
        let validator = Arc::new(AnyProvider::Mock(slow_mock));

        let config = TieredRetrievalConfig {
            enabled: true,
            validation_enabled: true,
            validation_threshold: 0.6,
            max_escalations: 1,
            validator_timeout_secs: 5,
            ..TieredRetrievalConfig::default()
        };

        let result = recall_tiered(
            &memory,
            "evidence",
            None,
            None,
            Some(&validator),
            &config,
            None,
        )
        .await
        .expect("timeout path must not propagate as error");

        assert!(
            !result.tier_escalated,
            "validator timeout must be treated as sufficient (fail-open)"
        );
    }

    #[tokio::test]
    async fn validate_evidence_llm_error_is_fail_open() {
        use zeph_llm::mock::MockProvider;

        let memory = crate::testing::mock_semantic_memory()
            .await
            .expect("mock_semantic_memory");

        // As above: evidence is needed so the failing validator is actually called.
        let conv_id = memory
            .sqlite()
            .create_conversation()
            .await
            .expect("create_conversation");
        memory
            .remember(conv_id, "user", "some evidence content", None)
            .await
            .expect("remember");

        let failing_mock = MockProvider::failing();
        let validator = Arc::new(AnyProvider::Mock(failing_mock));

        let config = TieredRetrievalConfig {
            enabled: true,
            validation_enabled: true,
            validation_threshold: 0.6,
            max_escalations: 1,
            ..TieredRetrievalConfig::default()
        };

        let result = recall_tiered(
            &memory,
            "evidence",
            None,
            None,
            Some(&validator),
            &config,
            None,
        )
        .await
        .expect("LLM error path must not propagate as retrieval error");

        assert!(
            !result.tier_escalated,
            "validator LLM error must be treated as sufficient (fail-open)"
        );
    }

    #[tokio::test]
    async fn recall_tiered_with_conversation_id_filter() {
        let memory = crate::testing::mock_semantic_memory()
            .await
            .expect("mock_semantic_memory");

        // A conversation id that was never created: the scoped search must come back empty.
        let conv_id = ConversationId(42);
        let config = TieredRetrievalConfig {
            enabled: true,
            validation_enabled: false,
            ..TieredRetrievalConfig::default()
        };

        let result = recall_tiered(
            &memory,
            "what did we discuss",
            Some(conv_id),
            None,
            None,
            &config,
            None,
        )
        .await
        .expect("conversation-scoped recall must not fail");

        assert!(result.messages.is_empty());
        assert_eq!(result.tokens_used, 0);
        assert!(!result.tier_escalated);
    }
}