1use std::future::Future;
17use std::pin::Pin;
18
19use futures::StreamExt as _;
20use futures::stream::FuturesUnordered;
21
22use zeph_common::memory::{AsyncMemoryRouter, CompressionLevel, GraphRecallParams, TokenCounting};
23use zeph_llm::provider::{Message, MessageMetadata, MessagePart, Role};
24
25use crate::error::AssemblerError;
26use crate::input::ContextAssemblyInput;
27use crate::slot::ContextSlot;
28
29pub(crate) fn levels_to_flags(levels: &[CompressionLevel]) -> (bool, bool, bool) {
37 if levels.is_empty() {
38 return (true, true, true);
39 }
40 let episodic = levels.contains(&CompressionLevel::Episodic);
41 let procedural = levels.contains(&CompressionLevel::Procedural);
42 let declarative = levels.contains(&CompressionLevel::Declarative);
43 (episodic, procedural, declarative)
44}
45
/// Header line that starts the message body built by `fetch_summaries`.
pub const SUMMARY_PREFIX: &str = "[conversation summaries]\n";
/// Header line that starts the message body built by `fetch_cross_session`.
pub const CROSS_SESSION_PREFIX: &str = "[cross-session context]\n";
/// Header line that starts the message body built by `fetch_semantic_recall`.
pub const RECALL_PREFIX: &str = "[semantic recall]\n";
/// Header line that starts the message body built by `fetch_corrections`.
pub const CORRECTIONS_PREFIX: &str = "[past corrections]\n";
/// Header line that starts the message body built by `fetch_document_rag`.
pub const DOCUMENT_RAG_PREFIX: &str = "## Relevant documents\n";
/// Header line that starts the message body built by `fetch_graph_facts`.
pub const GRAPH_FACTS_PREFIX: &str = "[known facts]\n";
58
/// Everything `ContextAssembler::gather` collected for one assembly pass.
///
/// Each optional field is filled from the matching `ContextSlot` variant as the
/// fetchers complete; a field stays at its `Default` value when its fetcher was
/// not scheduled or produced nothing.
#[derive(Default)]
pub struct PreparedContext {
    /// Message produced by the graph-facts fetcher, if any.
    pub graph_facts: Option<Message>,
    /// Message produced by the document-RAG fetcher, if any.
    pub doc_rag: Option<Message>,
    /// Message produced by the corrections fetcher, if any.
    pub corrections: Option<Message>,
    /// Message produced by the semantic-recall fetcher, if any.
    pub recall: Option<Message>,
    /// Score delivered together with the semantic-recall message.
    pub recall_confidence: Option<f32>,
    /// Message produced by the cross-session fetcher, if any.
    pub cross_session: Option<Message>,
    /// Message produced by the summaries fetcher, if any.
    pub summaries: Option<Message>,
    /// Text returned by the code index (`fetch_code_rag`), if any.
    pub code_context: Option<String>,
    /// Message produced by the persona-facts fetcher, if any.
    pub persona_facts: Option<Message>,
    /// Message produced by the trajectory-hints fetcher, if any.
    pub trajectory_hints: Option<Message>,
    /// Message produced by the tree-memory fetcher, if any.
    pub tree_memory: Option<Message>,
    /// Message produced by the reasoning-strategies fetcher, if any.
    pub reasoning_hints: Option<Message>,
    /// `true` when the effective context strategy resolved to `MemoryFirst`.
    pub memory_first: bool,
    /// Recent-history share of the budget allocation (`alloc.recent_history`).
    pub recent_history_budget: usize,
    /// Join handles of tasks spawned while fetching; in this file only the
    /// reasoning-strategies fetcher contributes one.
    pub background_tasks: Vec<tokio::task::JoinHandle<()>>,
}
99
/// Stateless entry point for context assembly; its only associated function in
/// this file is the async `gather`, which takes all inputs as an argument.
pub struct ContextAssembler;
104
/// Boxed, pinned, `Send` future that resolves to one `ContextSlot` or an
/// `AssemblerError`. This is the element type of the `FuturesUnordered` set
/// built by `schedule_context_fetchers` and drained by `drive_fetchers`.
type CtxFuture<'a> = Pin<Box<dyn Future<Output = Result<ContextSlot, AssemblerError>> + Send + 'a>>;
106
107fn empty_prepared_context() -> PreparedContext {
108 PreparedContext::default()
109}
110
111fn resolve_effective_strategy(
112 memory: &crate::input::ContextMemoryView,
113 sidequest_turn_counter: u64,
114) -> zeph_config::ContextStrategy {
115 match memory.context_strategy {
116 zeph_config::ContextStrategy::MemoryFirst => zeph_config::ContextStrategy::MemoryFirst,
117 zeph_config::ContextStrategy::Adaptive => {
118 if sidequest_turn_counter >= u64::from(memory.crossover_turn_threshold) {
119 zeph_config::ContextStrategy::MemoryFirst
120 } else {
121 zeph_config::ContextStrategy::FullHistory
122 }
123 }
124 _ => zeph_config::ContextStrategy::FullHistory,
125 }
126}
127
128fn correction_params(cfg: Option<&crate::input::CorrectionConfig>) -> (usize, f32) {
129 cfg.filter(|c| c.correction_detection)
130 .map_or((3, 0.75), |c| {
131 (
132 c.correction_recall_limit as usize,
133 c.correction_min_similarity,
134 )
135 })
136}
137
138#[allow(clippy::too_many_arguments)]
145fn schedule_context_fetchers<'r>(
146 memory: &'r crate::input::ContextMemoryView,
147 tc: &'r dyn TokenCounting,
148 query: &'r str,
149 scrub: fn(&str) -> std::borrow::Cow<'_, str>,
150 index: Option<&'r dyn crate::input::IndexAccess>,
151 router_ref: &'r dyn AsyncMemoryRouter,
152 summaries_budget: usize,
153 cross_session_budget: usize,
154 semantic_recall_budget: usize,
155 code_context_budget: usize,
156 graph_facts_budget: usize,
157 recall_limit: usize,
158 min_sim: f32,
159 active_levels: &[CompressionLevel],
160) -> FuturesUnordered<CtxFuture<'r>> {
161 let (episodic_active, procedural_active, declarative_active) = levels_to_flags(active_levels);
165
166 let fetchers: FuturesUnordered<CtxFuture<'r>> = FuturesUnordered::new();
167
168 if episodic_active && summaries_budget > 0 {
169 fetchers.push(Box::pin(async move {
170 fetch_summaries(memory, summaries_budget, tc)
171 .await
172 .map(ContextSlot::Summaries)
173 }));
174 }
175 if episodic_active && cross_session_budget > 0 {
176 fetchers.push(Box::pin(async move {
177 fetch_cross_session(memory, query, cross_session_budget, tc)
178 .await
179 .map(ContextSlot::CrossSession)
180 }));
181 }
182 if episodic_active && semantic_recall_budget > 0 {
183 fetchers.push(Box::pin(async move {
184 fetch_semantic_recall(memory, query, semantic_recall_budget, tc, Some(router_ref))
185 .await
186 .map(|(msg, score)| ContextSlot::SemanticRecall(msg, score))
187 }));
188 fetchers.push(Box::pin(async move {
189 fetch_document_rag(memory, query, semantic_recall_budget, tc)
190 .await
191 .map(ContextSlot::DocumentRag)
192 }));
193 }
194 fetchers.push(Box::pin(async move {
196 fetch_corrections(memory, query, recall_limit, min_sim, scrub)
197 .await
198 .map(ContextSlot::Corrections)
199 }));
200 if code_context_budget > 0
202 && let Some(idx) = index
203 {
204 fetchers.push(Box::pin(async move {
205 let result: Result<Option<String>, AssemblerError> =
206 idx.fetch_code_rag(query, code_context_budget).await;
207 result.map(ContextSlot::CodeContext)
208 }));
209 }
210 if declarative_active && graph_facts_budget > 0 {
211 fetchers.push(Box::pin(async move {
212 fetch_graph_facts(memory, query, graph_facts_budget, tc)
213 .await
214 .map(ContextSlot::GraphFacts)
215 }));
216 }
217 if declarative_active && memory.persona_config.context_budget_tokens > 0 {
218 fetchers.push(Box::pin(async move {
219 let persona_budget = memory.persona_config.context_budget_tokens;
220 fetch_persona_facts(memory, persona_budget, tc)
221 .await
222 .map(ContextSlot::PersonaFacts)
223 }));
224 }
225 if procedural_active && memory.trajectory_config.context_budget_tokens > 0 {
226 fetchers.push(Box::pin(async move {
227 let tbudget = memory.trajectory_config.context_budget_tokens;
228 fetch_trajectory_hints(memory, tbudget, tc)
229 .await
230 .map(ContextSlot::TrajectoryHints)
231 }));
232 }
233 if declarative_active && memory.tree_config.context_budget_tokens > 0 {
234 fetchers.push(Box::pin(async move {
235 let tbudget = memory.tree_config.context_budget_tokens;
236 fetch_tree_memory(memory, tbudget, tc)
237 .await
238 .map(ContextSlot::TreeMemory)
239 }));
240 }
241 if procedural_active
242 && memory.reasoning_config.enabled
243 && memory.reasoning_config.context_budget_tokens > 0
244 {
245 fetchers.push(Box::pin(async move {
246 let rbudget = memory.reasoning_config.context_budget_tokens;
247 let top_k = memory.reasoning_config.top_k;
248 fetch_reasoning_strategies(memory, query, rbudget, top_k, tc)
249 .await
250 .map(|(msg, handle)| ContextSlot::ReasoningStrategies(msg, handle))
251 }));
252 }
253
254 fetchers
255}
256
257async fn drive_fetchers(
258 mut fetchers: FuturesUnordered<CtxFuture<'_>>,
259 prepared: &mut PreparedContext,
260) -> Result<(), AssemblerError> {
261 while let Some(result) = fetchers.next().await {
262 match result {
263 Ok(slot) => match slot {
264 ContextSlot::Summaries(msg) => prepared.summaries = msg,
265 ContextSlot::CrossSession(msg) => prepared.cross_session = msg,
266 ContextSlot::SemanticRecall(msg, score) => {
267 prepared.recall = msg;
268 prepared.recall_confidence = score;
269 }
270 ContextSlot::DocumentRag(msg) => prepared.doc_rag = msg,
271 ContextSlot::Corrections(msg) => prepared.corrections = msg,
272 ContextSlot::CodeContext(text) => prepared.code_context = text,
273 ContextSlot::GraphFacts(msg) => prepared.graph_facts = msg,
274 ContextSlot::PersonaFacts(msg) => prepared.persona_facts = msg,
275 ContextSlot::TrajectoryHints(msg) => prepared.trajectory_hints = msg,
276 ContextSlot::TreeMemory(msg) => prepared.tree_memory = msg,
277 ContextSlot::ReasoningStrategies(msg, handle) => {
278 prepared.reasoning_hints = msg;
279 if let Some(h) = handle {
280 prepared.background_tasks.push(h);
281 }
282 }
283 },
284 Err(e) => return Err(e),
285 }
286 }
287 Ok(())
288}
289
290impl ContextAssembler {
291 #[tracing::instrument(name = "context.assembler.gather", skip_all)]
299 pub async fn gather(
300 input: &ContextAssemblyInput<'_>,
301 ) -> Result<PreparedContext, AssemblerError> {
302 let Some(ref budget) = input.context_manager.budget else {
303 return Ok(empty_prepared_context());
304 };
305
306 let memory = input.memory;
307 let tc = input.token_counter;
308
309 let effective_strategy = resolve_effective_strategy(memory, input.sidequest_turn_counter);
310 let memory_first = effective_strategy == zeph_config::ContextStrategy::MemoryFirst;
311
312 let system_prompt = input
313 .messages
314 .first()
315 .filter(|m| m.role == Role::System)
316 .map_or("", |m| m.content.as_str());
317
318 let digest_tokens = memory
319 .cached_session_digest
320 .as_ref()
321 .map_or(0, |(_, tokens)| *tokens);
322
323 let alloc = budget.allocate_with_opts(
324 system_prompt,
325 input.skills_prompt,
326 tc,
327 memory.graph_config.enabled,
328 digest_tokens,
329 memory_first,
330 );
331
332 let (recall_limit, min_sim) = correction_params(input.correction_config.as_ref());
333
334 let router_ref: &dyn AsyncMemoryRouter = input.router.as_ref();
335
336 tracing::debug!(
337 active_sources = alloc.active_sources(),
338 active_levels = ?input.active_levels,
339 "context budget allocated"
340 );
341
342 let fetchers = schedule_context_fetchers(
343 memory,
344 tc,
345 input.query,
346 input.scrub,
347 input.index,
348 router_ref,
349 alloc.summaries,
350 alloc.cross_session,
351 alloc.semantic_recall,
352 alloc.code_context,
353 alloc.graph_facts,
354 recall_limit,
355 min_sim,
356 input.active_levels,
357 );
358
359 let mut prepared = empty_prepared_context();
360 prepared.memory_first = memory_first;
361 prepared.recent_history_budget = alloc.recent_history;
362
363 drive_fetchers(fetchers, &mut prepared).await?;
364 Ok(prepared)
365 }
366}
367
368pub fn effective_recall_timeout_ms(configured: u64) -> u64 {
373 if configured == 0 {
374 tracing::warn!(
375 "recall_timeout_ms is 0, which would disable spreading activation recall; \
376 clamping to 100ms"
377 );
378 100
379 } else {
380 configured
381 }
382}
383
384use crate::input::ContextMemoryView;
385
386#[tracing::instrument(name = "context.graph_facts", skip_all)]
387#[allow(clippy::too_many_lines)] pub(crate) async fn fetch_graph_facts(
389 memory: &ContextMemoryView,
390 query: &str,
391 budget_tokens: usize,
392 tc: &dyn TokenCounting,
393) -> Result<Option<Message>, AssemblerError> {
394 use zeph_common::memory::{RecallView, SpreadingActivationParams, classify_graph_subgraph};
395
396 if budget_tokens == 0 || !memory.graph_config.enabled {
397 return Ok(None);
398 }
399 let Some(ref mem) = memory.memory else {
400 return Ok(None);
401 };
402 let recall_limit = memory.graph_config.recall_limit;
403 let temporal_decay_rate = memory.graph_config.temporal_decay_rate;
404 let sa_config = &memory.graph_config.spreading_activation;
405
406 let fused_query;
408 let effective_query = if let Some(ref state) = memory.memcot_state {
409 let max_state_chars = 2 * query.len();
410 let state_slice = if state.len() > max_state_chars {
411 let boundary = state.floor_char_boundary(max_state_chars);
412 &state[..boundary]
413 } else {
414 state.as_str()
415 };
416 fused_query = format!("[state] {state_slice}\n{query}");
417 &fused_query as &str
418 } else {
419 query
420 };
421
422 let edge_types = classify_graph_subgraph(effective_query);
423
424 let view = match memory.memcot_config.recall_view {
425 zeph_config::RecallViewConfig::ZoomIn => RecallView::ZoomIn,
426 zeph_config::RecallViewConfig::ZoomOut => RecallView::ZoomOut,
427 _ => RecallView::Head,
428 };
429
430 let sa_params = if sa_config.enabled {
431 Some(SpreadingActivationParams {
432 decay_lambda: sa_config.decay_lambda,
433 max_hops: sa_config.max_hops,
434 activation_threshold: sa_config.activation_threshold,
435 inhibition_threshold: sa_config.inhibition_threshold,
436 max_activated_nodes: sa_config.max_activated_nodes,
437 temporal_decay_rate,
438 seed_structural_weight: sa_config.seed_structural_weight,
439 seed_community_cap: sa_config.seed_community_cap,
440 alpha: sa_config.alpha,
441 })
442 } else {
443 None
444 };
445
446 let timeout_ms = effective_recall_timeout_ms(sa_config.recall_timeout_ms);
447 let recall_fut = mem.recall_graph_facts(
448 effective_query,
449 GraphRecallParams {
450 limit: recall_limit,
451 view,
452 zoom_out_neighbor_cap: memory.memcot_config.zoom_out_neighbor_cap,
453 max_hops: memory.graph_config.max_hops,
454 temporal_decay_rate,
455 edge_types: &edge_types,
456 spreading_activation: sa_params,
457 },
458 );
459 let recalled = match tokio::time::timeout(
460 std::time::Duration::from_millis(timeout_ms),
461 recall_fut,
462 )
463 .await
464 {
465 Ok(Ok(facts)) => facts,
466 Ok(Err(e)) => {
467 tracing::warn!("graph recall failed: {e:#}");
468 Vec::new()
469 }
470 Err(_) => {
471 tracing::warn!("graph recall timed out ({timeout_ms}ms)");
472 Vec::new()
473 }
474 };
475
476 if recalled.is_empty() {
477 return Ok(None);
478 }
479
480 let mut body = String::from(GRAPH_FACTS_PREFIX);
481 let mut tokens_so_far = tc.count_tokens(&body);
482
483 for rf in &recalled {
484 let fact_text = rf.fact.replace(['\n', '\r', '<', '>'], " ");
485 let line = if let Some(score) = rf.activation_score {
486 format!(
487 "- {} (confidence: {:.2}, activation: {:.2})\n",
488 fact_text, rf.confidence, score
489 )
490 } else {
491 format!("- {} (confidence: {:.2})\n", fact_text, rf.confidence)
492 };
493 let line_tokens = tc.count_tokens(&line);
494 if tokens_so_far + line_tokens > budget_tokens {
495 break;
496 }
497 body.push_str(&line);
498 tokens_so_far += line_tokens;
499
500 for nb in &rf.neighbors {
502 let nb_text = nb.fact.replace(['\n', '\r', '<', '>'], " ");
503 let nb_line = format!(" ~ {} (confidence: {:.2})\n", nb_text, nb.confidence);
504 let nb_tokens = tc.count_tokens(&nb_line);
505 if tokens_so_far + nb_tokens > budget_tokens {
506 break;
507 }
508 body.push_str(&nb_line);
509 tokens_so_far += nb_tokens;
510 }
511
512 if let Some(ref snippet) = rf.provenance_snippet {
514 let snip_line = format!(
515 " [source: {}]\n",
516 snippet.replace(['\n', '\r', '<', '>'], " ")
517 );
518 let snip_tokens = tc.count_tokens(&snip_line);
519 if tokens_so_far + snip_tokens <= budget_tokens {
520 body.push_str(&snip_line);
521 tokens_so_far += snip_tokens;
522 }
523 }
524 }
525
526 if body == GRAPH_FACTS_PREFIX {
527 return Ok(None);
528 }
529
530 Ok(Some(Message::from_legacy(Role::System, body)))
531}
532
533#[tracing::instrument(name = "context.persona_facts", skip_all)]
534pub(crate) async fn fetch_persona_facts(
535 memory: &ContextMemoryView,
536 budget_tokens: usize,
537 tc: &dyn TokenCounting,
538) -> Result<Option<Message>, AssemblerError> {
539 if budget_tokens == 0 || !memory.persona_config.enabled {
540 return Ok(None);
541 }
542 let Some(ref mem) = memory.memory else {
543 return Ok(None);
544 };
545
546 let min_confidence = memory.persona_config.min_confidence;
547 let facts = mem
548 .load_persona_facts(min_confidence)
549 .await
550 .map_err(AssemblerError::Memory)?;
551
552 if facts.is_empty() {
553 return Ok(None);
554 }
555
556 let mut body = String::from(crate::slot::PERSONA_PREFIX);
557 let mut tokens_so_far = tc.count_tokens(&body);
558
559 for fact in &facts {
560 let line = format!("[{}] {}\n", fact.category, fact.content);
561 let line_tokens = tc.count_tokens(&line);
562 if tokens_so_far + line_tokens > budget_tokens {
563 break;
564 }
565 body.push_str(&line);
566 tokens_so_far += line_tokens;
567 }
568
569 if body == crate::slot::PERSONA_PREFIX {
570 return Ok(None);
571 }
572
573 Ok(Some(Message::from_legacy(Role::System, body)))
574}
575
576#[tracing::instrument(name = "context.trajectory_hints", skip_all)]
577pub(crate) async fn fetch_trajectory_hints(
578 memory: &ContextMemoryView,
579 budget_tokens: usize,
580 tc: &dyn TokenCounting,
581) -> Result<Option<Message>, AssemblerError> {
582 if budget_tokens == 0 || !memory.trajectory_config.enabled {
583 return Ok(None);
584 }
585 let Some(ref mem) = memory.memory else {
586 return Ok(None);
587 };
588
589 let top_k = memory.trajectory_config.recall_top_k;
590 let min_conf = memory.trajectory_config.min_confidence;
591 let entries = mem
595 .load_trajectory_entries(Some("procedural"), top_k)
596 .await
597 .map_err(AssemblerError::Memory)?;
598
599 if entries.is_empty() {
600 return Ok(None);
601 }
602
603 let mut body = String::from(crate::slot::TRAJECTORY_PREFIX);
604 let mut tokens_so_far = tc.count_tokens(&body);
605
606 for entry in entries
607 .iter()
608 .filter(|e| e.confidence >= min_conf)
609 .take(top_k)
610 {
611 let line = format!("- {}: {}\n", entry.intent, entry.outcome);
612 let line_tokens = tc.count_tokens(&line);
613 if tokens_so_far + line_tokens > budget_tokens {
614 break;
615 }
616 body.push_str(&line);
617 tokens_so_far += line_tokens;
618 }
619
620 if body == crate::slot::TRAJECTORY_PREFIX {
621 return Ok(None);
622 }
623
624 Ok(Some(Message::from_legacy(Role::System, body)))
625}
626
627#[tracing::instrument(name = "context.tree_memory", skip_all)]
628pub(crate) async fn fetch_tree_memory(
629 memory: &ContextMemoryView,
630 budget_tokens: usize,
631 tc: &dyn TokenCounting,
632) -> Result<Option<Message>, AssemblerError> {
633 if budget_tokens == 0 || !memory.tree_config.enabled {
634 return Ok(None);
635 }
636 let Some(ref mem) = memory.memory else {
637 return Ok(None);
638 };
639
640 let top_k = memory.tree_config.recall_top_k;
641 let nodes = mem
642 .load_tree_nodes(1, top_k)
643 .await
644 .map_err(AssemblerError::Memory)?;
645
646 if nodes.is_empty() {
647 return Ok(None);
648 }
649
650 let mut body = String::from(crate::slot::TREE_MEMORY_PREFIX);
651 let mut tokens_so_far = tc.count_tokens(&body);
652
653 for node in nodes.iter().take(top_k) {
654 let line = format!("- {}\n", node.content);
655 let line_tokens = tc.count_tokens(&line);
656 if tokens_so_far + line_tokens > budget_tokens {
657 break;
658 }
659 body.push_str(&line);
660 tokens_so_far += line_tokens;
661 }
662
663 if body == crate::slot::TREE_MEMORY_PREFIX {
664 return Ok(None);
665 }
666
667 Ok(Some(Message::from_legacy(Role::System, body)))
668}
669
670#[tracing::instrument(name = "context.reasoning_strategies", skip_all)]
671pub(crate) async fn fetch_reasoning_strategies(
672 memory: &ContextMemoryView,
673 query: &str,
674 budget_tokens: usize,
675 top_k: usize,
676 tc: &dyn TokenCounting,
677) -> Result<(Option<Message>, Option<tokio::task::JoinHandle<()>>), AssemblerError> {
678 let budget_tokens = budget_tokens.min(500);
680 if budget_tokens == 0 {
681 return Ok((None, None));
682 }
683 let Some(ref mem) = memory.memory else {
684 return Ok((None, None));
685 };
686
687 let strategies = mem
688 .retrieve_reasoning_strategies(query, top_k)
689 .await
690 .map_err(AssemblerError::Memory)?;
691
692 if strategies.is_empty() {
693 return Ok((None, None));
694 }
695
696 let mut body = String::from(crate::slot::REASONING_PREFIX);
697 let mut tokens_so_far = tc.count_tokens(&body);
698 let mut injected_ids: Vec<String> = Vec::new();
699
700 for s in strategies.iter().take(top_k) {
701 let safe_summary = s.summary.replace(['\n', '\r', '<', '>'], " ");
704 let line = format!("- [{}] {}\n", s.outcome, safe_summary);
705 let line_tokens = tc.count_tokens(&line);
706 if tokens_so_far + line_tokens > budget_tokens {
707 break;
708 }
709 body.push_str(&line);
710 tokens_so_far += line_tokens;
711 injected_ids.push(s.id.clone());
712 }
713
714 if body == crate::slot::REASONING_PREFIX {
715 return Ok((None, None));
716 }
717
718 let handle = if injected_ids.is_empty() {
722 None
723 } else {
724 let mem_clone = mem.clone();
725 Some(tokio::spawn(async move {
726 if let Err(e) = mem_clone.mark_reasoning_used(&injected_ids).await {
727 tracing::warn!(error = %e, "reasoning: mark_used failed");
728 }
729 }))
730 };
731
732 Ok((Some(Message::from_legacy(Role::System, body)), handle))
733}
734
735#[tracing::instrument(name = "context.corrections", skip_all)]
736pub(crate) async fn fetch_corrections(
737 memory: &ContextMemoryView,
738 query: &str,
739 limit: usize,
740 min_score: f32,
741 scrub: fn(&str) -> std::borrow::Cow<'_, str>,
742) -> Result<Option<Message>, AssemblerError> {
743 let Some(ref mem) = memory.memory else {
744 return Ok(None);
745 };
746 let corrections = mem
747 .retrieve_corrections(query, limit, min_score)
748 .await
749 .map_err(AssemblerError::Memory)?;
750 if corrections.is_empty() {
751 return Ok(None);
752 }
753 let mut text = String::from(CORRECTIONS_PREFIX);
754 for c in &corrections {
755 text.push_str("- Past user correction: \"");
756 text.push_str(&scrub(&c.correction_text));
757 text.push_str("\"\n");
758 }
759 Ok(Some(Message::from_legacy(Role::System, text)))
760}
761
762#[tracing::instrument(name = "context.semantic_recall", skip_all)]
763pub(crate) async fn fetch_semantic_recall(
764 memory: &ContextMemoryView,
765 query: &str,
766 token_budget: usize,
767 tc: &dyn TokenCounting,
768 router: Option<&dyn AsyncMemoryRouter>,
769) -> Result<(Option<Message>, Option<f32>), AssemblerError> {
770 let Some(ref mem) = memory.memory else {
771 return Ok((None, None));
772 };
773 if memory.recall_limit == 0 || token_budget == 0 {
774 return Ok((None, None));
775 }
776
777 let recalled = mem
778 .recall(query, memory.recall_limit, router)
779 .await
780 .map_err(AssemblerError::Memory)?;
781 if recalled.is_empty() {
782 return Ok((None, None));
783 }
784
785 let top_score = recalled.first().map(|r| r.score);
786
787 let mut recall_text = String::with_capacity(token_budget * 3);
788 recall_text.push_str(RECALL_PREFIX);
789 let mut tokens_used = tc.count_tokens(&recall_text);
790
791 for item in &recalled {
792 if item.content.starts_with("[skipped]") || item.content.starts_with("[stopped]") {
793 continue;
794 }
795 let entry = format!("- [{}] {}\n", item.role, item.content);
796 let entry_tokens = tc.count_tokens(&entry);
797 if tokens_used + entry_tokens > token_budget {
798 break;
799 }
800 recall_text.push_str(&entry);
801 tokens_used += entry_tokens;
802 }
803
804 if tokens_used > tc.count_tokens(RECALL_PREFIX) {
805 Ok((
806 Some(Message::from_parts(
807 Role::System,
808 vec![MessagePart::Recall { text: recall_text }],
809 )),
810 top_score,
811 ))
812 } else {
813 Ok((None, None))
814 }
815}
816
817#[tracing::instrument(name = "context.document_rag", skip_all)]
818pub(crate) async fn fetch_document_rag(
819 memory: &ContextMemoryView,
820 query: &str,
821 token_budget: usize,
822 tc: &dyn TokenCounting,
823) -> Result<Option<Message>, AssemblerError> {
824 if !memory.document_config.rag_enabled || token_budget == 0 {
825 return Ok(None);
826 }
827 let Some(ref mem) = memory.memory else {
828 return Ok(None);
829 };
830
831 let collection = &memory.document_config.collection;
832 let top_k = memory.document_config.top_k;
833 let chunks = mem
834 .search_document_collection(collection, query, top_k)
835 .await
836 .map_err(AssemblerError::Memory)?;
837 if chunks.is_empty() {
838 return Ok(None);
839 }
840
841 let mut text = String::from(DOCUMENT_RAG_PREFIX);
842 let mut tokens_used = tc.count_tokens(&text);
843
844 for chunk in &chunks {
845 if chunk.text.is_empty() {
846 continue;
847 }
848 let entry = format!("{}\n", chunk.text);
849 let cost = tc.count_tokens(&entry);
850 if tokens_used + cost > token_budget {
851 break;
852 }
853 text.push_str(&entry);
854 tokens_used += cost;
855 }
856
857 if tokens_used > tc.count_tokens(DOCUMENT_RAG_PREFIX) {
858 Ok(Some(Message {
859 role: Role::System,
860 content: text,
861 parts: vec![],
862 metadata: MessageMetadata::default(),
863 }))
864 } else {
865 Ok(None)
866 }
867}
868
869#[tracing::instrument(name = "context.summaries", skip_all)]
870pub(crate) async fn fetch_summaries(
871 memory: &ContextMemoryView,
872 token_budget: usize,
873 tc: &dyn TokenCounting,
874) -> Result<Option<Message>, AssemblerError> {
875 let (Some(mem), Some(cid)) = (&memory.memory, memory.conversation_id) else {
876 return Ok(None);
877 };
878 if token_budget == 0 {
879 return Ok(None);
880 }
881
882 let summaries = mem
883 .load_summaries(cid)
884 .await
885 .map_err(AssemblerError::Memory)?;
886 if summaries.is_empty() {
887 return Ok(None);
888 }
889
890 let mut summary_text = String::from(SUMMARY_PREFIX);
891 let mut tokens_used = tc.count_tokens(&summary_text);
892
893 for summary in summaries.iter().rev() {
894 let first = summary.first_message_id.unwrap_or(0);
895 let last = summary.last_message_id.unwrap_or(0);
896 let entry = format!("- Messages {first}-{last}: {}\n", summary.content);
897 let cost = tc.count_tokens(&entry);
898 if tokens_used + cost > token_budget {
899 break;
900 }
901 summary_text.push_str(&entry);
902 tokens_used += cost;
903 }
904
905 if tokens_used > tc.count_tokens(SUMMARY_PREFIX) {
906 Ok(Some(Message::from_parts(
907 Role::System,
908 vec![MessagePart::Summary { text: summary_text }],
909 )))
910 } else {
911 Ok(None)
912 }
913}
914
915#[tracing::instrument(name = "context.cross_session", skip_all)]
916pub(crate) async fn fetch_cross_session(
917 memory: &ContextMemoryView,
918 query: &str,
919 token_budget: usize,
920 tc: &dyn TokenCounting,
921) -> Result<Option<Message>, AssemblerError> {
922 let (Some(mem), Some(cid)) = (&memory.memory, memory.conversation_id) else {
923 return Ok(None);
924 };
925 if token_budget == 0 {
926 return Ok(None);
927 }
928
929 let threshold = memory.cross_session_score_threshold;
930 let results: Vec<_> = mem
931 .search_session_summaries(query, 5, Some(cid))
932 .await
933 .map_err(AssemblerError::Memory)?
934 .into_iter()
935 .filter(|r| r.score >= threshold)
936 .collect();
937 if results.is_empty() {
938 return Ok(None);
939 }
940
941 let mut text = String::from(CROSS_SESSION_PREFIX);
942 let mut tokens_used = tc.count_tokens(&text);
943
944 for item in &results {
945 let entry = format!("- {}\n", item.summary_text);
946 let cost = tc.count_tokens(&entry);
947 if tokens_used + cost > token_budget {
948 break;
949 }
950 text.push_str(&entry);
951 tokens_used += cost;
952 }
953
954 if tokens_used > tc.count_tokens(CROSS_SESSION_PREFIX) {
955 Ok(Some(Message::from_parts(
956 Role::System,
957 vec![MessagePart::CrossSession { text }],
958 )))
959 } else {
960 Ok(None)
961 }
962}
963
/// Tail length at which `memory_first_keep_tail` stops extending the tail over
/// tool-result messages and performs its final preceding-message check.
pub const MAX_KEEP_TAIL_SCAN: usize = 50;
967
968#[must_use]
976pub fn memory_first_keep_tail(messages: &[Message], history_start: usize) -> usize {
977 use zeph_llm::provider::MessagePart;
978
979 let mut keep_tail = 2usize;
980 let len = messages.len();
981 let max = len.saturating_sub(history_start);
982
983 while keep_tail < max {
984 let first_retained = &messages[len - keep_tail];
985 let is_tool_result = first_retained.role == Role::User
986 && first_retained
987 .parts
988 .iter()
989 .any(|p| matches!(p, MessagePart::ToolResult { .. }));
990
991 if is_tool_result {
992 keep_tail += 1;
993 } else {
994 break;
995 }
996
997 if keep_tail >= MAX_KEEP_TAIL_SCAN {
998 let preceding_idx = len.saturating_sub(keep_tail + 1);
999 if preceding_idx >= history_start {
1000 let preceding = &messages[preceding_idx];
1001 let is_tool_use = preceding.role == Role::Assistant
1002 && preceding
1003 .parts
1004 .iter()
1005 .any(|p| matches!(p, MessagePart::ToolUse { .. }));
1006 if is_tool_use {
1007 keep_tail += 1;
1008 }
1009 }
1010 break;
1011 }
1012 }
1013
1014 keep_tail
1015}
1016
1017#[cfg(test)]
1018mod tests {
1019 use super::*;
1020 use crate::input::ContextMemoryView;
1021 use zeph_common::memory::CompressionLevel;
1022 use zeph_config::{
1023 ContextStrategy, DocumentConfig, GraphConfig, PersonaConfig, ReasoningConfig,
1024 TrajectoryConfig, TreeConfig,
1025 };
1026
1027 struct NaiveTokenCounter;
1028 impl zeph_common::memory::TokenCounting for NaiveTokenCounter {
1029 fn count_tokens(&self, text: &str) -> usize {
1030 text.split_whitespace().count()
1031 }
1032 fn count_tool_schema_tokens(&self, schema: &serde_json::Value) -> usize {
1033 schema.to_string().split_whitespace().count()
1034 }
1035 }
1036
1037 fn empty_view() -> ContextMemoryView {
1038 ContextMemoryView {
1039 memory: None,
1040 conversation_id: None,
1041 recall_limit: 10,
1042 cross_session_score_threshold: 0.5,
1043 context_strategy: ContextStrategy::default(),
1044 crossover_turn_threshold: 5,
1045 cached_session_digest: None,
1046 graph_config: GraphConfig::default(),
1047 document_config: DocumentConfig::default(),
1048 persona_config: PersonaConfig::default(),
1049 trajectory_config: TrajectoryConfig::default(),
1050 reasoning_config: ReasoningConfig::default(),
1051 memcot_config: zeph_config::MemCotConfig::default(),
1052 memcot_state: None,
1053 tree_config: TreeConfig::default(),
1054 }
1055 }
1056
1057 #[tokio::test]
1060 async fn fetch_graph_facts_returns_none_when_memory_is_none() {
1061 let view = empty_view();
1062 let tc = NaiveTokenCounter;
1063 let result = fetch_graph_facts(&view, "test", 1000, &tc).await.unwrap();
1064 assert!(result.is_none());
1065 }
1066
1067 #[tokio::test]
1068 async fn fetch_graph_facts_returns_none_when_budget_zero() {
1069 let mut view = empty_view();
1070 view.graph_config.enabled = true;
1071 let tc = NaiveTokenCounter;
1072 let result = fetch_graph_facts(&view, "test", 0, &tc).await.unwrap();
1073 assert!(result.is_none());
1074 }
1075
1076 #[tokio::test]
1077 async fn fetch_graph_facts_returns_none_when_graph_disabled() {
1078 let mut view = empty_view();
1079 view.graph_config.enabled = false;
1080 let tc = NaiveTokenCounter;
1081 let result = fetch_graph_facts(&view, "test", 1000, &tc).await.unwrap();
1082 assert!(result.is_none());
1083 }
1084
1085 #[tokio::test]
1088 async fn fetch_persona_facts_returns_none_when_memory_is_none() {
1089 let view = empty_view();
1090 let tc = NaiveTokenCounter;
1091 let result = fetch_persona_facts(&view, 1000, &tc).await.unwrap();
1092 assert!(result.is_none());
1093 }
1094
1095 #[tokio::test]
1096 async fn fetch_persona_facts_returns_none_when_budget_zero() {
1097 let mut view = empty_view();
1098 view.persona_config.enabled = true;
1099 let tc = NaiveTokenCounter;
1100 let result = fetch_persona_facts(&view, 0, &tc).await.unwrap();
1101 assert!(result.is_none());
1102 }
1103
1104 #[tokio::test]
1107 async fn fetch_trajectory_hints_returns_none_when_memory_is_none() {
1108 let view = empty_view();
1109 let tc = NaiveTokenCounter;
1110 let result = fetch_trajectory_hints(&view, 1000, &tc).await.unwrap();
1111 assert!(result.is_none());
1112 }
1113
1114 #[tokio::test]
1115 async fn fetch_trajectory_hints_returns_none_when_budget_zero() {
1116 let mut view = empty_view();
1117 view.trajectory_config.enabled = true;
1118 let tc = NaiveTokenCounter;
1119 let result = fetch_trajectory_hints(&view, 0, &tc).await.unwrap();
1120 assert!(result.is_none());
1121 }
1122
1123 #[tokio::test]
1126 async fn fetch_tree_memory_returns_none_when_memory_is_none() {
1127 let view = empty_view();
1128 let tc = NaiveTokenCounter;
1129 let result = fetch_tree_memory(&view, 1000, &tc).await.unwrap();
1130 assert!(result.is_none());
1131 }
1132
1133 #[tokio::test]
1134 async fn fetch_tree_memory_returns_none_when_budget_zero() {
1135 let mut view = empty_view();
1136 view.tree_config.enabled = true;
1137 let tc = NaiveTokenCounter;
1138 let result = fetch_tree_memory(&view, 0, &tc).await.unwrap();
1139 assert!(result.is_none());
1140 }
1141
1142 #[tokio::test]
1145 async fn fetch_corrections_returns_none_when_memory_is_none() {
1146 let view = empty_view();
1147 let result = fetch_corrections(&view, "test", 10, 0.5, |s| s.into())
1148 .await
1149 .unwrap();
1150 assert!(result.is_none());
1151 }
1152
1153 #[tokio::test]
1156 async fn fetch_semantic_recall_returns_none_when_memory_is_none() {
1157 let view = empty_view();
1158 let tc = NaiveTokenCounter;
1159 let result = fetch_semantic_recall(&view, "test", 1000, &tc, None)
1160 .await
1161 .unwrap();
1162 assert!(result.0.is_none() && result.1.is_none());
1163 }
1164
1165 #[tokio::test]
1166 async fn fetch_semantic_recall_returns_none_when_budget_zero() {
1167 let view = empty_view();
1168 let tc = NaiveTokenCounter;
1169 let result = fetch_semantic_recall(&view, "test", 0, &tc, None)
1170 .await
1171 .unwrap();
1172 assert!(result.0.is_none() && result.1.is_none());
1173 }
1174
1175 #[tokio::test]
1178 async fn fetch_document_rag_returns_none_when_memory_is_none() {
1179 let mut view = empty_view();
1180 view.document_config.rag_enabled = true;
1181 let tc = NaiveTokenCounter;
1182 let result = fetch_document_rag(&view, "test", 1000, &tc).await.unwrap();
1183 assert!(result.is_none());
1184 }
1185
1186 #[tokio::test]
1187 async fn fetch_document_rag_returns_none_when_rag_disabled() {
1188 let view = empty_view();
1189 let tc = NaiveTokenCounter;
1190 let result = fetch_document_rag(&view, "test", 1000, &tc).await.unwrap();
1191 assert!(result.is_none());
1192 }
1193
1194 #[tokio::test]
1197 async fn fetch_summaries_returns_none_when_memory_is_none() {
1198 let view = empty_view();
1199 let tc = NaiveTokenCounter;
1200 let result = fetch_summaries(&view, 1000, &tc).await.unwrap();
1201 assert!(result.is_none());
1202 }
1203
1204 #[tokio::test]
1207 async fn fetch_cross_session_returns_none_when_memory_is_none() {
1208 let view = empty_view();
1209 let tc = NaiveTokenCounter;
1210 let result = fetch_cross_session(&view, "test", 1000, &tc).await.unwrap();
1211 assert!(result.is_none());
1212 }
1213
/// An empty level list is treated as "no restriction": all three tier flags
/// come back `true`.
#[test]
fn levels_to_flags_empty_slice_enables_all_tiers() {
    let no_levels: &[CompressionLevel] = &[];
    let (episodic, procedural, declarative) = levels_to_flags(no_levels);
    assert!(episodic, "episodic should be active for empty slice");
    assert!(procedural, "procedural should be active for empty slice");
    assert!(declarative, "declarative should be active for empty slice");
}
1223
/// Listing every compression level explicitly turns on all three tier flags.
#[test]
fn levels_to_flags_full_set_enables_all_tiers() {
    let every_level = [
        CompressionLevel::Episodic,
        CompressionLevel::Procedural,
        CompressionLevel::Declarative,
    ];
    let (episodic, procedural, declarative) = levels_to_flags(&every_level);
    assert!(episodic);
    assert!(procedural);
    assert!(declarative);
}
1236
/// A list holding only `Episodic` enables that tier and leaves the other two off.
#[test]
fn levels_to_flags_episodic_only() {
    let only_episodic = [CompressionLevel::Episodic];
    let (episodic, procedural, declarative) = levels_to_flags(&only_episodic);
    assert!(episodic);
    assert!(!procedural, "procedural should be inactive");
    assert!(!declarative, "declarative should be inactive");
}
1244
/// `Episodic` plus `Procedural` enables exactly those two tiers; `Declarative`
/// stays off.
#[test]
fn levels_to_flags_episodic_and_procedural() {
    let two_levels = [CompressionLevel::Episodic, CompressionLevel::Procedural];
    let (episodic, procedural, declarative) = levels_to_flags(&two_levels);
    assert!(episodic);
    assert!(procedural);
    assert!(!declarative, "declarative should be inactive");
}
1253
/// A list holding only `Declarative` enables that tier and leaves the other two off.
#[test]
fn levels_to_flags_declarative_only() {
    let only_declarative = [CompressionLevel::Declarative];
    let (episodic, procedural, declarative) = levels_to_flags(&only_declarative);
    assert!(!episodic, "episodic should be inactive");
    assert!(!procedural, "procedural should be inactive");
    assert!(declarative);
}
1261
1262 #[tokio::test]
1265 async fn fetch_reasoning_strategies_returns_none_when_memory_is_none() {
1266 let mut view = empty_view();
1267 view.reasoning_config.enabled = true;
1268 let tc = NaiveTokenCounter;
1269 let (result, handle) = fetch_reasoning_strategies(&view, "query", 1000, 3, &tc)
1270 .await
1271 .unwrap();
1272 assert!(result.is_none());
1273 assert!(handle.is_none());
1274 }
1275
1276 #[tokio::test]
1277 async fn fetch_reasoning_strategies_returns_none_when_budget_zero() {
1278 let mut view = empty_view();
1279 view.reasoning_config.enabled = true;
1280 let tc = NaiveTokenCounter;
1281 let (result, handle) = fetch_reasoning_strategies(&view, "query", 0, 3, &tc)
1282 .await
1283 .unwrap();
1284 assert!(result.is_none());
1285 assert!(handle.is_none());
1286 }
1287
1288 use std::sync::{Arc, Mutex};
1291 use zeph_common::memory::{
1292 ContextMemoryBackend, GraphRecallParams, MemCorrection, MemDocumentChunk, MemGraphFact,
1293 MemPersonaFact, MemReasoningStrategy, MemRecalledMessage, MemSessionSummary, MemSummary,
1294 MemTrajectoryEntry, MemTreeNode,
1295 };
1296
/// Method names accepted by `MockMemoryBackend::with_fail_on`, which validates
/// its argument against this list.
///
/// Each entry matches the string literal the corresponding mock method compares
/// `fail_on` against.
const KNOWN_FAIL_ON: &[&str] = &[
    "load_persona_facts",
    "load_trajectory_entries",
    "load_tree_nodes",
    "load_summaries",
    "retrieve_reasoning_strategies",
    "mark_reasoning_used",
    "retrieve_corrections",
    "recall",
    "recall_graph_facts",
    "search_session_summaries",
    "search_document_collection",
];
1311
/// In-memory test double for `ContextMemoryBackend`.
///
/// Each `Vec` field holds the canned rows that the matching backend method
/// returns (cloned) on success. `Default` gives an all-empty backend with no
/// injected failure.
#[derive(Default)]
struct MockMemoryBackend {
    /// Rows returned by `load_persona_facts`.
    persona_facts: Vec<MemPersonaFact>,
    /// Rows returned by `load_trajectory_entries`.
    trajectory_entries: Vec<MemTrajectoryEntry>,
    /// Rows returned by `load_tree_nodes`.
    tree_nodes: Vec<MemTreeNode>,
    /// Rows returned by `load_summaries`.
    summaries: Vec<MemSummary>,
    /// Rows returned by `retrieve_reasoning_strategies`.
    reasoning_strategies: Vec<MemReasoningStrategy>,
    /// Rows returned by `retrieve_corrections`.
    corrections: Vec<MemCorrection>,
    /// Rows returned by `recall`.
    recalled: Vec<MemRecalledMessage>,
    /// Rows returned by `recall_graph_facts`.
    graph_facts: Vec<MemGraphFact>,
    /// Rows returned by `search_session_summaries`.
    session_summaries: Vec<MemSessionSummary>,
    /// Rows returned by `search_document_collection`.
    document_chunks: Vec<MemDocumentChunk>,
    /// Name of the one backend method that should return an error, if any.
    fail_on: Option<&'static str>,
    /// IDs handed to `mark_reasoning_used`, recorded so tests can inspect them.
    marked_ids: Mutex<Vec<String>>,
}
1329
1330 impl MockMemoryBackend {
1331 fn with_fail_on(method: &'static str) -> Self {
1332 debug_assert!(
1333 KNOWN_FAIL_ON.contains(&method),
1334 "unknown fail_on method name: {method}"
1335 );
1336 Self {
1337 fail_on: Some(method),
1338 ..Default::default()
1339 }
1340 }
1341
1342 fn fail_err(method: &str) -> Box<dyn std::error::Error + Send + Sync> {
1343 format!("mock error in {method}").into()
1344 }
1345 }
1346
1347 impl ContextMemoryBackend for MockMemoryBackend {
1348 fn load_persona_facts<'a>(
1349 &'a self,
1350 _min_confidence: f64,
1351 ) -> std::pin::Pin<
1352 Box<
1353 dyn std::future::Future<
1354 Output = Result<
1355 Vec<MemPersonaFact>,
1356 Box<dyn std::error::Error + Send + Sync>,
1357 >,
1358 > + Send
1359 + 'a,
1360 >,
1361 > {
1362 let result = if self.fail_on == Some("load_persona_facts") {
1363 Err(Self::fail_err("load_persona_facts"))
1364 } else {
1365 Ok(self.persona_facts.clone())
1366 };
1367 Box::pin(async move { result })
1368 }
1369
1370 fn load_trajectory_entries<'a>(
1371 &'a self,
1372 _tier: Option<&'a str>,
1373 _top_k: usize,
1374 ) -> std::pin::Pin<
1375 Box<
1376 dyn std::future::Future<
1377 Output = Result<
1378 Vec<MemTrajectoryEntry>,
1379 Box<dyn std::error::Error + Send + Sync>,
1380 >,
1381 > + Send
1382 + 'a,
1383 >,
1384 > {
1385 let result = if self.fail_on == Some("load_trajectory_entries") {
1386 Err(Self::fail_err("load_trajectory_entries"))
1387 } else {
1388 Ok(self.trajectory_entries.clone())
1389 };
1390 Box::pin(async move { result })
1391 }
1392
1393 fn load_tree_nodes<'a>(
1394 &'a self,
1395 _level: u32,
1396 _top_k: usize,
1397 ) -> std::pin::Pin<
1398 Box<
1399 dyn std::future::Future<
1400 Output = Result<Vec<MemTreeNode>, Box<dyn std::error::Error + Send + Sync>>,
1401 > + Send
1402 + 'a,
1403 >,
1404 > {
1405 let result = if self.fail_on == Some("load_tree_nodes") {
1406 Err(Self::fail_err("load_tree_nodes"))
1407 } else {
1408 Ok(self.tree_nodes.clone())
1409 };
1410 Box::pin(async move { result })
1411 }
1412
1413 fn load_summaries<'a>(
1414 &'a self,
1415 _conversation_id: i64,
1416 ) -> std::pin::Pin<
1417 Box<
1418 dyn std::future::Future<
1419 Output = Result<Vec<MemSummary>, Box<dyn std::error::Error + Send + Sync>>,
1420 > + Send
1421 + 'a,
1422 >,
1423 > {
1424 let result = if self.fail_on == Some("load_summaries") {
1425 Err(Self::fail_err("load_summaries"))
1426 } else {
1427 Ok(self.summaries.clone())
1428 };
1429 Box::pin(async move { result })
1430 }
1431
1432 fn retrieve_reasoning_strategies<'a>(
1433 &'a self,
1434 _query: &'a str,
1435 _top_k: usize,
1436 ) -> std::pin::Pin<
1437 Box<
1438 dyn std::future::Future<
1439 Output = Result<
1440 Vec<MemReasoningStrategy>,
1441 Box<dyn std::error::Error + Send + Sync>,
1442 >,
1443 > + Send
1444 + 'a,
1445 >,
1446 > {
1447 let result = if self.fail_on == Some("retrieve_reasoning_strategies") {
1448 Err(Self::fail_err("retrieve_reasoning_strategies"))
1449 } else {
1450 Ok(self.reasoning_strategies.clone())
1451 };
1452 Box::pin(async move { result })
1453 }
1454
1455 fn mark_reasoning_used<'a>(
1456 &'a self,
1457 ids: &'a [String],
1458 ) -> std::pin::Pin<
1459 Box<
1460 dyn std::future::Future<
1461 Output = Result<(), Box<dyn std::error::Error + Send + Sync>>,
1462 > + Send
1463 + 'a,
1464 >,
1465 > {
1466 if self.fail_on == Some("mark_reasoning_used") {
1467 return Box::pin(async move { Err(Self::fail_err("mark_reasoning_used")) });
1468 }
1469 let mut guard = self.marked_ids.lock().expect("marked_ids poisoned");
1470 guard.extend_from_slice(ids);
1471 Box::pin(async move { Ok(()) })
1472 }
1473
1474 fn retrieve_corrections<'a>(
1475 &'a self,
1476 _query: &'a str,
1477 _limit: usize,
1478 _min_score: f32,
1479 ) -> std::pin::Pin<
1480 Box<
1481 dyn std::future::Future<
1482 Output = Result<
1483 Vec<MemCorrection>,
1484 Box<dyn std::error::Error + Send + Sync>,
1485 >,
1486 > + Send
1487 + 'a,
1488 >,
1489 > {
1490 let result = if self.fail_on == Some("retrieve_corrections") {
1491 Err(Self::fail_err("retrieve_corrections"))
1492 } else {
1493 Ok(self.corrections.clone())
1494 };
1495 Box::pin(async move { result })
1496 }
1497
1498 fn recall<'a>(
1499 &'a self,
1500 _query: &'a str,
1501 _limit: usize,
1502 _router: Option<&'a dyn zeph_common::memory::AsyncMemoryRouter>,
1503 ) -> std::pin::Pin<
1504 Box<
1505 dyn std::future::Future<
1506 Output = Result<
1507 Vec<MemRecalledMessage>,
1508 Box<dyn std::error::Error + Send + Sync>,
1509 >,
1510 > + Send
1511 + 'a,
1512 >,
1513 > {
1514 let result = if self.fail_on == Some("recall") {
1515 Err(Self::fail_err("recall"))
1516 } else {
1517 Ok(self.recalled.clone())
1518 };
1519 Box::pin(async move { result })
1520 }
1521
1522 fn recall_graph_facts<'a>(
1523 &'a self,
1524 _query: &'a str,
1525 _params: GraphRecallParams<'a>,
1526 ) -> std::pin::Pin<
1527 Box<
1528 dyn std::future::Future<
1529 Output = Result<
1530 Vec<MemGraphFact>,
1531 Box<dyn std::error::Error + Send + Sync>,
1532 >,
1533 > + Send
1534 + 'a,
1535 >,
1536 > {
1537 let result = if self.fail_on == Some("recall_graph_facts") {
1538 Err(Self::fail_err("recall_graph_facts"))
1539 } else {
1540 Ok(self.graph_facts.clone())
1541 };
1542 Box::pin(async move { result })
1543 }
1544
1545 fn search_session_summaries<'a>(
1546 &'a self,
1547 _query: &'a str,
1548 _limit: usize,
1549 _current_conversation_id: Option<i64>,
1550 ) -> std::pin::Pin<
1551 Box<
1552 dyn std::future::Future<
1553 Output = Result<
1554 Vec<MemSessionSummary>,
1555 Box<dyn std::error::Error + Send + Sync>,
1556 >,
1557 > + Send
1558 + 'a,
1559 >,
1560 > {
1561 let result = if self.fail_on == Some("search_session_summaries") {
1562 Err(Self::fail_err("search_session_summaries"))
1563 } else {
1564 Ok(self.session_summaries.clone())
1565 };
1566 Box::pin(async move { result })
1567 }
1568
1569 fn search_document_collection<'a>(
1570 &'a self,
1571 _collection: &'a str,
1572 _query: &'a str,
1573 _top_k: usize,
1574 ) -> std::pin::Pin<
1575 Box<
1576 dyn std::future::Future<
1577 Output = Result<
1578 Vec<MemDocumentChunk>,
1579 Box<dyn std::error::Error + Send + Sync>,
1580 >,
1581 > + Send
1582 + 'a,
1583 >,
1584 > {
1585 let result = if self.fail_on == Some("search_document_collection") {
1586 Err(Self::fail_err("search_document_collection"))
1587 } else {
1588 Ok(self.document_chunks.clone())
1589 };
1590 Box::pin(async move { result })
1591 }
1592 }
1593
1594 fn mock_view(mock: MockMemoryBackend) -> ContextMemoryView {
1595 let mut v = empty_view();
1596 v.memory = Some(Arc::new(mock));
1597 v
1598 }
1599
1600 #[tokio::test]
1603 async fn fetch_graph_facts_returns_message_when_memory_present() {
1604 let mock = MockMemoryBackend {
1605 graph_facts: vec![zeph_common::memory::MemGraphFact {
1606 fact: "Rust is fast".to_string(),
1607 confidence: 0.9,
1608 activation_score: None,
1609 neighbors: vec![],
1610 provenance_snippet: None,
1611 }],
1612 ..Default::default()
1613 };
1614 let mut view = mock_view(mock);
1615 view.graph_config.enabled = true;
1616 view.graph_config.spreading_activation.recall_timeout_ms = 5000;
1618 let tc = NaiveTokenCounter;
1619 let result = fetch_graph_facts(&view, "test", 1000, &tc).await.unwrap();
1620 assert!(result.is_some(), "expected Some message");
1621 let msg = result.unwrap();
1622 assert!(
1623 msg.content.contains("Rust is fast"),
1624 "expected fact text in output, got: {}",
1625 msg.content
1626 );
1627 assert!(
1628 msg.content.starts_with(GRAPH_FACTS_PREFIX),
1629 "expected GRAPH_FACTS_PREFIX"
1630 );
1631 }
1632
1633 #[tokio::test]
1634 async fn fetch_graph_facts_swallows_error_and_returns_none() {
1635 let mock = MockMemoryBackend::with_fail_on("recall_graph_facts");
1636 let mut view = mock_view(mock);
1637 view.graph_config.enabled = true;
1638 view.graph_config.spreading_activation.recall_timeout_ms = 5000;
1639 let tc = NaiveTokenCounter;
1640 let result = fetch_graph_facts(&view, "test", 1000, &tc).await.unwrap();
1642 assert!(
1643 result.is_none(),
1644 "expected None when recall_graph_facts errors"
1645 );
1646 }
1647
1648 #[tokio::test]
1649 async fn fetch_graph_facts_returns_none_when_facts_empty() {
1650 let mock = MockMemoryBackend::default(); let mut view = mock_view(mock);
1652 view.graph_config.enabled = true;
1653 view.graph_config.spreading_activation.recall_timeout_ms = 5000;
1654 let tc = NaiveTokenCounter;
1655 let result = fetch_graph_facts(&view, "test", 1000, &tc).await.unwrap();
1656 assert!(result.is_none());
1657 }
1658
1659 #[tokio::test]
1662 async fn fetch_persona_facts_returns_message_when_persona_enabled() {
1663 let mock = MockMemoryBackend {
1664 persona_facts: vec![MemPersonaFact {
1665 category: "preference".to_string(),
1666 content: "prefers concise answers".to_string(),
1667 }],
1668 ..Default::default()
1669 };
1670 let mut view = mock_view(mock);
1671 view.persona_config.enabled = true;
1672 view.persona_config.context_budget_tokens = 1000;
1673 let tc = NaiveTokenCounter;
1674 let result = fetch_persona_facts(&view, 1000, &tc).await.unwrap();
1675 assert!(result.is_some());
1676 let msg = result.unwrap();
1677 assert!(msg.content.contains("preference"));
1678 assert!(msg.content.contains("prefers concise answers"));
1679 assert!(msg.content.starts_with(crate::slot::PERSONA_PREFIX));
1680 }
1681
1682 #[tokio::test]
1683 async fn fetch_persona_facts_propagates_error() {
1684 let mock = MockMemoryBackend::with_fail_on("load_persona_facts");
1685 let mut view = mock_view(mock);
1686 view.persona_config.enabled = true;
1687 let tc = NaiveTokenCounter;
1688 let result = fetch_persona_facts(&view, 1000, &tc).await;
1689 assert!(
1690 result.is_err(),
1691 "expected Err from load_persona_facts failure"
1692 );
1693 }
1694
1695 #[tokio::test]
1698 async fn fetch_trajectory_hints_returns_message_when_trajectory_enabled() {
1699 let mock = MockMemoryBackend {
1700 trajectory_entries: vec![MemTrajectoryEntry {
1701 intent: "summarize code".to_string(),
1702 outcome: "produced concise summary".to_string(),
1703 confidence: 0.9,
1704 }],
1705 ..Default::default()
1706 };
1707 let mut view = mock_view(mock);
1708 view.trajectory_config.enabled = true;
1709 view.trajectory_config.context_budget_tokens = 1000;
1710 view.trajectory_config.min_confidence = 0.5;
1711 let tc = NaiveTokenCounter;
1712 let result = fetch_trajectory_hints(&view, 1000, &tc).await.unwrap();
1713 assert!(result.is_some());
1714 let msg = result.unwrap();
1715 assert!(msg.content.contains("summarize code"));
1716 assert!(msg.content.starts_with(crate::slot::TRAJECTORY_PREFIX));
1717 }
1718
1719 #[tokio::test]
1720 async fn fetch_trajectory_hints_passes_tier_filter() {
1721 let mock = MockMemoryBackend {
1724 trajectory_entries: vec![
1725 MemTrajectoryEntry {
1726 intent: "debug async code".to_string(),
1727 outcome: "fixed deadlock".to_string(),
1728 confidence: 0.85,
1729 },
1730 MemTrajectoryEntry {
1731 intent: "low confidence task".to_string(),
1732 outcome: "irrelevant".to_string(),
1733 confidence: 0.3,
1734 },
1735 ],
1736 ..Default::default()
1737 };
1738 let mut view = mock_view(mock);
1739 view.trajectory_config.enabled = true;
1740 view.trajectory_config.context_budget_tokens = 1000;
1741 view.trajectory_config.min_confidence = 0.5;
1742 let tc = NaiveTokenCounter;
1743 let result = fetch_trajectory_hints(&view, 1000, &tc).await.unwrap();
1744 assert!(result.is_some(), "expected Some message");
1745 let msg = result.unwrap();
1746 assert!(
1747 msg.content.contains("debug async code"),
1748 "high-confidence entry must be included"
1749 );
1750 assert!(
1751 !msg.content.contains("low confidence task"),
1752 "entry below min_confidence must be filtered out"
1753 );
1754 }
1755
1756 #[tokio::test]
1757 async fn fetch_trajectory_hints_propagates_error() {
1758 let mock = MockMemoryBackend::with_fail_on("load_trajectory_entries");
1759 let mut view = mock_view(mock);
1760 view.trajectory_config.enabled = true;
1761 let tc = NaiveTokenCounter;
1762 let result = fetch_trajectory_hints(&view, 1000, &tc).await;
1763 assert!(result.is_err());
1764 }
1765
1766 #[tokio::test]
1769 async fn fetch_tree_memory_returns_message_when_tree_enabled() {
1770 let mock = MockMemoryBackend {
1771 tree_nodes: vec![MemTreeNode {
1772 content: "Topic: async Rust patterns".to_string(),
1773 }],
1774 ..Default::default()
1775 };
1776 let mut view = mock_view(mock);
1777 view.tree_config.enabled = true;
1778 view.tree_config.context_budget_tokens = 1000;
1779 let tc = NaiveTokenCounter;
1780 let result = fetch_tree_memory(&view, 1000, &tc).await.unwrap();
1781 assert!(result.is_some());
1782 let msg = result.unwrap();
1783 assert!(msg.content.contains("async Rust patterns"));
1784 assert!(msg.content.starts_with(crate::slot::TREE_MEMORY_PREFIX));
1785 }
1786
1787 #[tokio::test]
1788 async fn fetch_tree_memory_propagates_error() {
1789 let mock = MockMemoryBackend::with_fail_on("load_tree_nodes");
1790 let mut view = mock_view(mock);
1791 view.tree_config.enabled = true;
1792 let tc = NaiveTokenCounter;
1793 let result = fetch_tree_memory(&view, 1000, &tc).await;
1794 assert!(result.is_err());
1795 }
1796
1797 #[tokio::test]
1800 async fn fetch_corrections_returns_message_when_corrections_present() {
1801 let mock = MockMemoryBackend {
1802 corrections: vec![MemCorrection {
1803 correction_text: "use snake_case not camelCase".to_string(),
1804 }],
1805 ..Default::default()
1806 };
1807 let view = mock_view(mock);
1808 let result = fetch_corrections(&view, "query", 10, 0.5, |s| s.into())
1809 .await
1810 .unwrap();
1811 assert!(result.is_some());
1812 let msg = result.unwrap();
1813 assert!(msg.content.contains("snake_case"));
1814 assert!(msg.content.starts_with(CORRECTIONS_PREFIX));
1815 }
1816
1817 #[tokio::test]
1818 async fn fetch_corrections_propagates_error() {
1819 let mock = MockMemoryBackend::with_fail_on("retrieve_corrections");
1822 let view = mock_view(mock);
1823 let result = fetch_corrections(&view, "query", 10, 0.5, |s| s.into()).await;
1824 assert!(result.is_err(), "expected Err, got {result:?}");
1825 }
1826
1827 #[tokio::test]
1830 async fn fetch_semantic_recall_returns_message_with_content() {
1831 let mock = MockMemoryBackend {
1832 recalled: vec![
1833 MemRecalledMessage {
1834 role: "user".to_string(),
1835 content: "how does tokio work".to_string(),
1836 score: 0.95,
1837 },
1838 MemRecalledMessage {
1839 role: "assistant".to_string(),
1840 content: "tokio is an async runtime".to_string(),
1841 score: 0.88,
1842 },
1843 ],
1844 ..Default::default()
1845 };
1846 let mut view = mock_view(mock);
1847 view.recall_limit = 10;
1848 let tc = NaiveTokenCounter;
1849 let (msg, score) = fetch_semantic_recall(&view, "tokio", 1000, &tc, None)
1850 .await
1851 .unwrap();
1852 assert!(msg.is_some(), "expected Some message");
1853 assert!(score.is_some_and(|s| (s - 0.95_f32).abs() < f32::EPSILON));
1855 let msg = msg.unwrap();
1856 let has_recall_part = msg.parts.iter().any(|p| {
1858 if let zeph_llm::provider::MessagePart::Recall { text } = p {
1859 text.contains("how does tokio work")
1860 } else {
1861 false
1862 }
1863 });
1864 assert!(has_recall_part, "expected recalled content in Recall part");
1865 }
1866
1867 #[tokio::test]
1868 async fn fetch_semantic_recall_returns_none_when_recalled_empty() {
1869 let mock = MockMemoryBackend::default();
1870 let mut view = mock_view(mock);
1871 view.recall_limit = 10;
1872 let tc = NaiveTokenCounter;
1873 let (msg, score) = fetch_semantic_recall(&view, "query", 1000, &tc, None)
1874 .await
1875 .unwrap();
1876 assert!(msg.is_none());
1877 assert!(score.is_none());
1878 }
1879
1880 #[tokio::test]
1881 async fn fetch_semantic_recall_propagates_error() {
1882 let mock = MockMemoryBackend::with_fail_on("recall");
1883 let mut view = mock_view(mock);
1884 view.recall_limit = 10;
1885 let tc = NaiveTokenCounter;
1886 let result = fetch_semantic_recall(&view, "query", 1000, &tc, None).await;
1887 assert!(result.is_err());
1888 }
1889
1890 #[tokio::test]
1893 async fn fetch_document_rag_returns_message_when_rag_enabled() {
1894 let mock = MockMemoryBackend {
1895 document_chunks: vec![MemDocumentChunk {
1896 text: "Rust ownership rules prevent data races".to_string(),
1897 }],
1898 ..Default::default()
1899 };
1900 let mut view = mock_view(mock);
1901 view.document_config.rag_enabled = true;
1902 let tc = NaiveTokenCounter;
1903 let result = fetch_document_rag(&view, "ownership", 1000, &tc)
1904 .await
1905 .unwrap();
1906 assert!(result.is_some());
1907 let msg = result.unwrap();
1908 assert!(msg.content.contains("ownership rules"));
1909 assert!(msg.content.starts_with(DOCUMENT_RAG_PREFIX));
1910 }
1911
1912 #[tokio::test]
1913 async fn fetch_document_rag_propagates_error() {
1914 let mock = MockMemoryBackend::with_fail_on("search_document_collection");
1915 let mut view = mock_view(mock);
1916 view.document_config.rag_enabled = true;
1917 let tc = NaiveTokenCounter;
1918 let result = fetch_document_rag(&view, "query", 1000, &tc).await;
1919 assert!(result.is_err());
1920 }
1921
1922 #[tokio::test]
1925 async fn fetch_summaries_returns_message_when_summaries_present() {
1926 let mock = MockMemoryBackend {
1927 summaries: vec![MemSummary {
1928 first_message_id: Some(1),
1929 last_message_id: Some(5),
1930 content: "User asked about async Rust".to_string(),
1931 }],
1932 ..Default::default()
1933 };
1934 let mut view = mock_view(mock);
1935 view.conversation_id = Some(42);
1936 let tc = NaiveTokenCounter;
1937 let result = fetch_summaries(&view, 1000, &tc).await.unwrap();
1938 assert!(result.is_some());
1939 let msg = result.unwrap();
1940 let has_summary_part = msg.parts.iter().any(|p| {
1941 if let zeph_llm::provider::MessagePart::Summary { text } = p {
1942 text.contains("Messages 1-5") && text.contains("async Rust")
1943 } else {
1944 false
1945 }
1946 });
1947 assert!(
1948 has_summary_part,
1949 "expected Summary part with messages range"
1950 );
1951 }
1952
1953 #[tokio::test]
1954 async fn fetch_summaries_returns_none_without_conversation_id() {
1955 let mock = MockMemoryBackend {
1956 summaries: vec![MemSummary {
1957 first_message_id: Some(1),
1958 last_message_id: Some(5),
1959 content: "some content".to_string(),
1960 }],
1961 ..Default::default()
1962 };
1963 let mut view = mock_view(mock);
1964 view.conversation_id = None; let tc = NaiveTokenCounter;
1966 let result = fetch_summaries(&view, 1000, &tc).await.unwrap();
1967 assert!(result.is_none());
1968 }
1969
1970 #[tokio::test]
1971 async fn fetch_summaries_propagates_error() {
1972 let mock = MockMemoryBackend::with_fail_on("load_summaries");
1973 let mut view = mock_view(mock);
1974 view.conversation_id = Some(42);
1975 let tc = NaiveTokenCounter;
1976 let result = fetch_summaries(&view, 1000, &tc).await;
1977 assert!(result.is_err());
1978 }
1979
1980 #[tokio::test]
1983 async fn fetch_cross_session_returns_message_when_results_present() {
1984 let mock = MockMemoryBackend {
1985 session_summaries: vec![MemSessionSummary {
1986 summary_text: "Previous session: debugging tokio deadlock".to_string(),
1987 score: 0.9,
1988 }],
1989 ..Default::default()
1990 };
1991 let mut view = mock_view(mock);
1992 view.conversation_id = Some(1);
1993 view.cross_session_score_threshold = 0.5;
1994 let tc = NaiveTokenCounter;
1995 let result = fetch_cross_session(&view, "async", 1000, &tc)
1996 .await
1997 .unwrap();
1998 assert!(result.is_some());
1999 let msg = result.unwrap();
2000 let has_cross_session_part = msg.parts.iter().any(|p| {
2001 if let zeph_llm::provider::MessagePart::CrossSession { text } = p {
2002 text.contains("tokio deadlock")
2003 } else {
2004 false
2005 }
2006 });
2007 assert!(has_cross_session_part);
2008 }
2009
2010 #[tokio::test]
2011 async fn fetch_cross_session_propagates_error() {
2012 let mock = MockMemoryBackend::with_fail_on("search_session_summaries");
2013 let mut view = mock_view(mock);
2014 view.conversation_id = Some(1);
2015 let tc = NaiveTokenCounter;
2016 let result = fetch_cross_session(&view, "query", 1000, &tc).await;
2017 assert!(result.is_err());
2018 }
2019
2020 #[tokio::test]
2023 async fn fetch_reasoning_strategies_returns_message_and_marks_used() {
2024 let mock = Arc::new(MockMemoryBackend {
2025 reasoning_strategies: vec![
2026 MemReasoningStrategy {
2027 id: "strat-1".to_string(),
2028 outcome: "success".to_string(),
2029 summary: "break the problem into small steps".to_string(),
2030 },
2031 MemReasoningStrategy {
2032 id: "strat-2".to_string(),
2033 outcome: "success".to_string(),
2034 summary: "use tracing spans for debugging".to_string(),
2035 },
2036 ],
2037 ..Default::default()
2038 });
2039 let marked_ids = Arc::clone(&mock);
2040 let mut view = empty_view();
2041 view.memory = Some(mock);
2042 view.reasoning_config.enabled = true;
2043 view.reasoning_config.context_budget_tokens = 1000;
2044 let tc = NaiveTokenCounter;
2045 let (result, handle) = fetch_reasoning_strategies(&view, "debug", 1000, 5, &tc)
2046 .await
2047 .unwrap();
2048 assert!(result.is_some());
2049 let msg = result.unwrap();
2050 assert!(msg.content.starts_with(crate::slot::REASONING_PREFIX));
2051 assert!(msg.content.contains("break the problem"));
2052
2053 if let Some(h) = handle {
2055 h.await.expect("mark_reasoning_used task panicked");
2056 }
2057
2058 let ids = marked_ids.marked_ids.lock().expect("marked_ids poisoned");
2059 assert!(
2060 ids.contains(&"strat-1".to_string()),
2061 "expected strat-1 marked"
2062 );
2063 assert!(
2064 ids.contains(&"strat-2".to_string()),
2065 "expected strat-2 marked"
2066 );
2067 }
2068
2069 #[tokio::test]
2070 async fn fetch_reasoning_strategies_propagates_error() {
2071 let mock = MockMemoryBackend::with_fail_on("retrieve_reasoning_strategies");
2072 let mut view = mock_view(mock);
2073 view.reasoning_config.enabled = true;
2074 let tc = NaiveTokenCounter;
2075 let result = fetch_reasoning_strategies(&view, "query", 1000, 3, &tc).await;
2076 assert!(result.is_err());
2077 }
2078
2079 #[tokio::test]
2082 async fn fetch_semantic_recall_skips_skipped_and_stopped_messages() {
2083 let mock = MockMemoryBackend {
2084 recalled: vec![
2085 MemRecalledMessage {
2086 role: "user".to_string(),
2087 content: "[skipped] some content".to_string(),
2088 score: 0.95,
2089 },
2090 MemRecalledMessage {
2091 role: "user".to_string(),
2092 content: "[stopped] other content".to_string(),
2093 score: 0.90,
2094 },
2095 MemRecalledMessage {
2096 role: "user".to_string(),
2097 content: "valid content to recall".to_string(),
2098 score: 0.85,
2099 },
2100 ],
2101 ..Default::default()
2102 };
2103 let mut view = mock_view(mock);
2104 view.recall_limit = 10;
2105 let tc = NaiveTokenCounter;
2106 let (msg, _) = fetch_semantic_recall(&view, "query", 1000, &tc, None)
2107 .await
2108 .unwrap();
2109 assert!(msg.is_some());
2110 let msg = msg.unwrap();
2111 let full_text = msg.parts.iter().find_map(|p| {
2112 if let zeph_llm::provider::MessagePart::Recall { text } = p {
2113 Some(text.clone())
2114 } else {
2115 None
2116 }
2117 });
2118 let text = full_text.unwrap_or_default();
2119 assert!(
2120 !text.contains("[skipped]"),
2121 "skipped messages must be excluded"
2122 );
2123 assert!(
2124 !text.contains("[stopped]"),
2125 "stopped messages must be excluded"
2126 );
2127 assert!(
2128 text.contains("valid content to recall"),
2129 "valid messages must be included"
2130 );
2131 }
2132
2133 #[tokio::test]
2134 async fn fetch_cross_session_filters_below_threshold() {
2135 let mock = MockMemoryBackend {
2136 session_summaries: vec![
2137 MemSessionSummary {
2138 summary_text: "high relevance session".to_string(),
2139 score: 0.9,
2140 },
2141 MemSessionSummary {
2142 summary_text: "low relevance session".to_string(),
2143 score: 0.2,
2144 },
2145 ],
2146 ..Default::default()
2147 };
2148 let mut view = mock_view(mock);
2149 view.conversation_id = Some(1);
2150 view.cross_session_score_threshold = 0.5;
2151 let tc = NaiveTokenCounter;
2152 let result = fetch_cross_session(&view, "query", 1000, &tc)
2153 .await
2154 .unwrap();
2155 assert!(result.is_some());
2156 let msg = result.unwrap();
2157 let text = msg
2158 .parts
2159 .iter()
2160 .find_map(|p| {
2161 if let zeph_llm::provider::MessagePart::CrossSession { text } = p {
2162 Some(text.clone())
2163 } else {
2164 None
2165 }
2166 })
2167 .unwrap_or_default();
2168 assert!(
2169 text.contains("high relevance"),
2170 "high score must be included"
2171 );
2172 assert!(
2173 !text.contains("low relevance"),
2174 "low score must be filtered out"
2175 );
2176 }
2177
2178 #[tokio::test]
2179 async fn fetch_document_rag_skips_empty_chunks() {
2180 let mock = MockMemoryBackend {
2181 document_chunks: vec![
2182 MemDocumentChunk {
2183 text: String::new(),
2184 }, MemDocumentChunk {
2186 text: "real content here".to_string(),
2187 },
2188 ],
2189 ..Default::default()
2190 };
2191 let mut view = mock_view(mock);
2192 view.document_config.rag_enabled = true;
2193 let tc = NaiveTokenCounter;
2194 let result = fetch_document_rag(&view, "query", 1000, &tc).await.unwrap();
2195 assert!(result.is_some());
2196 let msg = result.unwrap();
2197 assert!(msg.content.contains("real content here"));
2198 assert!(!msg.content.contains("\n\n\n"));
2200 }
2201
2202 #[tokio::test]
2203 async fn fetch_graph_facts_sanitizes_injection_payloads() {
2204 let mock = MockMemoryBackend {
2206 graph_facts: vec![zeph_common::memory::MemGraphFact {
2207 fact: "fact with <script>alert(1)</script> and\nnewline".to_string(),
2208 confidence: 0.8,
2209 activation_score: None,
2210 neighbors: vec![],
2211 provenance_snippet: None,
2212 }],
2213 ..Default::default()
2214 };
2215 let mut view = mock_view(mock);
2216 view.graph_config.enabled = true;
2217 view.graph_config.spreading_activation.recall_timeout_ms = 5000;
2218 let tc = NaiveTokenCounter;
2219 let result = fetch_graph_facts(&view, "test", 1000, &tc).await.unwrap();
2220 assert!(result.is_some());
2221 let msg = result.unwrap();
2222 assert!(
2223 !msg.content.contains('<'),
2224 "angle brackets must be sanitized"
2225 );
2226 assert!(
2229 !msg.content.contains("\n\n"),
2230 "embedded newlines must be sanitized, no double-newline sequences expected"
2231 );
2232 }
2233
2234 #[tokio::test]
2235 async fn fetch_reasoning_strategies_sanitizes_injection_payloads() {
2236 let mock = MockMemoryBackend {
2238 reasoning_strategies: vec![MemReasoningStrategy {
2239 id: "s1".to_string(),
2240 outcome: "success".to_string(),
2241 summary: "strategy with <b>bold</b> and\nnewline".to_string(),
2242 }],
2243 ..Default::default()
2244 };
2245 let mut view = mock_view(mock);
2246 view.reasoning_config.enabled = true;
2247 let tc = NaiveTokenCounter;
2248 let (result, _handle) = fetch_reasoning_strategies(&view, "query", 1000, 3, &tc)
2249 .await
2250 .unwrap();
2251 assert!(result.is_some());
2252 let msg = result.unwrap();
2253 assert!(
2254 !msg.content.contains('<'),
2255 "angle brackets must be sanitized in strategy summaries"
2256 );
2257 }
2258
2259 #[tokio::test]
2262 async fn fetch_persona_facts_truncates_at_budget() {
2263 let tc = NaiveTokenCounter;
2264 let first_line = "[pref] brief\n";
2266 let budget = tc.count_tokens(crate::slot::PERSONA_PREFIX) + tc.count_tokens(first_line);
2267 let mock = MockMemoryBackend {
2268 persona_facts: vec![
2269 MemPersonaFact {
2270 category: "pref".to_string(),
2271 content: "brief".to_string(),
2272 },
2273 MemPersonaFact {
2274 category: "lang".to_string(),
2275 content: "english".to_string(),
2276 },
2277 ],
2278 ..Default::default()
2279 };
2280 let mut view = mock_view(mock);
2281 view.persona_config.enabled = true;
2282 let result = fetch_persona_facts(&view, budget, &tc).await.unwrap();
2283 let msg = result.unwrap();
2284 assert!(msg.content.contains("brief"), "first fact must be included");
2285 assert!(
2286 !msg.content.contains("english"),
2287 "second fact must be truncated by budget"
2288 );
2289 }
2290
2291 #[tokio::test]
2292 async fn fetch_semantic_recall_truncates_at_budget() {
2293 let tc = NaiveTokenCounter;
2294 let first_entry = "- [user] first message\n";
2296 let budget = tc.count_tokens(RECALL_PREFIX) + tc.count_tokens(first_entry);
2297 let mock = MockMemoryBackend {
2298 recalled: vec![
2299 MemRecalledMessage {
2300 role: "user".to_string(),
2301 content: "first message".to_string(),
2302 score: 0.95,
2303 },
2304 MemRecalledMessage {
2305 role: "user".to_string(),
2306 content: "second message that should be truncated".to_string(),
2307 score: 0.80,
2308 },
2309 ],
2310 ..Default::default()
2311 };
2312 let mut view = mock_view(mock);
2313 view.recall_limit = 10;
2314 let (msg, _) = fetch_semantic_recall(&view, "query", budget, &tc, None)
2315 .await
2316 .unwrap();
2317 assert!(msg.is_some());
2318 let text = msg
2319 .unwrap()
2320 .parts
2321 .iter()
2322 .find_map(|p| {
2323 if let zeph_llm::provider::MessagePart::Recall { text } = p {
2324 Some(text.clone())
2325 } else {
2326 None
2327 }
2328 })
2329 .unwrap_or_default();
2330 assert!(
2331 text.contains("first message"),
2332 "first entry must be included"
2333 );
2334 assert!(
2335 !text.contains("second message"),
2336 "second entry must be truncated by budget"
2337 );
2338 }
2339
2340 #[tokio::test]
2343 async fn fetch_graph_facts_sanitizes_provenance_snippet() {
2344 use zeph_common::memory::MemGraphNeighbor;
2345 let mock = MockMemoryBackend {
2346 graph_facts: vec![zeph_common::memory::MemGraphFact {
2347 fact: "safe fact".to_string(),
2348 confidence: 0.9,
2349 activation_score: None,
2350 neighbors: vec![MemGraphNeighbor {
2351 fact: "neighbor".to_string(),
2352 confidence: 0.7,
2353 }],
2354 provenance_snippet: Some("source with <injection>\nand newline".to_string()),
2355 }],
2356 ..Default::default()
2357 };
2358 let mut view = mock_view(mock);
2359 view.graph_config.enabled = true;
2360 view.graph_config.spreading_activation.recall_timeout_ms = 5000;
2361 let tc = NaiveTokenCounter;
2362 let result = fetch_graph_facts(&view, "test", 1000, &tc).await.unwrap();
2363 assert!(result.is_some());
2364 let msg = result.unwrap();
2365 assert!(
2366 !msg.content.contains('<'),
2367 "angle brackets in provenance_snippet must be sanitized"
2368 );
2369 assert!(
2370 !msg.content.contains("\n\n"),
2371 "newlines in provenance_snippet must be sanitized"
2372 );
2373 assert!(
2374 msg.content.contains("[source:"),
2375 "provenance snippet must be rendered"
2376 );
2377 }
2378}