Skip to main content

zeph_context/
assembler.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4//! Stateless context assembler.
5//!
6//! [`ContextAssembler`] gathers all memory-sourced context for a single agent turn by running
7//! all async fetch operations concurrently. It takes only borrowed references via
8//! [`ContextAssemblyInput`] and returns a [`PreparedContext`] ready for injection.
9//!
10//! Invariants:
11//! - No `Agent` field mutations inside `gather()`.
12//! - No channel communication inside `gather()`.
13//! - All `send_status` calls remain in `Agent::prepare_context`.
14//! - `session_digest` is cached (not async) and stays in `Agent::apply_prepared_context`.
15
16use std::future::Future;
17use std::pin::Pin;
18
19use futures::StreamExt as _;
20use futures::stream::FuturesUnordered;
21
22use zeph_common::memory::{AsyncMemoryRouter, CompressionLevel, GraphRecallParams, TokenCounting};
23use zeph_llm::provider::{Message, MessageMetadata, MessagePart, Role};
24
25use crate::error::AssemblerError;
26use crate::input::ContextAssemblyInput;
27use crate::slot::ContextSlot;
28
29/// Map a slice of active compression levels to per-tier boolean flags.
30///
31/// Returns `(episodic_active, procedural_active, declarative_active)`.
32///
33/// An empty slice means "no tier filtering": all three flags are `true`. This is the defensive
34/// default — passing an empty slice preserves legacy behaviour instead of silently suppressing
35/// all memory recall.
36pub(crate) fn levels_to_flags(levels: &[CompressionLevel]) -> (bool, bool, bool) {
37    if levels.is_empty() {
38        return (true, true, true);
39    }
40    let episodic = levels.contains(&CompressionLevel::Episodic);
41    let procedural = levels.contains(&CompressionLevel::Procedural);
42    let declarative = levels.contains(&CompressionLevel::Declarative);
43    (episodic, procedural, declarative)
44}
45
/// Header line that opens an injected block of past-session summaries.
pub const SUMMARY_PREFIX: &str = "[conversation summaries]\n";
/// Header line that opens an injected block of cross-session context.
pub const CROSS_SESSION_PREFIX: &str = "[cross-session context]\n";
/// Header line that opens an injected block of semantic-recall hits.
pub const RECALL_PREFIX: &str = "[semantic recall]\n";
/// Header line that opens an injected block of past user corrections.
pub const CORRECTIONS_PREFIX: &str = "[past corrections]\n";
/// Markdown heading that opens an injected block of document-RAG chunks.
pub const DOCUMENT_RAG_PREFIX: &str = "## Relevant documents\n";
/// Header line that opens an injected block of knowledge-graph facts.
pub const GRAPH_FACTS_PREFIX: &str = "[known facts]\n";
58
59/// Result of one context-assembly pass.
60///
61/// All source fields are `Option` — `None` means disabled, empty, or budget-exhausted.
62/// `session_digest` is excluded: it is a cached value injected by `Agent::apply_prepared_context`.
63#[derive(Default)]
64pub struct PreparedContext {
65    /// Knowledge graph fact recall.
66    pub graph_facts: Option<Message>,
67    /// Document RAG context.
68    pub doc_rag: Option<Message>,
69    /// Past user corrections.
70    pub corrections: Option<Message>,
71    /// Semantic recall results.
72    pub recall: Option<Message>,
73    /// Top-1 similarity score from semantic recall.
74    pub recall_confidence: Option<f32>,
75    /// Cross-session memory context.
76    pub cross_session: Option<Message>,
77    /// Past-conversation summaries.
78    pub summaries: Option<Message>,
79    /// Code-index RAG context (repo map or file context).
80    pub code_context: Option<String>,
81    /// Persona memory facts.
82    pub persona_facts: Option<Message>,
83    /// Trajectory hints.
84    pub trajectory_hints: Option<Message>,
85    /// `TiMem` tree memory summary.
86    pub tree_memory: Option<Message>,
87    /// Distilled reasoning strategies from the `ReasoningBank` (#3343).
88    pub reasoning_hints: Option<Message>,
89    /// Whether the memory-first context strategy is active for this turn.
90    pub memory_first: bool,
91    /// Token budget for recent conversation history (passed to trim step in apply).
92    pub recent_history_budget: usize,
93    /// Background tasks spawned during context assembly that must be tracked to completion.
94    ///
95    /// Callers are responsible for awaiting or aborting these handles at an appropriate boundary
96    /// (e.g., turn end). See async discipline rule: fire-and-forget tasks MUST be tracked.
97    pub background_tasks: Vec<tokio::task::JoinHandle<()>>,
98}
99
100/// Stateless coordinator for parallel context fetching.
101///
102/// All logic is in [`ContextAssembler::gather`]. No state is stored on this type.
103pub struct ContextAssembler;
104
105type CtxFuture<'a> = Pin<Box<dyn Future<Output = Result<ContextSlot, AssemblerError>> + Send + 'a>>;
106
107fn empty_prepared_context() -> PreparedContext {
108    PreparedContext::default()
109}
110
111fn resolve_effective_strategy(
112    memory: &crate::input::ContextMemoryView,
113    sidequest_turn_counter: u64,
114) -> zeph_config::ContextStrategy {
115    match memory.context_strategy {
116        zeph_config::ContextStrategy::MemoryFirst => zeph_config::ContextStrategy::MemoryFirst,
117        zeph_config::ContextStrategy::Adaptive => {
118            if sidequest_turn_counter >= u64::from(memory.crossover_turn_threshold) {
119                zeph_config::ContextStrategy::MemoryFirst
120            } else {
121                zeph_config::ContextStrategy::FullHistory
122            }
123        }
124        _ => zeph_config::ContextStrategy::FullHistory,
125    }
126}
127
128fn correction_params(cfg: Option<&crate::input::CorrectionConfig>) -> (usize, f32) {
129    cfg.filter(|c| c.correction_detection)
130        .map_or((3, 0.75), |c| {
131            (
132                c.correction_recall_limit as usize,
133                c.correction_min_similarity,
134            )
135        })
136}
137
138/// Schedules all enabled context fetchers and returns them as a set of concurrent futures.
139///
140/// `router_ref` borrows from `router`, which is a local owned by `gather`. Using a separate
141/// lifetime `'r` for `router_ref` avoids tying it to `'a` (the input lifetime), which would
142/// require `router` to outlive `input`. All `usize` budget values are passed by copy so the
143/// returned futures do not borrow from `alloc`.
144#[allow(clippy::too_many_arguments)]
145fn schedule_context_fetchers<'r>(
146    memory: &'r crate::input::ContextMemoryView,
147    tc: &'r dyn TokenCounting,
148    query: &'r str,
149    scrub: fn(&str) -> std::borrow::Cow<'_, str>,
150    index: Option<&'r dyn crate::input::IndexAccess>,
151    router_ref: &'r dyn AsyncMemoryRouter,
152    summaries_budget: usize,
153    cross_session_budget: usize,
154    semantic_recall_budget: usize,
155    code_context_budget: usize,
156    graph_facts_budget: usize,
157    recall_limit: usize,
158    min_sim: f32,
159    active_levels: &[CompressionLevel],
160) -> FuturesUnordered<CtxFuture<'r>> {
161    // TODO(critic): episodic_active currently gates summaries + cross-session + recall + doc_rag
162    // together. If future RetrievalPolicy variants ever drop Episodic, the cheap summary fetchers
163    // will be silently disabled — split into raw vs compressed sub-tiers. (#3455 follow-up)
164    let (episodic_active, procedural_active, declarative_active) = levels_to_flags(active_levels);
165
166    let fetchers: FuturesUnordered<CtxFuture<'r>> = FuturesUnordered::new();
167
168    if episodic_active && summaries_budget > 0 {
169        fetchers.push(Box::pin(async move {
170            fetch_summaries(memory, summaries_budget, tc)
171                .await
172                .map(ContextSlot::Summaries)
173        }));
174    }
175    if episodic_active && cross_session_budget > 0 {
176        fetchers.push(Box::pin(async move {
177            fetch_cross_session(memory, query, cross_session_budget, tc)
178                .await
179                .map(ContextSlot::CrossSession)
180        }));
181    }
182    if episodic_active && semantic_recall_budget > 0 {
183        fetchers.push(Box::pin(async move {
184            fetch_semantic_recall(memory, query, semantic_recall_budget, tc, Some(router_ref))
185                .await
186                .map(|(msg, score)| ContextSlot::SemanticRecall(msg, score))
187        }));
188        fetchers.push(Box::pin(async move {
189            fetch_document_rag(memory, query, semantic_recall_budget, tc)
190                .await
191                .map(ContextSlot::DocumentRag)
192        }));
193    }
194    // Corrections are safety-critical and never budget-gated or tier-gated.
195    fetchers.push(Box::pin(async move {
196        fetch_corrections(memory, query, recall_limit, min_sim, scrub)
197            .await
198            .map(ContextSlot::Corrections)
199    }));
200    // Code RAG is request-driven, not memory-tier; exempt from tier filtering.
201    if code_context_budget > 0
202        && let Some(idx) = index
203    {
204        fetchers.push(Box::pin(async move {
205            let result: Result<Option<String>, AssemblerError> =
206                idx.fetch_code_rag(query, code_context_budget).await;
207            result.map(ContextSlot::CodeContext)
208        }));
209    }
210    if declarative_active && graph_facts_budget > 0 {
211        fetchers.push(Box::pin(async move {
212            fetch_graph_facts(memory, query, graph_facts_budget, tc)
213                .await
214                .map(ContextSlot::GraphFacts)
215        }));
216    }
217    if declarative_active && memory.persona_config.context_budget_tokens > 0 {
218        fetchers.push(Box::pin(async move {
219            let persona_budget = memory.persona_config.context_budget_tokens;
220            fetch_persona_facts(memory, persona_budget, tc)
221                .await
222                .map(ContextSlot::PersonaFacts)
223        }));
224    }
225    if procedural_active && memory.trajectory_config.context_budget_tokens > 0 {
226        fetchers.push(Box::pin(async move {
227            let tbudget = memory.trajectory_config.context_budget_tokens;
228            fetch_trajectory_hints(memory, tbudget, tc)
229                .await
230                .map(ContextSlot::TrajectoryHints)
231        }));
232    }
233    if declarative_active && memory.tree_config.context_budget_tokens > 0 {
234        fetchers.push(Box::pin(async move {
235            let tbudget = memory.tree_config.context_budget_tokens;
236            fetch_tree_memory(memory, tbudget, tc)
237                .await
238                .map(ContextSlot::TreeMemory)
239        }));
240    }
241    if procedural_active
242        && memory.reasoning_config.enabled
243        && memory.reasoning_config.context_budget_tokens > 0
244    {
245        fetchers.push(Box::pin(async move {
246            let rbudget = memory.reasoning_config.context_budget_tokens;
247            let top_k = memory.reasoning_config.top_k;
248            fetch_reasoning_strategies(memory, query, rbudget, top_k, tc)
249                .await
250                .map(|(msg, handle)| ContextSlot::ReasoningStrategies(msg, handle))
251        }));
252    }
253
254    fetchers
255}
256
257async fn drive_fetchers(
258    mut fetchers: FuturesUnordered<CtxFuture<'_>>,
259    prepared: &mut PreparedContext,
260) -> Result<(), AssemblerError> {
261    while let Some(result) = fetchers.next().await {
262        match result {
263            Ok(slot) => match slot {
264                ContextSlot::Summaries(msg) => prepared.summaries = msg,
265                ContextSlot::CrossSession(msg) => prepared.cross_session = msg,
266                ContextSlot::SemanticRecall(msg, score) => {
267                    prepared.recall = msg;
268                    prepared.recall_confidence = score;
269                }
270                ContextSlot::DocumentRag(msg) => prepared.doc_rag = msg,
271                ContextSlot::Corrections(msg) => prepared.corrections = msg,
272                ContextSlot::CodeContext(text) => prepared.code_context = text,
273                ContextSlot::GraphFacts(msg) => prepared.graph_facts = msg,
274                ContextSlot::PersonaFacts(msg) => prepared.persona_facts = msg,
275                ContextSlot::TrajectoryHints(msg) => prepared.trajectory_hints = msg,
276                ContextSlot::TreeMemory(msg) => prepared.tree_memory = msg,
277                ContextSlot::ReasoningStrategies(msg, handle) => {
278                    prepared.reasoning_hints = msg;
279                    if let Some(h) = handle {
280                        prepared.background_tasks.push(h);
281                    }
282                }
283            },
284            Err(e) => return Err(e),
285        }
286    }
287    Ok(())
288}
289
290impl ContextAssembler {
291    /// Gather all context sources concurrently and return a [`PreparedContext`].
292    ///
293    /// Returns an empty `PreparedContext` immediately when `context_manager.budget` is `None`.
294    ///
295    /// # Errors
296    ///
297    /// Propagates errors from any async fetch operation.
298    #[tracing::instrument(name = "context.assembler.gather", skip_all)]
299    pub async fn gather(
300        input: &ContextAssemblyInput<'_>,
301    ) -> Result<PreparedContext, AssemblerError> {
302        let Some(ref budget) = input.context_manager.budget else {
303            return Ok(empty_prepared_context());
304        };
305
306        let memory = input.memory;
307        let tc = input.token_counter;
308
309        let effective_strategy = resolve_effective_strategy(memory, input.sidequest_turn_counter);
310        let memory_first = effective_strategy == zeph_config::ContextStrategy::MemoryFirst;
311
312        let system_prompt = input
313            .messages
314            .first()
315            .filter(|m| m.role == Role::System)
316            .map_or("", |m| m.content.as_str());
317
318        let digest_tokens = memory
319            .cached_session_digest
320            .as_ref()
321            .map_or(0, |(_, tokens)| *tokens);
322
323        let alloc = budget.allocate_with_opts(
324            system_prompt,
325            input.skills_prompt,
326            tc,
327            memory.graph_config.enabled,
328            digest_tokens,
329            memory_first,
330        );
331
332        let (recall_limit, min_sim) = correction_params(input.correction_config.as_ref());
333
334        let router_ref: &dyn AsyncMemoryRouter = input.router.as_ref();
335
336        tracing::debug!(
337            active_sources = alloc.active_sources(),
338            active_levels = ?input.active_levels,
339            "context budget allocated"
340        );
341
342        let fetchers = schedule_context_fetchers(
343            memory,
344            tc,
345            input.query,
346            input.scrub,
347            input.index,
348            router_ref,
349            alloc.summaries,
350            alloc.cross_session,
351            alloc.semantic_recall,
352            alloc.code_context,
353            alloc.graph_facts,
354            recall_limit,
355            min_sim,
356            input.active_levels,
357        );
358
359        let mut prepared = empty_prepared_context();
360        prepared.memory_first = memory_first;
361        prepared.recent_history_budget = alloc.recent_history;
362
363        drive_fetchers(fetchers, &mut prepared).await?;
364        Ok(prepared)
365    }
366}
367
368/// Clamp recall timeout to a safe minimum.
369///
370/// A configured value of 0 would disable spreading activation recall entirely;
371/// clamping to 100ms preserves the user's intent while preventing a silent no-op.
372pub fn effective_recall_timeout_ms(configured: u64) -> u64 {
373    if configured == 0 {
374        tracing::warn!(
375            "recall_timeout_ms is 0, which would disable spreading activation recall; \
376             clamping to 100ms"
377        );
378        100
379    } else {
380        configured
381    }
382}
383
384use crate::input::ContextMemoryView;
385
386#[tracing::instrument(name = "context.graph_facts", skip_all)]
387#[allow(clippy::too_many_lines)] // single-pass view-aware enrichment pipeline
388pub(crate) async fn fetch_graph_facts(
389    memory: &ContextMemoryView,
390    query: &str,
391    budget_tokens: usize,
392    tc: &dyn TokenCounting,
393) -> Result<Option<Message>, AssemblerError> {
394    use zeph_common::memory::{RecallView, SpreadingActivationParams, classify_graph_subgraph};
395
396    if budget_tokens == 0 || !memory.graph_config.enabled {
397        return Ok(None);
398    }
399    let Some(ref mem) = memory.memory else {
400        return Ok(None);
401    };
402    let recall_limit = memory.graph_config.recall_limit;
403    let temporal_decay_rate = memory.graph_config.temporal_decay_rate;
404    let sa_config = &memory.graph_config.spreading_activation;
405
406    // Fuse MemCoT semantic state into the recall query (spec §A8: state ≤ 2 × query.len()).
407    let fused_query;
408    let effective_query = if let Some(ref state) = memory.memcot_state {
409        let max_state_chars = 2 * query.len();
410        let state_slice = if state.len() > max_state_chars {
411            let boundary = state.floor_char_boundary(max_state_chars);
412            &state[..boundary]
413        } else {
414            state.as_str()
415        };
416        fused_query = format!("[state] {state_slice}\n{query}");
417        &fused_query as &str
418    } else {
419        query
420    };
421
422    let edge_types = classify_graph_subgraph(effective_query);
423
424    let view = match memory.memcot_config.recall_view {
425        zeph_config::RecallViewConfig::ZoomIn => RecallView::ZoomIn,
426        zeph_config::RecallViewConfig::ZoomOut => RecallView::ZoomOut,
427        _ => RecallView::Head,
428    };
429
430    let sa_params = if sa_config.enabled {
431        Some(SpreadingActivationParams {
432            decay_lambda: sa_config.decay_lambda,
433            max_hops: sa_config.max_hops,
434            activation_threshold: sa_config.activation_threshold,
435            inhibition_threshold: sa_config.inhibition_threshold,
436            max_activated_nodes: sa_config.max_activated_nodes,
437            temporal_decay_rate,
438            seed_structural_weight: sa_config.seed_structural_weight,
439            seed_community_cap: sa_config.seed_community_cap,
440            alpha: sa_config.alpha,
441        })
442    } else {
443        None
444    };
445
446    let timeout_ms = effective_recall_timeout_ms(sa_config.recall_timeout_ms);
447    let recall_fut = mem.recall_graph_facts(
448        effective_query,
449        GraphRecallParams {
450            limit: recall_limit,
451            view,
452            zoom_out_neighbor_cap: memory.memcot_config.zoom_out_neighbor_cap,
453            max_hops: memory.graph_config.max_hops,
454            temporal_decay_rate,
455            edge_types: &edge_types,
456            spreading_activation: sa_params,
457        },
458    );
459    let recalled = match tokio::time::timeout(
460        std::time::Duration::from_millis(timeout_ms),
461        recall_fut,
462    )
463    .await
464    {
465        Ok(Ok(facts)) => facts,
466        Ok(Err(e)) => {
467            tracing::warn!("graph recall failed: {e:#}");
468            Vec::new()
469        }
470        Err(_) => {
471            tracing::warn!("graph recall timed out ({timeout_ms}ms)");
472            Vec::new()
473        }
474    };
475
476    if recalled.is_empty() {
477        return Ok(None);
478    }
479
480    let mut body = String::from(GRAPH_FACTS_PREFIX);
481    let mut tokens_so_far = tc.count_tokens(&body);
482
483    for rf in &recalled {
484        let fact_text = rf.fact.replace(['\n', '\r', '<', '>'], " ");
485        let line = if let Some(score) = rf.activation_score {
486            format!(
487                "- {} (confidence: {:.2}, activation: {:.2})\n",
488                fact_text, rf.confidence, score
489            )
490        } else {
491            format!("- {} (confidence: {:.2})\n", fact_text, rf.confidence)
492        };
493        let line_tokens = tc.count_tokens(&line);
494        if tokens_so_far + line_tokens > budget_tokens {
495            break;
496        }
497        body.push_str(&line);
498        tokens_so_far += line_tokens;
499
500        // Append ZoomOut neighbors after the head fact.
501        for nb in &rf.neighbors {
502            let nb_text = nb.fact.replace(['\n', '\r', '<', '>'], " ");
503            let nb_line = format!("  ~ {} (confidence: {:.2})\n", nb_text, nb.confidence);
504            let nb_tokens = tc.count_tokens(&nb_line);
505            if tokens_so_far + nb_tokens > budget_tokens {
506                break;
507            }
508            body.push_str(&nb_line);
509            tokens_so_far += nb_tokens;
510        }
511
512        // Append ZoomIn provenance snippet if present.
513        if let Some(ref snippet) = rf.provenance_snippet {
514            let snip_line = format!(
515                "  [source: {}]\n",
516                snippet.replace(['\n', '\r', '<', '>'], " ")
517            );
518            let snip_tokens = tc.count_tokens(&snip_line);
519            if tokens_so_far + snip_tokens <= budget_tokens {
520                body.push_str(&snip_line);
521                tokens_so_far += snip_tokens;
522            }
523        }
524    }
525
526    if body == GRAPH_FACTS_PREFIX {
527        return Ok(None);
528    }
529
530    Ok(Some(Message::from_legacy(Role::System, body)))
531}
532
533#[tracing::instrument(name = "context.persona_facts", skip_all)]
534pub(crate) async fn fetch_persona_facts(
535    memory: &ContextMemoryView,
536    budget_tokens: usize,
537    tc: &dyn TokenCounting,
538) -> Result<Option<Message>, AssemblerError> {
539    if budget_tokens == 0 || !memory.persona_config.enabled {
540        return Ok(None);
541    }
542    let Some(ref mem) = memory.memory else {
543        return Ok(None);
544    };
545
546    let min_confidence = memory.persona_config.min_confidence;
547    let facts = mem
548        .load_persona_facts(min_confidence)
549        .await
550        .map_err(AssemblerError::Memory)?;
551
552    if facts.is_empty() {
553        return Ok(None);
554    }
555
556    let mut body = String::from(crate::slot::PERSONA_PREFIX);
557    let mut tokens_so_far = tc.count_tokens(&body);
558
559    for fact in &facts {
560        let line = format!("[{}] {}\n", fact.category, fact.content);
561        let line_tokens = tc.count_tokens(&line);
562        if tokens_so_far + line_tokens > budget_tokens {
563            break;
564        }
565        body.push_str(&line);
566        tokens_so_far += line_tokens;
567    }
568
569    if body == crate::slot::PERSONA_PREFIX {
570        return Ok(None);
571    }
572
573    Ok(Some(Message::from_legacy(Role::System, body)))
574}
575
576#[tracing::instrument(name = "context.trajectory_hints", skip_all)]
577pub(crate) async fn fetch_trajectory_hints(
578    memory: &ContextMemoryView,
579    budget_tokens: usize,
580    tc: &dyn TokenCounting,
581) -> Result<Option<Message>, AssemblerError> {
582    if budget_tokens == 0 || !memory.trajectory_config.enabled {
583        return Ok(None);
584    }
585    let Some(ref mem) = memory.memory else {
586        return Ok(None);
587    };
588
589    let top_k = memory.trajectory_config.recall_top_k;
590    let min_conf = memory.trajectory_config.min_confidence;
591    // Load procedural trajectory entries via the backend abstraction.
592    // The "procedural" filter maps to the same tier used by the original
593    // sqlite().load_trajectory_entries(Some("procedural"), top_k) call.
594    let entries = mem
595        .load_trajectory_entries(Some("procedural"), top_k)
596        .await
597        .map_err(AssemblerError::Memory)?;
598
599    if entries.is_empty() {
600        return Ok(None);
601    }
602
603    let mut body = String::from(crate::slot::TRAJECTORY_PREFIX);
604    let mut tokens_so_far = tc.count_tokens(&body);
605
606    for entry in entries
607        .iter()
608        .filter(|e| e.confidence >= min_conf)
609        .take(top_k)
610    {
611        let line = format!("- {}: {}\n", entry.intent, entry.outcome);
612        let line_tokens = tc.count_tokens(&line);
613        if tokens_so_far + line_tokens > budget_tokens {
614            break;
615        }
616        body.push_str(&line);
617        tokens_so_far += line_tokens;
618    }
619
620    if body == crate::slot::TRAJECTORY_PREFIX {
621        return Ok(None);
622    }
623
624    Ok(Some(Message::from_legacy(Role::System, body)))
625}
626
627#[tracing::instrument(name = "context.tree_memory", skip_all)]
628pub(crate) async fn fetch_tree_memory(
629    memory: &ContextMemoryView,
630    budget_tokens: usize,
631    tc: &dyn TokenCounting,
632) -> Result<Option<Message>, AssemblerError> {
633    if budget_tokens == 0 || !memory.tree_config.enabled {
634        return Ok(None);
635    }
636    let Some(ref mem) = memory.memory else {
637        return Ok(None);
638    };
639
640    let top_k = memory.tree_config.recall_top_k;
641    let nodes = mem
642        .load_tree_nodes(1, top_k)
643        .await
644        .map_err(AssemblerError::Memory)?;
645
646    if nodes.is_empty() {
647        return Ok(None);
648    }
649
650    let mut body = String::from(crate::slot::TREE_MEMORY_PREFIX);
651    let mut tokens_so_far = tc.count_tokens(&body);
652
653    for node in nodes.iter().take(top_k) {
654        let line = format!("- {}\n", node.content);
655        let line_tokens = tc.count_tokens(&line);
656        if tokens_so_far + line_tokens > budget_tokens {
657            break;
658        }
659        body.push_str(&line);
660        tokens_so_far += line_tokens;
661    }
662
663    if body == crate::slot::TREE_MEMORY_PREFIX {
664        return Ok(None);
665    }
666
667    Ok(Some(Message::from_legacy(Role::System, body)))
668}
669
670#[tracing::instrument(name = "context.reasoning_strategies", skip_all)]
671pub(crate) async fn fetch_reasoning_strategies(
672    memory: &ContextMemoryView,
673    query: &str,
674    budget_tokens: usize,
675    top_k: usize,
676    tc: &dyn TokenCounting,
677) -> Result<(Option<Message>, Option<tokio::task::JoinHandle<()>>), AssemblerError> {
678    // S1: enforce the ≤500-token spec cap documented in ReasoningConfig.
679    let budget_tokens = budget_tokens.min(500);
680    if budget_tokens == 0 {
681        return Ok((None, None));
682    }
683    let Some(ref mem) = memory.memory else {
684        return Ok((None, None));
685    };
686
687    let strategies = mem
688        .retrieve_reasoning_strategies(query, top_k)
689        .await
690        .map_err(AssemblerError::Memory)?;
691
692    if strategies.is_empty() {
693        return Ok((None, None));
694    }
695
696    let mut body = String::from(crate::slot::REASONING_PREFIX);
697    let mut tokens_so_far = tc.count_tokens(&body);
698    let mut injected_ids: Vec<String> = Vec::new();
699
700    for s in strategies.iter().take(top_k) {
701        // S-Med1: sanitize distilled summaries to prevent stored injection payloads
702        // from reaching the system prompt (mirrors fetch_graph_facts scrub pattern).
703        let safe_summary = s.summary.replace(['\n', '\r', '<', '>'], " ");
704        let line = format!("- [{}] {}\n", s.outcome, safe_summary);
705        let line_tokens = tc.count_tokens(&line);
706        if tokens_so_far + line_tokens > budget_tokens {
707            break;
708        }
709        body.push_str(&line);
710        tokens_so_far += line_tokens;
711        injected_ids.push(s.id.clone());
712    }
713
714    if body == crate::slot::REASONING_PREFIX {
715        return Ok((None, None));
716    }
717
718    // C4 split: mark_used only for strategies that made it past budget truncation.
719    // Spawn the task and return the handle so the caller can track it (async discipline rule:
720    // fire-and-forget tasks MUST be tracked; handle stored in PreparedContext::background_tasks).
721    let handle = if injected_ids.is_empty() {
722        None
723    } else {
724        let mem_clone = mem.clone();
725        Some(tokio::spawn(async move {
726            if let Err(e) = mem_clone.mark_reasoning_used(&injected_ids).await {
727                tracing::warn!(error = %e, "reasoning: mark_used failed");
728            }
729        }))
730    };
731
732    Ok((Some(Message::from_legacy(Role::System, body)), handle))
733}
734
735#[tracing::instrument(name = "context.corrections", skip_all)]
736pub(crate) async fn fetch_corrections(
737    memory: &ContextMemoryView,
738    query: &str,
739    limit: usize,
740    min_score: f32,
741    scrub: fn(&str) -> std::borrow::Cow<'_, str>,
742) -> Result<Option<Message>, AssemblerError> {
743    let Some(ref mem) = memory.memory else {
744        return Ok(None);
745    };
746    let corrections = mem
747        .retrieve_corrections(query, limit, min_score)
748        .await
749        .map_err(AssemblerError::Memory)?;
750    if corrections.is_empty() {
751        return Ok(None);
752    }
753    let mut text = String::from(CORRECTIONS_PREFIX);
754    for c in &corrections {
755        text.push_str("- Past user correction: \"");
756        text.push_str(&scrub(&c.correction_text));
757        text.push_str("\"\n");
758    }
759    Ok(Some(Message::from_legacy(Role::System, text)))
760}
761
762#[tracing::instrument(name = "context.semantic_recall", skip_all)]
763pub(crate) async fn fetch_semantic_recall(
764    memory: &ContextMemoryView,
765    query: &str,
766    token_budget: usize,
767    tc: &dyn TokenCounting,
768    router: Option<&dyn AsyncMemoryRouter>,
769) -> Result<(Option<Message>, Option<f32>), AssemblerError> {
770    let Some(ref mem) = memory.memory else {
771        return Ok((None, None));
772    };
773    if memory.recall_limit == 0 || token_budget == 0 {
774        return Ok((None, None));
775    }
776
777    let recalled = mem
778        .recall(query, memory.recall_limit, router)
779        .await
780        .map_err(AssemblerError::Memory)?;
781    if recalled.is_empty() {
782        return Ok((None, None));
783    }
784
785    let top_score = recalled.first().map(|r| r.score);
786
787    let mut recall_text = String::with_capacity(token_budget * 3);
788    recall_text.push_str(RECALL_PREFIX);
789    let mut tokens_used = tc.count_tokens(&recall_text);
790
791    for item in &recalled {
792        if item.content.starts_with("[skipped]") || item.content.starts_with("[stopped]") {
793            continue;
794        }
795        let entry = format!("- [{}] {}\n", item.role, item.content);
796        let entry_tokens = tc.count_tokens(&entry);
797        if tokens_used + entry_tokens > token_budget {
798            break;
799        }
800        recall_text.push_str(&entry);
801        tokens_used += entry_tokens;
802    }
803
804    if tokens_used > tc.count_tokens(RECALL_PREFIX) {
805        Ok((
806            Some(Message::from_parts(
807                Role::System,
808                vec![MessagePart::Recall { text: recall_text }],
809            )),
810            top_score,
811        ))
812    } else {
813        Ok((None, None))
814    }
815}
816
817#[tracing::instrument(name = "context.document_rag", skip_all)]
818pub(crate) async fn fetch_document_rag(
819    memory: &ContextMemoryView,
820    query: &str,
821    token_budget: usize,
822    tc: &dyn TokenCounting,
823) -> Result<Option<Message>, AssemblerError> {
824    if !memory.document_config.rag_enabled || token_budget == 0 {
825        return Ok(None);
826    }
827    let Some(ref mem) = memory.memory else {
828        return Ok(None);
829    };
830
831    let collection = &memory.document_config.collection;
832    let top_k = memory.document_config.top_k;
833    let chunks = mem
834        .search_document_collection(collection, query, top_k)
835        .await
836        .map_err(AssemblerError::Memory)?;
837    if chunks.is_empty() {
838        return Ok(None);
839    }
840
841    let mut text = String::from(DOCUMENT_RAG_PREFIX);
842    let mut tokens_used = tc.count_tokens(&text);
843
844    for chunk in &chunks {
845        if chunk.text.is_empty() {
846            continue;
847        }
848        let entry = format!("{}\n", chunk.text);
849        let cost = tc.count_tokens(&entry);
850        if tokens_used + cost > token_budget {
851            break;
852        }
853        text.push_str(&entry);
854        tokens_used += cost;
855    }
856
857    if tokens_used > tc.count_tokens(DOCUMENT_RAG_PREFIX) {
858        Ok(Some(Message {
859            role: Role::System,
860            content: text,
861            parts: vec![],
862            metadata: MessageMetadata::default(),
863        }))
864    } else {
865        Ok(None)
866    }
867}
868
869#[tracing::instrument(name = "context.summaries", skip_all)]
870pub(crate) async fn fetch_summaries(
871    memory: &ContextMemoryView,
872    token_budget: usize,
873    tc: &dyn TokenCounting,
874) -> Result<Option<Message>, AssemblerError> {
875    let (Some(mem), Some(cid)) = (&memory.memory, memory.conversation_id) else {
876        return Ok(None);
877    };
878    if token_budget == 0 {
879        return Ok(None);
880    }
881
882    let summaries = mem
883        .load_summaries(cid)
884        .await
885        .map_err(AssemblerError::Memory)?;
886    if summaries.is_empty() {
887        return Ok(None);
888    }
889
890    let mut summary_text = String::from(SUMMARY_PREFIX);
891    let mut tokens_used = tc.count_tokens(&summary_text);
892
893    for summary in summaries.iter().rev() {
894        let first = summary.first_message_id.unwrap_or(0);
895        let last = summary.last_message_id.unwrap_or(0);
896        let entry = format!("- Messages {first}-{last}: {}\n", summary.content);
897        let cost = tc.count_tokens(&entry);
898        if tokens_used + cost > token_budget {
899            break;
900        }
901        summary_text.push_str(&entry);
902        tokens_used += cost;
903    }
904
905    if tokens_used > tc.count_tokens(SUMMARY_PREFIX) {
906        Ok(Some(Message::from_parts(
907            Role::System,
908            vec![MessagePart::Summary { text: summary_text }],
909        )))
910    } else {
911        Ok(None)
912    }
913}
914
915#[tracing::instrument(name = "context.cross_session", skip_all)]
916pub(crate) async fn fetch_cross_session(
917    memory: &ContextMemoryView,
918    query: &str,
919    token_budget: usize,
920    tc: &dyn TokenCounting,
921) -> Result<Option<Message>, AssemblerError> {
922    let (Some(mem), Some(cid)) = (&memory.memory, memory.conversation_id) else {
923        return Ok(None);
924    };
925    if token_budget == 0 {
926        return Ok(None);
927    }
928
929    let threshold = memory.cross_session_score_threshold;
930    let results: Vec<_> = mem
931        .search_session_summaries(query, 5, Some(cid))
932        .await
933        .map_err(AssemblerError::Memory)?
934        .into_iter()
935        .filter(|r| r.score >= threshold)
936        .collect();
937    if results.is_empty() {
938        return Ok(None);
939    }
940
941    let mut text = String::from(CROSS_SESSION_PREFIX);
942    let mut tokens_used = tc.count_tokens(&text);
943
944    for item in &results {
945        let entry = format!("- {}\n", item.summary_text);
946        let cost = tc.count_tokens(&entry);
947        if tokens_used + cost > token_budget {
948            break;
949        }
950        text.push_str(&entry);
951        tokens_used += cost;
952    }
953
954    if tokens_used > tc.count_tokens(CROSS_SESSION_PREFIX) {
955        Ok(Some(Message::from_parts(
956            Role::System,
957            vec![MessagePart::CrossSession { text }],
958        )))
959    } else {
960        Ok(None)
961    }
962}
963
964/// Maximum number of messages scanned backward by [`memory_first_keep_tail`] before
965/// stopping at the next non-`ToolResult` boundary, to avoid O(N) scans on long sessions.
966pub const MAX_KEEP_TAIL_SCAN: usize = 50;
967
968/// Compute how many tail messages to keep when the `MemoryFirst` strategy is active.
969///
970/// Always keeps at least 2 messages. Extends the tail as long as the boundary message is
971/// a `ToolResult` (user message with a `ToolResult` part) to avoid splitting a tool-call
972/// round-trip. Capped at `MAX_KEEP_TAIL_SCAN` to prevent O(N) scans on long sessions.
973///
974/// `history_start` is the index of the first non-system message (typically 1).
975#[must_use]
976pub fn memory_first_keep_tail(messages: &[Message], history_start: usize) -> usize {
977    use zeph_llm::provider::MessagePart;
978
979    let mut keep_tail = 2usize;
980    let len = messages.len();
981    let max = len.saturating_sub(history_start);
982
983    while keep_tail < max {
984        let first_retained = &messages[len - keep_tail];
985        let is_tool_result = first_retained.role == Role::User
986            && first_retained
987                .parts
988                .iter()
989                .any(|p| matches!(p, MessagePart::ToolResult { .. }));
990
991        if is_tool_result {
992            keep_tail += 1;
993        } else {
994            break;
995        }
996
997        if keep_tail >= MAX_KEEP_TAIL_SCAN {
998            let preceding_idx = len.saturating_sub(keep_tail + 1);
999            if preceding_idx >= history_start {
1000                let preceding = &messages[preceding_idx];
1001                let is_tool_use = preceding.role == Role::Assistant
1002                    && preceding
1003                        .parts
1004                        .iter()
1005                        .any(|p| matches!(p, MessagePart::ToolUse { .. }));
1006                if is_tool_use {
1007                    keep_tail += 1;
1008                }
1009            }
1010            break;
1011        }
1012    }
1013
1014    keep_tail
1015}
1016
1017#[cfg(test)]
1018mod tests {
1019    use super::*;
1020    use crate::input::ContextMemoryView;
1021    use zeph_common::memory::CompressionLevel;
1022    use zeph_config::{
1023        ContextStrategy, DocumentConfig, GraphConfig, PersonaConfig, ReasoningConfig,
1024        TrajectoryConfig, TreeConfig,
1025    };
1026
1027    struct NaiveTokenCounter;
1028    impl zeph_common::memory::TokenCounting for NaiveTokenCounter {
1029        fn count_tokens(&self, text: &str) -> usize {
1030            text.split_whitespace().count()
1031        }
1032        fn count_tool_schema_tokens(&self, schema: &serde_json::Value) -> usize {
1033            schema.to_string().split_whitespace().count()
1034        }
1035    }
1036
1037    fn empty_view() -> ContextMemoryView {
1038        ContextMemoryView {
1039            memory: None,
1040            conversation_id: None,
1041            recall_limit: 10,
1042            cross_session_score_threshold: 0.5,
1043            context_strategy: ContextStrategy::default(),
1044            crossover_turn_threshold: 5,
1045            cached_session_digest: None,
1046            graph_config: GraphConfig::default(),
1047            document_config: DocumentConfig::default(),
1048            persona_config: PersonaConfig::default(),
1049            trajectory_config: TrajectoryConfig::default(),
1050            reasoning_config: ReasoningConfig::default(),
1051            memcot_config: zeph_config::MemCotConfig::default(),
1052            memcot_state: None,
1053            tree_config: TreeConfig::default(),
1054        }
1055    }
1056
1057    // ── fetch_graph_facts ─────────────────────────────────────────────────────
1058
1059    #[tokio::test]
1060    async fn fetch_graph_facts_returns_none_when_memory_is_none() {
1061        let view = empty_view();
1062        let tc = NaiveTokenCounter;
1063        let result = fetch_graph_facts(&view, "test", 1000, &tc).await.unwrap();
1064        assert!(result.is_none());
1065    }
1066
1067    #[tokio::test]
1068    async fn fetch_graph_facts_returns_none_when_budget_zero() {
1069        let mut view = empty_view();
1070        view.graph_config.enabled = true;
1071        let tc = NaiveTokenCounter;
1072        let result = fetch_graph_facts(&view, "test", 0, &tc).await.unwrap();
1073        assert!(result.is_none());
1074    }
1075
1076    #[tokio::test]
1077    async fn fetch_graph_facts_returns_none_when_graph_disabled() {
1078        let mut view = empty_view();
1079        view.graph_config.enabled = false;
1080        let tc = NaiveTokenCounter;
1081        let result = fetch_graph_facts(&view, "test", 1000, &tc).await.unwrap();
1082        assert!(result.is_none());
1083    }
1084
1085    // ── fetch_persona_facts ───────────────────────────────────────────────────
1086
1087    #[tokio::test]
1088    async fn fetch_persona_facts_returns_none_when_memory_is_none() {
1089        let view = empty_view();
1090        let tc = NaiveTokenCounter;
1091        let result = fetch_persona_facts(&view, 1000, &tc).await.unwrap();
1092        assert!(result.is_none());
1093    }
1094
1095    #[tokio::test]
1096    async fn fetch_persona_facts_returns_none_when_budget_zero() {
1097        let mut view = empty_view();
1098        view.persona_config.enabled = true;
1099        let tc = NaiveTokenCounter;
1100        let result = fetch_persona_facts(&view, 0, &tc).await.unwrap();
1101        assert!(result.is_none());
1102    }
1103
1104    // ── fetch_trajectory_hints ────────────────────────────────────────────────
1105
1106    #[tokio::test]
1107    async fn fetch_trajectory_hints_returns_none_when_memory_is_none() {
1108        let view = empty_view();
1109        let tc = NaiveTokenCounter;
1110        let result = fetch_trajectory_hints(&view, 1000, &tc).await.unwrap();
1111        assert!(result.is_none());
1112    }
1113
1114    #[tokio::test]
1115    async fn fetch_trajectory_hints_returns_none_when_budget_zero() {
1116        let mut view = empty_view();
1117        view.trajectory_config.enabled = true;
1118        let tc = NaiveTokenCounter;
1119        let result = fetch_trajectory_hints(&view, 0, &tc).await.unwrap();
1120        assert!(result.is_none());
1121    }
1122
1123    // ── fetch_tree_memory ─────────────────────────────────────────────────────
1124
1125    #[tokio::test]
1126    async fn fetch_tree_memory_returns_none_when_memory_is_none() {
1127        let view = empty_view();
1128        let tc = NaiveTokenCounter;
1129        let result = fetch_tree_memory(&view, 1000, &tc).await.unwrap();
1130        assert!(result.is_none());
1131    }
1132
1133    #[tokio::test]
1134    async fn fetch_tree_memory_returns_none_when_budget_zero() {
1135        let mut view = empty_view();
1136        view.tree_config.enabled = true;
1137        let tc = NaiveTokenCounter;
1138        let result = fetch_tree_memory(&view, 0, &tc).await.unwrap();
1139        assert!(result.is_none());
1140    }
1141
1142    // ── fetch_corrections ─────────────────────────────────────────────────────
1143
1144    #[tokio::test]
1145    async fn fetch_corrections_returns_none_when_memory_is_none() {
1146        let view = empty_view();
1147        let result = fetch_corrections(&view, "test", 10, 0.5, |s| s.into())
1148            .await
1149            .unwrap();
1150        assert!(result.is_none());
1151    }
1152
1153    // ── fetch_semantic_recall ─────────────────────────────────────────────────
1154
1155    #[tokio::test]
1156    async fn fetch_semantic_recall_returns_none_when_memory_is_none() {
1157        let view = empty_view();
1158        let tc = NaiveTokenCounter;
1159        let result = fetch_semantic_recall(&view, "test", 1000, &tc, None)
1160            .await
1161            .unwrap();
1162        assert!(result.0.is_none() && result.1.is_none());
1163    }
1164
1165    #[tokio::test]
1166    async fn fetch_semantic_recall_returns_none_when_budget_zero() {
1167        let view = empty_view();
1168        let tc = NaiveTokenCounter;
1169        let result = fetch_semantic_recall(&view, "test", 0, &tc, None)
1170            .await
1171            .unwrap();
1172        assert!(result.0.is_none() && result.1.is_none());
1173    }
1174
1175    // ── fetch_document_rag ────────────────────────────────────────────────────
1176
1177    #[tokio::test]
1178    async fn fetch_document_rag_returns_none_when_memory_is_none() {
1179        let mut view = empty_view();
1180        view.document_config.rag_enabled = true;
1181        let tc = NaiveTokenCounter;
1182        let result = fetch_document_rag(&view, "test", 1000, &tc).await.unwrap();
1183        assert!(result.is_none());
1184    }
1185
1186    #[tokio::test]
1187    async fn fetch_document_rag_returns_none_when_rag_disabled() {
1188        let view = empty_view();
1189        let tc = NaiveTokenCounter;
1190        let result = fetch_document_rag(&view, "test", 1000, &tc).await.unwrap();
1191        assert!(result.is_none());
1192    }
1193
1194    // ── fetch_summaries ───────────────────────────────────────────────────────
1195
1196    #[tokio::test]
1197    async fn fetch_summaries_returns_none_when_memory_is_none() {
1198        let view = empty_view();
1199        let tc = NaiveTokenCounter;
1200        let result = fetch_summaries(&view, 1000, &tc).await.unwrap();
1201        assert!(result.is_none());
1202    }
1203
1204    // ── fetch_cross_session ───────────────────────────────────────────────────
1205
1206    #[tokio::test]
1207    async fn fetch_cross_session_returns_none_when_memory_is_none() {
1208        let view = empty_view();
1209        let tc = NaiveTokenCounter;
1210        let result = fetch_cross_session(&view, "test", 1000, &tc).await.unwrap();
1211        assert!(result.is_none());
1212    }
1213
1214    // ── levels_to_flags ───────────────────────────────────────────────────────
1215
1216    #[test]
1217    fn levels_to_flags_empty_slice_enables_all_tiers() {
1218        let (e, p, d) = levels_to_flags(&[]);
1219        assert!(e, "episodic should be active for empty slice");
1220        assert!(p, "procedural should be active for empty slice");
1221        assert!(d, "declarative should be active for empty slice");
1222    }
1223
1224    #[test]
1225    fn levels_to_flags_full_set_enables_all_tiers() {
1226        let all = &[
1227            CompressionLevel::Episodic,
1228            CompressionLevel::Procedural,
1229            CompressionLevel::Declarative,
1230        ];
1231        let (e, p, d) = levels_to_flags(all);
1232        assert!(e);
1233        assert!(p);
1234        assert!(d);
1235    }
1236
1237    #[test]
1238    fn levels_to_flags_episodic_only() {
1239        let (e, p, d) = levels_to_flags(&[CompressionLevel::Episodic]);
1240        assert!(e);
1241        assert!(!p, "procedural should be inactive");
1242        assert!(!d, "declarative should be inactive");
1243    }
1244
1245    #[test]
1246    fn levels_to_flags_episodic_and_procedural() {
1247        let (e, p, d) =
1248            levels_to_flags(&[CompressionLevel::Episodic, CompressionLevel::Procedural]);
1249        assert!(e);
1250        assert!(p);
1251        assert!(!d, "declarative should be inactive");
1252    }
1253
1254    #[test]
1255    fn levels_to_flags_declarative_only() {
1256        let (e, p, d) = levels_to_flags(&[CompressionLevel::Declarative]);
1257        assert!(!e, "episodic should be inactive");
1258        assert!(!p, "procedural should be inactive");
1259        assert!(d);
1260    }
1261
1262    // ── fetch_reasoning_strategies ────────────────────────────────────────────
1263
1264    #[tokio::test]
1265    async fn fetch_reasoning_strategies_returns_none_when_memory_is_none() {
1266        let mut view = empty_view();
1267        view.reasoning_config.enabled = true;
1268        let tc = NaiveTokenCounter;
1269        let (result, handle) = fetch_reasoning_strategies(&view, "query", 1000, 3, &tc)
1270            .await
1271            .unwrap();
1272        assert!(result.is_none());
1273        assert!(handle.is_none());
1274    }
1275
1276    #[tokio::test]
1277    async fn fetch_reasoning_strategies_returns_none_when_budget_zero() {
1278        let mut view = empty_view();
1279        view.reasoning_config.enabled = true;
1280        let tc = NaiveTokenCounter;
1281        let (result, handle) = fetch_reasoning_strategies(&view, "query", 0, 3, &tc)
1282            .await
1283            .unwrap();
1284        assert!(result.is_none());
1285        assert!(handle.is_none());
1286    }
1287
1288    // ── MockMemoryBackend ─────────────────────────────────────────────────────
1289
1290    use std::sync::{Arc, Mutex};
1291    use zeph_common::memory::{
1292        ContextMemoryBackend, GraphRecallParams, MemCorrection, MemDocumentChunk, MemGraphFact,
1293        MemPersonaFact, MemReasoningStrategy, MemRecalledMessage, MemSessionSummary, MemSummary,
1294        MemTrajectoryEntry, MemTreeNode,
1295    };
1296
1297    /// Known method names accepted by [`MockMemoryBackend::fail_on`].
1298    const KNOWN_FAIL_ON: &[&str] = &[
1299        "load_persona_facts",
1300        "load_trajectory_entries",
1301        "load_tree_nodes",
1302        "load_summaries",
1303        "retrieve_reasoning_strategies",
1304        "mark_reasoning_used",
1305        "retrieve_corrections",
1306        "recall",
1307        "recall_graph_facts",
1308        "search_session_summaries",
1309        "search_document_collection",
1310    ];
1311
1312    #[derive(Default)]
1313    struct MockMemoryBackend {
1314        persona_facts: Vec<MemPersonaFact>,
1315        trajectory_entries: Vec<MemTrajectoryEntry>,
1316        tree_nodes: Vec<MemTreeNode>,
1317        summaries: Vec<MemSummary>,
1318        reasoning_strategies: Vec<MemReasoningStrategy>,
1319        corrections: Vec<MemCorrection>,
1320        recalled: Vec<MemRecalledMessage>,
1321        graph_facts: Vec<MemGraphFact>,
1322        session_summaries: Vec<MemSessionSummary>,
1323        document_chunks: Vec<MemDocumentChunk>,
1324        /// When `Some("method_name")`, that method returns `Err(...)`.
1325        fail_on: Option<&'static str>,
1326        /// Tracks IDs passed to `mark_reasoning_used`.
1327        marked_ids: Mutex<Vec<String>>,
1328    }
1329
1330    impl MockMemoryBackend {
1331        fn with_fail_on(method: &'static str) -> Self {
1332            debug_assert!(
1333                KNOWN_FAIL_ON.contains(&method),
1334                "unknown fail_on method name: {method}"
1335            );
1336            Self {
1337                fail_on: Some(method),
1338                ..Default::default()
1339            }
1340        }
1341
1342        fn fail_err(method: &str) -> Box<dyn std::error::Error + Send + Sync> {
1343            format!("mock error in {method}").into()
1344        }
1345    }
1346
1347    impl ContextMemoryBackend for MockMemoryBackend {
1348        fn load_persona_facts<'a>(
1349            &'a self,
1350            _min_confidence: f64,
1351        ) -> std::pin::Pin<
1352            Box<
1353                dyn std::future::Future<
1354                        Output = Result<
1355                            Vec<MemPersonaFact>,
1356                            Box<dyn std::error::Error + Send + Sync>,
1357                        >,
1358                    > + Send
1359                    + 'a,
1360            >,
1361        > {
1362            let result = if self.fail_on == Some("load_persona_facts") {
1363                Err(Self::fail_err("load_persona_facts"))
1364            } else {
1365                Ok(self.persona_facts.clone())
1366            };
1367            Box::pin(async move { result })
1368        }
1369
1370        fn load_trajectory_entries<'a>(
1371            &'a self,
1372            _tier: Option<&'a str>,
1373            _top_k: usize,
1374        ) -> std::pin::Pin<
1375            Box<
1376                dyn std::future::Future<
1377                        Output = Result<
1378                            Vec<MemTrajectoryEntry>,
1379                            Box<dyn std::error::Error + Send + Sync>,
1380                        >,
1381                    > + Send
1382                    + 'a,
1383            >,
1384        > {
1385            let result = if self.fail_on == Some("load_trajectory_entries") {
1386                Err(Self::fail_err("load_trajectory_entries"))
1387            } else {
1388                Ok(self.trajectory_entries.clone())
1389            };
1390            Box::pin(async move { result })
1391        }
1392
1393        fn load_tree_nodes<'a>(
1394            &'a self,
1395            _level: u32,
1396            _top_k: usize,
1397        ) -> std::pin::Pin<
1398            Box<
1399                dyn std::future::Future<
1400                        Output = Result<Vec<MemTreeNode>, Box<dyn std::error::Error + Send + Sync>>,
1401                    > + Send
1402                    + 'a,
1403            >,
1404        > {
1405            let result = if self.fail_on == Some("load_tree_nodes") {
1406                Err(Self::fail_err("load_tree_nodes"))
1407            } else {
1408                Ok(self.tree_nodes.clone())
1409            };
1410            Box::pin(async move { result })
1411        }
1412
1413        fn load_summaries<'a>(
1414            &'a self,
1415            _conversation_id: i64,
1416        ) -> std::pin::Pin<
1417            Box<
1418                dyn std::future::Future<
1419                        Output = Result<Vec<MemSummary>, Box<dyn std::error::Error + Send + Sync>>,
1420                    > + Send
1421                    + 'a,
1422            >,
1423        > {
1424            let result = if self.fail_on == Some("load_summaries") {
1425                Err(Self::fail_err("load_summaries"))
1426            } else {
1427                Ok(self.summaries.clone())
1428            };
1429            Box::pin(async move { result })
1430        }
1431
1432        fn retrieve_reasoning_strategies<'a>(
1433            &'a self,
1434            _query: &'a str,
1435            _top_k: usize,
1436        ) -> std::pin::Pin<
1437            Box<
1438                dyn std::future::Future<
1439                        Output = Result<
1440                            Vec<MemReasoningStrategy>,
1441                            Box<dyn std::error::Error + Send + Sync>,
1442                        >,
1443                    > + Send
1444                    + 'a,
1445            >,
1446        > {
1447            let result = if self.fail_on == Some("retrieve_reasoning_strategies") {
1448                Err(Self::fail_err("retrieve_reasoning_strategies"))
1449            } else {
1450                Ok(self.reasoning_strategies.clone())
1451            };
1452            Box::pin(async move { result })
1453        }
1454
1455        fn mark_reasoning_used<'a>(
1456            &'a self,
1457            ids: &'a [String],
1458        ) -> std::pin::Pin<
1459            Box<
1460                dyn std::future::Future<
1461                        Output = Result<(), Box<dyn std::error::Error + Send + Sync>>,
1462                    > + Send
1463                    + 'a,
1464            >,
1465        > {
1466            if self.fail_on == Some("mark_reasoning_used") {
1467                return Box::pin(async move { Err(Self::fail_err("mark_reasoning_used")) });
1468            }
1469            let mut guard = self.marked_ids.lock().expect("marked_ids poisoned");
1470            guard.extend_from_slice(ids);
1471            Box::pin(async move { Ok(()) })
1472        }
1473
1474        fn retrieve_corrections<'a>(
1475            &'a self,
1476            _query: &'a str,
1477            _limit: usize,
1478            _min_score: f32,
1479        ) -> std::pin::Pin<
1480            Box<
1481                dyn std::future::Future<
1482                        Output = Result<
1483                            Vec<MemCorrection>,
1484                            Box<dyn std::error::Error + Send + Sync>,
1485                        >,
1486                    > + Send
1487                    + 'a,
1488            >,
1489        > {
1490            let result = if self.fail_on == Some("retrieve_corrections") {
1491                Err(Self::fail_err("retrieve_corrections"))
1492            } else {
1493                Ok(self.corrections.clone())
1494            };
1495            Box::pin(async move { result })
1496        }
1497
1498        fn recall<'a>(
1499            &'a self,
1500            _query: &'a str,
1501            _limit: usize,
1502            _router: Option<&'a dyn zeph_common::memory::AsyncMemoryRouter>,
1503        ) -> std::pin::Pin<
1504            Box<
1505                dyn std::future::Future<
1506                        Output = Result<
1507                            Vec<MemRecalledMessage>,
1508                            Box<dyn std::error::Error + Send + Sync>,
1509                        >,
1510                    > + Send
1511                    + 'a,
1512            >,
1513        > {
1514            let result = if self.fail_on == Some("recall") {
1515                Err(Self::fail_err("recall"))
1516            } else {
1517                Ok(self.recalled.clone())
1518            };
1519            Box::pin(async move { result })
1520        }
1521
1522        fn recall_graph_facts<'a>(
1523            &'a self,
1524            _query: &'a str,
1525            _params: GraphRecallParams<'a>,
1526        ) -> std::pin::Pin<
1527            Box<
1528                dyn std::future::Future<
1529                        Output = Result<
1530                            Vec<MemGraphFact>,
1531                            Box<dyn std::error::Error + Send + Sync>,
1532                        >,
1533                    > + Send
1534                    + 'a,
1535            >,
1536        > {
1537            let result = if self.fail_on == Some("recall_graph_facts") {
1538                Err(Self::fail_err("recall_graph_facts"))
1539            } else {
1540                Ok(self.graph_facts.clone())
1541            };
1542            Box::pin(async move { result })
1543        }
1544
1545        fn search_session_summaries<'a>(
1546            &'a self,
1547            _query: &'a str,
1548            _limit: usize,
1549            _current_conversation_id: Option<i64>,
1550        ) -> std::pin::Pin<
1551            Box<
1552                dyn std::future::Future<
1553                        Output = Result<
1554                            Vec<MemSessionSummary>,
1555                            Box<dyn std::error::Error + Send + Sync>,
1556                        >,
1557                    > + Send
1558                    + 'a,
1559            >,
1560        > {
1561            let result = if self.fail_on == Some("search_session_summaries") {
1562                Err(Self::fail_err("search_session_summaries"))
1563            } else {
1564                Ok(self.session_summaries.clone())
1565            };
1566            Box::pin(async move { result })
1567        }
1568
1569        fn search_document_collection<'a>(
1570            &'a self,
1571            _collection: &'a str,
1572            _query: &'a str,
1573            _top_k: usize,
1574        ) -> std::pin::Pin<
1575            Box<
1576                dyn std::future::Future<
1577                        Output = Result<
1578                            Vec<MemDocumentChunk>,
1579                            Box<dyn std::error::Error + Send + Sync>,
1580                        >,
1581                    > + Send
1582                    + 'a,
1583            >,
1584        > {
1585            let result = if self.fail_on == Some("search_document_collection") {
1586                Err(Self::fail_err("search_document_collection"))
1587            } else {
1588                Ok(self.document_chunks.clone())
1589            };
1590            Box::pin(async move { result })
1591        }
1592    }
1593
1594    fn mock_view(mock: MockMemoryBackend) -> ContextMemoryView {
1595        let mut v = empty_view();
1596        v.memory = Some(Arc::new(mock));
1597        v
1598    }
1599
1600    // ── fetch_graph_facts (happy path) ────────────────────────────────────────
1601
1602    #[tokio::test]
1603    async fn fetch_graph_facts_returns_message_when_memory_present() {
1604        let mock = MockMemoryBackend {
1605            graph_facts: vec![zeph_common::memory::MemGraphFact {
1606                fact: "Rust is fast".to_string(),
1607                confidence: 0.9,
1608                activation_score: None,
1609                neighbors: vec![],
1610                provenance_snippet: None,
1611            }],
1612            ..Default::default()
1613        };
1614        let mut view = mock_view(mock);
1615        view.graph_config.enabled = true;
1616        // recall_timeout_ms must be non-zero or it gets clamped to 100ms
1617        view.graph_config.spreading_activation.recall_timeout_ms = 5000;
1618        let tc = NaiveTokenCounter;
1619        let result = fetch_graph_facts(&view, "test", 1000, &tc).await.unwrap();
1620        assert!(result.is_some(), "expected Some message");
1621        let msg = result.unwrap();
1622        assert!(
1623            msg.content.contains("Rust is fast"),
1624            "expected fact text in output, got: {}",
1625            msg.content
1626        );
1627        assert!(
1628            msg.content.starts_with(GRAPH_FACTS_PREFIX),
1629            "expected GRAPH_FACTS_PREFIX"
1630        );
1631    }
1632
1633    #[tokio::test]
1634    async fn fetch_graph_facts_swallows_error_and_returns_none() {
1635        let mock = MockMemoryBackend::with_fail_on("recall_graph_facts");
1636        let mut view = mock_view(mock);
1637        view.graph_config.enabled = true;
1638        view.graph_config.spreading_activation.recall_timeout_ms = 5000;
1639        let tc = NaiveTokenCounter;
1640        // B1: fetch_graph_facts swallows errors via tracing::warn! and returns Ok(None)
1641        let result = fetch_graph_facts(&view, "test", 1000, &tc).await.unwrap();
1642        assert!(
1643            result.is_none(),
1644            "expected None when recall_graph_facts errors"
1645        );
1646    }
1647
1648    #[tokio::test]
1649    async fn fetch_graph_facts_returns_none_when_facts_empty() {
1650        let mock = MockMemoryBackend::default(); // empty graph_facts
1651        let mut view = mock_view(mock);
1652        view.graph_config.enabled = true;
1653        view.graph_config.spreading_activation.recall_timeout_ms = 5000;
1654        let tc = NaiveTokenCounter;
1655        let result = fetch_graph_facts(&view, "test", 1000, &tc).await.unwrap();
1656        assert!(result.is_none());
1657    }
1658
1659    // ── fetch_persona_facts ───────────────────────────────────────────────────
1660
1661    #[tokio::test]
1662    async fn fetch_persona_facts_returns_message_when_persona_enabled() {
1663        let mock = MockMemoryBackend {
1664            persona_facts: vec![MemPersonaFact {
1665                category: "preference".to_string(),
1666                content: "prefers concise answers".to_string(),
1667            }],
1668            ..Default::default()
1669        };
1670        let mut view = mock_view(mock);
1671        view.persona_config.enabled = true;
1672        view.persona_config.context_budget_tokens = 1000;
1673        let tc = NaiveTokenCounter;
1674        let result = fetch_persona_facts(&view, 1000, &tc).await.unwrap();
1675        assert!(result.is_some());
1676        let msg = result.unwrap();
1677        assert!(msg.content.contains("preference"));
1678        assert!(msg.content.contains("prefers concise answers"));
1679        assert!(msg.content.starts_with(crate::slot::PERSONA_PREFIX));
1680    }
1681
1682    #[tokio::test]
1683    async fn fetch_persona_facts_propagates_error() {
1684        let mock = MockMemoryBackend::with_fail_on("load_persona_facts");
1685        let mut view = mock_view(mock);
1686        view.persona_config.enabled = true;
1687        let tc = NaiveTokenCounter;
1688        let result = fetch_persona_facts(&view, 1000, &tc).await;
1689        assert!(
1690            result.is_err(),
1691            "expected Err from load_persona_facts failure"
1692        );
1693    }
1694
1695    // ── fetch_trajectory_hints ────────────────────────────────────────────────
1696
1697    #[tokio::test]
1698    async fn fetch_trajectory_hints_returns_message_when_trajectory_enabled() {
1699        let mock = MockMemoryBackend {
1700            trajectory_entries: vec![MemTrajectoryEntry {
1701                intent: "summarize code".to_string(),
1702                outcome: "produced concise summary".to_string(),
1703                confidence: 0.9,
1704            }],
1705            ..Default::default()
1706        };
1707        let mut view = mock_view(mock);
1708        view.trajectory_config.enabled = true;
1709        view.trajectory_config.context_budget_tokens = 1000;
1710        view.trajectory_config.min_confidence = 0.5;
1711        let tc = NaiveTokenCounter;
1712        let result = fetch_trajectory_hints(&view, 1000, &tc).await.unwrap();
1713        assert!(result.is_some());
1714        let msg = result.unwrap();
1715        assert!(msg.content.contains("summarize code"));
1716        assert!(msg.content.starts_with(crate::slot::TRAJECTORY_PREFIX));
1717    }
1718
1719    #[tokio::test]
1720    async fn fetch_trajectory_hints_passes_tier_filter() {
1721        // I1: confidence filtering — entry below min_confidence must be excluded,
1722        // entry above must be present. Verifies the .filter(|e| e.confidence >= min_conf) branch.
1723        let mock = MockMemoryBackend {
1724            trajectory_entries: vec![
1725                MemTrajectoryEntry {
1726                    intent: "debug async code".to_string(),
1727                    outcome: "fixed deadlock".to_string(),
1728                    confidence: 0.85,
1729                },
1730                MemTrajectoryEntry {
1731                    intent: "low confidence task".to_string(),
1732                    outcome: "irrelevant".to_string(),
1733                    confidence: 0.3,
1734                },
1735            ],
1736            ..Default::default()
1737        };
1738        let mut view = mock_view(mock);
1739        view.trajectory_config.enabled = true;
1740        view.trajectory_config.context_budget_tokens = 1000;
1741        view.trajectory_config.min_confidence = 0.5;
1742        let tc = NaiveTokenCounter;
1743        let result = fetch_trajectory_hints(&view, 1000, &tc).await.unwrap();
1744        assert!(result.is_some(), "expected Some message");
1745        let msg = result.unwrap();
1746        assert!(
1747            msg.content.contains("debug async code"),
1748            "high-confidence entry must be included"
1749        );
1750        assert!(
1751            !msg.content.contains("low confidence task"),
1752            "entry below min_confidence must be filtered out"
1753        );
1754    }
1755
1756    #[tokio::test]
1757    async fn fetch_trajectory_hints_propagates_error() {
1758        let mock = MockMemoryBackend::with_fail_on("load_trajectory_entries");
1759        let mut view = mock_view(mock);
1760        view.trajectory_config.enabled = true;
1761        let tc = NaiveTokenCounter;
1762        let result = fetch_trajectory_hints(&view, 1000, &tc).await;
1763        assert!(result.is_err());
1764    }
1765
1766    // ── fetch_tree_memory ─────────────────────────────────────────────────────
1767
1768    #[tokio::test]
1769    async fn fetch_tree_memory_returns_message_when_tree_enabled() {
1770        let mock = MockMemoryBackend {
1771            tree_nodes: vec![MemTreeNode {
1772                content: "Topic: async Rust patterns".to_string(),
1773            }],
1774            ..Default::default()
1775        };
1776        let mut view = mock_view(mock);
1777        view.tree_config.enabled = true;
1778        view.tree_config.context_budget_tokens = 1000;
1779        let tc = NaiveTokenCounter;
1780        let result = fetch_tree_memory(&view, 1000, &tc).await.unwrap();
1781        assert!(result.is_some());
1782        let msg = result.unwrap();
1783        assert!(msg.content.contains("async Rust patterns"));
1784        assert!(msg.content.starts_with(crate::slot::TREE_MEMORY_PREFIX));
1785    }
1786
1787    #[tokio::test]
1788    async fn fetch_tree_memory_propagates_error() {
1789        let mock = MockMemoryBackend::with_fail_on("load_tree_nodes");
1790        let mut view = mock_view(mock);
1791        view.tree_config.enabled = true;
1792        let tc = NaiveTokenCounter;
1793        let result = fetch_tree_memory(&view, 1000, &tc).await;
1794        assert!(result.is_err());
1795    }
1796
1797    // ── fetch_corrections ─────────────────────────────────────────────────────
1798
1799    #[tokio::test]
1800    async fn fetch_corrections_returns_message_when_corrections_present() {
1801        let mock = MockMemoryBackend {
1802            corrections: vec![MemCorrection {
1803                correction_text: "use snake_case not camelCase".to_string(),
1804            }],
1805            ..Default::default()
1806        };
1807        let view = mock_view(mock);
1808        let result = fetch_corrections(&view, "query", 10, 0.5, |s| s.into())
1809            .await
1810            .unwrap();
1811        assert!(result.is_some());
1812        let msg = result.unwrap();
1813        assert!(msg.content.contains("snake_case"));
1814        assert!(msg.content.starts_with(CORRECTIONS_PREFIX));
1815    }
1816
1817    #[tokio::test]
1818    async fn fetch_corrections_propagates_error() {
1819        // fetch_corrections uses map_err(AssemblerError::Memory)? so retrieve_corrections
1820        // errors are propagated instead of silently discarded.
1821        let mock = MockMemoryBackend::with_fail_on("retrieve_corrections");
1822        let view = mock_view(mock);
1823        let result = fetch_corrections(&view, "query", 10, 0.5, |s| s.into()).await;
1824        assert!(result.is_err(), "expected Err, got {result:?}");
1825    }
1826
1827    // ── fetch_semantic_recall ─────────────────────────────────────────────────
1828
1829    #[tokio::test]
1830    async fn fetch_semantic_recall_returns_message_with_content() {
1831        let mock = MockMemoryBackend {
1832            recalled: vec![
1833                MemRecalledMessage {
1834                    role: "user".to_string(),
1835                    content: "how does tokio work".to_string(),
1836                    score: 0.95,
1837                },
1838                MemRecalledMessage {
1839                    role: "assistant".to_string(),
1840                    content: "tokio is an async runtime".to_string(),
1841                    score: 0.88,
1842                },
1843            ],
1844            ..Default::default()
1845        };
1846        let mut view = mock_view(mock);
1847        view.recall_limit = 10;
1848        let tc = NaiveTokenCounter;
1849        let (msg, score) = fetch_semantic_recall(&view, "tokio", 1000, &tc, None)
1850            .await
1851            .unwrap();
1852        assert!(msg.is_some(), "expected Some message");
1853        // I4: verify score equals first message's score
1854        assert!(score.is_some_and(|s| (s - 0.95_f32).abs() < f32::EPSILON));
1855        let msg = msg.unwrap();
1856        // content is in parts.Recall so check parts
1857        let has_recall_part = msg.parts.iter().any(|p| {
1858            if let zeph_llm::provider::MessagePart::Recall { text } = p {
1859                text.contains("how does tokio work")
1860            } else {
1861                false
1862            }
1863        });
1864        assert!(has_recall_part, "expected recalled content in Recall part");
1865    }
1866
1867    #[tokio::test]
1868    async fn fetch_semantic_recall_returns_none_when_recalled_empty() {
1869        let mock = MockMemoryBackend::default();
1870        let mut view = mock_view(mock);
1871        view.recall_limit = 10;
1872        let tc = NaiveTokenCounter;
1873        let (msg, score) = fetch_semantic_recall(&view, "query", 1000, &tc, None)
1874            .await
1875            .unwrap();
1876        assert!(msg.is_none());
1877        assert!(score.is_none());
1878    }
1879
1880    #[tokio::test]
1881    async fn fetch_semantic_recall_propagates_error() {
1882        let mock = MockMemoryBackend::with_fail_on("recall");
1883        let mut view = mock_view(mock);
1884        view.recall_limit = 10;
1885        let tc = NaiveTokenCounter;
1886        let result = fetch_semantic_recall(&view, "query", 1000, &tc, None).await;
1887        assert!(result.is_err());
1888    }
1889
1890    // ── fetch_document_rag ────────────────────────────────────────────────────
1891
1892    #[tokio::test]
1893    async fn fetch_document_rag_returns_message_when_rag_enabled() {
1894        let mock = MockMemoryBackend {
1895            document_chunks: vec![MemDocumentChunk {
1896                text: "Rust ownership rules prevent data races".to_string(),
1897            }],
1898            ..Default::default()
1899        };
1900        let mut view = mock_view(mock);
1901        view.document_config.rag_enabled = true;
1902        let tc = NaiveTokenCounter;
1903        let result = fetch_document_rag(&view, "ownership", 1000, &tc)
1904            .await
1905            .unwrap();
1906        assert!(result.is_some());
1907        let msg = result.unwrap();
1908        assert!(msg.content.contains("ownership rules"));
1909        assert!(msg.content.starts_with(DOCUMENT_RAG_PREFIX));
1910    }
1911
1912    #[tokio::test]
1913    async fn fetch_document_rag_propagates_error() {
1914        let mock = MockMemoryBackend::with_fail_on("search_document_collection");
1915        let mut view = mock_view(mock);
1916        view.document_config.rag_enabled = true;
1917        let tc = NaiveTokenCounter;
1918        let result = fetch_document_rag(&view, "query", 1000, &tc).await;
1919        assert!(result.is_err());
1920    }
1921
1922    // ── fetch_summaries ───────────────────────────────────────────────────────
1923
1924    #[tokio::test]
1925    async fn fetch_summaries_returns_message_when_summaries_present() {
1926        let mock = MockMemoryBackend {
1927            summaries: vec![MemSummary {
1928                first_message_id: Some(1),
1929                last_message_id: Some(5),
1930                content: "User asked about async Rust".to_string(),
1931            }],
1932            ..Default::default()
1933        };
1934        let mut view = mock_view(mock);
1935        view.conversation_id = Some(42);
1936        let tc = NaiveTokenCounter;
1937        let result = fetch_summaries(&view, 1000, &tc).await.unwrap();
1938        assert!(result.is_some());
1939        let msg = result.unwrap();
1940        let has_summary_part = msg.parts.iter().any(|p| {
1941            if let zeph_llm::provider::MessagePart::Summary { text } = p {
1942                text.contains("Messages 1-5") && text.contains("async Rust")
1943            } else {
1944                false
1945            }
1946        });
1947        assert!(
1948            has_summary_part,
1949            "expected Summary part with messages range"
1950        );
1951    }
1952
1953    #[tokio::test]
1954    async fn fetch_summaries_returns_none_without_conversation_id() {
1955        let mock = MockMemoryBackend {
1956            summaries: vec![MemSummary {
1957                first_message_id: Some(1),
1958                last_message_id: Some(5),
1959                content: "some content".to_string(),
1960            }],
1961            ..Default::default()
1962        };
1963        let mut view = mock_view(mock);
1964        view.conversation_id = None; // no conversation_id → must return None
1965        let tc = NaiveTokenCounter;
1966        let result = fetch_summaries(&view, 1000, &tc).await.unwrap();
1967        assert!(result.is_none());
1968    }
1969
1970    #[tokio::test]
1971    async fn fetch_summaries_propagates_error() {
1972        let mock = MockMemoryBackend::with_fail_on("load_summaries");
1973        let mut view = mock_view(mock);
1974        view.conversation_id = Some(42);
1975        let tc = NaiveTokenCounter;
1976        let result = fetch_summaries(&view, 1000, &tc).await;
1977        assert!(result.is_err());
1978    }
1979
1980    // ── fetch_cross_session ───────────────────────────────────────────────────
1981
1982    #[tokio::test]
1983    async fn fetch_cross_session_returns_message_when_results_present() {
1984        let mock = MockMemoryBackend {
1985            session_summaries: vec![MemSessionSummary {
1986                summary_text: "Previous session: debugging tokio deadlock".to_string(),
1987                score: 0.9,
1988            }],
1989            ..Default::default()
1990        };
1991        let mut view = mock_view(mock);
1992        view.conversation_id = Some(1);
1993        view.cross_session_score_threshold = 0.5;
1994        let tc = NaiveTokenCounter;
1995        let result = fetch_cross_session(&view, "async", 1000, &tc)
1996            .await
1997            .unwrap();
1998        assert!(result.is_some());
1999        let msg = result.unwrap();
2000        let has_cross_session_part = msg.parts.iter().any(|p| {
2001            if let zeph_llm::provider::MessagePart::CrossSession { text } = p {
2002                text.contains("tokio deadlock")
2003            } else {
2004                false
2005            }
2006        });
2007        assert!(has_cross_session_part);
2008    }
2009
2010    #[tokio::test]
2011    async fn fetch_cross_session_propagates_error() {
2012        let mock = MockMemoryBackend::with_fail_on("search_session_summaries");
2013        let mut view = mock_view(mock);
2014        view.conversation_id = Some(1);
2015        let tc = NaiveTokenCounter;
2016        let result = fetch_cross_session(&view, "query", 1000, &tc).await;
2017        assert!(result.is_err());
2018    }
2019
2020    // ── fetch_reasoning_strategies (happy path + mark_used) ──────────────────
2021
2022    #[tokio::test]
2023    async fn fetch_reasoning_strategies_returns_message_and_marks_used() {
2024        let mock = Arc::new(MockMemoryBackend {
2025            reasoning_strategies: vec![
2026                MemReasoningStrategy {
2027                    id: "strat-1".to_string(),
2028                    outcome: "success".to_string(),
2029                    summary: "break the problem into small steps".to_string(),
2030                },
2031                MemReasoningStrategy {
2032                    id: "strat-2".to_string(),
2033                    outcome: "success".to_string(),
2034                    summary: "use tracing spans for debugging".to_string(),
2035                },
2036            ],
2037            ..Default::default()
2038        });
2039        let marked_ids = Arc::clone(&mock);
2040        let mut view = empty_view();
2041        view.memory = Some(mock);
2042        view.reasoning_config.enabled = true;
2043        view.reasoning_config.context_budget_tokens = 1000;
2044        let tc = NaiveTokenCounter;
2045        let (result, handle) = fetch_reasoning_strategies(&view, "debug", 1000, 5, &tc)
2046            .await
2047            .unwrap();
2048        assert!(result.is_some());
2049        let msg = result.unwrap();
2050        assert!(msg.content.starts_with(crate::slot::REASONING_PREFIX));
2051        assert!(msg.content.contains("break the problem"));
2052
2053        // Await the returned JoinHandle to ensure mark_reasoning_used completes before assertion.
2054        if let Some(h) = handle {
2055            h.await.expect("mark_reasoning_used task panicked");
2056        }
2057
2058        let ids = marked_ids.marked_ids.lock().expect("marked_ids poisoned");
2059        assert!(
2060            ids.contains(&"strat-1".to_string()),
2061            "expected strat-1 marked"
2062        );
2063        assert!(
2064            ids.contains(&"strat-2".to_string()),
2065            "expected strat-2 marked"
2066        );
2067    }
2068
2069    #[tokio::test]
2070    async fn fetch_reasoning_strategies_propagates_error() {
2071        let mock = MockMemoryBackend::with_fail_on("retrieve_reasoning_strategies");
2072        let mut view = mock_view(mock);
2073        view.reasoning_config.enabled = true;
2074        let tc = NaiveTokenCounter;
2075        let result = fetch_reasoning_strategies(&view, "query", 1000, 3, &tc).await;
2076        assert!(result.is_err());
2077    }
2078
2079    // ── edge cases ────────────────────────────────────────────────────────────
2080
2081    #[tokio::test]
2082    async fn fetch_semantic_recall_skips_skipped_and_stopped_messages() {
2083        let mock = MockMemoryBackend {
2084            recalled: vec![
2085                MemRecalledMessage {
2086                    role: "user".to_string(),
2087                    content: "[skipped] some content".to_string(),
2088                    score: 0.95,
2089                },
2090                MemRecalledMessage {
2091                    role: "user".to_string(),
2092                    content: "[stopped] other content".to_string(),
2093                    score: 0.90,
2094                },
2095                MemRecalledMessage {
2096                    role: "user".to_string(),
2097                    content: "valid content to recall".to_string(),
2098                    score: 0.85,
2099                },
2100            ],
2101            ..Default::default()
2102        };
2103        let mut view = mock_view(mock);
2104        view.recall_limit = 10;
2105        let tc = NaiveTokenCounter;
2106        let (msg, _) = fetch_semantic_recall(&view, "query", 1000, &tc, None)
2107            .await
2108            .unwrap();
2109        assert!(msg.is_some());
2110        let msg = msg.unwrap();
2111        let full_text = msg.parts.iter().find_map(|p| {
2112            if let zeph_llm::provider::MessagePart::Recall { text } = p {
2113                Some(text.clone())
2114            } else {
2115                None
2116            }
2117        });
2118        let text = full_text.unwrap_or_default();
2119        assert!(
2120            !text.contains("[skipped]"),
2121            "skipped messages must be excluded"
2122        );
2123        assert!(
2124            !text.contains("[stopped]"),
2125            "stopped messages must be excluded"
2126        );
2127        assert!(
2128            text.contains("valid content to recall"),
2129            "valid messages must be included"
2130        );
2131    }
2132
2133    #[tokio::test]
2134    async fn fetch_cross_session_filters_below_threshold() {
2135        let mock = MockMemoryBackend {
2136            session_summaries: vec![
2137                MemSessionSummary {
2138                    summary_text: "high relevance session".to_string(),
2139                    score: 0.9,
2140                },
2141                MemSessionSummary {
2142                    summary_text: "low relevance session".to_string(),
2143                    score: 0.2,
2144                },
2145            ],
2146            ..Default::default()
2147        };
2148        let mut view = mock_view(mock);
2149        view.conversation_id = Some(1);
2150        view.cross_session_score_threshold = 0.5;
2151        let tc = NaiveTokenCounter;
2152        let result = fetch_cross_session(&view, "query", 1000, &tc)
2153            .await
2154            .unwrap();
2155        assert!(result.is_some());
2156        let msg = result.unwrap();
2157        let text = msg
2158            .parts
2159            .iter()
2160            .find_map(|p| {
2161                if let zeph_llm::provider::MessagePart::CrossSession { text } = p {
2162                    Some(text.clone())
2163                } else {
2164                    None
2165                }
2166            })
2167            .unwrap_or_default();
2168        assert!(
2169            text.contains("high relevance"),
2170            "high score must be included"
2171        );
2172        assert!(
2173            !text.contains("low relevance"),
2174            "low score must be filtered out"
2175        );
2176    }
2177
2178    #[tokio::test]
2179    async fn fetch_document_rag_skips_empty_chunks() {
2180        let mock = MockMemoryBackend {
2181            document_chunks: vec![
2182                MemDocumentChunk {
2183                    text: String::new(),
2184                }, // empty — must be skipped
2185                MemDocumentChunk {
2186                    text: "real content here".to_string(),
2187                },
2188            ],
2189            ..Default::default()
2190        };
2191        let mut view = mock_view(mock);
2192        view.document_config.rag_enabled = true;
2193        let tc = NaiveTokenCounter;
2194        let result = fetch_document_rag(&view, "query", 1000, &tc).await.unwrap();
2195        assert!(result.is_some());
2196        let msg = result.unwrap();
2197        assert!(msg.content.contains("real content here"));
2198        // empty chunk text should not produce an empty line before prefix
2199        assert!(!msg.content.contains("\n\n\n"));
2200    }
2201
2202    #[tokio::test]
2203    async fn fetch_graph_facts_sanitizes_injection_payloads() {
2204        // I3: newlines and angle brackets are replaced with spaces
2205        let mock = MockMemoryBackend {
2206            graph_facts: vec![zeph_common::memory::MemGraphFact {
2207                fact: "fact with <script>alert(1)</script> and\nnewline".to_string(),
2208                confidence: 0.8,
2209                activation_score: None,
2210                neighbors: vec![],
2211                provenance_snippet: None,
2212            }],
2213            ..Default::default()
2214        };
2215        let mut view = mock_view(mock);
2216        view.graph_config.enabled = true;
2217        view.graph_config.spreading_activation.recall_timeout_ms = 5000;
2218        let tc = NaiveTokenCounter;
2219        let result = fetch_graph_facts(&view, "test", 1000, &tc).await.unwrap();
2220        assert!(result.is_some());
2221        let msg = result.unwrap();
2222        assert!(
2223            !msg.content.contains('<'),
2224            "angle brackets must be sanitized"
2225        );
2226        // The formatter adds trailing \n to each line, but embedded \n in fact text is replaced
2227        // with spaces. Verify no double-newline sequences exist (would indicate unsanitized \n).
2228        assert!(
2229            !msg.content.contains("\n\n"),
2230            "embedded newlines must be sanitized, no double-newline sequences expected"
2231        );
2232    }
2233
2234    #[tokio::test]
2235    async fn fetch_reasoning_strategies_sanitizes_injection_payloads() {
2236        // I3: newlines and angle brackets are replaced with spaces in strategy summaries
2237        let mock = MockMemoryBackend {
2238            reasoning_strategies: vec![MemReasoningStrategy {
2239                id: "s1".to_string(),
2240                outcome: "success".to_string(),
2241                summary: "strategy with <b>bold</b> and\nnewline".to_string(),
2242            }],
2243            ..Default::default()
2244        };
2245        let mut view = mock_view(mock);
2246        view.reasoning_config.enabled = true;
2247        let tc = NaiveTokenCounter;
2248        let (result, _handle) = fetch_reasoning_strategies(&view, "query", 1000, 3, &tc)
2249            .await
2250            .unwrap();
2251        assert!(result.is_some());
2252        let msg = result.unwrap();
2253        assert!(
2254            !msg.content.contains('<'),
2255            "angle brackets must be sanitized in strategy summaries"
2256        );
2257    }
2258
2259    // ── budget truncation (CR-1) ──────────────────────────────────────────────
2260
2261    #[tokio::test]
2262    async fn fetch_persona_facts_truncates_at_budget() {
2263        let tc = NaiveTokenCounter;
2264        // Tight budget: fits prefix + exactly 1 fact line, second must be omitted.
2265        let first_line = "[pref] brief\n";
2266        let budget = tc.count_tokens(crate::slot::PERSONA_PREFIX) + tc.count_tokens(first_line);
2267        let mock = MockMemoryBackend {
2268            persona_facts: vec![
2269                MemPersonaFact {
2270                    category: "pref".to_string(),
2271                    content: "brief".to_string(),
2272                },
2273                MemPersonaFact {
2274                    category: "lang".to_string(),
2275                    content: "english".to_string(),
2276                },
2277            ],
2278            ..Default::default()
2279        };
2280        let mut view = mock_view(mock);
2281        view.persona_config.enabled = true;
2282        let result = fetch_persona_facts(&view, budget, &tc).await.unwrap();
2283        let msg = result.unwrap();
2284        assert!(msg.content.contains("brief"), "first fact must be included");
2285        assert!(
2286            !msg.content.contains("english"),
2287            "second fact must be truncated by budget"
2288        );
2289    }
2290
2291    #[tokio::test]
2292    async fn fetch_semantic_recall_truncates_at_budget() {
2293        let tc = NaiveTokenCounter;
2294        // Tight budget: fits prefix + exactly 1 recall entry, second must be omitted.
2295        let first_entry = "- [user] first message\n";
2296        let budget = tc.count_tokens(RECALL_PREFIX) + tc.count_tokens(first_entry);
2297        let mock = MockMemoryBackend {
2298            recalled: vec![
2299                MemRecalledMessage {
2300                    role: "user".to_string(),
2301                    content: "first message".to_string(),
2302                    score: 0.95,
2303                },
2304                MemRecalledMessage {
2305                    role: "user".to_string(),
2306                    content: "second message that should be truncated".to_string(),
2307                    score: 0.80,
2308                },
2309            ],
2310            ..Default::default()
2311        };
2312        let mut view = mock_view(mock);
2313        view.recall_limit = 10;
2314        let (msg, _) = fetch_semantic_recall(&view, "query", budget, &tc, None)
2315            .await
2316            .unwrap();
2317        assert!(msg.is_some());
2318        let text = msg
2319            .unwrap()
2320            .parts
2321            .iter()
2322            .find_map(|p| {
2323                if let zeph_llm::provider::MessagePart::Recall { text } = p {
2324                    Some(text.clone())
2325                } else {
2326                    None
2327                }
2328            })
2329            .unwrap_or_default();
2330        assert!(
2331            text.contains("first message"),
2332            "first entry must be included"
2333        );
2334        assert!(
2335            !text.contains("second message"),
2336            "second entry must be truncated by budget"
2337        );
2338    }
2339
2340    // ── provenance_snippet sanitization (CR-2 test) ───────────────────────────
2341
2342    #[tokio::test]
2343    async fn fetch_graph_facts_sanitizes_provenance_snippet() {
2344        use zeph_common::memory::MemGraphNeighbor;
2345        let mock = MockMemoryBackend {
2346            graph_facts: vec![zeph_common::memory::MemGraphFact {
2347                fact: "safe fact".to_string(),
2348                confidence: 0.9,
2349                activation_score: None,
2350                neighbors: vec![MemGraphNeighbor {
2351                    fact: "neighbor".to_string(),
2352                    confidence: 0.7,
2353                }],
2354                provenance_snippet: Some("source with <injection>\nand newline".to_string()),
2355            }],
2356            ..Default::default()
2357        };
2358        let mut view = mock_view(mock);
2359        view.graph_config.enabled = true;
2360        view.graph_config.spreading_activation.recall_timeout_ms = 5000;
2361        let tc = NaiveTokenCounter;
2362        let result = fetch_graph_facts(&view, "test", 1000, &tc).await.unwrap();
2363        assert!(result.is_some());
2364        let msg = result.unwrap();
2365        assert!(
2366            !msg.content.contains('<'),
2367            "angle brackets in provenance_snippet must be sanitized"
2368        );
2369        assert!(
2370            !msg.content.contains("\n\n"),
2371            "newlines in provenance_snippet must be sanitized"
2372        );
2373        assert!(
2374            msg.content.contains("[source:"),
2375            "provenance snippet must be rendered"
2376        );
2377    }
2378}