1use std::future::Future;
17use std::pin::Pin;
18
19use futures::StreamExt as _;
20use futures::stream::FuturesUnordered;
21
22use zeph_common::memory::{AsyncMemoryRouter, CompressionLevel, GraphRecallParams, TokenCounting};
23use zeph_llm::provider::{Message, MessageMetadata, MessagePart, Role};
24
25use crate::error::ContextError;
26use crate::input::ContextAssemblyInput;
27use crate::slot::ContextSlot;
28
29pub(crate) fn levels_to_flags(levels: &[CompressionLevel]) -> (bool, bool, bool) {
37 if levels.is_empty() {
38 return (true, true, true);
39 }
40 let episodic = levels.contains(&CompressionLevel::Episodic);
41 let procedural = levels.contains(&CompressionLevel::Procedural);
42 let declarative = levels.contains(&CompressionLevel::Declarative);
43 (episodic, procedural, declarative)
44}
45
/// Header line that opens the conversation-summaries text built by `fetch_summaries`.
pub const SUMMARY_PREFIX: &str = "[conversation summaries]\n";
/// Header line that opens the cross-session text built by `fetch_cross_session`.
pub const CROSS_SESSION_PREFIX: &str = "[cross-session context]\n";
/// Header line that opens the semantic-recall text built by `fetch_semantic_recall`.
pub const RECALL_PREFIX: &str = "[semantic recall]\n";
/// Header line that opens the past-corrections text built by `fetch_corrections`.
pub const CORRECTIONS_PREFIX: &str = "[past corrections]\n";
/// Header line that opens the document-RAG text built by `fetch_document_rag`.
pub const DOCUMENT_RAG_PREFIX: &str = "## Relevant documents\n";
/// Header line that opens the graph-facts text built by `fetch_graph_facts`.
pub const GRAPH_FACTS_PREFIX: &str = "[known facts]\n";
58
/// Context pieces gathered for one assembly pass, one field per source.
///
/// `ContextAssembler::gather` starts from an all-`None` value and overwrites each
/// optional field with whatever the matching fetcher returned. `None` can therefore
/// mean either that the source was not scheduled or that it produced nothing.
pub struct PreparedContext {
    /// Graph-fact message (stored from `ContextSlot::GraphFacts`).
    pub graph_facts: Option<Message>,
    /// Document-RAG message (stored from `ContextSlot::DocumentRag`).
    pub doc_rag: Option<Message>,
    /// Past-corrections message (stored from `ContextSlot::Corrections`).
    pub corrections: Option<Message>,
    /// Semantic-recall message (stored from `ContextSlot::SemanticRecall`).
    pub recall: Option<Message>,
    /// Score delivered together with the semantic-recall message; `fetch_semantic_recall`
    /// takes it from the first recalled item.
    pub recall_confidence: Option<f32>,
    /// Cross-session message (stored from `ContextSlot::CrossSession`).
    pub cross_session: Option<Message>,
    /// Conversation-summaries message (stored from `ContextSlot::Summaries`).
    pub summaries: Option<Message>,
    /// Code context text returned by the index (stored from `ContextSlot::CodeContext`).
    pub code_context: Option<String>,
    /// Persona-facts message (stored from `ContextSlot::PersonaFacts`).
    pub persona_facts: Option<Message>,
    /// Trajectory-hints message (stored from `ContextSlot::TrajectoryHints`).
    pub trajectory_hints: Option<Message>,
    /// Tree-memory message (stored from `ContextSlot::TreeMemory`).
    pub tree_memory: Option<Message>,
    /// Reasoning-strategies message (stored from `ContextSlot::ReasoningStrategies`).
    pub reasoning_hints: Option<Message>,
    /// `true` when the effective context strategy resolved to `MemoryFirst`.
    pub memory_first: bool,
    /// Copy of the allocation's `recent_history` value (presumably a token count —
    /// confirm against the budget allocator).
    pub recent_history_budget: usize,
}
93
/// Stateless entry point for context gathering; see `ContextAssembler::gather`.
pub struct ContextAssembler;
98
/// Boxed, pinned, `Send` future resolving to one fetched [`ContextSlot`] or a
/// [`ContextError`]; boxing gives differently-typed fetchers a single element type.
type CtxFuture<'a> = Pin<Box<dyn Future<Output = Result<ContextSlot, ContextError>> + Send + 'a>>;
100
101fn empty_prepared_context() -> PreparedContext {
102 PreparedContext {
103 graph_facts: None,
104 doc_rag: None,
105 corrections: None,
106 recall: None,
107 recall_confidence: None,
108 cross_session: None,
109 summaries: None,
110 code_context: None,
111 persona_facts: None,
112 trajectory_hints: None,
113 tree_memory: None,
114 reasoning_hints: None,
115 memory_first: false,
116 recent_history_budget: 0,
117 }
118}
119
120fn resolve_effective_strategy(
123 memory: &crate::input::ContextMemoryView,
124 sidequest_turn_counter: u64,
125) -> zeph_config::ContextStrategy {
126 match memory.context_strategy {
127 zeph_config::ContextStrategy::FullHistory => zeph_config::ContextStrategy::FullHistory,
128 zeph_config::ContextStrategy::MemoryFirst => zeph_config::ContextStrategy::MemoryFirst,
129 zeph_config::ContextStrategy::Adaptive => {
130 if sidequest_turn_counter >= u64::from(memory.crossover_turn_threshold) {
131 zeph_config::ContextStrategy::MemoryFirst
132 } else {
133 zeph_config::ContextStrategy::FullHistory
134 }
135 }
136 }
137}
138
139fn correction_params(cfg: Option<&crate::input::CorrectionConfig>) -> (usize, f32) {
140 cfg.filter(|c| c.correction_detection)
141 .map_or((3, 0.75), |c| {
142 (
143 c.correction_recall_limit as usize,
144 c.correction_min_similarity,
145 )
146 })
147}
148
/// Builds the set of context-fetch futures to run for one assembly pass.
///
/// Each enabled source is wrapped in a boxed future that resolves to the matching
/// [`ContextSlot`] variant and is pushed into a [`FuturesUnordered`]. Nothing is
/// awaited here; the caller drives the returned set.
///
/// Gating, as implemented below:
/// - episodic tier: summaries, cross-session, semantic recall and document RAG, each
///   also requiring a non-zero budget (document RAG shares `semantic_recall_budget`);
/// - declarative tier: graph facts, persona facts and tree memory;
/// - procedural tier: trajectory hints and reasoning strategies;
/// - corrections are always scheduled, regardless of tier flags;
/// - code context requires `code_context_budget > 0` and an `index`, regardless of
///   tier flags.
// Every budget and limit is passed as its own argument, hence the long parameter list.
#[allow(clippy::too_many_arguments)]
fn schedule_context_fetchers<'r>(
    memory: &'r crate::input::ContextMemoryView,
    tc: &'r dyn TokenCounting,
    query: &'r str,
    scrub: fn(&str) -> std::borrow::Cow<'_, str>,
    index: Option<&'r dyn crate::input::IndexAccess>,
    router_ref: &'r dyn AsyncMemoryRouter,
    summaries_budget: usize,
    cross_session_budget: usize,
    semantic_recall_budget: usize,
    code_context_budget: usize,
    graph_facts_budget: usize,
    recall_limit: usize,
    min_sim: f32,
    active_levels: &[CompressionLevel],
) -> FuturesUnordered<CtxFuture<'r>> {
    // `active_levels` is consumed synchronously here, so it is not tied to `'r`.
    let (episodic_active, procedural_active, declarative_active) = levels_to_flags(active_levels);

    let fetchers: FuturesUnordered<CtxFuture<'r>> = FuturesUnordered::new();

    // --- Episodic tier ---
    if episodic_active && summaries_budget > 0 {
        fetchers.push(Box::pin(async move {
            fetch_summaries(memory, summaries_budget, tc)
                .await
                .map(ContextSlot::Summaries)
        }));
    }
    if episodic_active && cross_session_budget > 0 {
        fetchers.push(Box::pin(async move {
            fetch_cross_session(memory, query, cross_session_budget, tc)
                .await
                .map(ContextSlot::CrossSession)
        }));
    }
    if episodic_active && semantic_recall_budget > 0 {
        fetchers.push(Box::pin(async move {
            fetch_semantic_recall(memory, query, semantic_recall_budget, tc, Some(router_ref))
                .await
                .map(|(msg, score)| ContextSlot::SemanticRecall(msg, score))
        }));
        // Document RAG is given the same budget value as semantic recall.
        fetchers.push(Box::pin(async move {
            fetch_document_rag(memory, query, semantic_recall_budget, tc)
                .await
                .map(ContextSlot::DocumentRag)
        }));
    }
    // Corrections are scheduled unconditionally.
    fetchers.push(Box::pin(async move {
        fetch_corrections(memory, query, recall_limit, min_sim, scrub)
            .await
            .map(ContextSlot::Corrections)
    }));
    // Code context only needs a budget and an index; tier flags are not consulted.
    if code_context_budget > 0
        && let Some(idx) = index
    {
        fetchers.push(Box::pin(async move {
            let result: Result<Option<String>, ContextError> =
                idx.fetch_code_rag(query, code_context_budget).await;
            result.map(ContextSlot::CodeContext)
        }));
    }
    // --- Declarative tier ---
    if declarative_active && graph_facts_budget > 0 {
        fetchers.push(Box::pin(async move {
            fetch_graph_facts(memory, query, graph_facts_budget, tc)
                .await
                .map(ContextSlot::GraphFacts)
        }));
    }
    // The remaining sources read their budget from their own config section.
    if declarative_active && memory.persona_config.context_budget_tokens > 0 {
        fetchers.push(Box::pin(async move {
            let persona_budget = memory.persona_config.context_budget_tokens;
            fetch_persona_facts(memory, persona_budget, tc)
                .await
                .map(ContextSlot::PersonaFacts)
        }));
    }
    // --- Procedural tier ---
    if procedural_active && memory.trajectory_config.context_budget_tokens > 0 {
        fetchers.push(Box::pin(async move {
            let tbudget = memory.trajectory_config.context_budget_tokens;
            fetch_trajectory_hints(memory, tbudget, tc)
                .await
                .map(ContextSlot::TrajectoryHints)
        }));
    }
    if declarative_active && memory.tree_config.context_budget_tokens > 0 {
        fetchers.push(Box::pin(async move {
            let tbudget = memory.tree_config.context_budget_tokens;
            fetch_tree_memory(memory, tbudget, tc)
                .await
                .map(ContextSlot::TreeMemory)
        }));
    }
    if procedural_active
        && memory.reasoning_config.enabled
        && memory.reasoning_config.context_budget_tokens > 0
    {
        fetchers.push(Box::pin(async move {
            let rbudget = memory.reasoning_config.context_budget_tokens;
            let top_k = memory.reasoning_config.top_k;
            fetch_reasoning_strategies(memory, query, rbudget, top_k, tc)
                .await
                .map(ContextSlot::ReasoningStrategies)
        }));
    }

    fetchers
}
267
268async fn drive_fetchers(
269 mut fetchers: FuturesUnordered<CtxFuture<'_>>,
270 prepared: &mut PreparedContext,
271) -> Result<(), ContextError> {
272 while let Some(result) = fetchers.next().await {
273 match result {
274 Ok(slot) => match slot {
275 ContextSlot::Summaries(msg) => prepared.summaries = msg,
276 ContextSlot::CrossSession(msg) => prepared.cross_session = msg,
277 ContextSlot::SemanticRecall(msg, score) => {
278 prepared.recall = msg;
279 prepared.recall_confidence = score;
280 }
281 ContextSlot::DocumentRag(msg) => prepared.doc_rag = msg,
282 ContextSlot::Corrections(msg) => prepared.corrections = msg,
283 ContextSlot::CodeContext(text) => prepared.code_context = text,
284 ContextSlot::GraphFacts(msg) => prepared.graph_facts = msg,
285 ContextSlot::PersonaFacts(msg) => prepared.persona_facts = msg,
286 ContextSlot::TrajectoryHints(msg) => prepared.trajectory_hints = msg,
287 ContextSlot::TreeMemory(msg) => prepared.tree_memory = msg,
288 ContextSlot::ReasoningStrategies(msg) => prepared.reasoning_hints = msg,
289 },
290 Err(e) => return Err(e),
291 }
292 }
293 Ok(())
294}
295
impl ContextAssembler {
    /// Gathers all context sources for the current turn.
    ///
    /// When the context manager has no budget configured, an empty
    /// [`PreparedContext`] is returned right away. Otherwise the method resolves the
    /// effective strategy, allocates the budget, schedules the fetchers and drives
    /// them to completion. The result also carries `memory_first` and the
    /// allocation's `recent_history` value.
    ///
    /// # Errors
    ///
    /// Returns the first [`ContextError`] produced by any scheduled fetcher.
    pub async fn gather(input: &ContextAssemblyInput<'_>) -> Result<PreparedContext, ContextError> {
        // Without a budget there is nothing to allocate, so no source is fetched.
        let Some(ref budget) = input.context_manager.budget else {
            return Ok(empty_prepared_context());
        };

        let memory = input.memory;
        let tc = input.token_counter;

        let effective_strategy = resolve_effective_strategy(memory, input.sidequest_turn_counter);
        let memory_first = effective_strategy == zeph_config::ContextStrategy::MemoryFirst;

        // Only a leading System message counts as the system prompt; otherwise use "".
        let system_prompt = input
            .messages
            .first()
            .filter(|m| m.role == Role::System)
            .map_or("", |m| m.content.as_str());

        // Second element of the cached digest tuple, or 0 when no digest is cached.
        let digest_tokens = memory
            .cached_session_digest
            .as_ref()
            .map_or(0, |(_, tokens)| *tokens);

        let alloc = budget.allocate_with_opts(
            system_prompt,
            input.skills_prompt,
            tc,
            memory.graph_config.enabled,
            digest_tokens,
            memory_first,
        );

        let (recall_limit, min_sim) = correction_params(input.correction_config.as_ref());

        let router_ref: &dyn AsyncMemoryRouter = input.router.as_ref();

        tracing::debug!(
            active_sources = alloc.active_sources(),
            active_levels = ?input.active_levels,
            "context budget allocated"
        );

        let fetchers = schedule_context_fetchers(
            memory,
            tc,
            input.query,
            input.scrub,
            input.index,
            router_ref,
            alloc.summaries,
            alloc.cross_session,
            alloc.semantic_recall,
            alloc.code_context,
            alloc.graph_facts,
            recall_limit,
            min_sim,
            input.active_levels,
        );

        let mut prepared = empty_prepared_context();
        prepared.memory_first = memory_first;
        prepared.recent_history_budget = alloc.recent_history;

        drive_fetchers(fetchers, &mut prepared).await?;
        Ok(prepared)
    }
}
370
371pub fn effective_recall_timeout_ms(configured: u64) -> u64 {
376 if configured == 0 {
377 tracing::warn!(
378 "recall_timeout_ms is 0, which would disable spreading activation recall; \
379 clamping to 100ms"
380 );
381 100
382 } else {
383 configured
384 }
385}
386
387use crate::input::ContextMemoryView;
388
/// Recalls graph facts for `query` and renders them as a system message.
///
/// Returns `Ok(None)` when the budget is zero, the graph is disabled, no memory
/// backend is attached, recall yields nothing (including on error or timeout), or
/// not even one fact line fits into `budget_tokens`.
///
/// The body starts with [`GRAPH_FACTS_PREFIX`]. Each fact becomes one `- …` line,
/// optionally followed by ` ~ …` neighbor lines and a ` [source: …]` line. Newlines
/// and `<`/`>` in recalled text are replaced with spaces before rendering.
///
/// # Errors
///
/// As written, this function never returns `Err`: recall failures and timeouts are
/// logged and treated as an empty result.
// Kept as one function: query fusion, recall and rendering share many locals.
#[allow(clippy::too_many_lines)]
pub(crate) async fn fetch_graph_facts(
    memory: &ContextMemoryView,
    query: &str,
    budget_tokens: usize,
    tc: &dyn TokenCounting,
) -> Result<Option<Message>, ContextError> {
    use zeph_common::memory::{RecallView, SpreadingActivationParams, classify_graph_subgraph};

    if budget_tokens == 0 || !memory.graph_config.enabled {
        return Ok(None);
    }
    let Some(ref mem) = memory.memory else {
        return Ok(None);
    };
    let recall_limit = memory.graph_config.recall_limit;
    let temporal_decay_rate = memory.graph_config.temporal_decay_rate;
    let sa_config = &memory.graph_config.spreading_activation;

    // Declared up front so the borrowed `effective_query` can outlive the `if` below.
    let fused_query;
    // When a state string is present, prepend (a bounded slice of) it to the query.
    let effective_query = if let Some(ref state) = memory.memcot_state {
        // Limit the state excerpt to twice the query's byte length, cut back to a
        // char boundary so the slice cannot split a UTF-8 sequence.
        let max_state_chars = 2 * query.len();
        let state_slice = if state.len() > max_state_chars {
            let boundary = state.floor_char_boundary(max_state_chars);
            &state[..boundary]
        } else {
            state.as_str()
        };
        fused_query = format!("[state] {state_slice}\n{query}");
        &fused_query as &str
    } else {
        query
    };

    let edge_types = classify_graph_subgraph(effective_query);

    // Map the config enum onto the recall API's enum, variant for variant.
    let view = match memory.memcot_config.recall_view {
        zeph_config::RecallViewConfig::Head => RecallView::Head,
        zeph_config::RecallViewConfig::ZoomIn => RecallView::ZoomIn,
        zeph_config::RecallViewConfig::ZoomOut => RecallView::ZoomOut,
    };

    // Spreading-activation parameters are passed only when that feature is enabled.
    let sa_params = if sa_config.enabled {
        Some(SpreadingActivationParams {
            decay_lambda: sa_config.decay_lambda,
            max_hops: sa_config.max_hops,
            activation_threshold: sa_config.activation_threshold,
            inhibition_threshold: sa_config.inhibition_threshold,
            max_activated_nodes: sa_config.max_activated_nodes,
            temporal_decay_rate,
            seed_structural_weight: sa_config.seed_structural_weight,
            seed_community_cap: sa_config.seed_community_cap,
        })
    } else {
        None
    };

    // A configured timeout of 0 is clamped to 100ms by the helper.
    let timeout_ms = effective_recall_timeout_ms(sa_config.recall_timeout_ms);
    let recall_fut = mem.recall_graph_facts(
        effective_query,
        GraphRecallParams {
            limit: recall_limit,
            view,
            zoom_out_neighbor_cap: memory.memcot_config.zoom_out_neighbor_cap,
            max_hops: memory.graph_config.max_hops,
            temporal_decay_rate,
            edge_types: &edge_types,
            spreading_activation: sa_params,
        },
    );
    // Best effort: both a recall error and a timeout degrade to "no facts".
    let recalled = match tokio::time::timeout(
        std::time::Duration::from_millis(timeout_ms),
        recall_fut,
    )
    .await
    {
        Ok(Ok(facts)) => facts,
        Ok(Err(e)) => {
            tracing::warn!("graph recall failed: {e:#}");
            Vec::new()
        }
        Err(_) => {
            tracing::warn!("graph recall timed out ({timeout_ms}ms)");
            Vec::new()
        }
    };

    if recalled.is_empty() {
        return Ok(None);
    }

    let mut body = String::from(GRAPH_FACTS_PREFIX);
    // The prefix itself counts against the budget.
    let mut tokens_so_far = tc.count_tokens(&body);

    for rf in &recalled {
        // Flatten line breaks and angle brackets so a fact stays on one line.
        let fact_text = rf.fact.replace(['\n', '\r', '<', '>'], " ");
        let line = if let Some(score) = rf.activation_score {
            format!(
                "- {} (confidence: {:.2}, activation: {:.2})\n",
                fact_text, rf.confidence, score
            )
        } else {
            format!("- {} (confidence: {:.2})\n", fact_text, rf.confidence)
        };
        let line_tokens = tc.count_tokens(&line);
        // The first fact that does not fit ends rendering altogether.
        if tokens_so_far + line_tokens > budget_tokens {
            break;
        }
        body.push_str(&line);
        tokens_so_far += line_tokens;

        for nb in &rf.neighbors {
            let nb_text = nb.fact.replace(['\n', '\r', '<', '>'], " ");
            let nb_line = format!(" ~ {} (confidence: {:.2})\n", nb_text, nb.confidence);
            let nb_tokens = tc.count_tokens(&nb_line);
            // This `break` leaves only the neighbor loop; the fact loop continues.
            if tokens_so_far + nb_tokens > budget_tokens {
                break;
            }
            body.push_str(&nb_line);
            tokens_so_far += nb_tokens;
        }

        if let Some(ref snippet) = rf.provenance_snippet {
            let snip_line = format!(
                " [source: {}]\n",
                snippet.replace(['\n', '\r', '<', '>'], " ")
            );
            let snip_tokens = tc.count_tokens(&snip_line);
            // A snippet that does not fit is skipped without ending rendering.
            if tokens_so_far + snip_tokens <= budget_tokens {
                body.push_str(&snip_line);
                tokens_so_far += snip_tokens;
            }
        }
    }

    // Nothing was appended after the prefix: the first fact already exceeded the budget.
    if body == GRAPH_FACTS_PREFIX {
        return Ok(None);
    }

    Ok(Some(Message::from_legacy(Role::System, body)))
}
533
534pub(crate) async fn fetch_persona_facts(
535 memory: &ContextMemoryView,
536 budget_tokens: usize,
537 tc: &dyn TokenCounting,
538) -> Result<Option<Message>, ContextError> {
539 if budget_tokens == 0 || !memory.persona_config.enabled {
540 return Ok(None);
541 }
542 let Some(ref mem) = memory.memory else {
543 return Ok(None);
544 };
545
546 let min_confidence = memory.persona_config.min_confidence;
547 let facts = mem
548 .load_persona_facts(min_confidence)
549 .await
550 .map_err(ContextError::Memory)?;
551
552 if facts.is_empty() {
553 return Ok(None);
554 }
555
556 let mut body = String::from(crate::slot::PERSONA_PREFIX);
557 let mut tokens_so_far = tc.count_tokens(&body);
558
559 for fact in &facts {
560 let line = format!("[{}] {}\n", fact.category, fact.content);
561 let line_tokens = tc.count_tokens(&line);
562 if tokens_so_far + line_tokens > budget_tokens {
563 break;
564 }
565 body.push_str(&line);
566 tokens_so_far += line_tokens;
567 }
568
569 if body == crate::slot::PERSONA_PREFIX {
570 return Ok(None);
571 }
572
573 Ok(Some(Message::from_legacy(Role::System, body)))
574}
575
576pub(crate) async fn fetch_trajectory_hints(
577 memory: &ContextMemoryView,
578 budget_tokens: usize,
579 tc: &dyn TokenCounting,
580) -> Result<Option<Message>, ContextError> {
581 if budget_tokens == 0 || !memory.trajectory_config.enabled {
582 return Ok(None);
583 }
584 let Some(ref mem) = memory.memory else {
585 return Ok(None);
586 };
587
588 let top_k = memory.trajectory_config.recall_top_k;
589 let min_conf = memory.trajectory_config.min_confidence;
590 let entries = mem
594 .load_trajectory_entries(Some("procedural"), top_k)
595 .await
596 .map_err(ContextError::Memory)?;
597
598 if entries.is_empty() {
599 return Ok(None);
600 }
601
602 let mut body = String::from(crate::slot::TRAJECTORY_PREFIX);
603 let mut tokens_so_far = tc.count_tokens(&body);
604
605 for entry in entries
606 .iter()
607 .filter(|e| e.confidence >= min_conf)
608 .take(top_k)
609 {
610 let line = format!("- {}: {}\n", entry.intent, entry.outcome);
611 let line_tokens = tc.count_tokens(&line);
612 if tokens_so_far + line_tokens > budget_tokens {
613 break;
614 }
615 body.push_str(&line);
616 tokens_so_far += line_tokens;
617 }
618
619 if body == crate::slot::TRAJECTORY_PREFIX {
620 return Ok(None);
621 }
622
623 Ok(Some(Message::from_legacy(Role::System, body)))
624}
625
626pub(crate) async fn fetch_tree_memory(
627 memory: &ContextMemoryView,
628 budget_tokens: usize,
629 tc: &dyn TokenCounting,
630) -> Result<Option<Message>, ContextError> {
631 if budget_tokens == 0 || !memory.tree_config.enabled {
632 return Ok(None);
633 }
634 let Some(ref mem) = memory.memory else {
635 return Ok(None);
636 };
637
638 let top_k = memory.tree_config.recall_top_k;
639 let nodes = mem
640 .load_tree_nodes(1, top_k)
641 .await
642 .map_err(ContextError::Memory)?;
643
644 if nodes.is_empty() {
645 return Ok(None);
646 }
647
648 let mut body = String::from(crate::slot::TREE_MEMORY_PREFIX);
649 let mut tokens_so_far = tc.count_tokens(&body);
650
651 for node in nodes.iter().take(top_k) {
652 let line = format!("- {}\n", node.content);
653 let line_tokens = tc.count_tokens(&line);
654 if tokens_so_far + line_tokens > budget_tokens {
655 break;
656 }
657 body.push_str(&line);
658 tokens_so_far += line_tokens;
659 }
660
661 if body == crate::slot::TREE_MEMORY_PREFIX {
662 return Ok(None);
663 }
664
665 Ok(Some(Message::from_legacy(Role::System, body)))
666}
667
668pub(crate) async fn fetch_reasoning_strategies(
669 memory: &ContextMemoryView,
670 query: &str,
671 budget_tokens: usize,
672 top_k: usize,
673 tc: &dyn TokenCounting,
674) -> Result<Option<Message>, ContextError> {
675 let budget_tokens = budget_tokens.min(500);
677 if budget_tokens == 0 {
678 return Ok(None);
679 }
680 let Some(ref mem) = memory.memory else {
681 return Ok(None);
682 };
683
684 let strategies = mem
685 .retrieve_reasoning_strategies(query, top_k)
686 .await
687 .map_err(ContextError::Memory)?;
688
689 if strategies.is_empty() {
690 return Ok(None);
691 }
692
693 let mut body = String::from(crate::slot::REASONING_PREFIX);
694 let mut tokens_so_far = tc.count_tokens(&body);
695 let mut injected_ids: Vec<String> = Vec::new();
696
697 for s in strategies.iter().take(top_k) {
698 let safe_summary = s.summary.replace(['\n', '\r', '<', '>'], " ");
701 let line = format!("- [{}] {}\n", s.outcome, safe_summary);
702 let line_tokens = tc.count_tokens(&line);
703 if tokens_so_far + line_tokens > budget_tokens {
704 break;
705 }
706 body.push_str(&line);
707 tokens_so_far += line_tokens;
708 injected_ids.push(s.id.clone());
709 }
710
711 if body == crate::slot::REASONING_PREFIX {
712 return Ok(None);
713 }
714
715 if !injected_ids.is_empty() {
718 let mem_clone = mem.clone();
719 tokio::spawn(async move {
720 if let Err(e) = mem_clone.mark_reasoning_used(&injected_ids).await {
721 tracing::warn!(error = %e, "reasoning: mark_used failed");
722 }
723 });
724 }
725
726 Ok(Some(Message::from_legacy(Role::System, body)))
727}
728
729pub(crate) async fn fetch_corrections(
730 memory: &ContextMemoryView,
731 query: &str,
732 limit: usize,
733 min_score: f32,
734 scrub: fn(&str) -> std::borrow::Cow<'_, str>,
735) -> Result<Option<Message>, ContextError> {
736 let Some(ref mem) = memory.memory else {
737 return Ok(None);
738 };
739 let corrections = mem
740 .retrieve_corrections(query, limit, min_score)
741 .await
742 .unwrap_or_default();
743 if corrections.is_empty() {
744 return Ok(None);
745 }
746 let mut text = String::from(CORRECTIONS_PREFIX);
747 for c in &corrections {
748 text.push_str("- Past user correction: \"");
749 text.push_str(&scrub(&c.correction_text));
750 text.push_str("\"\n");
751 }
752 Ok(Some(Message::from_legacy(Role::System, text)))
753}
754
755pub(crate) async fn fetch_semantic_recall(
756 memory: &ContextMemoryView,
757 query: &str,
758 token_budget: usize,
759 tc: &dyn TokenCounting,
760 router: Option<&dyn AsyncMemoryRouter>,
761) -> Result<(Option<Message>, Option<f32>), ContextError> {
762 let Some(ref mem) = memory.memory else {
763 return Ok((None, None));
764 };
765 if memory.recall_limit == 0 || token_budget == 0 {
766 return Ok((None, None));
767 }
768
769 let recalled = mem
770 .recall(query, memory.recall_limit, router)
771 .await
772 .map_err(ContextError::Memory)?;
773 if recalled.is_empty() {
774 return Ok((None, None));
775 }
776
777 let top_score = recalled.first().map(|r| r.score);
778
779 let mut recall_text = String::with_capacity(token_budget * 3);
780 recall_text.push_str(RECALL_PREFIX);
781 let mut tokens_used = tc.count_tokens(&recall_text);
782
783 for item in &recalled {
784 if item.content.starts_with("[skipped]") || item.content.starts_with("[stopped]") {
785 continue;
786 }
787 let entry = format!("- [{}] {}\n", item.role, item.content);
788 let entry_tokens = tc.count_tokens(&entry);
789 if tokens_used + entry_tokens > token_budget {
790 break;
791 }
792 recall_text.push_str(&entry);
793 tokens_used += entry_tokens;
794 }
795
796 if tokens_used > tc.count_tokens(RECALL_PREFIX) {
797 Ok((
798 Some(Message::from_parts(
799 Role::System,
800 vec![MessagePart::Recall { text: recall_text }],
801 )),
802 top_score,
803 ))
804 } else {
805 Ok((None, None))
806 }
807}
808
809pub(crate) async fn fetch_document_rag(
810 memory: &ContextMemoryView,
811 query: &str,
812 token_budget: usize,
813 tc: &dyn TokenCounting,
814) -> Result<Option<Message>, ContextError> {
815 if !memory.document_config.rag_enabled || token_budget == 0 {
816 return Ok(None);
817 }
818 let Some(ref mem) = memory.memory else {
819 return Ok(None);
820 };
821
822 let collection = &memory.document_config.collection;
823 let top_k = memory.document_config.top_k;
824 let chunks = mem
825 .search_document_collection(collection, query, top_k)
826 .await
827 .map_err(ContextError::Memory)?;
828 if chunks.is_empty() {
829 return Ok(None);
830 }
831
832 let mut text = String::from(DOCUMENT_RAG_PREFIX);
833 let mut tokens_used = tc.count_tokens(&text);
834
835 for chunk in &chunks {
836 if chunk.text.is_empty() {
837 continue;
838 }
839 let entry = format!("{}\n", chunk.text);
840 let cost = tc.count_tokens(&entry);
841 if tokens_used + cost > token_budget {
842 break;
843 }
844 text.push_str(&entry);
845 tokens_used += cost;
846 }
847
848 if tokens_used > tc.count_tokens(DOCUMENT_RAG_PREFIX) {
849 Ok(Some(Message {
850 role: Role::System,
851 content: text,
852 parts: vec![],
853 metadata: MessageMetadata::default(),
854 }))
855 } else {
856 Ok(None)
857 }
858}
859
860pub(crate) async fn fetch_summaries(
861 memory: &ContextMemoryView,
862 token_budget: usize,
863 tc: &dyn TokenCounting,
864) -> Result<Option<Message>, ContextError> {
865 let (Some(mem), Some(cid)) = (&memory.memory, memory.conversation_id) else {
866 return Ok(None);
867 };
868 if token_budget == 0 {
869 return Ok(None);
870 }
871
872 let summaries = mem
873 .load_summaries(cid)
874 .await
875 .map_err(ContextError::Memory)?;
876 if summaries.is_empty() {
877 return Ok(None);
878 }
879
880 let mut summary_text = String::from(SUMMARY_PREFIX);
881 let mut tokens_used = tc.count_tokens(&summary_text);
882
883 for summary in summaries.iter().rev() {
884 let first = summary.first_message_id.unwrap_or(0);
885 let last = summary.last_message_id.unwrap_or(0);
886 let entry = format!("- Messages {first}-{last}: {}\n", summary.content);
887 let cost = tc.count_tokens(&entry);
888 if tokens_used + cost > token_budget {
889 break;
890 }
891 summary_text.push_str(&entry);
892 tokens_used += cost;
893 }
894
895 if tokens_used > tc.count_tokens(SUMMARY_PREFIX) {
896 Ok(Some(Message::from_parts(
897 Role::System,
898 vec![MessagePart::Summary { text: summary_text }],
899 )))
900 } else {
901 Ok(None)
902 }
903}
904
905pub(crate) async fn fetch_cross_session(
906 memory: &ContextMemoryView,
907 query: &str,
908 token_budget: usize,
909 tc: &dyn TokenCounting,
910) -> Result<Option<Message>, ContextError> {
911 let (Some(mem), Some(cid)) = (&memory.memory, memory.conversation_id) else {
912 return Ok(None);
913 };
914 if token_budget == 0 {
915 return Ok(None);
916 }
917
918 let threshold = memory.cross_session_score_threshold;
919 let results: Vec<_> = mem
920 .search_session_summaries(query, 5, Some(cid))
921 .await
922 .map_err(ContextError::Memory)?
923 .into_iter()
924 .filter(|r| r.score >= threshold)
925 .collect();
926 if results.is_empty() {
927 return Ok(None);
928 }
929
930 let mut text = String::from(CROSS_SESSION_PREFIX);
931 let mut tokens_used = tc.count_tokens(&text);
932
933 for item in &results {
934 let entry = format!("- {}\n", item.summary_text);
935 let cost = tc.count_tokens(&entry);
936 if tokens_used + cost > token_budget {
937 break;
938 }
939 text.push_str(&entry);
940 tokens_used += cost;
941 }
942
943 if tokens_used > tc.count_tokens(CROSS_SESSION_PREFIX) {
944 Ok(Some(Message::from_parts(
945 Role::System,
946 vec![MessagePart::CrossSession { text }],
947 )))
948 } else {
949 Ok(None)
950 }
951}
952
/// Cap at which `memory_first_keep_tail` stops widening the retained tail.
pub const MAX_KEEP_TAIL_SCAN: usize = 50;
956
/// Computes how many messages at the end of `messages` should be retained.
///
/// The count starts at 2. While the first retained message is a `User` message
/// containing a `ToolResult` part, the tail is widened by one message. Widening never
/// reaches past `history_start`.
///
/// Once the count reaches [`MAX_KEEP_TAIL_SCAN`], scanning stops. Before stopping,
/// the message just before the retained tail is checked; if it is an `Assistant`
/// message with a `ToolUse` part and lies at or after `history_start`, it is included
/// as well.
///
/// The result is never below 2, even when fewer than two messages lie past
/// `history_start`; callers must cope with a count larger than the available
/// history.
#[must_use]
pub fn memory_first_keep_tail(messages: &[Message], history_start: usize) -> usize {
    use zeph_llm::provider::MessagePart;

    let mut keep_tail = 2usize;
    let len = messages.len();
    // Number of messages at or after `history_start`; the tail may not exceed it.
    let max = len.saturating_sub(history_start);

    while keep_tail < max {
        // `keep_tail < max <= len`, so this index is in bounds.
        let first_retained = &messages[len - keep_tail];
        let is_tool_result = first_retained.role == Role::User
            && first_retained
                .parts
                .iter()
                .any(|p| matches!(p, MessagePart::ToolResult { .. }));

        if is_tool_result {
            // Widen the tail to pull in the message before the tool result.
            keep_tail += 1;
        } else {
            break;
        }

        if keep_tail >= MAX_KEEP_TAIL_SCAN {
            // Scan cap reached: look one message further back, then stop.
            let preceding_idx = len.saturating_sub(keep_tail + 1);
            if preceding_idx >= history_start {
                let preceding = &messages[preceding_idx];
                let is_tool_use = preceding.role == Role::Assistant
                    && preceding
                        .parts
                        .iter()
                        .any(|p| matches!(p, MessagePart::ToolUse { .. }));
                if is_tool_use {
                    keep_tail += 1;
                }
            }
            break;
        }
    }

    keep_tail
}
1005
1006#[cfg(test)]
1007mod tests {
1008 use super::*;
1009 use crate::input::ContextMemoryView;
1010 use zeph_common::memory::CompressionLevel;
1011 use zeph_config::{
1012 ContextStrategy, DocumentConfig, GraphConfig, PersonaConfig, ReasoningConfig,
1013 TrajectoryConfig, TreeConfig,
1014 };
1015
1016 struct NaiveTokenCounter;
1017 impl zeph_common::memory::TokenCounting for NaiveTokenCounter {
1018 fn count_tokens(&self, text: &str) -> usize {
1019 text.split_whitespace().count()
1020 }
1021 fn count_tool_schema_tokens(&self, schema: &serde_json::Value) -> usize {
1022 schema.to_string().split_whitespace().count()
1023 }
1024 }
1025
1026 fn empty_view() -> ContextMemoryView {
1027 ContextMemoryView {
1028 memory: None,
1029 conversation_id: None,
1030 recall_limit: 10,
1031 cross_session_score_threshold: 0.5,
1032 context_strategy: ContextStrategy::default(),
1033 crossover_turn_threshold: 5,
1034 cached_session_digest: None,
1035 graph_config: GraphConfig::default(),
1036 document_config: DocumentConfig::default(),
1037 persona_config: PersonaConfig::default(),
1038 trajectory_config: TrajectoryConfig::default(),
1039 reasoning_config: ReasoningConfig::default(),
1040 memcot_config: zeph_config::MemCotConfig::default(),
1041 memcot_state: None,
1042 tree_config: TreeConfig::default(),
1043 }
1044 }
1045
1046 #[tokio::test]
1049 async fn fetch_graph_facts_returns_none_when_memory_is_none() {
1050 let view = empty_view();
1051 let tc = NaiveTokenCounter;
1052 let result = fetch_graph_facts(&view, "test", 1000, &tc).await.unwrap();
1053 assert!(result.is_none());
1054 }
1055
1056 #[tokio::test]
1057 async fn fetch_graph_facts_returns_none_when_budget_zero() {
1058 let mut view = empty_view();
1059 view.graph_config.enabled = true;
1060 let tc = NaiveTokenCounter;
1061 let result = fetch_graph_facts(&view, "test", 0, &tc).await.unwrap();
1062 assert!(result.is_none());
1063 }
1064
1065 #[tokio::test]
1066 async fn fetch_graph_facts_returns_none_when_graph_disabled() {
1067 let mut view = empty_view();
1068 view.graph_config.enabled = false;
1069 let tc = NaiveTokenCounter;
1070 let result = fetch_graph_facts(&view, "test", 1000, &tc).await.unwrap();
1071 assert!(result.is_none());
1072 }
1073
1074 #[tokio::test]
1077 async fn fetch_persona_facts_returns_none_when_memory_is_none() {
1078 let view = empty_view();
1079 let tc = NaiveTokenCounter;
1080 let result = fetch_persona_facts(&view, 1000, &tc).await.unwrap();
1081 assert!(result.is_none());
1082 }
1083
1084 #[tokio::test]
1085 async fn fetch_persona_facts_returns_none_when_budget_zero() {
1086 let mut view = empty_view();
1087 view.persona_config.enabled = true;
1088 let tc = NaiveTokenCounter;
1089 let result = fetch_persona_facts(&view, 0, &tc).await.unwrap();
1090 assert!(result.is_none());
1091 }
1092
1093 #[tokio::test]
1096 async fn fetch_trajectory_hints_returns_none_when_memory_is_none() {
1097 let view = empty_view();
1098 let tc = NaiveTokenCounter;
1099 let result = fetch_trajectory_hints(&view, 1000, &tc).await.unwrap();
1100 assert!(result.is_none());
1101 }
1102
1103 #[tokio::test]
1104 async fn fetch_trajectory_hints_returns_none_when_budget_zero() {
1105 let mut view = empty_view();
1106 view.trajectory_config.enabled = true;
1107 let tc = NaiveTokenCounter;
1108 let result = fetch_trajectory_hints(&view, 0, &tc).await.unwrap();
1109 assert!(result.is_none());
1110 }
1111
1112 #[tokio::test]
1115 async fn fetch_tree_memory_returns_none_when_memory_is_none() {
1116 let view = empty_view();
1117 let tc = NaiveTokenCounter;
1118 let result = fetch_tree_memory(&view, 1000, &tc).await.unwrap();
1119 assert!(result.is_none());
1120 }
1121
1122 #[tokio::test]
1123 async fn fetch_tree_memory_returns_none_when_budget_zero() {
1124 let mut view = empty_view();
1125 view.tree_config.enabled = true;
1126 let tc = NaiveTokenCounter;
1127 let result = fetch_tree_memory(&view, 0, &tc).await.unwrap();
1128 assert!(result.is_none());
1129 }
1130
1131 #[tokio::test]
1134 async fn fetch_corrections_returns_none_when_memory_is_none() {
1135 let view = empty_view();
1136 let result = fetch_corrections(&view, "test", 10, 0.5, |s| s.into())
1137 .await
1138 .unwrap();
1139 assert!(result.is_none());
1140 }
1141
1142 #[tokio::test]
1145 async fn fetch_semantic_recall_returns_none_when_memory_is_none() {
1146 let view = empty_view();
1147 let tc = NaiveTokenCounter;
1148 let result = fetch_semantic_recall(&view, "test", 1000, &tc, None)
1149 .await
1150 .unwrap();
1151 assert!(result.0.is_none() && result.1.is_none());
1152 }
1153
1154 #[tokio::test]
1155 async fn fetch_semantic_recall_returns_none_when_budget_zero() {
1156 let view = empty_view();
1157 let tc = NaiveTokenCounter;
1158 let result = fetch_semantic_recall(&view, "test", 0, &tc, None)
1159 .await
1160 .unwrap();
1161 assert!(result.0.is_none() && result.1.is_none());
1162 }
1163
1164 #[tokio::test]
1167 async fn fetch_document_rag_returns_none_when_memory_is_none() {
1168 let mut view = empty_view();
1169 view.document_config.rag_enabled = true;
1170 let tc = NaiveTokenCounter;
1171 let result = fetch_document_rag(&view, "test", 1000, &tc).await.unwrap();
1172 assert!(result.is_none());
1173 }
1174
1175 #[tokio::test]
1176 async fn fetch_document_rag_returns_none_when_rag_disabled() {
1177 let view = empty_view();
1178 let tc = NaiveTokenCounter;
1179 let result = fetch_document_rag(&view, "test", 1000, &tc).await.unwrap();
1180 assert!(result.is_none());
1181 }
1182
1183 #[tokio::test]
1186 async fn fetch_summaries_returns_none_when_memory_is_none() {
1187 let view = empty_view();
1188 let tc = NaiveTokenCounter;
1189 let result = fetch_summaries(&view, 1000, &tc).await.unwrap();
1190 assert!(result.is_none());
1191 }
1192
1193 #[tokio::test]
1196 async fn fetch_cross_session_returns_none_when_memory_is_none() {
1197 let view = empty_view();
1198 let tc = NaiveTokenCounter;
1199 let result = fetch_cross_session(&view, "test", 1000, &tc).await.unwrap();
1200 assert!(result.is_none());
1201 }
1202
/// An empty level slice is treated as "no restriction": all three flags
/// returned by `levels_to_flags` must be `true`.
#[test]
fn levels_to_flags_empty_slice_enables_all_tiers() {
    let (episodic, procedural, declarative) = levels_to_flags(&[]);
    assert!(episodic, "episodic should be active for empty slice");
    assert!(procedural, "procedural should be active for empty slice");
    assert!(declarative, "declarative should be active for empty slice");
}
1212
/// Listing every compression level explicitly must enable all three flags.
#[test]
fn levels_to_flags_full_set_enables_all_tiers() {
    let every_level = [
        CompressionLevel::Episodic,
        CompressionLevel::Procedural,
        CompressionLevel::Declarative,
    ];
    let (episodic, procedural, declarative) = levels_to_flags(&every_level);
    assert!(episodic);
    assert!(procedural);
    assert!(declarative);
}
1225
/// Requesting only the episodic level must enable the episodic flag and
/// leave the procedural and declarative flags off.
#[test]
fn levels_to_flags_episodic_only() {
    let requested = [CompressionLevel::Episodic];
    let (episodic, procedural, declarative) = levels_to_flags(&requested);
    assert!(episodic);
    assert!(!procedural, "procedural should be inactive");
    assert!(!declarative, "declarative should be inactive");
}
1233
/// Requesting the episodic and procedural levels must enable exactly those
/// two flags; the declarative flag stays off.
#[test]
fn levels_to_flags_episodic_and_procedural() {
    let requested = [CompressionLevel::Episodic, CompressionLevel::Procedural];
    let (episodic, procedural, declarative) = levels_to_flags(&requested);
    assert!(episodic);
    assert!(procedural);
    assert!(!declarative, "declarative should be inactive");
}
1242
/// Requesting only the declarative level must enable the declarative flag
/// and leave the episodic and procedural flags off.
#[test]
fn levels_to_flags_declarative_only() {
    let requested = [CompressionLevel::Declarative];
    let (episodic, procedural, declarative) = levels_to_flags(&requested);
    assert!(!episodic, "episodic should be inactive");
    assert!(!procedural, "procedural should be inactive");
    assert!(declarative);
}
1250
1251 #[tokio::test]
1254 async fn fetch_reasoning_strategies_returns_none_when_memory_is_none() {
1255 let mut view = empty_view();
1256 view.reasoning_config.enabled = true;
1257 let tc = NaiveTokenCounter;
1258 let result = fetch_reasoning_strategies(&view, "query", 1000, 3, &tc)
1259 .await
1260 .unwrap();
1261 assert!(result.is_none());
1262 }
1263
1264 #[tokio::test]
1265 async fn fetch_reasoning_strategies_returns_none_when_budget_zero() {
1266 let mut view = empty_view();
1267 view.reasoning_config.enabled = true;
1268 let tc = NaiveTokenCounter;
1269 let result = fetch_reasoning_strategies(&view, "query", 0, 3, &tc)
1270 .await
1271 .unwrap();
1272 assert!(result.is_none());
1273 }
1274
1275 use std::sync::{Arc, Mutex};
1278 use zeph_common::memory::{
1279 ContextMemoryBackend, GraphRecallParams, MemCorrection, MemDocumentChunk, MemGraphFact,
1280 MemPersonaFact, MemReasoningStrategy, MemRecalledMessage, MemSessionSummary, MemSummary,
1281 MemTrajectoryEntry, MemTreeNode,
1282 };
1283
/// Method names accepted by `MockMemoryBackend::with_fail_on`.
///
/// Each entry is the exact string that the corresponding mocked
/// `ContextMemoryBackend` method compares against `fail_on` to decide
/// whether it should return an error. Keep this list in sync with the
/// literals used in the trait impl.
const KNOWN_FAIL_ON: &[&str] = &[
    "load_persona_facts",
    "load_trajectory_entries",
    "load_tree_nodes",
    "load_summaries",
    "retrieve_reasoning_strategies",
    "mark_reasoning_used",
    "retrieve_corrections",
    "recall",
    "recall_graph_facts",
    "search_session_summaries",
    "search_document_collection",
];
1298
/// In-memory test double for `ContextMemoryBackend`.
///
/// Every `Vec` field holds canned rows that the matching trait method
/// returns as a clone. Setting `fail_on` makes exactly one method return a
/// mock error instead.
#[derive(Default)]
struct MockMemoryBackend {
    // Returned by `load_persona_facts`.
    persona_facts: Vec<MemPersonaFact>,
    // Returned by `load_trajectory_entries`.
    trajectory_entries: Vec<MemTrajectoryEntry>,
    // Returned by `load_tree_nodes`.
    tree_nodes: Vec<MemTreeNode>,
    // Returned by `load_summaries`.
    summaries: Vec<MemSummary>,
    // Returned by `retrieve_reasoning_strategies`.
    reasoning_strategies: Vec<MemReasoningStrategy>,
    // Returned by `retrieve_corrections`.
    corrections: Vec<MemCorrection>,
    // Returned by `recall`.
    recalled: Vec<MemRecalledMessage>,
    // Returned by `recall_graph_facts`.
    graph_facts: Vec<MemGraphFact>,
    // Returned by `search_session_summaries`.
    session_summaries: Vec<MemSessionSummary>,
    // Returned by `search_document_collection`.
    document_chunks: Vec<MemDocumentChunk>,
    // Name of the single method that should fail; `None` means no failures.
    fail_on: Option<&'static str>,
    // Ids passed to `mark_reasoning_used`, recorded so tests can inspect them.
    marked_ids: Mutex<Vec<String>>,
}
1316
1317 impl MockMemoryBackend {
1318 fn with_fail_on(method: &'static str) -> Self {
1319 debug_assert!(
1320 KNOWN_FAIL_ON.contains(&method),
1321 "unknown fail_on method name: {method}"
1322 );
1323 Self {
1324 fail_on: Some(method),
1325 ..Default::default()
1326 }
1327 }
1328
1329 fn fail_err(method: &str) -> Box<dyn std::error::Error + Send + Sync> {
1330 format!("mock error in {method}").into()
1331 }
1332 }
1333
1334 impl ContextMemoryBackend for MockMemoryBackend {
1335 fn load_persona_facts<'a>(
1336 &'a self,
1337 _min_confidence: f64,
1338 ) -> std::pin::Pin<
1339 Box<
1340 dyn std::future::Future<
1341 Output = Result<
1342 Vec<MemPersonaFact>,
1343 Box<dyn std::error::Error + Send + Sync>,
1344 >,
1345 > + Send
1346 + 'a,
1347 >,
1348 > {
1349 let result = if self.fail_on == Some("load_persona_facts") {
1350 Err(Self::fail_err("load_persona_facts"))
1351 } else {
1352 Ok(self.persona_facts.clone())
1353 };
1354 Box::pin(async move { result })
1355 }
1356
1357 fn load_trajectory_entries<'a>(
1358 &'a self,
1359 _tier: Option<&'a str>,
1360 _top_k: usize,
1361 ) -> std::pin::Pin<
1362 Box<
1363 dyn std::future::Future<
1364 Output = Result<
1365 Vec<MemTrajectoryEntry>,
1366 Box<dyn std::error::Error + Send + Sync>,
1367 >,
1368 > + Send
1369 + 'a,
1370 >,
1371 > {
1372 let result = if self.fail_on == Some("load_trajectory_entries") {
1373 Err(Self::fail_err("load_trajectory_entries"))
1374 } else {
1375 Ok(self.trajectory_entries.clone())
1376 };
1377 Box::pin(async move { result })
1378 }
1379
1380 fn load_tree_nodes<'a>(
1381 &'a self,
1382 _level: u32,
1383 _top_k: usize,
1384 ) -> std::pin::Pin<
1385 Box<
1386 dyn std::future::Future<
1387 Output = Result<Vec<MemTreeNode>, Box<dyn std::error::Error + Send + Sync>>,
1388 > + Send
1389 + 'a,
1390 >,
1391 > {
1392 let result = if self.fail_on == Some("load_tree_nodes") {
1393 Err(Self::fail_err("load_tree_nodes"))
1394 } else {
1395 Ok(self.tree_nodes.clone())
1396 };
1397 Box::pin(async move { result })
1398 }
1399
1400 fn load_summaries<'a>(
1401 &'a self,
1402 _conversation_id: i64,
1403 ) -> std::pin::Pin<
1404 Box<
1405 dyn std::future::Future<
1406 Output = Result<Vec<MemSummary>, Box<dyn std::error::Error + Send + Sync>>,
1407 > + Send
1408 + 'a,
1409 >,
1410 > {
1411 let result = if self.fail_on == Some("load_summaries") {
1412 Err(Self::fail_err("load_summaries"))
1413 } else {
1414 Ok(self.summaries.clone())
1415 };
1416 Box::pin(async move { result })
1417 }
1418
1419 fn retrieve_reasoning_strategies<'a>(
1420 &'a self,
1421 _query: &'a str,
1422 _top_k: usize,
1423 ) -> std::pin::Pin<
1424 Box<
1425 dyn std::future::Future<
1426 Output = Result<
1427 Vec<MemReasoningStrategy>,
1428 Box<dyn std::error::Error + Send + Sync>,
1429 >,
1430 > + Send
1431 + 'a,
1432 >,
1433 > {
1434 let result = if self.fail_on == Some("retrieve_reasoning_strategies") {
1435 Err(Self::fail_err("retrieve_reasoning_strategies"))
1436 } else {
1437 Ok(self.reasoning_strategies.clone())
1438 };
1439 Box::pin(async move { result })
1440 }
1441
1442 fn mark_reasoning_used<'a>(
1443 &'a self,
1444 ids: &'a [String],
1445 ) -> std::pin::Pin<
1446 Box<
1447 dyn std::future::Future<
1448 Output = Result<(), Box<dyn std::error::Error + Send + Sync>>,
1449 > + Send
1450 + 'a,
1451 >,
1452 > {
1453 if self.fail_on == Some("mark_reasoning_used") {
1454 return Box::pin(async move { Err(Self::fail_err("mark_reasoning_used")) });
1455 }
1456 let mut guard = self.marked_ids.lock().expect("marked_ids poisoned");
1457 guard.extend_from_slice(ids);
1458 Box::pin(async move { Ok(()) })
1459 }
1460
1461 fn retrieve_corrections<'a>(
1462 &'a self,
1463 _query: &'a str,
1464 _limit: usize,
1465 _min_score: f32,
1466 ) -> std::pin::Pin<
1467 Box<
1468 dyn std::future::Future<
1469 Output = Result<
1470 Vec<MemCorrection>,
1471 Box<dyn std::error::Error + Send + Sync>,
1472 >,
1473 > + Send
1474 + 'a,
1475 >,
1476 > {
1477 let result = if self.fail_on == Some("retrieve_corrections") {
1478 Err(Self::fail_err("retrieve_corrections"))
1479 } else {
1480 Ok(self.corrections.clone())
1481 };
1482 Box::pin(async move { result })
1483 }
1484
1485 fn recall<'a>(
1486 &'a self,
1487 _query: &'a str,
1488 _limit: usize,
1489 _router: Option<&'a dyn zeph_common::memory::AsyncMemoryRouter>,
1490 ) -> std::pin::Pin<
1491 Box<
1492 dyn std::future::Future<
1493 Output = Result<
1494 Vec<MemRecalledMessage>,
1495 Box<dyn std::error::Error + Send + Sync>,
1496 >,
1497 > + Send
1498 + 'a,
1499 >,
1500 > {
1501 let result = if self.fail_on == Some("recall") {
1502 Err(Self::fail_err("recall"))
1503 } else {
1504 Ok(self.recalled.clone())
1505 };
1506 Box::pin(async move { result })
1507 }
1508
1509 fn recall_graph_facts<'a>(
1510 &'a self,
1511 _query: &'a str,
1512 _params: GraphRecallParams<'a>,
1513 ) -> std::pin::Pin<
1514 Box<
1515 dyn std::future::Future<
1516 Output = Result<
1517 Vec<MemGraphFact>,
1518 Box<dyn std::error::Error + Send + Sync>,
1519 >,
1520 > + Send
1521 + 'a,
1522 >,
1523 > {
1524 let result = if self.fail_on == Some("recall_graph_facts") {
1525 Err(Self::fail_err("recall_graph_facts"))
1526 } else {
1527 Ok(self.graph_facts.clone())
1528 };
1529 Box::pin(async move { result })
1530 }
1531
1532 fn search_session_summaries<'a>(
1533 &'a self,
1534 _query: &'a str,
1535 _limit: usize,
1536 _current_conversation_id: Option<i64>,
1537 ) -> std::pin::Pin<
1538 Box<
1539 dyn std::future::Future<
1540 Output = Result<
1541 Vec<MemSessionSummary>,
1542 Box<dyn std::error::Error + Send + Sync>,
1543 >,
1544 > + Send
1545 + 'a,
1546 >,
1547 > {
1548 let result = if self.fail_on == Some("search_session_summaries") {
1549 Err(Self::fail_err("search_session_summaries"))
1550 } else {
1551 Ok(self.session_summaries.clone())
1552 };
1553 Box::pin(async move { result })
1554 }
1555
1556 fn search_document_collection<'a>(
1557 &'a self,
1558 _collection: &'a str,
1559 _query: &'a str,
1560 _top_k: usize,
1561 ) -> std::pin::Pin<
1562 Box<
1563 dyn std::future::Future<
1564 Output = Result<
1565 Vec<MemDocumentChunk>,
1566 Box<dyn std::error::Error + Send + Sync>,
1567 >,
1568 > + Send
1569 + 'a,
1570 >,
1571 > {
1572 let result = if self.fail_on == Some("search_document_collection") {
1573 Err(Self::fail_err("search_document_collection"))
1574 } else {
1575 Ok(self.document_chunks.clone())
1576 };
1577 Box::pin(async move { result })
1578 }
1579 }
1580
1581 fn mock_view(mock: MockMemoryBackend) -> ContextMemoryView {
1582 let mut v = empty_view();
1583 v.memory = Some(Arc::new(mock));
1584 v
1585 }
1586
1587 #[tokio::test]
1590 async fn fetch_graph_facts_returns_message_when_memory_present() {
1591 let mock = MockMemoryBackend {
1592 graph_facts: vec![zeph_common::memory::MemGraphFact {
1593 fact: "Rust is fast".to_string(),
1594 confidence: 0.9,
1595 activation_score: None,
1596 neighbors: vec![],
1597 provenance_snippet: None,
1598 }],
1599 ..Default::default()
1600 };
1601 let mut view = mock_view(mock);
1602 view.graph_config.enabled = true;
1603 view.graph_config.spreading_activation.recall_timeout_ms = 5000;
1605 let tc = NaiveTokenCounter;
1606 let result = fetch_graph_facts(&view, "test", 1000, &tc).await.unwrap();
1607 assert!(result.is_some(), "expected Some message");
1608 let msg = result.unwrap();
1609 assert!(
1610 msg.content.contains("Rust is fast"),
1611 "expected fact text in output, got: {}",
1612 msg.content
1613 );
1614 assert!(
1615 msg.content.starts_with(GRAPH_FACTS_PREFIX),
1616 "expected GRAPH_FACTS_PREFIX"
1617 );
1618 }
1619
1620 #[tokio::test]
1621 async fn fetch_graph_facts_swallows_error_and_returns_none() {
1622 let mock = MockMemoryBackend::with_fail_on("recall_graph_facts");
1623 let mut view = mock_view(mock);
1624 view.graph_config.enabled = true;
1625 view.graph_config.spreading_activation.recall_timeout_ms = 5000;
1626 let tc = NaiveTokenCounter;
1627 let result = fetch_graph_facts(&view, "test", 1000, &tc).await.unwrap();
1629 assert!(
1630 result.is_none(),
1631 "expected None when recall_graph_facts errors"
1632 );
1633 }
1634
1635 #[tokio::test]
1636 async fn fetch_graph_facts_returns_none_when_facts_empty() {
1637 let mock = MockMemoryBackend::default(); let mut view = mock_view(mock);
1639 view.graph_config.enabled = true;
1640 view.graph_config.spreading_activation.recall_timeout_ms = 5000;
1641 let tc = NaiveTokenCounter;
1642 let result = fetch_graph_facts(&view, "test", 1000, &tc).await.unwrap();
1643 assert!(result.is_none());
1644 }
1645
1646 #[tokio::test]
1649 async fn fetch_persona_facts_returns_message_when_persona_enabled() {
1650 let mock = MockMemoryBackend {
1651 persona_facts: vec![MemPersonaFact {
1652 category: "preference".to_string(),
1653 content: "prefers concise answers".to_string(),
1654 }],
1655 ..Default::default()
1656 };
1657 let mut view = mock_view(mock);
1658 view.persona_config.enabled = true;
1659 view.persona_config.context_budget_tokens = 1000;
1660 let tc = NaiveTokenCounter;
1661 let result = fetch_persona_facts(&view, 1000, &tc).await.unwrap();
1662 assert!(result.is_some());
1663 let msg = result.unwrap();
1664 assert!(msg.content.contains("preference"));
1665 assert!(msg.content.contains("prefers concise answers"));
1666 assert!(msg.content.starts_with(crate::slot::PERSONA_PREFIX));
1667 }
1668
1669 #[tokio::test]
1670 async fn fetch_persona_facts_propagates_error() {
1671 let mock = MockMemoryBackend::with_fail_on("load_persona_facts");
1672 let mut view = mock_view(mock);
1673 view.persona_config.enabled = true;
1674 let tc = NaiveTokenCounter;
1675 let result = fetch_persona_facts(&view, 1000, &tc).await;
1676 assert!(
1677 result.is_err(),
1678 "expected Err from load_persona_facts failure"
1679 );
1680 }
1681
1682 #[tokio::test]
1685 async fn fetch_trajectory_hints_returns_message_when_trajectory_enabled() {
1686 let mock = MockMemoryBackend {
1687 trajectory_entries: vec![MemTrajectoryEntry {
1688 intent: "summarize code".to_string(),
1689 outcome: "produced concise summary".to_string(),
1690 confidence: 0.9,
1691 }],
1692 ..Default::default()
1693 };
1694 let mut view = mock_view(mock);
1695 view.trajectory_config.enabled = true;
1696 view.trajectory_config.context_budget_tokens = 1000;
1697 view.trajectory_config.min_confidence = 0.5;
1698 let tc = NaiveTokenCounter;
1699 let result = fetch_trajectory_hints(&view, 1000, &tc).await.unwrap();
1700 assert!(result.is_some());
1701 let msg = result.unwrap();
1702 assert!(msg.content.contains("summarize code"));
1703 assert!(msg.content.starts_with(crate::slot::TRAJECTORY_PREFIX));
1704 }
1705
1706 #[tokio::test]
1707 async fn fetch_trajectory_hints_passes_tier_filter() {
1708 let mock = MockMemoryBackend {
1711 trajectory_entries: vec![
1712 MemTrajectoryEntry {
1713 intent: "debug async code".to_string(),
1714 outcome: "fixed deadlock".to_string(),
1715 confidence: 0.85,
1716 },
1717 MemTrajectoryEntry {
1718 intent: "low confidence task".to_string(),
1719 outcome: "irrelevant".to_string(),
1720 confidence: 0.3,
1721 },
1722 ],
1723 ..Default::default()
1724 };
1725 let mut view = mock_view(mock);
1726 view.trajectory_config.enabled = true;
1727 view.trajectory_config.context_budget_tokens = 1000;
1728 view.trajectory_config.min_confidence = 0.5;
1729 let tc = NaiveTokenCounter;
1730 let result = fetch_trajectory_hints(&view, 1000, &tc).await.unwrap();
1731 assert!(result.is_some(), "expected Some message");
1732 let msg = result.unwrap();
1733 assert!(
1734 msg.content.contains("debug async code"),
1735 "high-confidence entry must be included"
1736 );
1737 assert!(
1738 !msg.content.contains("low confidence task"),
1739 "entry below min_confidence must be filtered out"
1740 );
1741 }
1742
1743 #[tokio::test]
1744 async fn fetch_trajectory_hints_propagates_error() {
1745 let mock = MockMemoryBackend::with_fail_on("load_trajectory_entries");
1746 let mut view = mock_view(mock);
1747 view.trajectory_config.enabled = true;
1748 let tc = NaiveTokenCounter;
1749 let result = fetch_trajectory_hints(&view, 1000, &tc).await;
1750 assert!(result.is_err());
1751 }
1752
1753 #[tokio::test]
1756 async fn fetch_tree_memory_returns_message_when_tree_enabled() {
1757 let mock = MockMemoryBackend {
1758 tree_nodes: vec![MemTreeNode {
1759 content: "Topic: async Rust patterns".to_string(),
1760 }],
1761 ..Default::default()
1762 };
1763 let mut view = mock_view(mock);
1764 view.tree_config.enabled = true;
1765 view.tree_config.context_budget_tokens = 1000;
1766 let tc = NaiveTokenCounter;
1767 let result = fetch_tree_memory(&view, 1000, &tc).await.unwrap();
1768 assert!(result.is_some());
1769 let msg = result.unwrap();
1770 assert!(msg.content.contains("async Rust patterns"));
1771 assert!(msg.content.starts_with(crate::slot::TREE_MEMORY_PREFIX));
1772 }
1773
1774 #[tokio::test]
1775 async fn fetch_tree_memory_propagates_error() {
1776 let mock = MockMemoryBackend::with_fail_on("load_tree_nodes");
1777 let mut view = mock_view(mock);
1778 view.tree_config.enabled = true;
1779 let tc = NaiveTokenCounter;
1780 let result = fetch_tree_memory(&view, 1000, &tc).await;
1781 assert!(result.is_err());
1782 }
1783
1784 #[tokio::test]
1787 async fn fetch_corrections_returns_message_when_corrections_present() {
1788 let mock = MockMemoryBackend {
1789 corrections: vec![MemCorrection {
1790 correction_text: "use snake_case not camelCase".to_string(),
1791 }],
1792 ..Default::default()
1793 };
1794 let view = mock_view(mock);
1795 let result = fetch_corrections(&view, "query", 10, 0.5, |s| s.into())
1796 .await
1797 .unwrap();
1798 assert!(result.is_some());
1799 let msg = result.unwrap();
1800 assert!(msg.content.contains("snake_case"));
1801 assert!(msg.content.starts_with(CORRECTIONS_PREFIX));
1802 }
1803
1804 #[tokio::test]
1805 async fn fetch_corrections_swallows_error_returns_none() {
1806 let mock = MockMemoryBackend::with_fail_on("retrieve_corrections");
1809 let view = mock_view(mock);
1810 let result = fetch_corrections(&view, "query", 10, 0.5, |s| s.into())
1811 .await
1812 .unwrap();
1813 assert!(result.is_none());
1814 }
1815
1816 #[tokio::test]
1819 async fn fetch_semantic_recall_returns_message_with_content() {
1820 let mock = MockMemoryBackend {
1821 recalled: vec![
1822 MemRecalledMessage {
1823 role: "user".to_string(),
1824 content: "how does tokio work".to_string(),
1825 score: 0.95,
1826 },
1827 MemRecalledMessage {
1828 role: "assistant".to_string(),
1829 content: "tokio is an async runtime".to_string(),
1830 score: 0.88,
1831 },
1832 ],
1833 ..Default::default()
1834 };
1835 let mut view = mock_view(mock);
1836 view.recall_limit = 10;
1837 let tc = NaiveTokenCounter;
1838 let (msg, score) = fetch_semantic_recall(&view, "tokio", 1000, &tc, None)
1839 .await
1840 .unwrap();
1841 assert!(msg.is_some(), "expected Some message");
1842 assert!(score.is_some_and(|s| (s - 0.95_f32).abs() < f32::EPSILON));
1844 let msg = msg.unwrap();
1845 let has_recall_part = msg.parts.iter().any(|p| {
1847 if let zeph_llm::provider::MessagePart::Recall { text } = p {
1848 text.contains("how does tokio work")
1849 } else {
1850 false
1851 }
1852 });
1853 assert!(has_recall_part, "expected recalled content in Recall part");
1854 }
1855
1856 #[tokio::test]
1857 async fn fetch_semantic_recall_returns_none_when_recalled_empty() {
1858 let mock = MockMemoryBackend::default();
1859 let mut view = mock_view(mock);
1860 view.recall_limit = 10;
1861 let tc = NaiveTokenCounter;
1862 let (msg, score) = fetch_semantic_recall(&view, "query", 1000, &tc, None)
1863 .await
1864 .unwrap();
1865 assert!(msg.is_none());
1866 assert!(score.is_none());
1867 }
1868
1869 #[tokio::test]
1870 async fn fetch_semantic_recall_propagates_error() {
1871 let mock = MockMemoryBackend::with_fail_on("recall");
1872 let mut view = mock_view(mock);
1873 view.recall_limit = 10;
1874 let tc = NaiveTokenCounter;
1875 let result = fetch_semantic_recall(&view, "query", 1000, &tc, None).await;
1876 assert!(result.is_err());
1877 }
1878
1879 #[tokio::test]
1882 async fn fetch_document_rag_returns_message_when_rag_enabled() {
1883 let mock = MockMemoryBackend {
1884 document_chunks: vec![MemDocumentChunk {
1885 text: "Rust ownership rules prevent data races".to_string(),
1886 }],
1887 ..Default::default()
1888 };
1889 let mut view = mock_view(mock);
1890 view.document_config.rag_enabled = true;
1891 let tc = NaiveTokenCounter;
1892 let result = fetch_document_rag(&view, "ownership", 1000, &tc)
1893 .await
1894 .unwrap();
1895 assert!(result.is_some());
1896 let msg = result.unwrap();
1897 assert!(msg.content.contains("ownership rules"));
1898 assert!(msg.content.starts_with(DOCUMENT_RAG_PREFIX));
1899 }
1900
1901 #[tokio::test]
1902 async fn fetch_document_rag_propagates_error() {
1903 let mock = MockMemoryBackend::with_fail_on("search_document_collection");
1904 let mut view = mock_view(mock);
1905 view.document_config.rag_enabled = true;
1906 let tc = NaiveTokenCounter;
1907 let result = fetch_document_rag(&view, "query", 1000, &tc).await;
1908 assert!(result.is_err());
1909 }
1910
1911 #[tokio::test]
1914 async fn fetch_summaries_returns_message_when_summaries_present() {
1915 let mock = MockMemoryBackend {
1916 summaries: vec![MemSummary {
1917 first_message_id: Some(1),
1918 last_message_id: Some(5),
1919 content: "User asked about async Rust".to_string(),
1920 }],
1921 ..Default::default()
1922 };
1923 let mut view = mock_view(mock);
1924 view.conversation_id = Some(42);
1925 let tc = NaiveTokenCounter;
1926 let result = fetch_summaries(&view, 1000, &tc).await.unwrap();
1927 assert!(result.is_some());
1928 let msg = result.unwrap();
1929 let has_summary_part = msg.parts.iter().any(|p| {
1930 if let zeph_llm::provider::MessagePart::Summary { text } = p {
1931 text.contains("Messages 1-5") && text.contains("async Rust")
1932 } else {
1933 false
1934 }
1935 });
1936 assert!(
1937 has_summary_part,
1938 "expected Summary part with messages range"
1939 );
1940 }
1941
1942 #[tokio::test]
1943 async fn fetch_summaries_returns_none_without_conversation_id() {
1944 let mock = MockMemoryBackend {
1945 summaries: vec![MemSummary {
1946 first_message_id: Some(1),
1947 last_message_id: Some(5),
1948 content: "some content".to_string(),
1949 }],
1950 ..Default::default()
1951 };
1952 let mut view = mock_view(mock);
1953 view.conversation_id = None; let tc = NaiveTokenCounter;
1955 let result = fetch_summaries(&view, 1000, &tc).await.unwrap();
1956 assert!(result.is_none());
1957 }
1958
1959 #[tokio::test]
1960 async fn fetch_summaries_propagates_error() {
1961 let mock = MockMemoryBackend::with_fail_on("load_summaries");
1962 let mut view = mock_view(mock);
1963 view.conversation_id = Some(42);
1964 let tc = NaiveTokenCounter;
1965 let result = fetch_summaries(&view, 1000, &tc).await;
1966 assert!(result.is_err());
1967 }
1968
1969 #[tokio::test]
1972 async fn fetch_cross_session_returns_message_when_results_present() {
1973 let mock = MockMemoryBackend {
1974 session_summaries: vec![MemSessionSummary {
1975 summary_text: "Previous session: debugging tokio deadlock".to_string(),
1976 score: 0.9,
1977 }],
1978 ..Default::default()
1979 };
1980 let mut view = mock_view(mock);
1981 view.conversation_id = Some(1);
1982 view.cross_session_score_threshold = 0.5;
1983 let tc = NaiveTokenCounter;
1984 let result = fetch_cross_session(&view, "async", 1000, &tc)
1985 .await
1986 .unwrap();
1987 assert!(result.is_some());
1988 let msg = result.unwrap();
1989 let has_cross_session_part = msg.parts.iter().any(|p| {
1990 if let zeph_llm::provider::MessagePart::CrossSession { text } = p {
1991 text.contains("tokio deadlock")
1992 } else {
1993 false
1994 }
1995 });
1996 assert!(has_cross_session_part);
1997 }
1998
1999 #[tokio::test]
2000 async fn fetch_cross_session_propagates_error() {
2001 let mock = MockMemoryBackend::with_fail_on("search_session_summaries");
2002 let mut view = mock_view(mock);
2003 view.conversation_id = Some(1);
2004 let tc = NaiveTokenCounter;
2005 let result = fetch_cross_session(&view, "query", 1000, &tc).await;
2006 assert!(result.is_err());
2007 }
2008
2009 #[tokio::test]
2012 async fn fetch_reasoning_strategies_returns_message_and_marks_used() {
2013 let mock = Arc::new(MockMemoryBackend {
2014 reasoning_strategies: vec![
2015 MemReasoningStrategy {
2016 id: "strat-1".to_string(),
2017 outcome: "success".to_string(),
2018 summary: "break the problem into small steps".to_string(),
2019 },
2020 MemReasoningStrategy {
2021 id: "strat-2".to_string(),
2022 outcome: "success".to_string(),
2023 summary: "use tracing spans for debugging".to_string(),
2024 },
2025 ],
2026 ..Default::default()
2027 });
2028 let marked_ids = Arc::clone(&mock);
2029 let mut view = empty_view();
2030 view.memory = Some(mock);
2031 view.reasoning_config.enabled = true;
2032 view.reasoning_config.context_budget_tokens = 1000;
2033 let tc = NaiveTokenCounter;
2034 let result = fetch_reasoning_strategies(&view, "debug", 1000, 5, &tc)
2035 .await
2036 .unwrap();
2037 assert!(result.is_some());
2038 let msg = result.unwrap();
2039 assert!(msg.content.starts_with(crate::slot::REASONING_PREFIX));
2040 assert!(msg.content.contains("break the problem"));
2041
2042 tokio::task::yield_now().await;
2046 tokio::task::yield_now().await;
2047
2048 let ids = marked_ids.marked_ids.lock().expect("marked_ids poisoned");
2049 assert!(
2050 ids.contains(&"strat-1".to_string()),
2051 "expected strat-1 marked"
2052 );
2053 assert!(
2054 ids.contains(&"strat-2".to_string()),
2055 "expected strat-2 marked"
2056 );
2057 }
2058
2059 #[tokio::test]
2060 async fn fetch_reasoning_strategies_propagates_error() {
2061 let mock = MockMemoryBackend::with_fail_on("retrieve_reasoning_strategies");
2062 let mut view = mock_view(mock);
2063 view.reasoning_config.enabled = true;
2064 let tc = NaiveTokenCounter;
2065 let result = fetch_reasoning_strategies(&view, "query", 1000, 3, &tc).await;
2066 assert!(result.is_err());
2067 }
2068
2069 #[tokio::test]
2072 async fn fetch_semantic_recall_skips_skipped_and_stopped_messages() {
2073 let mock = MockMemoryBackend {
2074 recalled: vec![
2075 MemRecalledMessage {
2076 role: "user".to_string(),
2077 content: "[skipped] some content".to_string(),
2078 score: 0.95,
2079 },
2080 MemRecalledMessage {
2081 role: "user".to_string(),
2082 content: "[stopped] other content".to_string(),
2083 score: 0.90,
2084 },
2085 MemRecalledMessage {
2086 role: "user".to_string(),
2087 content: "valid content to recall".to_string(),
2088 score: 0.85,
2089 },
2090 ],
2091 ..Default::default()
2092 };
2093 let mut view = mock_view(mock);
2094 view.recall_limit = 10;
2095 let tc = NaiveTokenCounter;
2096 let (msg, _) = fetch_semantic_recall(&view, "query", 1000, &tc, None)
2097 .await
2098 .unwrap();
2099 assert!(msg.is_some());
2100 let msg = msg.unwrap();
2101 let full_text = msg.parts.iter().find_map(|p| {
2102 if let zeph_llm::provider::MessagePart::Recall { text } = p {
2103 Some(text.clone())
2104 } else {
2105 None
2106 }
2107 });
2108 let text = full_text.unwrap_or_default();
2109 assert!(
2110 !text.contains("[skipped]"),
2111 "skipped messages must be excluded"
2112 );
2113 assert!(
2114 !text.contains("[stopped]"),
2115 "stopped messages must be excluded"
2116 );
2117 assert!(
2118 text.contains("valid content to recall"),
2119 "valid messages must be included"
2120 );
2121 }
2122
2123 #[tokio::test]
2124 async fn fetch_cross_session_filters_below_threshold() {
2125 let mock = MockMemoryBackend {
2126 session_summaries: vec![
2127 MemSessionSummary {
2128 summary_text: "high relevance session".to_string(),
2129 score: 0.9,
2130 },
2131 MemSessionSummary {
2132 summary_text: "low relevance session".to_string(),
2133 score: 0.2,
2134 },
2135 ],
2136 ..Default::default()
2137 };
2138 let mut view = mock_view(mock);
2139 view.conversation_id = Some(1);
2140 view.cross_session_score_threshold = 0.5;
2141 let tc = NaiveTokenCounter;
2142 let result = fetch_cross_session(&view, "query", 1000, &tc)
2143 .await
2144 .unwrap();
2145 assert!(result.is_some());
2146 let msg = result.unwrap();
2147 let text = msg
2148 .parts
2149 .iter()
2150 .find_map(|p| {
2151 if let zeph_llm::provider::MessagePart::CrossSession { text } = p {
2152 Some(text.clone())
2153 } else {
2154 None
2155 }
2156 })
2157 .unwrap_or_default();
2158 assert!(
2159 text.contains("high relevance"),
2160 "high score must be included"
2161 );
2162 assert!(
2163 !text.contains("low relevance"),
2164 "low score must be filtered out"
2165 );
2166 }
2167
2168 #[tokio::test]
2169 async fn fetch_document_rag_skips_empty_chunks() {
2170 let mock = MockMemoryBackend {
2171 document_chunks: vec![
2172 MemDocumentChunk {
2173 text: String::new(),
2174 }, MemDocumentChunk {
2176 text: "real content here".to_string(),
2177 },
2178 ],
2179 ..Default::default()
2180 };
2181 let mut view = mock_view(mock);
2182 view.document_config.rag_enabled = true;
2183 let tc = NaiveTokenCounter;
2184 let result = fetch_document_rag(&view, "query", 1000, &tc).await.unwrap();
2185 assert!(result.is_some());
2186 let msg = result.unwrap();
2187 assert!(msg.content.contains("real content here"));
2188 assert!(!msg.content.contains("\n\n\n"));
2190 }
2191
2192 #[tokio::test]
2193 async fn fetch_graph_facts_sanitizes_injection_payloads() {
2194 let mock = MockMemoryBackend {
2196 graph_facts: vec![zeph_common::memory::MemGraphFact {
2197 fact: "fact with <script>alert(1)</script> and\nnewline".to_string(),
2198 confidence: 0.8,
2199 activation_score: None,
2200 neighbors: vec![],
2201 provenance_snippet: None,
2202 }],
2203 ..Default::default()
2204 };
2205 let mut view = mock_view(mock);
2206 view.graph_config.enabled = true;
2207 view.graph_config.spreading_activation.recall_timeout_ms = 5000;
2208 let tc = NaiveTokenCounter;
2209 let result = fetch_graph_facts(&view, "test", 1000, &tc).await.unwrap();
2210 assert!(result.is_some());
2211 let msg = result.unwrap();
2212 assert!(
2213 !msg.content.contains('<'),
2214 "angle brackets must be sanitized"
2215 );
2216 assert!(
2219 !msg.content.contains("\n\n"),
2220 "embedded newlines must be sanitized, no double-newline sequences expected"
2221 );
2222 }
2223
2224 #[tokio::test]
2225 async fn fetch_reasoning_strategies_sanitizes_injection_payloads() {
2226 let mock = MockMemoryBackend {
2228 reasoning_strategies: vec![MemReasoningStrategy {
2229 id: "s1".to_string(),
2230 outcome: "success".to_string(),
2231 summary: "strategy with <b>bold</b> and\nnewline".to_string(),
2232 }],
2233 ..Default::default()
2234 };
2235 let mut view = mock_view(mock);
2236 view.reasoning_config.enabled = true;
2237 let tc = NaiveTokenCounter;
2238 let result = fetch_reasoning_strategies(&view, "query", 1000, 3, &tc)
2239 .await
2240 .unwrap();
2241 assert!(result.is_some());
2242 let msg = result.unwrap();
2243 assert!(
2244 !msg.content.contains('<'),
2245 "angle brackets must be sanitized in strategy summaries"
2246 );
2247 }
2248
2249 #[tokio::test]
2252 async fn fetch_persona_facts_truncates_at_budget() {
2253 let tc = NaiveTokenCounter;
2254 let first_line = "[pref] brief\n";
2256 let budget = tc.count_tokens(crate::slot::PERSONA_PREFIX) + tc.count_tokens(first_line);
2257 let mock = MockMemoryBackend {
2258 persona_facts: vec![
2259 MemPersonaFact {
2260 category: "pref".to_string(),
2261 content: "brief".to_string(),
2262 },
2263 MemPersonaFact {
2264 category: "lang".to_string(),
2265 content: "english".to_string(),
2266 },
2267 ],
2268 ..Default::default()
2269 };
2270 let mut view = mock_view(mock);
2271 view.persona_config.enabled = true;
2272 let result = fetch_persona_facts(&view, budget, &tc).await.unwrap();
2273 let msg = result.unwrap();
2274 assert!(msg.content.contains("brief"), "first fact must be included");
2275 assert!(
2276 !msg.content.contains("english"),
2277 "second fact must be truncated by budget"
2278 );
2279 }
2280
2281 #[tokio::test]
2282 async fn fetch_semantic_recall_truncates_at_budget() {
2283 let tc = NaiveTokenCounter;
2284 let first_entry = "- [user] first message\n";
2286 let budget = tc.count_tokens(RECALL_PREFIX) + tc.count_tokens(first_entry);
2287 let mock = MockMemoryBackend {
2288 recalled: vec![
2289 MemRecalledMessage {
2290 role: "user".to_string(),
2291 content: "first message".to_string(),
2292 score: 0.95,
2293 },
2294 MemRecalledMessage {
2295 role: "user".to_string(),
2296 content: "second message that should be truncated".to_string(),
2297 score: 0.80,
2298 },
2299 ],
2300 ..Default::default()
2301 };
2302 let mut view = mock_view(mock);
2303 view.recall_limit = 10;
2304 let (msg, _) = fetch_semantic_recall(&view, "query", budget, &tc, None)
2305 .await
2306 .unwrap();
2307 assert!(msg.is_some());
2308 let text = msg
2309 .unwrap()
2310 .parts
2311 .iter()
2312 .find_map(|p| {
2313 if let zeph_llm::provider::MessagePart::Recall { text } = p {
2314 Some(text.clone())
2315 } else {
2316 None
2317 }
2318 })
2319 .unwrap_or_default();
2320 assert!(
2321 text.contains("first message"),
2322 "first entry must be included"
2323 );
2324 assert!(
2325 !text.contains("second message"),
2326 "second entry must be truncated by budget"
2327 );
2328 }
2329
2330 #[tokio::test]
2333 async fn fetch_graph_facts_sanitizes_provenance_snippet() {
2334 use zeph_common::memory::MemGraphNeighbor;
2335 let mock = MockMemoryBackend {
2336 graph_facts: vec![zeph_common::memory::MemGraphFact {
2337 fact: "safe fact".to_string(),
2338 confidence: 0.9,
2339 activation_score: None,
2340 neighbors: vec![MemGraphNeighbor {
2341 fact: "neighbor".to_string(),
2342 confidence: 0.7,
2343 }],
2344 provenance_snippet: Some("source with <injection>\nand newline".to_string()),
2345 }],
2346 ..Default::default()
2347 };
2348 let mut view = mock_view(mock);
2349 view.graph_config.enabled = true;
2350 view.graph_config.spreading_activation.recall_timeout_ms = 5000;
2351 let tc = NaiveTokenCounter;
2352 let result = fetch_graph_facts(&view, "test", 1000, &tc).await.unwrap();
2353 assert!(result.is_some());
2354 let msg = result.unwrap();
2355 assert!(
2356 !msg.content.contains('<'),
2357 "angle brackets in provenance_snippet must be sanitized"
2358 );
2359 assert!(
2360 !msg.content.contains("\n\n"),
2361 "newlines in provenance_snippet must be sanitized"
2362 );
2363 assert!(
2364 msg.content.contains("[source:"),
2365 "provenance snippet must be rendered"
2366 );
2367 }
2368}