Skip to main content

zeph_agent_context/
compaction.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4//! Task-aware pruning strategy for tool output eviction.
5//!
6//! Provides relevance scoring, subgoal tracking, density classification, and
7//! focus auto-consolidation. These types and functions are used by the
8//! [`crate::service::ContextService`] summarization path and are referenced from
9//! `zeph-core` via re-export.
10
11use std::collections::{HashMap, HashSet};
12use std::time::Duration;
13
14use zeph_llm::provider::{LlmProvider, Message, MessagePart, Role};
15use zeph_memory::TokenCounter;
16
17// ── Scoring ───────────────────────────────────────────────────────────────────
18
/// Score assigned to one tool-output message by the pruning scorers in this module.
///
/// `relevance` is normally within `0.0..=1.0`; the subgoal scorer adds a small
/// recency bonus on top of its tier value, so it can slightly exceed `1.0` there.
#[derive(Debug, Clone)]
pub struct BlockScore {
    /// Position of the scored message in the messages slice.
    pub msg_index: usize,
    /// How strongly the block relates to the current task goal.
    pub relevance: f32,
    /// Highest similarity to any block with strictly higher relevance (`0.0` when
    /// the scorer does not compute redundancy).
    pub redundancy: f32,
    /// `relevance - redundancy`. Zero or negative values mark eviction candidates.
    pub mig: f32,
}
31
/// Rust keywords, English filler words and build-tool noise that would otherwise
/// dominate token overlap without carrying any task signal.
///
/// Entries shorter than three bytes never reach this set in practice because
/// `tokenize` drops such tokens before consulting it; they are kept for completeness.
static STOP_WORDS: std::sync::LazyLock<HashSet<&'static str>> = std::sync::LazyLock::new(|| {
    HashSet::from([
        // Rust keywords and literals.
        "fn", "pub", "let", "use", "mod", "impl", "struct", "enum", "trait", "type", "for", "if",
        "else", "match", "return", "self", "super", "crate", "true", "false", "mut", "ref",
        "where", "in", "as", "const", "static", "extern", "unsafe", "async", "await", "move",
        "box", "dyn", "loop", "while", "break", "continue", "yield", "do", "try",
        // English filler words.
        "the", "a", "an", "is", "are", "was", "be", "to", "of", "and", "or", "not", "with",
        "from", "by", "at", "on", "it", "this", "that", "have", "has", "had",
        // Cargo / compiler output noise.
        "cargo", "rustc", "warning", "error", "note", "help", "running",
    ])
});
46
47fn tokenize(text: &str) -> Vec<String> {
48    text.split(|c: char| !c.is_alphanumeric() && c != '_')
49        .filter(|t| t.len() >= 3)
50        .map(str::to_lowercase)
51        .filter(|t| !STOP_WORDS.contains(t.as_str()))
52        .collect()
53}
54
/// Compute the relative frequency of every distinct token.
///
/// Each value is `occurrences / tokens.len()`. An empty input yields an empty map
/// (the divisor is clamped to 1 so no division by zero can occur).
#[allow(clippy::cast_precision_loss)]
fn term_frequencies(tokens: &[String]) -> HashMap<String, f32> {
    let denominator = tokens.len().max(1) as f32;
    let mut occurrences: HashMap<String, usize> = HashMap::new();
    for token in tokens {
        *occurrences.entry(token.clone()).or_default() += 1;
    }
    let mut frequencies = HashMap::with_capacity(occurrences.len());
    for (term, count) in occurrences {
        frequencies.insert(term, count as f32 / denominator);
    }
    frequencies
}
67
/// Frequency-weighted overlap between two term-frequency maps.
///
/// The numerator sums `min(freq_a, freq_b)` over shared terms. The denominator
/// sums every frequency of `tf_a` plus the frequencies of terms that occur only
/// in `tf_b`. Because shared terms contribute `freq_a` (not the larger of the
/// two) to the denominator, the measure is not symmetric in its arguments.
/// Returns `0.0` when the denominator is zero.
fn tf_weighted_similarity(tf_a: &HashMap<String, f32>, tf_b: &HashMap<String, f32>) -> f32 {
    let (shared_mass, a_mass) =
        tf_a.iter()
            .fold((0.0_f32, 0.0_f32), |(shared, mass), (term, &freq_a)| {
                let shared = match tf_b.get(term) {
                    Some(&freq_b) => shared + freq_a.min(freq_b),
                    None => shared,
                };
                (shared, mass + freq_a)
            });
    let denominator = tf_b
        .iter()
        .filter(|(term, _)| !tf_a.contains_key(*term))
        .fold(a_mass, |mass, (_, &freq_b)| mass + freq_b);
    if denominator == 0.0 {
        return 0.0;
    }
    shared_mass / denominator
}
88
89/// Extract text content from a message suitable for scoring.
90#[must_use]
91pub fn extract_scorable_text(msg: &Message) -> String {
92    let mut parts_text = String::new();
93    for part in &msg.parts {
94        match part {
95            MessagePart::ToolOutput {
96                body, tool_name, ..
97            } => {
98                parts_text.push_str(tool_name.as_str());
99                parts_text.push(' ');
100                parts_text.push_str(body);
101                parts_text.push(' ');
102            }
103            MessagePart::ToolResult { content, .. } => {
104                parts_text.push_str(content);
105                parts_text.push(' ');
106            }
107            _ => {}
108        }
109    }
110    if parts_text.is_empty() {
111        msg.content.clone()
112    } else {
113        parts_text
114    }
115}
116
117/// Score each tool-output message block against the task goal.
118#[must_use]
119pub fn score_blocks_task_aware(
120    messages: &[Message],
121    task_goal: &str,
122    _tc: &TokenCounter,
123) -> Vec<BlockScore> {
124    let goal_tokens = tokenize(task_goal);
125    let goal_tf = term_frequencies(&goal_tokens);
126    let mut scores = Vec::new();
127    for (i, msg) in messages.iter().enumerate() {
128        if i == 0 || msg.metadata.focus_pinned {
129            continue;
130        }
131        let has_tool_output = msg.parts.iter().any(|p| {
132            matches!(
133                p,
134                MessagePart::ToolOutput { .. } | MessagePart::ToolResult { .. }
135            )
136        });
137        if !has_tool_output {
138            continue;
139        }
140        let text = extract_scorable_text(msg);
141        let tokens = tokenize(&text);
142        let tf = term_frequencies(&tokens);
143        let relevance = tf_weighted_similarity(&goal_tf, &tf);
144        scores.push(BlockScore {
145            msg_index: i,
146            relevance,
147            redundancy: 0.0,
148            mig: relevance,
149        });
150    }
151    scores
152}
153
154/// Score blocks using MIG (relevance − redundancy) with temporal partitioning.
155#[must_use]
156pub fn score_blocks_mig(
157    messages: &[Message],
158    task_goal: Option<&str>,
159    tc: &TokenCounter,
160) -> Vec<BlockScore> {
161    #[allow(clippy::cast_precision_loss)]
162    let mut scores = if let Some(goal) = task_goal {
163        score_blocks_task_aware(messages, goal, tc)
164    } else {
165        let total = messages.len();
166        messages
167            .iter()
168            .enumerate()
169            .filter(|(i, msg)| {
170                *i > 0
171                    && !msg.metadata.focus_pinned
172                    && msg.parts.iter().any(|p| {
173                        matches!(
174                            p,
175                            MessagePart::ToolOutput { .. } | MessagePart::ToolResult { .. }
176                        )
177                    })
178            })
179            .map(|(i, _)| {
180                let relevance = i as f32 / total as f32;
181                BlockScore {
182                    msg_index: i,
183                    relevance,
184                    redundancy: 0.0,
185                    mig: relevance,
186                }
187            })
188            .collect()
189    };
190    let texts: Vec<_> = scores
191        .iter()
192        .map(|s| {
193            let tokens = tokenize(&extract_scorable_text(&messages[s.msg_index]));
194            term_frequencies(&tokens)
195        })
196        .collect();
197    for i in 0..scores.len() {
198        let mut max_redundancy = 0.0_f32;
199        for j in 0..scores.len() {
200            if i == j {
201                continue;
202            }
203            if scores[j].relevance > scores[i].relevance {
204                let sim = tf_weighted_similarity(&texts[i], &texts[j]);
205                max_redundancy = max_redundancy.max(sim);
206            }
207        }
208        scores[i].redundancy = max_redundancy;
209        scores[i].mig = scores[i].relevance - max_redundancy;
210    }
211    scores
212}
213
214/// Score each tool-output message block by subgoal tier membership.
215#[must_use]
216#[allow(clippy::cast_precision_loss)]
217pub fn score_blocks_subgoal(
218    messages: &[Message],
219    registry: &SubgoalRegistry,
220    _tc: &TokenCounter,
221) -> Vec<BlockScore> {
222    let total = messages.len().max(1) as f32;
223    let mut scores = Vec::new();
224    for (i, msg) in messages.iter().enumerate() {
225        if i == 0 || msg.metadata.focus_pinned {
226            continue;
227        }
228        let has_tool_output = msg.parts.iter().any(|p| {
229            matches!(
230                p,
231                MessagePart::ToolOutput { .. } | MessagePart::ToolResult { .. }
232            )
233        });
234        if !has_tool_output {
235            continue;
236        }
237        let recency = i as f32 / total * 0.05;
238        let relevance = match registry.subgoal_state(i) {
239            Some(SubgoalState::Active) => 1.0_f32 + recency,
240            Some(SubgoalState::Completed) => 0.3_f32 + recency,
241            None => 0.1_f32 + recency,
242        };
243        scores.push(BlockScore {
244            msg_index: i,
245            relevance,
246            redundancy: 0.0,
247            mig: relevance,
248        });
249    }
250    scores
251}
252
253/// Score tool-output blocks using subgoal tiers combined with MIG redundancy.
254#[must_use]
255pub fn score_blocks_subgoal_mig(
256    messages: &[Message],
257    registry: &SubgoalRegistry,
258    tc: &TokenCounter,
259) -> Vec<BlockScore> {
260    let mut scores = score_blocks_subgoal(messages, registry, tc);
261    let texts: Vec<_> = scores
262        .iter()
263        .map(|s| {
264            let tokens = tokenize(&extract_scorable_text(&messages[s.msg_index]));
265            term_frequencies(&tokens)
266        })
267        .collect();
268    for i in 0..scores.len() {
269        let mut max_redundancy = 0.0_f32;
270        for j in 0..scores.len() {
271            if i == j {
272                continue;
273            }
274            if scores[j].relevance > scores[i].relevance {
275                let sim = tf_weighted_similarity(&texts[i], &texts[j]);
276                max_redundancy = max_redundancy.max(sim);
277            }
278        }
279        scores[i].redundancy = max_redundancy;
280        scores[i].mig = scores[i].relevance - max_redundancy;
281    }
282    scores
283}
284
285// ── SubgoalRegistry ───────────────────────────────────────────────────────────
286
/// Identifier of a subgoal, handed out by [`SubgoalRegistry::push_active`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct SubgoalId(pub u32);
290
/// Where a subgoal is in its lifecycle.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SubgoalState {
    /// The subgoal is being worked on; its messages receive the highest tier score.
    Active,
    /// The subgoal is finished; its messages become candidates for summarization.
    Completed,
}
299
300/// A tracked subgoal with message span.
301#[derive(Debug, Clone)]
302pub struct Subgoal {
303    pub id: SubgoalId,
304    pub description: String,
305    pub state: SubgoalState,
306    /// Index of the first message in this subgoal's span.
307    pub start_msg_index: usize,
308    /// Index of the last message known to belong to this subgoal.
309    pub end_msg_index: usize,
310}
311
312/// In-memory registry of all subgoals in the current session.
313///
314/// Not persisted across restarts — subgoal state is transient session data.
315#[derive(Debug, Default)]
316pub struct SubgoalRegistry {
317    pub subgoals: Vec<Subgoal>,
318    next_id: u32,
319    /// Maps message index → subgoal ID for fast lookup during compaction.
320    pub msg_to_subgoal: std::collections::HashMap<usize, SubgoalId>,
321    last_tagged_index: usize,
322}
323
324impl SubgoalRegistry {
325    /// Register a new active subgoal starting at the given message index.
326    ///
327    /// Auto-completes any existing Active subgoal before creating the new one.
328    pub fn push_active(&mut self, description: String, start_msg_index: usize) -> SubgoalId {
329        if let Some(active) = self
330            .subgoals
331            .iter_mut()
332            .find(|s| s.state == SubgoalState::Active)
333        {
334            active.state = SubgoalState::Completed;
335        }
336        let id = SubgoalId(self.next_id);
337        self.next_id = self.next_id.wrapping_add(1);
338        self.subgoals.push(Subgoal {
339            id,
340            description,
341            state: SubgoalState::Active,
342            start_msg_index,
343            end_msg_index: start_msg_index,
344        });
345        self.last_tagged_index = start_msg_index.saturating_sub(1);
346        id
347    }
348
349    /// Mark the current active subgoal as completed and assign an end boundary.
350    pub fn complete_active(&mut self, end_msg_index: usize) {
351        if let Some(active) = self
352            .subgoals
353            .iter_mut()
354            .find(|s| s.state == SubgoalState::Active)
355        {
356            active.state = SubgoalState::Completed;
357            active.end_msg_index = end_msg_index;
358        }
359    }
360
361    /// Extend the active subgoal to cover new messages up to `new_end`.
362    pub fn extend_active(&mut self, new_end: usize) {
363        if let Some(active) = self
364            .subgoals
365            .iter_mut()
366            .find(|s| s.state == SubgoalState::Active)
367        {
368            active.end_msg_index = new_end;
369            let start = self.last_tagged_index.saturating_add(1);
370            for idx in start..=new_end {
371                self.msg_to_subgoal.insert(idx, active.id);
372            }
373            if new_end >= start {
374                self.last_tagged_index = new_end;
375            }
376        }
377    }
378
379    /// Tag messages in range `[start, end]` with the given subgoal ID.
380    pub fn tag_range(&mut self, start: usize, end: usize, id: SubgoalId) {
381        for idx in start..=end {
382            self.msg_to_subgoal.insert(idx, id);
383        }
384        if end > self.last_tagged_index {
385            self.last_tagged_index = end;
386        }
387    }
388
389    /// Get the subgoal state for a given message index.
390    #[must_use]
391    pub fn subgoal_state(&self, msg_index: usize) -> Option<SubgoalState> {
392        let sg_id = self.msg_to_subgoal.get(&msg_index)?;
393        self.subgoals
394            .iter()
395            .find(|s| &s.id == sg_id)
396            .map(|s| s.state)
397    }
398
399    /// Get the current active subgoal (for debug output and TUI metrics).
400    #[must_use]
401    pub fn active_subgoal(&self) -> Option<&Subgoal> {
402        self.subgoals
403            .iter()
404            .find(|s| s.state == SubgoalState::Active)
405    }
406
407    /// Rebuild the registry after compaction.
408    ///
409    /// When `old_compact_end == 0`, repairs shifted indices without dropping subgoals.
410    /// When `old_compact_end > 0`, drops subgoals whose entire span was drained.
411    pub fn rebuild_after_compaction(&mut self, messages: &[Message], old_compact_end: usize) {
412        self.msg_to_subgoal.clear();
413        if self.subgoals.is_empty() {
414            self.last_tagged_index = 0;
415            return;
416        }
417        if old_compact_end > 0 {
418            self.subgoals
419                .retain(|s| s.state == SubgoalState::Active || s.end_msg_index >= old_compact_end);
420        }
421        if self.subgoals.is_empty() {
422            self.last_tagged_index = 0;
423            return;
424        }
425        let mut last_idx = 0usize;
426        for (i, _msg) in messages.iter().enumerate().skip(1) {
427            let id = self
428                .subgoals
429                .iter()
430                .filter(|s| s.state == SubgoalState::Active)
431                .find(|s| i >= s.start_msg_index && i <= s.end_msg_index)
432                .map(|s| s.id)
433                .or_else(|| {
434                    self.subgoals
435                        .iter()
436                        .filter(|s| s.state == SubgoalState::Completed)
437                        .find(|s| i >= s.start_msg_index && i <= s.end_msg_index)
438                        .map(|s| s.id)
439                });
440            if let Some(id) = id {
441                self.msg_to_subgoal.insert(i, id);
442                last_idx = i;
443            }
444        }
445        self.last_tagged_index = last_idx;
446    }
447}
448
449// ── ContentDensity ────────────────────────────────────────────────────────────
450
/// Density class of a piece of message content, as decided by [`classify_density`].
///
/// A line counts as structured when it starts with four spaces or when, after
/// leading whitespace, it starts with a code fence (three backticks or `~~~`)
/// or one of `{`, `[`, `|`, `$`, `>`, `#`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ContentDensity {
    /// More than half of the lines are structured.
    High,
    /// Half or fewer of the lines are structured (including empty content).
    Low,
}
459
460/// Classify a message's content density.
461#[must_use]
462pub fn classify_density(content: &str) -> ContentDensity {
463    let lines: Vec<&str> = content.lines().collect();
464    if lines.is_empty() {
465        return ContentDensity::Low;
466    }
467    let structured = lines
468        .iter()
469        .filter(|line| {
470            let trimmed = line.trim_start();
471            trimmed.starts_with("```")
472                || trimmed.starts_with("~~~")
473                || trimmed.starts_with('{')
474                || trimmed.starts_with('[')
475                || trimmed.starts_with('|')
476                || trimmed.starts_with('$')
477                || trimmed.starts_with('>')
478                || trimmed.starts_with('#')
479                || (line.len() >= 4 && line.starts_with("    "))
480        })
481        .count();
482    #[allow(clippy::cast_precision_loss)]
483    let ratio = structured as f32 / lines.len() as f32;
484    if ratio > 0.5 {
485        ContentDensity::High
486    } else {
487        ContentDensity::Low
488    }
489}
490
491/// Partition messages into (high-density, low-density) groups.
492#[must_use]
493pub fn partition_by_density(messages: &[Message]) -> (Vec<Message>, Vec<Message>) {
494    let mut high = Vec::new();
495    let mut low = Vec::new();
496    for msg in messages {
497        if msg.metadata.focus_pinned {
498            continue;
499        }
500        match classify_density(&msg.content) {
501            ContentDensity::High => high.push(msg.clone()),
502            ContentDensity::Low => low.push(msg.clone()),
503        }
504    }
505    (high, low)
506}
507
508// ── SubgoalExtractionResult ───────────────────────────────────────────────────
509
/// Result of a background subgoal-extraction LLM call.
#[derive(Debug)]
pub struct SubgoalExtractionResult {
    /// The subgoal the agent is currently working toward.
    pub current: String,
    /// The subgoal that was just finished, present only when the LLM reported a
    /// transition (a `COMPLETED:` value other than NONE).
    pub completed: Option<String>,
}
518
519// ── Focus auto-consolidation ──────────────────────────────────────────────────
520
521/// Automatically consolidate low-relevance context into a knowledge-block summary.
522///
523/// # Errors
524///
525/// Returns an error if the provider call returns an error or if the 20-second timeout
526/// elapses before the provider responds.
527pub async fn run_focus_auto_consolidation(
528    messages: &[Message],
529    min_window: usize,
530    provider: impl LlmProvider,
531    max_chars: usize,
532) -> Result<Option<String>, Box<dyn std::error::Error + Send + Sync>> {
533    let _span = tracing::info_span!("ctx.compaction.focus_auto_consolidate").entered();
534
535    if messages.len() < min_window {
536        return Ok(None);
537    }
538    let task_goal = messages
539        .iter()
540        .rev()
541        .find(|m| m.role == Role::User)
542        .map_or("", |m| m.content.as_str());
543    if task_goal.is_empty() {
544        tracing::debug!("focus_auto_consolidation: no user message found, skipping");
545        return Ok(None);
546    }
547    let messages_owned: Vec<Message> = messages.to_vec();
548    let task_goal_owned = task_goal.to_string();
549    let scores = tokio::task::spawn_blocking(move || {
550        let tc = TokenCounter::default();
551        score_blocks_mig(
552            &messages_owned,
553            Some(task_goal_owned.as_str()).filter(|s| !s.is_empty()),
554            &tc,
555        )
556    })
557    .await
558    .map_err(|e| format!("score_blocks_mig panicked: {e}"))?;
559
560    let low_relevance: HashSet<usize> = scores
561        .iter()
562        .filter(|s| s.mig <= 0.0)
563        .map(|s| s.msg_index)
564        .collect();
565    let window_indices = find_low_relevance_window(messages, &low_relevance, min_window);
566    if window_indices.is_empty() {
567        return Ok(None);
568    }
569    let combined: String = window_indices
570        .iter()
571        .map(|&i| extract_scorable_text(&messages[i]))
572        .collect::<Vec<_>>()
573        .join("\n---\n");
574    let prompt = format!(
575        "Extract up to 10 key facts the agent must remember from the following context. \
576         Return bullet points only (one per line, starting with `- `).\n\n{combined}"
577    );
578    let request = vec![Message::from_legacy(Role::User, &prompt)];
579    let raw = tokio::time::timeout(Duration::from_secs(20), provider.chat(&request))
580        .await
581        .map_err(|_| {
582            Box::new(std::io::Error::other(
583                "focus auto-consolidation timed out after 20s",
584            )) as Box<dyn std::error::Error + Send + Sync>
585        })?
586        .map_err(|e| {
587            Box::new(std::io::Error::other(format!(
588                "focus auto-consolidation provider error: {e}"
589            ))) as Box<dyn std::error::Error + Send + Sync>
590        })?;
591    let truncated = if raw.len() <= max_chars {
592        raw
593    } else {
594        let boundary = raw
595            .char_indices()
596            .map(|(i, _)| i)
597            .take_while(|&i| i <= max_chars)
598            .last()
599            .unwrap_or(0);
600        raw[..boundary].to_owned()
601    };
602    if truncated.is_empty() {
603        return Ok(None);
604    }
605    Ok(Some(truncated))
606}
607
608fn find_low_relevance_window(
609    messages: &[Message],
610    low_relevance: &HashSet<usize>,
611    min_window: usize,
612) -> Vec<usize> {
613    let mut best: Vec<usize> = Vec::new();
614    let mut current: Vec<usize> = Vec::new();
615    for (i, msg) in messages.iter().enumerate() {
616        if i == 0 || msg.metadata.focus_pinned {
617            current.clear();
618            continue;
619        }
620        if low_relevance.contains(&i) {
621            current.push(i);
622        } else {
623            if current.len() >= min_window && best.is_empty() {
624                best.append(&mut current);
625            }
626            current.clear();
627        }
628    }
629    if current.len() >= min_window && best.is_empty() {
630        best = current;
631    }
632    best
633}
634
635#[cfg(test)]
636mod tests {
637    use super::*;
638    use std::collections::HashMap;
639
/// Stop words must not appear in the token list.
#[test]
fn tokenize_filters_stop_words() {
    let tokens = tokenize("fn main() { let x = 5; }");
    for stop_word in ["fn", "let"] {
        assert!(!tokens.contains(&stop_word.to_string()));
    }
}
646
/// Ordinary domain words survive tokenization.
#[test]
fn tokenize_keeps_meaningful_tokens() {
    let tokens = tokenize("authentication middleware session");
    for expected in ["authentication", "middleware", "session"] {
        assert!(tokens.contains(&expected.to_string()));
    }
}
654
/// A profile compared with itself has similarity 1.
#[test]
fn tf_weighted_similarity_identical_is_one() {
    let profile = term_frequencies(&tokenize("authentication session token"));
    let similarity = tf_weighted_similarity(&profile, &profile);
    assert!((similarity - 1.0).abs() < f32::EPSILON);
}
662
/// Profiles without any shared term have similarity 0.
#[test]
fn tf_weighted_similarity_disjoint_is_zero() {
    let auth_tf = term_frequencies(&tokenize("authentication session"));
    let db_tf = term_frequencies(&tokenize("database migration schema"));
    let similarity = tf_weighted_similarity(&auth_tf, &db_tf);
    assert!(similarity.abs() < f32::EPSILON);
}
671
/// An empty profile is not similar to anything.
#[test]
fn tf_weighted_similarity_empty_is_zero() {
    let empty_tf: HashMap<String, f32> = HashMap::new();
    let populated_tf = term_frequencies(&tokenize("authentication session"));
    let similarity = tf_weighted_similarity(&empty_tf, &populated_tf);
    assert!(similarity.abs() < f32::EPSILON);
}
679
680    fn make_tool_output_msg(body: &str) -> Message {
681        use zeph_llm::provider::{MessageMetadata, MessagePart};
682        let mut msg = Message {
683            role: Role::User,
684            content: body.to_string(),
685            parts: vec![MessagePart::ToolOutput {
686                tool_name: "read".into(),
687                body: body.to_string(),
688                compacted_at: None,
689            }],
690            metadata: MessageMetadata::default(),
691        };
692        msg.rebuild_content();
693        msg
694    }
695
/// The message at index 0 is never scored.
#[test]
fn score_blocks_task_aware_skips_system_prompt() {
    let system = Message::from_legacy(Role::System, "system prompt");
    let tool_block = make_tool_output_msg("authentication session middleware");
    let messages = vec![system, tool_block];
    let counter = TokenCounter::default();
    let scores = score_blocks_task_aware(&messages, "authentication session", &counter);
    let scored_indices: Vec<usize> = scores.iter().map(|s| s.msg_index).collect();
    assert_eq!(scored_indices.len(), 1);
    assert_eq!(scored_indices[0], 1);
}
707
/// Focus-pinned messages are excluded from scoring even when they carry tool output.
#[test]
fn score_blocks_task_aware_skips_pinned_messages() {
    use zeph_llm::provider::MessageMetadata;
    let tc = TokenCounter::default();
    // The pinned message deliberately carries a tool-output part: a message
    // without one is skipped regardless of the pin flag, which would make this
    // test pass even if the pin check were removed.
    let mut pinned = make_tool_output_msg("authentication session knowledge");
    pinned.metadata = MessageMetadata::focus_pinned();
    pinned.metadata.focus_pinned = true;
    let messages = vec![
        Message::from_legacy(Role::System, "sys"),
        pinned,
        make_tool_output_msg("authentication session"),
    ];
    let scores = score_blocks_task_aware(&messages, "authentication session", &tc);
    assert!(
        scores.iter().all(|s| s.msg_index != 1),
        "pinned message must not be scored"
    );
    assert!(
        scores.iter().any(|s| s.msg_index == 2),
        "unpinned tool output must still be scored"
    );
}
728
/// A block sharing terms with the goal outranks an unrelated block.
#[test]
fn score_blocks_task_aware_relevant_block_scores_higher() {
    let messages = vec![
        Message::from_legacy(Role::System, "sys"),
        make_tool_output_msg("authentication middleware session token implementation"),
        make_tool_output_msg("database schema migration foreign key index"),
    ];
    let counter = TokenCounter::default();
    let scores = score_blocks_task_aware(&messages, "authentication session token", &counter);
    assert_eq!(scores.len(), 2);
    let relevance_of = |index: usize| {
        scores
            .iter()
            .find(|s| s.msg_index == index)
            .unwrap()
            .relevance
    };
    assert!(
        relevance_of(1) > relevance_of(2),
        "auth block must score higher than db block"
    );
}
746
/// `push_active` registers exactly one subgoal, in the active state.
#[test]
fn subgoal_registry_push_active_creates_active_subgoal() {
    let mut registry = SubgoalRegistry::default();
    let id = registry.push_active("Implement login endpoint".into(), 1);
    assert_eq!(registry.subgoals.len(), 1);
    let created = &registry.subgoals[0];
    assert_eq!(created.id, id);
    assert_eq!(created.state, SubgoalState::Active);
}
755
/// `complete_active` moves the active subgoal to the completed state.
#[test]
fn subgoal_registry_complete_active_transitions_state() {
    let mut registry = SubgoalRegistry::default();
    registry.push_active("initial subgoal".into(), 1);
    registry.complete_active(5);
    let first = &registry.subgoals[0];
    assert_eq!(first.state, SubgoalState::Completed);
    assert!(registry.active_subgoal().is_none());
}
764
/// Registering a second subgoal completes the first, leaving one active subgoal.
#[test]
fn subgoal_registry_push_active_auto_completes_existing_active() {
    let mut registry = SubgoalRegistry::default();
    registry.push_active("first subgoal".into(), 1);
    registry.push_active("second subgoal".into(), 6);
    let states: Vec<SubgoalState> = registry.subgoals.iter().map(|s| s.state).collect();
    assert_eq!(states[0], SubgoalState::Completed);
    assert_eq!(states[1], SubgoalState::Active);
    let active_count = states
        .iter()
        .filter(|state| **state == SubgoalState::Active)
        .count();
    assert_eq!(active_count, 1);
}
779
/// Each `extend_active` call tags only the indices that were not tagged before.
#[test]
fn subgoal_registry_extend_active_tags_incrementally() {
    let mut registry = SubgoalRegistry::default();
    let id = registry.push_active("subgoal".into(), 3);

    registry.extend_active(5);
    for idx in 3..=5 {
        assert_eq!(registry.subgoal_state(idx), Some(SubgoalState::Active));
    }
    assert_eq!(registry.msg_to_subgoal.get(&3), Some(&id));

    registry.extend_active(7);
    for idx in 6..=7 {
        assert_eq!(registry.subgoal_state(idx), Some(SubgoalState::Active));
    }
    assert_eq!(registry.msg_to_subgoal.len(), 5);
}
794
/// `subgoal_state` reports completed, active and untagged messages correctly.
#[test]
fn subgoal_registry_subgoal_state_returns_correct_tier() {
    let mut registry = SubgoalRegistry::default();

    // First subgoal: tagged over 1..=5, then completed.
    registry.push_active("completed subgoal".into(), 1);
    registry.tag_range(1, 5, SubgoalId(0));
    registry.complete_active(5);

    // Second subgoal: active over 6..=9.
    registry.push_active("active subgoal".into(), 6);
    registry.extend_active(9);

    let expectations = [
        (1, Some(SubgoalState::Completed)),
        (6, Some(SubgoalState::Active)),
        (0, None),
    ];
    for (msg_index, expected) in expectations {
        assert_eq!(registry.subgoal_state(msg_index), expected);
    }
}
807
/// Content without any lines is classified as low density.
#[test]
fn classify_density_empty_string_is_low() {
    let density = classify_density("");
    assert_eq!(density, ContentDensity::Low);
}
812
/// Fenced code plus a shell prompt line is classified as high density.
#[test]
fn classify_density_all_structured_is_high() {
    let density = classify_density("```rust\nfn main() {}\n```\n$ cargo build\n");
    assert_eq!(density, ContentDensity::High);
}
818
/// Plain sentences without structural markers are classified as low density.
#[test]
fn classify_density_all_prose_is_low() {
    let density = classify_density(
        "This is a sentence.\nAnother sentence here.\nNo structured content at all.",
    );
    assert_eq!(density, ContentDensity::Low);
}
824
825    // ─── run_focus_auto_consolidation tests ──────────────────────────────────
826
/// Test double whose `chat` always succeeds with a fixed reply.
struct StubProvider {
    /// Text returned verbatim by `chat`.
    response: &'static str,
}
830
831    impl zeph_llm::provider::LlmProvider for StubProvider {
832        async fn chat(&self, _messages: &[Message]) -> Result<String, zeph_llm::LlmError> {
833            Ok(self.response.to_owned())
834        }
835
836        async fn chat_stream(
837            &self,
838            messages: &[Message],
839        ) -> Result<zeph_llm::provider::ChatStream, zeph_llm::LlmError> {
840            let r = self.chat(messages).await?;
841            Ok(Box::pin(futures::stream::once(async move {
842                Ok::<_, zeph_llm::LlmError>(zeph_llm::provider::StreamChunk::Content(r))
843            })))
844        }
845
846        fn supports_streaming(&self) -> bool {
847            false
848        }
849
850        async fn embed(&self, _text: &str) -> Result<Vec<f32>, zeph_llm::LlmError> {
851            Ok(vec![])
852        }
853
854        fn supports_embeddings(&self) -> bool {
855            false
856        }
857
858        fn name(&self) -> &'static str {
859            "stub"
860        }
861    }
862
/// Test double whose `chat` and `chat_stream` never complete.
struct HangingProvider;
864
865    impl zeph_llm::provider::LlmProvider for HangingProvider {
866        async fn chat(&self, _messages: &[Message]) -> Result<String, zeph_llm::LlmError> {
867            std::future::pending::<()>().await;
868            unreachable!()
869        }
870
871        async fn chat_stream(
872            &self,
873            _messages: &[Message],
874        ) -> Result<zeph_llm::provider::ChatStream, zeph_llm::LlmError> {
875            std::future::pending::<()>().await;
876            unreachable!()
877        }
878
879        fn supports_streaming(&self) -> bool {
880            false
881        }
882
883        async fn embed(&self, _text: &str) -> Result<Vec<f32>, zeph_llm::LlmError> {
884            Ok(vec![])
885        }
886
887        fn supports_embeddings(&self) -> bool {
888            false
889        }
890
891        fn name(&self) -> &'static str {
892            "hanging"
893        }
894    }
895
896    #[tokio::test]
897    async fn run_focus_auto_consolidation_returns_none_for_small_history() {
898        let messages = vec![
899            Message::from_legacy(Role::System, "sys"),
900            make_tool_output_msg("some tool output here"),
901        ];
902        // min_window = 6, but only 2 messages → None.
903        let result = run_focus_auto_consolidation(
904            &messages,
905            6,
906            StubProvider {
907                response: "- fact one",
908            },
909            4096,
910        )
911        .await
912        .unwrap();
913        assert!(result.is_none());
914    }
915
916    #[tokio::test]
917    async fn run_focus_auto_consolidation_produces_summary() {
918        let mut messages = vec![Message::from_legacy(Role::System, "sys")];
919        for _ in 0..6 {
920            messages.push(make_tool_output_msg(
921                "database schema migration foreign key index",
922            ));
923        }
924        messages.push(Message::from_legacy(
925            Role::User,
926            "Help me with authentication",
927        ));
928
929        let result = run_focus_auto_consolidation(
930            &messages,
931            4,
932            StubProvider {
933                response: "- database schema uses foreign keys",
934            },
935            4096,
936        )
937        .await
938        .unwrap();
939
940        assert!(result.is_some());
941        let summary = result.unwrap();
942        assert!(!summary.is_empty());
943    }
944
945    #[tokio::test]
946    async fn run_focus_auto_consolidation_skips_when_no_user_message() {
947        // S2/S3: when no User message is present, must return None instead of
948        // entering recency mode and eagerly consolidating all history.
949        let mut messages = vec![Message::from_legacy(Role::System, "sys")];
950        for i in 0..8 {
951            messages.push(make_tool_output_msg(&format!("tool output {i}")));
952        }
953
954        let result = run_focus_auto_consolidation(
955            &messages,
956            4,
957            StubProvider {
958                response: "- should not be reached",
959            },
960            4096,
961        )
962        .await
963        .unwrap();
964
965        assert!(
966            result.is_none(),
967            "must return None when no user message is present (S2/S3)"
968        );
969    }
970
971    #[tokio::test]
972    async fn auto_consolidation_timeout_recovers() {
973        let mut messages = vec![Message::from_legacy(Role::System, "sys")];
974        for _ in 0..6 {
975            messages.push(make_tool_output_msg(
976                "database schema migration foreign key index",
977            ));
978        }
979        messages.push(Message::from_legacy(
980            Role::User,
981            "Help me with authentication",
982        ));
983
984        // Wrap in a short timeout to avoid waiting the full 20s internal timeout.
985        let result = tokio::time::timeout(
986            std::time::Duration::from_millis(50),
987            run_focus_auto_consolidation(&messages, 4, HangingProvider, 4096),
988        )
989        .await;
990
991        // Either: outer timeout fires (Err), or inner 20s timeout fires (Ok(Err)).
992        // Both cases must not panic.
993        match result {
994            Err(_elapsed) => {
995                // Outer timeout fired — no panic, correct.
996            }
997            Ok(inner) => {
998                assert!(inner.is_err(), "hanging provider must return an error");
999            }
1000        }
1001    }
1002}