// zeph_context/assembler.rs
1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4//! Stateless context assembler.
5//!
6//! [`ContextAssembler`] gathers all memory-sourced context for a single agent turn by running
7//! all async fetch operations concurrently. It takes only borrowed references via
8//! [`ContextAssemblyInput`] and returns a [`PreparedContext`] ready for injection.
9//!
10//! Invariants:
11//! - No `Agent` field mutations inside `gather()`.
12//! - No channel communication inside `gather()`.
13//! - All `send_status` calls remain in `Agent::prepare_context`.
14//! - `session_digest` is cached (not async) and stays in `Agent::apply_prepared_context`.
15
16use std::future::Future;
17use std::pin::Pin;
18
19use futures::StreamExt as _;
20use futures::stream::FuturesUnordered;
21
22use zeph_llm::provider::{Message, MessageMetadata, MessagePart, Role};
23use zeph_memory::TokenCounter;
24
25use crate::error::ContextError;
26use crate::input::ContextAssemblyInput;
27use crate::slot::ContextSlot;
28
// Injection prefixes: each memory-sourced context block is labeled with a
// distinct header so the model (and log readers) can attribute provenance.
/// Prefix for past-session summary injections.
pub const SUMMARY_PREFIX: &str = "[conversation summaries]\n";
/// Prefix for cross-session context injections.
pub const CROSS_SESSION_PREFIX: &str = "[cross-session context]\n";
/// Prefix for semantic recall injections.
pub const RECALL_PREFIX: &str = "[semantic recall]\n";
/// Prefix for past-correction injections.
pub const CORRECTIONS_PREFIX: &str = "[past corrections]\n";
/// Prefix for document RAG injections.
pub const DOCUMENT_RAG_PREFIX: &str = "## Relevant documents\n";
/// Prefix for knowledge graph fact injections.
pub const GRAPH_FACTS_PREFIX: &str = "[known facts]\n";
41
/// Result of one context-assembly pass.
///
/// All source fields are `Option` — `None` means disabled, empty, or budget-exhausted.
/// `session_digest` is excluded: it is a cached value injected by `Agent::apply_prepared_context`.
pub struct PreparedContext {
    /// Knowledge graph fact recall.
    pub graph_facts: Option<Message>,
    /// Document RAG context.
    pub doc_rag: Option<Message>,
    /// Past user corrections (never budget-gated; fetched best-effort every turn).
    pub corrections: Option<Message>,
    /// Semantic recall results.
    pub recall: Option<Message>,
    /// Top-1 similarity score from semantic recall (set together with `recall`).
    pub recall_confidence: Option<f32>,
    /// Cross-session memory context.
    pub cross_session: Option<Message>,
    /// Past-conversation summaries.
    pub summaries: Option<Message>,
    /// Code-index RAG context (repo map or file context).
    pub code_context: Option<String>,
    /// Persona memory facts.
    pub persona_facts: Option<Message>,
    /// Trajectory hints.
    pub trajectory_hints: Option<Message>,
    /// `TiMem` tree memory summary.
    pub tree_memory: Option<Message>,
    /// Whether the memory-first context strategy is active for this turn.
    pub memory_first: bool,
    /// Token budget for recent conversation history (passed to trim step in apply).
    pub recent_history_budget: usize,
}
74
/// Stateless coordinator for parallel context fetching.
///
/// All logic is in [`ContextAssembler::gather`]. No state is stored on this type;
/// it exists purely as a namespace for the assembly entry point.
pub struct ContextAssembler;
79
80impl ContextAssembler {
81    /// Gather all context sources concurrently and return a [`PreparedContext`].
82    ///
83    /// Returns an empty `PreparedContext` immediately when `context_manager.budget` is `None`.
84    ///
85    /// # Errors
86    ///
87    /// Propagates errors from any async fetch operation.
88    #[allow(clippy::too_many_lines)]
89    pub async fn gather(input: &ContextAssemblyInput<'_>) -> Result<PreparedContext, ContextError> {
90        type CtxFuture<'a> =
91            Pin<Box<dyn Future<Output = Result<ContextSlot, ContextError>> + Send + 'a>>;
92
93        let Some(ref budget) = input.context_manager.budget else {
94            return Ok(PreparedContext {
95                graph_facts: None,
96                doc_rag: None,
97                corrections: None,
98                recall: None,
99                recall_confidence: None,
100                cross_session: None,
101                summaries: None,
102                code_context: None,
103                persona_facts: None,
104                trajectory_hints: None,
105                tree_memory: None,
106                memory_first: false,
107                recent_history_budget: 0,
108            });
109        };
110
111        let memory = input.memory;
112        let tc = input.token_counter;
113
114        let effective_strategy = match memory.context_strategy {
115            zeph_config::ContextStrategy::FullHistory => zeph_config::ContextStrategy::FullHistory,
116            zeph_config::ContextStrategy::MemoryFirst => zeph_config::ContextStrategy::MemoryFirst,
117            zeph_config::ContextStrategy::Adaptive => {
118                if input.sidequest_turn_counter >= u64::from(memory.crossover_turn_threshold) {
119                    zeph_config::ContextStrategy::MemoryFirst
120                } else {
121                    zeph_config::ContextStrategy::FullHistory
122                }
123            }
124        };
125        let memory_first = effective_strategy == zeph_config::ContextStrategy::MemoryFirst;
126
127        let system_prompt = input
128            .messages
129            .first()
130            .filter(|m| m.role == Role::System)
131            .map_or("", |m| m.content.as_str());
132
133        let digest_tokens = memory
134            .cached_session_digest
135            .as_ref()
136            .map_or(0, |(_, tokens)| *tokens);
137
138        let graph_enabled = memory.graph_config.enabled;
139
140        let alloc = budget.allocate_with_opts(
141            system_prompt,
142            input.skills_prompt,
143            tc,
144            graph_enabled,
145            digest_tokens,
146            memory_first,
147        );
148
149        let correction_params = input
150            .correction_config
151            .filter(|c| c.correction_detection)
152            .map(|c| {
153                (
154                    c.correction_recall_limit as usize,
155                    c.correction_min_similarity,
156                )
157            });
158        let (recall_limit, min_sim) = correction_params.unwrap_or((3, 0.75));
159
160        let router = input.context_manager.build_router();
161        let router_ref: &dyn zeph_memory::AsyncMemoryRouter = router.as_ref();
162        let query = input.query;
163        let scrub = input.scrub;
164
165        let mut fetchers: FuturesUnordered<CtxFuture<'_>> = FuturesUnordered::new();
166
167        tracing::debug!(
168            active_sources = alloc.active_sources(),
169            "context budget allocated"
170        );
171
172        if alloc.summaries > 0 {
173            fetchers.push(Box::pin(async {
174                fetch_summaries(memory, alloc.summaries, tc)
175                    .await
176                    .map(ContextSlot::Summaries)
177            }));
178        }
179        if alloc.cross_session > 0 {
180            fetchers.push(Box::pin(async {
181                fetch_cross_session(memory, query, alloc.cross_session, tc)
182                    .await
183                    .map(ContextSlot::CrossSession)
184            }));
185        }
186        if alloc.semantic_recall > 0 {
187            fetchers.push(Box::pin(async {
188                fetch_semantic_recall(memory, query, alloc.semantic_recall, tc, Some(router_ref))
189                    .await
190                    .map(|(msg, score)| ContextSlot::SemanticRecall(msg, score))
191            }));
192            fetchers.push(Box::pin(async {
193                fetch_document_rag(memory, query, alloc.semantic_recall, tc)
194                    .await
195                    .map(ContextSlot::DocumentRag)
196            }));
197        }
198        // Corrections are safety-critical and never budget-gated.
199        fetchers.push(Box::pin(async {
200            fetch_corrections(memory, query, recall_limit, min_sim, scrub)
201                .await
202                .map(ContextSlot::Corrections)
203        }));
204        if alloc.code_context > 0
205            && let Some(index) = input.index
206        {
207            let budget = alloc.code_context;
208            fetchers.push(Box::pin(async move {
209                let result: Result<Option<String>, ContextError> =
210                    index.fetch_code_rag(query, budget).await;
211                result.map(ContextSlot::CodeContext)
212            }));
213        }
214        if alloc.graph_facts > 0 {
215            fetchers.push(Box::pin(async {
216                fetch_graph_facts(memory, query, alloc.graph_facts, tc)
217                    .await
218                    .map(ContextSlot::GraphFacts)
219            }));
220        }
221        if memory.persona_config.context_budget_tokens > 0 {
222            fetchers.push(Box::pin(async {
223                let persona_budget = memory.persona_config.context_budget_tokens;
224                fetch_persona_facts(memory, persona_budget, tc)
225                    .await
226                    .map(ContextSlot::PersonaFacts)
227            }));
228        }
229        if memory.trajectory_config.context_budget_tokens > 0 {
230            fetchers.push(Box::pin(async {
231                let tbudget = memory.trajectory_config.context_budget_tokens;
232                fetch_trajectory_hints(memory, tbudget, tc)
233                    .await
234                    .map(ContextSlot::TrajectoryHints)
235            }));
236        }
237        if memory.tree_config.context_budget_tokens > 0 {
238            fetchers.push(Box::pin(async {
239                let tbudget = memory.tree_config.context_budget_tokens;
240                fetch_tree_memory(memory, tbudget, tc)
241                    .await
242                    .map(ContextSlot::TreeMemory)
243            }));
244        }
245
246        let mut prepared = PreparedContext {
247            graph_facts: None,
248            doc_rag: None,
249            corrections: None,
250            recall: None,
251            recall_confidence: None,
252            cross_session: None,
253            summaries: None,
254            code_context: None,
255            persona_facts: None,
256            trajectory_hints: None,
257            tree_memory: None,
258            memory_first,
259            recent_history_budget: alloc.recent_history,
260        };
261
262        while let Some(result) = fetchers.next().await {
263            match result {
264                Ok(slot) => match slot {
265                    ContextSlot::Summaries(msg) => prepared.summaries = msg,
266                    ContextSlot::CrossSession(msg) => prepared.cross_session = msg,
267                    ContextSlot::SemanticRecall(msg, score) => {
268                        prepared.recall = msg;
269                        prepared.recall_confidence = score;
270                    }
271                    ContextSlot::DocumentRag(msg) => prepared.doc_rag = msg,
272                    ContextSlot::Corrections(msg) => prepared.corrections = msg,
273                    ContextSlot::CodeContext(text) => prepared.code_context = text,
274                    ContextSlot::GraphFacts(msg) => prepared.graph_facts = msg,
275                    ContextSlot::PersonaFacts(msg) => prepared.persona_facts = msg,
276                    ContextSlot::TrajectoryHints(msg) => prepared.trajectory_hints = msg,
277                    ContextSlot::TreeMemory(msg) => prepared.tree_memory = msg,
278                },
279                Err(e) => return Err(e),
280            }
281        }
282
283        Ok(prepared)
284    }
285}
286
287/// Clamp recall timeout to a safe minimum.
288///
289/// A configured value of 0 would disable spreading activation recall entirely;
290/// clamping to 100ms preserves the user's intent while preventing a silent no-op.
291pub fn effective_recall_timeout_ms(configured: u64) -> u64 {
292    if configured == 0 {
293        tracing::warn!(
294            "recall_timeout_ms is 0, which would disable spreading activation recall; \
295             clamping to 100ms"
296        );
297        100
298    } else {
299        configured
300    }
301}
302
303use crate::input::ContextMemoryView;
304
/// Fetch knowledge-graph facts relevant to `query`, bounded by `budget_tokens`.
///
/// Two recall modes, selected by `graph_config.spreading_activation.enabled`:
/// spreading activation is best-effort (errors and timeouts degrade to no facts),
/// while plain multi-hop recall propagates errors as `ContextError::Memory`.
/// Returns `Ok(None)` when the feature is disabled, memory is absent, nothing is
/// recalled, or no fact fits the budget.
pub(crate) async fn fetch_graph_facts(
    memory: &ContextMemoryView,
    query: &str,
    budget_tokens: usize,
    tc: &TokenCounter,
) -> Result<Option<Message>, ContextError> {
    if budget_tokens == 0 || !memory.graph_config.enabled {
        return Ok(None);
    }
    let Some(ref mem) = memory.memory else {
        return Ok(None);
    };
    let recall_limit = memory.graph_config.recall_limit;
    let temporal_decay_rate = memory.graph_config.temporal_decay_rate;
    // Query classification narrows recall to the relevant edge types.
    let edge_types = zeph_memory::classify_graph_subgraph(query);
    let sa_config = &memory.graph_config.spreading_activation;

    let mut body = String::from(GRAPH_FACTS_PREFIX);
    let mut tokens_so_far = tc.count_tokens(&body);

    if sa_config.enabled {
        let sa_params = zeph_memory::graph::SpreadingActivationParams {
            decay_lambda: sa_config.decay_lambda,
            max_hops: sa_config.max_hops,
            activation_threshold: sa_config.activation_threshold,
            inhibition_threshold: sa_config.inhibition_threshold,
            max_activated_nodes: sa_config.max_activated_nodes,
            temporal_decay_rate,
            seed_structural_weight: sa_config.seed_structural_weight,
            seed_community_cap: sa_config.seed_community_cap,
        };
        let timeout_ms = effective_recall_timeout_ms(sa_config.recall_timeout_ms);
        let recall_fut = mem.recall_graph_activated(query, recall_limit, sa_params, &edge_types);
        // Best-effort: recall failure or timeout degrades to an empty fact list
        // rather than failing the whole context-assembly pass (unlike the non-SA
        // branch below, which propagates errors).
        let activated_facts =
            match tokio::time::timeout(std::time::Duration::from_millis(timeout_ms), recall_fut)
                .await
            {
                Ok(Ok(facts)) => facts,
                Ok(Err(e)) => {
                    tracing::warn!("spreading activation recall failed: {e:#}");
                    Vec::new()
                }
                Err(_) => {
                    tracing::warn!("spreading activation recall timed out ({timeout_ms}ms)");
                    Vec::new()
                }
            };

        if activated_facts.is_empty() {
            return Ok(None);
        }

        for f in &activated_facts {
            // Sanitize fact text so it cannot break the line framing or inject markup.
            let fact_text = f.edge.fact.replace(['\n', '\r', '<', '>'], " ");
            let line = format!(
                "- {} (confidence: {:.2}, activation: {:.2})\n",
                fact_text, f.edge.confidence, f.activation_score
            );
            let line_tokens = tc.count_tokens(&line);
            // Stop at the first fact that would exceed the budget.
            if tokens_so_far + line_tokens > budget_tokens {
                break;
            }
            body.push_str(&line);
            tokens_so_far += line_tokens;
        }
    } else {
        let max_hops = memory.graph_config.max_hops;
        let facts = mem
            .recall_graph(
                query,
                recall_limit,
                max_hops,
                None,
                temporal_decay_rate,
                &edge_types,
            )
            .await
            .map_err(|e| {
                tracing::warn!("graph recall failed: {e:#}");
                ContextError::Memory(e)
            })?;

        if facts.is_empty() {
            return Ok(None);
        }

        for f in &facts {
            // Same sanitization and budget gating as the spreading-activation branch.
            let fact_text = f.fact.replace(['\n', '\r', '<', '>'], " ");
            let line = format!("- {} (confidence: {:.2})\n", fact_text, f.confidence);
            let line_tokens = tc.count_tokens(&line);
            if tokens_so_far + line_tokens > budget_tokens {
                break;
            }
            body.push_str(&line);
            tokens_so_far += line_tokens;
        }
    }

    // Only the bare prefix means nothing fit: suppress the injection entirely.
    if body == GRAPH_FACTS_PREFIX {
        return Ok(None);
    }

    Ok(Some(Message::from_legacy(Role::System, body)))
}
409
410pub(crate) async fn fetch_persona_facts(
411    memory: &ContextMemoryView,
412    budget_tokens: usize,
413    tc: &TokenCounter,
414) -> Result<Option<Message>, ContextError> {
415    if budget_tokens == 0 || !memory.persona_config.enabled {
416        return Ok(None);
417    }
418    let Some(ref mem) = memory.memory else {
419        return Ok(None);
420    };
421
422    let min_confidence = memory.persona_config.min_confidence;
423    let facts = mem.sqlite().load_persona_facts(min_confidence).await?;
424
425    if facts.is_empty() {
426        return Ok(None);
427    }
428
429    let mut body = String::from(crate::slot::PERSONA_PREFIX);
430    let mut tokens_so_far = tc.count_tokens(&body);
431
432    for fact in &facts {
433        let line = format!("[{}] {}\n", fact.category, fact.content);
434        let line_tokens = tc.count_tokens(&line);
435        if tokens_so_far + line_tokens > budget_tokens {
436            break;
437        }
438        body.push_str(&line);
439        tokens_so_far += line_tokens;
440    }
441
442    if body == crate::slot::PERSONA_PREFIX {
443        return Ok(None);
444    }
445
446    Ok(Some(Message::from_legacy(Role::System, body)))
447}
448
449pub(crate) async fn fetch_trajectory_hints(
450    memory: &ContextMemoryView,
451    budget_tokens: usize,
452    tc: &TokenCounter,
453) -> Result<Option<Message>, ContextError> {
454    if budget_tokens == 0 || !memory.trajectory_config.enabled {
455        return Ok(None);
456    }
457    let Some(ref mem) = memory.memory else {
458        return Ok(None);
459    };
460
461    let top_k = memory.trajectory_config.recall_top_k;
462    let min_conf = memory.trajectory_config.min_confidence;
463    let entries = mem
464        .sqlite()
465        .load_trajectory_entries(Some("procedural"), top_k)
466        .await?;
467
468    if entries.is_empty() {
469        return Ok(None);
470    }
471
472    let mut body = String::from(crate::slot::TRAJECTORY_PREFIX);
473    let mut tokens_so_far = tc.count_tokens(&body);
474
475    for entry in entries
476        .iter()
477        .filter(|e| e.confidence >= min_conf)
478        .take(top_k)
479    {
480        let line = format!("- {}: {}\n", entry.intent, entry.outcome);
481        let line_tokens = tc.count_tokens(&line);
482        if tokens_so_far + line_tokens > budget_tokens {
483            break;
484        }
485        body.push_str(&line);
486        tokens_so_far += line_tokens;
487    }
488
489    if body == crate::slot::TRAJECTORY_PREFIX {
490        return Ok(None);
491    }
492
493    Ok(Some(Message::from_legacy(Role::System, body)))
494}
495
496pub(crate) async fn fetch_tree_memory(
497    memory: &ContextMemoryView,
498    budget_tokens: usize,
499    tc: &TokenCounter,
500) -> Result<Option<Message>, ContextError> {
501    if budget_tokens == 0 || !memory.tree_config.enabled {
502        return Ok(None);
503    }
504    let Some(ref mem) = memory.memory else {
505        return Ok(None);
506    };
507
508    let top_k = memory.tree_config.recall_top_k;
509    let nodes = mem.sqlite().load_tree_level(1, top_k).await?;
510
511    if nodes.is_empty() {
512        return Ok(None);
513    }
514
515    let mut body = String::from(crate::slot::TREE_MEMORY_PREFIX);
516    let mut tokens_so_far = tc.count_tokens(&body);
517
518    for node in nodes.iter().take(top_k) {
519        let line = format!("- {}\n", node.content);
520        let line_tokens = tc.count_tokens(&line);
521        if tokens_so_far + line_tokens > budget_tokens {
522            break;
523        }
524        body.push_str(&line);
525        tokens_so_far += line_tokens;
526    }
527
528    if body == crate::slot::TREE_MEMORY_PREFIX {
529        return Ok(None);
530    }
531
532    Ok(Some(Message::from_legacy(Role::System, body)))
533}
534
535pub(crate) async fn fetch_corrections(
536    memory: &ContextMemoryView,
537    query: &str,
538    limit: usize,
539    min_score: f32,
540    scrub: fn(&str) -> std::borrow::Cow<'_, str>,
541) -> Result<Option<Message>, ContextError> {
542    let Some(ref mem) = memory.memory else {
543        return Ok(None);
544    };
545    let corrections = mem
546        .retrieve_similar_corrections(query, limit, min_score)
547        .await
548        .unwrap_or_default();
549    if corrections.is_empty() {
550        return Ok(None);
551    }
552    let mut text = String::from(CORRECTIONS_PREFIX);
553    for c in &corrections {
554        text.push_str("- Past user correction: \"");
555        text.push_str(&scrub(&c.correction_text));
556        text.push_str("\"\n");
557    }
558    Ok(Some(Message::from_legacy(Role::System, text)))
559}
560
561pub(crate) async fn fetch_semantic_recall(
562    memory: &ContextMemoryView,
563    query: &str,
564    token_budget: usize,
565    tc: &TokenCounter,
566    router: Option<&dyn zeph_memory::AsyncMemoryRouter>,
567) -> Result<(Option<Message>, Option<f32>), ContextError> {
568    let Some(ref mem) = memory.memory else {
569        return Ok((None, None));
570    };
571    if memory.recall_limit == 0 || token_budget == 0 {
572        return Ok((None, None));
573    }
574
575    let recalled = if let Some(r) = router {
576        mem.recall_routed_async(query, memory.recall_limit, None, r)
577            .await?
578    } else {
579        mem.recall(query, memory.recall_limit, None).await?
580    };
581    if recalled.is_empty() {
582        return Ok((None, None));
583    }
584
585    let top_score = recalled.first().map(|r| r.score);
586
587    let mut recall_text = String::with_capacity(token_budget * 3);
588    recall_text.push_str(RECALL_PREFIX);
589    let mut tokens_used = tc.count_tokens(&recall_text);
590
591    for item in &recalled {
592        if item.message.content.starts_with("[skipped]")
593            || item.message.content.starts_with("[stopped]")
594        {
595            continue;
596        }
597        let role_label = match item.message.role {
598            Role::User => "user",
599            Role::Assistant => "assistant",
600            Role::System => "system",
601        };
602        let entry = format!("- [{}] {}\n", role_label, item.message.content);
603        let entry_tokens = tc.count_tokens(&entry);
604        if tokens_used + entry_tokens > token_budget {
605            break;
606        }
607        recall_text.push_str(&entry);
608        tokens_used += entry_tokens;
609    }
610
611    if tokens_used > tc.count_tokens(RECALL_PREFIX) {
612        Ok((
613            Some(Message::from_parts(
614                Role::System,
615                vec![MessagePart::Recall { text: recall_text }],
616            )),
617            top_score,
618        ))
619    } else {
620        Ok((None, None))
621    }
622}
623
624pub(crate) async fn fetch_document_rag(
625    memory: &ContextMemoryView,
626    query: &str,
627    token_budget: usize,
628    tc: &TokenCounter,
629) -> Result<Option<Message>, ContextError> {
630    if !memory.document_config.rag_enabled || token_budget == 0 {
631        return Ok(None);
632    }
633    let Some(ref mem) = memory.memory else {
634        return Ok(None);
635    };
636
637    let collection = &memory.document_config.collection;
638    let top_k = memory.document_config.top_k;
639    let points = mem
640        .search_document_collection(collection, query, top_k)
641        .await?;
642    if points.is_empty() {
643        return Ok(None);
644    }
645
646    let mut text = String::from(DOCUMENT_RAG_PREFIX);
647    let mut tokens_used = tc.count_tokens(&text);
648
649    for point in &points {
650        let chunk = point
651            .payload
652            .get("text")
653            .and_then(|v| v.as_str())
654            .unwrap_or_default();
655        if chunk.is_empty() {
656            continue;
657        }
658        let entry = format!("{chunk}\n");
659        let cost = tc.count_tokens(&entry);
660        if tokens_used + cost > token_budget {
661            break;
662        }
663        text.push_str(&entry);
664        tokens_used += cost;
665    }
666
667    if tokens_used > tc.count_tokens(DOCUMENT_RAG_PREFIX) {
668        Ok(Some(Message {
669            role: Role::System,
670            content: text,
671            parts: vec![],
672            metadata: MessageMetadata::default(),
673        }))
674    } else {
675        Ok(None)
676    }
677}
678
679pub(crate) async fn fetch_summaries(
680    memory: &ContextMemoryView,
681    token_budget: usize,
682    tc: &TokenCounter,
683) -> Result<Option<Message>, ContextError> {
684    let (Some(mem), Some(cid)) = (&memory.memory, memory.conversation_id) else {
685        return Ok(None);
686    };
687    if token_budget == 0 {
688        return Ok(None);
689    }
690
691    let summaries = mem.load_summaries(cid).await?;
692    if summaries.is_empty() {
693        return Ok(None);
694    }
695
696    let mut summary_text = String::from(SUMMARY_PREFIX);
697    let mut tokens_used = tc.count_tokens(&summary_text);
698
699    for summary in summaries.iter().rev() {
700        let first = summary.first_message_id.map_or(0, |m| m.0);
701        let last = summary.last_message_id.map_or(0, |m| m.0);
702        let entry = format!("- Messages {first}-{last}: {}\n", summary.content);
703        let cost = tc.count_tokens(&entry);
704        if tokens_used + cost > token_budget {
705            break;
706        }
707        summary_text.push_str(&entry);
708        tokens_used += cost;
709    }
710
711    if tokens_used > tc.count_tokens(SUMMARY_PREFIX) {
712        Ok(Some(Message::from_parts(
713            Role::System,
714            vec![MessagePart::Summary { text: summary_text }],
715        )))
716    } else {
717        Ok(None)
718    }
719}
720
721pub(crate) async fn fetch_cross_session(
722    memory: &ContextMemoryView,
723    query: &str,
724    token_budget: usize,
725    tc: &TokenCounter,
726) -> Result<Option<Message>, ContextError> {
727    let (Some(mem), Some(cid)) = (&memory.memory, memory.conversation_id) else {
728        return Ok(None);
729    };
730    if token_budget == 0 {
731        return Ok(None);
732    }
733
734    let threshold = memory.cross_session_score_threshold;
735    let results: Vec<_> = mem
736        .search_session_summaries(query, 5, Some(cid))
737        .await?
738        .into_iter()
739        .filter(|r| r.score >= threshold)
740        .collect();
741    if results.is_empty() {
742        return Ok(None);
743    }
744
745    let mut text = String::from(CROSS_SESSION_PREFIX);
746    let mut tokens_used = tc.count_tokens(&text);
747
748    for item in &results {
749        let entry = format!("- {}\n", item.summary_text);
750        let cost = tc.count_tokens(&entry);
751        if tokens_used + cost > token_budget {
752            break;
753        }
754        text.push_str(&entry);
755        tokens_used += cost;
756    }
757
758    if tokens_used > tc.count_tokens(CROSS_SESSION_PREFIX) {
759        Ok(Some(Message::from_parts(
760            Role::System,
761            vec![MessagePart::CrossSession { text }],
762        )))
763    } else {
764        Ok(None)
765    }
766}
767
/// Maximum number of messages scanned backward by [`memory_first_keep_tail`] before
/// stopping at the next non-`ToolResult` boundary, to avoid O(N) scans on long sessions.
pub const MAX_KEEP_TAIL_SCAN: usize = 50;

/// Compute how many tail messages to keep when the `MemoryFirst` strategy is active.
///
/// Always keeps at least 2 messages. Extends the tail as long as the boundary message is
/// a `ToolResult` (user message with a `ToolResult` part) to avoid splitting a tool-call
/// round-trip. Capped at `MAX_KEEP_TAIL_SCAN` to prevent O(N) scans on long sessions.
///
/// `history_start` is the index of the first non-system message (typically 1).
#[must_use]
pub fn memory_first_keep_tail(messages: &[Message], history_start: usize) -> usize {
    use zeph_llm::provider::MessagePart;

    // Minimum tail: the latest user turn plus the assistant reply.
    let mut keep_tail = 2usize;
    let len = messages.len();
    // Never extend past the start of non-system history.
    let max = len.saturating_sub(history_start);

    while keep_tail < max {
        // The oldest message that would be retained at the current tail size.
        let first_retained = &messages[len - keep_tail];
        let is_tool_result = first_retained.role == Role::User
            && first_retained
                .parts
                .iter()
                .any(|p| matches!(p, MessagePart::ToolResult { .. }));

        // A ToolResult boundary would orphan its ToolUse — widen the tail by one.
        if is_tool_result {
            keep_tail += 1;
        } else {
            break;
        }

        if keep_tail >= MAX_KEEP_TAIL_SCAN {
            // Scan cap reached mid-round-trip: include the preceding assistant
            // ToolUse (if any) so the retained window still pairs use with result,
            // then stop extending regardless.
            let preceding_idx = len.saturating_sub(keep_tail + 1);
            if preceding_idx >= history_start {
                let preceding = &messages[preceding_idx];
                let is_tool_use = preceding.role == Role::Assistant
                    && preceding
                        .parts
                        .iter()
                        .any(|p| matches!(p, MessagePart::ToolUse { .. }));
                if is_tool_use {
                    keep_tail += 1;
                }
            }
            break;
        }
    }

    keep_tail
}
820
821#[cfg(test)]
822mod tests {
823    use super::*;
824    use crate::input::ContextMemoryView;
825    use zeph_config::{
826        ContextStrategy, DocumentConfig, GraphConfig, PersonaConfig, TrajectoryConfig, TreeConfig,
827    };
828    use zeph_memory::TokenCounter;
829
830    fn empty_view() -> ContextMemoryView {
831        ContextMemoryView {
832            memory: None,
833            conversation_id: None,
834            recall_limit: 10,
835            cross_session_score_threshold: 0.5,
836            context_strategy: ContextStrategy::default(),
837            crossover_turn_threshold: 5,
838            cached_session_digest: None,
839            graph_config: GraphConfig::default(),
840            document_config: DocumentConfig::default(),
841            persona_config: PersonaConfig::default(),
842            trajectory_config: TrajectoryConfig::default(),
843            tree_config: TreeConfig::default(),
844        }
845    }
846
847    // ── fetch_graph_facts ─────────────────────────────────────────────────────
848
849    #[tokio::test]
850    async fn fetch_graph_facts_returns_none_when_memory_is_none() {
851        let view = empty_view();
852        let tc = TokenCounter::new();
853        let result = fetch_graph_facts(&view, "test", 1000, &tc).await.unwrap();
854        assert!(result.is_none());
855    }
856
857    #[tokio::test]
858    async fn fetch_graph_facts_returns_none_when_budget_zero() {
859        let mut view = empty_view();
860        view.graph_config.enabled = true;
861        let tc = TokenCounter::new();
862        let result = fetch_graph_facts(&view, "test", 0, &tc).await.unwrap();
863        assert!(result.is_none());
864    }
865
866    #[tokio::test]
867    async fn fetch_graph_facts_returns_none_when_graph_disabled() {
868        let mut view = empty_view();
869        view.graph_config.enabled = false;
870        let tc = TokenCounter::new();
871        let result = fetch_graph_facts(&view, "test", 1000, &tc).await.unwrap();
872        assert!(result.is_none());
873    }
874
875    // ── fetch_persona_facts ───────────────────────────────────────────────────
876
877    #[tokio::test]
878    async fn fetch_persona_facts_returns_none_when_memory_is_none() {
879        let view = empty_view();
880        let tc = TokenCounter::new();
881        let result = fetch_persona_facts(&view, 1000, &tc).await.unwrap();
882        assert!(result.is_none());
883    }
884
885    #[tokio::test]
886    async fn fetch_persona_facts_returns_none_when_budget_zero() {
887        let mut view = empty_view();
888        view.persona_config.enabled = true;
889        let tc = TokenCounter::new();
890        let result = fetch_persona_facts(&view, 0, &tc).await.unwrap();
891        assert!(result.is_none());
892    }
893
894    // ── fetch_trajectory_hints ────────────────────────────────────────────────
895
896    #[tokio::test]
897    async fn fetch_trajectory_hints_returns_none_when_memory_is_none() {
898        let view = empty_view();
899        let tc = TokenCounter::new();
900        let result = fetch_trajectory_hints(&view, 1000, &tc).await.unwrap();
901        assert!(result.is_none());
902    }
903
904    #[tokio::test]
905    async fn fetch_trajectory_hints_returns_none_when_budget_zero() {
906        let mut view = empty_view();
907        view.trajectory_config.enabled = true;
908        let tc = TokenCounter::new();
909        let result = fetch_trajectory_hints(&view, 0, &tc).await.unwrap();
910        assert!(result.is_none());
911    }
912
913    // ── fetch_tree_memory ─────────────────────────────────────────────────────
914
915    #[tokio::test]
916    async fn fetch_tree_memory_returns_none_when_memory_is_none() {
917        let view = empty_view();
918        let tc = TokenCounter::new();
919        let result = fetch_tree_memory(&view, 1000, &tc).await.unwrap();
920        assert!(result.is_none());
921    }
922
923    #[tokio::test]
924    async fn fetch_tree_memory_returns_none_when_budget_zero() {
925        let mut view = empty_view();
926        view.tree_config.enabled = true;
927        let tc = TokenCounter::new();
928        let result = fetch_tree_memory(&view, 0, &tc).await.unwrap();
929        assert!(result.is_none());
930    }
931
932    // ── fetch_corrections ─────────────────────────────────────────────────────
933
934    #[tokio::test]
935    async fn fetch_corrections_returns_none_when_memory_is_none() {
936        let view = empty_view();
937        let result = fetch_corrections(&view, "test", 10, 0.5, |s| s.into())
938            .await
939            .unwrap();
940        assert!(result.is_none());
941    }
942
943    // ── fetch_semantic_recall ─────────────────────────────────────────────────
944
945    #[tokio::test]
946    async fn fetch_semantic_recall_returns_none_when_memory_is_none() {
947        let view = empty_view();
948        let tc = TokenCounter::new();
949        let result = fetch_semantic_recall(&view, "test", 1000, &tc, None)
950            .await
951            .unwrap();
952        assert!(result.0.is_none() && result.1.is_none());
953    }
954
955    #[tokio::test]
956    async fn fetch_semantic_recall_returns_none_when_budget_zero() {
957        let view = empty_view();
958        let tc = TokenCounter::new();
959        let result = fetch_semantic_recall(&view, "test", 0, &tc, None)
960            .await
961            .unwrap();
962        assert!(result.0.is_none() && result.1.is_none());
963    }
964
965    // ── fetch_document_rag ────────────────────────────────────────────────────
966
967    #[tokio::test]
968    async fn fetch_document_rag_returns_none_when_memory_is_none() {
969        let mut view = empty_view();
970        view.document_config.rag_enabled = true;
971        let tc = TokenCounter::new();
972        let result = fetch_document_rag(&view, "test", 1000, &tc).await.unwrap();
973        assert!(result.is_none());
974    }
975
976    #[tokio::test]
977    async fn fetch_document_rag_returns_none_when_rag_disabled() {
978        let view = empty_view();
979        let tc = TokenCounter::new();
980        let result = fetch_document_rag(&view, "test", 1000, &tc).await.unwrap();
981        assert!(result.is_none());
982    }
983
984    // ── fetch_summaries ───────────────────────────────────────────────────────
985
986    #[tokio::test]
987    async fn fetch_summaries_returns_none_when_memory_is_none() {
988        let view = empty_view();
989        let tc = TokenCounter::new();
990        let result = fetch_summaries(&view, 1000, &tc).await.unwrap();
991        assert!(result.is_none());
992    }
993
994    // ── fetch_cross_session ───────────────────────────────────────────────────
995
996    #[tokio::test]
997    async fn fetch_cross_session_returns_none_when_memory_is_none() {
998        let view = empty_view();
999        let tc = TokenCounter::new();
1000        let result = fetch_cross_session(&view, "test", 1000, &tc).await.unwrap();
1001        assert!(result.is_none());
1002    }
1003}