Skip to main content

zeph_agent_context/
compaction.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4//! Task-aware pruning strategy for tool output eviction.
5//!
6//! Provides relevance scoring, subgoal tracking, density classification, and
7//! focus auto-consolidation. These types and functions are used by the
8//! [`crate::service::ContextService`] summarization path and are referenced from
9//! `zeph-core` via re-export.
10
11use std::collections::{HashMap, HashSet};
12use std::time::Duration;
13
14use zeph_llm::provider::{LlmProvider, Message, MessagePart, Role};
15use zeph_memory::TokenCounter;
16
17// ── Scoring ───────────────────────────────────────────────────────────────────
18
/// Per-message relevance score used by task-aware, subgoal-tier and MIG pruning.
#[derive(Debug, Clone)]
pub struct BlockScore {
    /// Index of the scored message in the `messages` slice passed to the scorer.
    pub msg_index: usize,
    /// How relevant the block is to the current work.
    ///
    /// Task-goal similarity and the recency fallback produce values in `0.0..=1.0`.
    /// Subgoal-tier scoring adds a small recency bonus (< 0.05) on top of its tier weight,
    /// so it can slightly exceed `1.0`.
    pub relevance: f32,
    /// Highest term-overlap similarity (`0.0..=1.0`) to any strictly more relevant block.
    /// Stays `0.0` until a MIG redundancy pass fills it in.
    pub redundancy: f32,
    /// Marginal information gain: `relevance − redundancy`. Equals `relevance` before a MIG
    /// pass. Values at or below zero mark good eviction candidates.
    pub mig: f32,
}
31
/// Stop-words removed by [`tokenize`].
///
/// Covers Rust keywords, common English function words, and ubiquitous cargo/rustc
/// diagnostic words. These dominate token overlap without carrying any task signal.
/// Entries shorter than three bytes (e.g. `fn`, `if`, `a`) are already discarded by
/// `tokenize`'s length filter before this set is consulted.
static STOP_WORDS: std::sync::LazyLock<HashSet<&'static str>> = std::sync::LazyLock::new(|| {
    const WORDS: &[&str] = &[
        // Rust keywords and keyword-like identifiers.
        "fn", "pub", "let", "use", "mod", "impl", "struct", "enum", "trait", "type", "for", "if",
        "else", "match", "return", "self", "super", "crate", "true", "false", "mut", "ref",
        "where", "in", "as", "const", "static", "extern", "unsafe", "async", "await", "move",
        "box", "dyn", "loop", "while", "break", "continue", "yield", "do", "try",
        // English function words.
        "the", "a", "an", "is", "are", "was", "be", "to", "of", "and", "or", "not", "with",
        "from", "by", "at", "on", "it", "this", "that", "have", "has", "had",
        // Toolchain / diagnostic noise.
        "cargo", "rustc", "warning", "error", "note", "help", "running",
    ];
    WORDS.iter().copied().collect()
});
46
47fn tokenize(text: &str) -> Vec<String> {
48    text.split(|c: char| !c.is_alphanumeric() && c != '_')
49        .filter(|t| t.len() >= 3)
50        .map(str::to_lowercase)
51        .filter(|t| !STOP_WORDS.contains(t.as_str()))
52        .collect()
53}
54
/// Relative term frequencies: each distinct token maps to `count / tokens.len()`.
///
/// Empty input yields an empty map. The divisor is clamped to 1, so no `0 / 0` ever occurs.
#[allow(clippy::cast_precision_loss)] // token counts are small; f32 precision is ample here
fn term_frequencies(tokens: &[String]) -> HashMap<String, f32> {
    let denominator = tokens.len().max(1) as f32;
    let counts = tokens
        .iter()
        .fold(HashMap::<String, usize>::new(), |mut acc, token| {
            *acc.entry(token.clone()).or_default() += 1;
            acc
        });
    counts
        .into_iter()
        .map(|(term, n)| (term, n as f32 / denominator))
        .collect()
}
67
/// Weighted Jaccard (Ruzicka) similarity between two term-frequency maps.
///
/// Computes `Σ min(a_t, b_t) / Σ max(a_t, b_t)` over the union of terms. A term missing
/// from one map counts as frequency 0 there.
///
/// The result is symmetric and lies in `0.0..=1.0`:
/// - identical non-empty maps give `1.0`;
/// - disjoint maps give `0.0`;
/// - two empty maps give `0.0` rather than `0 / 0`.
fn tf_weighted_similarity(tf_a: &HashMap<String, f32>, tf_b: &HashMap<String, f32>) -> f32 {
    let mut intersection = 0.0_f32;
    let mut union = 0.0_f32;
    for (term, &freq_a) in tf_a {
        if let Some(&freq_b) = tf_b.get(term) {
            intersection += freq_a.min(freq_b);
            // Shared terms contribute their larger weight to the union. Adding only
            // `freq_a` made the score asymmetric and inflated it whenever `b` was heavier.
            union += freq_a.max(freq_b);
        } else {
            union += freq_a;
        }
    }
    for (term, &freq_b) in tf_b {
        if !tf_a.contains_key(term) {
            union += freq_b;
        }
    }
    if union == 0.0 {
        0.0
    } else {
        intersection / union
    }
}
88
89/// Extract text content from a message suitable for scoring.
90#[must_use]
91pub fn extract_scorable_text(msg: &Message) -> String {
92    let mut parts_text = String::new();
93    for part in &msg.parts {
94        match part {
95            MessagePart::ToolOutput {
96                body, tool_name, ..
97            } => {
98                parts_text.push_str(tool_name.as_str());
99                parts_text.push(' ');
100                parts_text.push_str(body);
101                parts_text.push(' ');
102            }
103            MessagePart::ToolResult { content, .. } => {
104                parts_text.push_str(content);
105                parts_text.push(' ');
106            }
107            _ => {}
108        }
109    }
110    if parts_text.is_empty() {
111        msg.content.clone()
112    } else {
113        parts_text
114    }
115}
116
117/// Score each tool-output message block against the task goal.
118#[must_use]
119pub fn score_blocks_task_aware(
120    messages: &[Message],
121    task_goal: &str,
122    _tc: &TokenCounter,
123) -> Vec<BlockScore> {
124    let goal_tokens = tokenize(task_goal);
125    let goal_tf = term_frequencies(&goal_tokens);
126    let mut scores = Vec::new();
127    for (i, msg) in messages.iter().enumerate() {
128        if i == 0 || msg.metadata.focus_pinned {
129            continue;
130        }
131        let has_tool_output = msg.parts.iter().any(|p| {
132            matches!(
133                p,
134                MessagePart::ToolOutput { .. } | MessagePart::ToolResult { .. }
135            )
136        });
137        if !has_tool_output {
138            continue;
139        }
140        let text = extract_scorable_text(msg);
141        let tokens = tokenize(&text);
142        let tf = term_frequencies(&tokens);
143        let relevance = tf_weighted_similarity(&goal_tf, &tf);
144        scores.push(BlockScore {
145            msg_index: i,
146            relevance,
147            redundancy: 0.0,
148            mig: relevance,
149        });
150    }
151    scores
152}
153
154/// Score blocks using MIG (relevance − redundancy) with temporal partitioning.
155#[must_use]
156pub fn score_blocks_mig(
157    messages: &[Message],
158    task_goal: Option<&str>,
159    tc: &TokenCounter,
160) -> Vec<BlockScore> {
161    #[allow(clippy::cast_precision_loss)]
162    let mut scores = if let Some(goal) = task_goal {
163        score_blocks_task_aware(messages, goal, tc)
164    } else {
165        let total = messages.len();
166        messages
167            .iter()
168            .enumerate()
169            .filter(|(i, msg)| {
170                *i > 0
171                    && !msg.metadata.focus_pinned
172                    && msg.parts.iter().any(|p| {
173                        matches!(
174                            p,
175                            MessagePart::ToolOutput { .. } | MessagePart::ToolResult { .. }
176                        )
177                    })
178            })
179            .map(|(i, _)| {
180                let relevance = i as f32 / total as f32;
181                BlockScore {
182                    msg_index: i,
183                    relevance,
184                    redundancy: 0.0,
185                    mig: relevance,
186                }
187            })
188            .collect()
189    };
190    let texts: Vec<_> = scores
191        .iter()
192        .map(|s| {
193            let tokens = tokenize(&extract_scorable_text(&messages[s.msg_index]));
194            term_frequencies(&tokens)
195        })
196        .collect();
197    for i in 0..scores.len() {
198        let mut max_redundancy = 0.0_f32;
199        for j in 0..scores.len() {
200            if i == j {
201                continue;
202            }
203            if scores[j].relevance > scores[i].relevance {
204                let sim = tf_weighted_similarity(&texts[i], &texts[j]);
205                max_redundancy = max_redundancy.max(sim);
206            }
207        }
208        scores[i].redundancy = max_redundancy;
209        scores[i].mig = scores[i].relevance - max_redundancy;
210    }
211    scores
212}
213
214/// Score each tool-output message block by subgoal tier membership.
215#[must_use]
216#[allow(clippy::cast_precision_loss)]
217pub fn score_blocks_subgoal(
218    messages: &[Message],
219    registry: &SubgoalRegistry,
220    _tc: &TokenCounter,
221) -> Vec<BlockScore> {
222    let total = messages.len().max(1) as f32;
223    let mut scores = Vec::new();
224    for (i, msg) in messages.iter().enumerate() {
225        if i == 0 || msg.metadata.focus_pinned {
226            continue;
227        }
228        let has_tool_output = msg.parts.iter().any(|p| {
229            matches!(
230                p,
231                MessagePart::ToolOutput { .. } | MessagePart::ToolResult { .. }
232            )
233        });
234        if !has_tool_output {
235            continue;
236        }
237        let recency = i as f32 / total * 0.05;
238        let relevance = match registry.subgoal_state(i) {
239            Some(SubgoalState::Active) => 1.0_f32 + recency,
240            Some(SubgoalState::Completed) => 0.3_f32 + recency,
241            None => 0.1_f32 + recency,
242        };
243        scores.push(BlockScore {
244            msg_index: i,
245            relevance,
246            redundancy: 0.0,
247            mig: relevance,
248        });
249    }
250    scores
251}
252
253/// Score tool-output blocks using subgoal tiers combined with MIG redundancy.
254#[must_use]
255pub fn score_blocks_subgoal_mig(
256    messages: &[Message],
257    registry: &SubgoalRegistry,
258    tc: &TokenCounter,
259) -> Vec<BlockScore> {
260    let mut scores = score_blocks_subgoal(messages, registry, tc);
261    let texts: Vec<_> = scores
262        .iter()
263        .map(|s| {
264            let tokens = tokenize(&extract_scorable_text(&messages[s.msg_index]));
265            term_frequencies(&tokens)
266        })
267        .collect();
268    for i in 0..scores.len() {
269        let mut max_redundancy = 0.0_f32;
270        for j in 0..scores.len() {
271            if i == j {
272                continue;
273            }
274            if scores[j].relevance > scores[i].relevance {
275                let sim = tf_weighted_similarity(&texts[i], &texts[j]);
276                max_redundancy = max_redundancy.max(sim);
277            }
278        }
279        scores[i].redundancy = max_redundancy;
280        scores[i].mig = scores[i].relevance - max_redundancy;
281    }
282    scores
283}
284
285// ── SubgoalRegistry ───────────────────────────────────────────────────────────
286
/// Identifier of a subgoal, unique within the current session.
///
/// Values are handed out in increasing order by [`SubgoalRegistry::push_active`], wrapping
/// on `u32` overflow.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct SubgoalId(pub u32);
290
/// Where a subgoal is in its lifecycle.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum SubgoalState {
    /// In progress. Messages tagged with an active subgoal are protected.
    Active,
    /// Finished. Messages tagged with a completed subgoal are candidates for summarization.
    Completed,
}
300
301/// A tracked subgoal with message span.
302#[derive(Debug, Clone)]
303pub struct Subgoal {
304    pub id: SubgoalId,
305    pub description: String,
306    pub state: SubgoalState,
307    /// Index of the first message in this subgoal's span.
308    pub start_msg_index: usize,
309    /// Index of the last message known to belong to this subgoal.
310    pub end_msg_index: usize,
311}
312
313/// In-memory registry of all subgoals in the current session.
314///
315/// Not persisted across restarts — subgoal state is transient session data.
316#[derive(Debug, Default)]
317pub struct SubgoalRegistry {
318    pub subgoals: Vec<Subgoal>,
319    next_id: u32,
320    /// Maps message index → subgoal ID for fast lookup during compaction.
321    pub msg_to_subgoal: std::collections::HashMap<usize, SubgoalId>,
322    last_tagged_index: usize,
323}
324
325impl SubgoalRegistry {
326    /// Register a new active subgoal starting at the given message index.
327    ///
328    /// Auto-completes any existing Active subgoal before creating the new one.
329    pub fn push_active(&mut self, description: String, start_msg_index: usize) -> SubgoalId {
330        if let Some(active) = self
331            .subgoals
332            .iter_mut()
333            .find(|s| s.state == SubgoalState::Active)
334        {
335            active.state = SubgoalState::Completed;
336        }
337        let id = SubgoalId(self.next_id);
338        self.next_id = self.next_id.wrapping_add(1);
339        self.subgoals.push(Subgoal {
340            id,
341            description,
342            state: SubgoalState::Active,
343            start_msg_index,
344            end_msg_index: start_msg_index,
345        });
346        self.last_tagged_index = start_msg_index.saturating_sub(1);
347        id
348    }
349
350    /// Mark the current active subgoal as completed and assign an end boundary.
351    pub fn complete_active(&mut self, end_msg_index: usize) {
352        if let Some(active) = self
353            .subgoals
354            .iter_mut()
355            .find(|s| s.state == SubgoalState::Active)
356        {
357            active.state = SubgoalState::Completed;
358            active.end_msg_index = end_msg_index;
359        }
360    }
361
362    /// Extend the active subgoal to cover new messages up to `new_end`.
363    pub fn extend_active(&mut self, new_end: usize) {
364        if let Some(active) = self
365            .subgoals
366            .iter_mut()
367            .find(|s| s.state == SubgoalState::Active)
368        {
369            active.end_msg_index = new_end;
370            let start = self.last_tagged_index.saturating_add(1);
371            for idx in start..=new_end {
372                self.msg_to_subgoal.insert(idx, active.id);
373            }
374            if new_end >= start {
375                self.last_tagged_index = new_end;
376            }
377        }
378    }
379
380    /// Tag messages in range `[start, end]` with the given subgoal ID.
381    pub fn tag_range(&mut self, start: usize, end: usize, id: SubgoalId) {
382        for idx in start..=end {
383            self.msg_to_subgoal.insert(idx, id);
384        }
385        if end > self.last_tagged_index {
386            self.last_tagged_index = end;
387        }
388    }
389
390    /// Get the subgoal state for a given message index.
391    #[must_use]
392    pub fn subgoal_state(&self, msg_index: usize) -> Option<SubgoalState> {
393        let sg_id = self.msg_to_subgoal.get(&msg_index)?;
394        self.subgoals
395            .iter()
396            .find(|s| &s.id == sg_id)
397            .map(|s| s.state)
398    }
399
400    /// Get the current active subgoal (for debug output and TUI metrics).
401    #[must_use]
402    pub fn active_subgoal(&self) -> Option<&Subgoal> {
403        self.subgoals
404            .iter()
405            .find(|s| s.state == SubgoalState::Active)
406    }
407
408    /// Rebuild the registry after compaction.
409    ///
410    /// When `old_compact_end == 0`, repairs shifted indices without dropping subgoals.
411    /// When `old_compact_end > 0`, drops subgoals whose entire span was drained.
412    pub fn rebuild_after_compaction(&mut self, messages: &[Message], old_compact_end: usize) {
413        self.msg_to_subgoal.clear();
414        if self.subgoals.is_empty() {
415            self.last_tagged_index = 0;
416            return;
417        }
418        if old_compact_end > 0 {
419            self.subgoals
420                .retain(|s| s.state == SubgoalState::Active || s.end_msg_index >= old_compact_end);
421        }
422        if self.subgoals.is_empty() {
423            self.last_tagged_index = 0;
424            return;
425        }
426        let mut last_idx = 0usize;
427        for (i, _msg) in messages.iter().enumerate().skip(1) {
428            let id = self
429                .subgoals
430                .iter()
431                .filter(|s| s.state == SubgoalState::Active)
432                .find(|s| i >= s.start_msg_index && i <= s.end_msg_index)
433                .map(|s| s.id)
434                .or_else(|| {
435                    self.subgoals
436                        .iter()
437                        .filter(|s| s.state == SubgoalState::Completed)
438                        .find(|s| i >= s.start_msg_index && i <= s.end_msg_index)
439                        .map(|s| s.id)
440                });
441            if let Some(id) = id {
442                self.msg_to_subgoal.insert(i, id);
443                last_idx = i;
444            }
445        }
446        self.last_tagged_index = last_idx;
447    }
448}
449
450// ── ContentDensity ────────────────────────────────────────────────────────────
451
/// Density classification for a message or segment, as computed by [`classify_density`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum ContentDensity {
    /// More than 50% of lines look structured. Structured means one of:
    /// - a code fence;
    /// - a JSON/array opener;
    /// - a table row;
    /// - a shell prompt;
    /// - a quote;
    /// - a heading or comment;
    /// - a four-space-indented line.
    High,
    /// 50% or fewer lines look structured. Empty content also classifies as `Low`.
    Low,
}

/// Line prefixes that mark a line as structured. They are checked after trimming leading
/// whitespace.
const STRUCTURED_LINE_PREFIXES: [&str; 8] = ["```", "~~~", "{", "[", "|", "$", ">", "#"];

/// Classify a message's content density.
///
/// A line counts as structured when either:
/// - after trimming leading whitespace, it starts with one of [`STRUCTURED_LINE_PREFIXES`];
/// - the untrimmed line starts with four spaces (this includes whitespace-only lines of at
///   least four spaces).
///
/// Content with more than 50% structured lines is [`ContentDensity::High`]. Everything
/// else, including empty content, is [`ContentDensity::Low`].
///
/// NOTE(review): bullet-list markers such as `- `, `* ` or `1. ` are not recognised, so
/// plain markdown lists classify as low density. Confirm whether they should count.
#[must_use]
pub fn classify_density(content: &str) -> ContentDensity {
    let is_structured = |line: &str| {
        let trimmed = line.trim_start();
        STRUCTURED_LINE_PREFIXES
            .iter()
            .any(|prefix| trimmed.starts_with(*prefix))
            || line.starts_with("    ")
    };
    // Single pass: count total lines and structured lines without collecting them.
    let (total, structured) = content
        .lines()
        .fold((0usize, 0usize), |(total, structured), line| {
            (total + 1, structured + usize::from(is_structured(line)))
        });
    if total == 0 {
        return ContentDensity::Low;
    }
    // Line counts are assumed to stay far below f32's exact-integer range.
    #[allow(clippy::cast_precision_loss)]
    let ratio = structured as f32 / total as f32;
    if ratio > 0.5 {
        ContentDensity::High
    } else {
        ContentDensity::Low
    }
}
492
493/// Partition messages into (high-density, low-density) groups.
494#[must_use]
495pub fn partition_by_density(messages: &[Message]) -> (Vec<Message>, Vec<Message>) {
496    let mut high = Vec::new();
497    let mut low = Vec::new();
498    for msg in messages {
499        if msg.metadata.focus_pinned {
500            continue;
501        }
502        match classify_density(&msg.content) {
503            ContentDensity::High => high.push(msg.clone()),
504            ContentDensity::Low => low.push(msg.clone()),
505        }
506    }
507    (high, low)
508}
509
510// ── SubgoalExtractionResult ───────────────────────────────────────────────────
511
/// Result of the background LLM call that extracts the agent's subgoals.
#[derive(Debug)]
pub struct SubgoalExtractionResult {
    /// Subgoal the agent is currently working toward.
    pub current: String,
    /// Subgoal that was just finished, if the LLM reported a transition. `None` when the
    /// `COMPLETED:` field was `NONE`.
    pub completed: Option<String>,
}
520
521// ── Focus auto-consolidation ──────────────────────────────────────────────────
522
523/// Automatically consolidate low-relevance context into a knowledge-block summary.
524///
525/// # Errors
526///
527/// Returns an error if the provider call returns an error or if the 20-second timeout
528/// elapses before the provider responds.
529#[tracing::instrument(name = "ctx.compaction.focus_auto_consolidate", skip_all)]
530pub async fn run_focus_auto_consolidation(
531    messages: &[Message],
532    min_window: usize,
533    provider: impl LlmProvider,
534    max_chars: usize,
535) -> Result<Option<String>, Box<dyn std::error::Error + Send + Sync>> {
536    if messages.len() < min_window {
537        return Ok(None);
538    }
539    let task_goal = messages
540        .iter()
541        .rev()
542        .find(|m| m.role == Role::User)
543        .map_or("", |m| m.content.as_str());
544    if task_goal.is_empty() {
545        tracing::debug!("focus_auto_consolidation: no user message found, skipping");
546        return Ok(None);
547    }
548    let messages_owned: Vec<Message> = messages.to_vec();
549    let task_goal_owned = task_goal.to_string();
550    let scores = tokio::task::spawn_blocking(move || {
551        let tc = TokenCounter::default();
552        score_blocks_mig(
553            &messages_owned,
554            Some(task_goal_owned.as_str()).filter(|s| !s.is_empty()),
555            &tc,
556        )
557    })
558    .await
559    .map_err(|e| format!("score_blocks_mig panicked: {e}"))?;
560
561    let low_relevance: HashSet<usize> = scores
562        .iter()
563        .filter(|s| s.mig <= 0.0)
564        .map(|s| s.msg_index)
565        .collect();
566    let window_indices = find_low_relevance_window(messages, &low_relevance, min_window);
567    if window_indices.is_empty() {
568        return Ok(None);
569    }
570    let combined: String = window_indices
571        .iter()
572        .map(|&i| extract_scorable_text(&messages[i]))
573        .collect::<Vec<_>>()
574        .join("\n---\n");
575    let prompt = format!(
576        "Extract up to 10 key facts the agent must remember from the following context. \
577         Return bullet points only (one per line, starting with `- `).\n\n{combined}"
578    );
579    let request = vec![Message::from_legacy(Role::User, &prompt)];
580    let raw = tokio::time::timeout(Duration::from_secs(20), provider.chat(&request))
581        .await
582        .map_err(|_| {
583            Box::new(std::io::Error::other(
584                "focus auto-consolidation timed out after 20s",
585            )) as Box<dyn std::error::Error + Send + Sync>
586        })?
587        .map_err(|e| {
588            Box::new(std::io::Error::other(format!(
589                "focus auto-consolidation provider error: {e}"
590            ))) as Box<dyn std::error::Error + Send + Sync>
591        })?;
592    let truncated = if raw.len() <= max_chars {
593        raw
594    } else {
595        let boundary = raw
596            .char_indices()
597            .map(|(i, _)| i)
598            .take_while(|&i| i <= max_chars)
599            .last()
600            .unwrap_or(0);
601        raw[..boundary].to_owned()
602    };
603    if truncated.is_empty() {
604        return Ok(None);
605    }
606    Ok(Some(truncated))
607}
608
609fn find_low_relevance_window(
610    messages: &[Message],
611    low_relevance: &HashSet<usize>,
612    min_window: usize,
613) -> Vec<usize> {
614    let mut best: Vec<usize> = Vec::new();
615    let mut current: Vec<usize> = Vec::new();
616    for (i, msg) in messages.iter().enumerate() {
617        if i == 0 || msg.metadata.focus_pinned {
618            current.clear();
619            continue;
620        }
621        if low_relevance.contains(&i) {
622            current.push(i);
623        } else {
624            if current.len() >= min_window && best.is_empty() {
625                best.append(&mut current);
626            }
627            current.clear();
628        }
629    }
630    if current.len() >= min_window && best.is_empty() {
631        best = current;
632    }
633    best
634}
635
636#[cfg(test)]
637mod tests {
638    use super::*;
639    use std::collections::HashMap;
640
641    #[test]
642    fn tokenize_filters_stop_words() {
643        let tokens = tokenize("fn main() { let x = 5; }");
644        assert!(!tokens.contains(&"fn".to_string()));
645        assert!(!tokens.contains(&"let".to_string()));
646    }
647
648    #[test]
649    fn tokenize_keeps_meaningful_tokens() {
650        let tokens = tokenize("authentication middleware session");
651        assert!(tokens.contains(&"authentication".to_string()));
652        assert!(tokens.contains(&"middleware".to_string()));
653        assert!(tokens.contains(&"session".to_string()));
654    }
655
656    #[test]
657    fn tf_weighted_similarity_identical_is_one() {
658        let tokens = tokenize("authentication session token");
659        let tf = term_frequencies(&tokens);
660        let sim = tf_weighted_similarity(&tf, &tf);
661        assert!((sim - 1.0).abs() < f32::EPSILON);
662    }
663
664    #[test]
665    fn tf_weighted_similarity_disjoint_is_zero() {
666        let tokens_a = tokenize("authentication session");
667        let tokens_b = tokenize("database migration schema");
668        let tf_a = term_frequencies(&tokens_a);
669        let tf_b = term_frequencies(&tokens_b);
670        assert!(tf_weighted_similarity(&tf_a, &tf_b).abs() < f32::EPSILON);
671    }
672
673    #[test]
674    fn tf_weighted_similarity_empty_is_zero() {
675        let tf_empty: HashMap<String, f32> = HashMap::new();
676        let tokens = tokenize("authentication session");
677        let tf = term_frequencies(&tokens);
678        assert!(tf_weighted_similarity(&tf_empty, &tf).abs() < f32::EPSILON);
679    }
680
681    fn make_tool_output_msg(body: &str) -> Message {
682        use zeph_llm::provider::{MessageMetadata, MessagePart};
683        let mut msg = Message {
684            role: Role::User,
685            content: body.to_string(),
686            parts: vec![MessagePart::ToolOutput {
687                tool_name: "read".into(),
688                body: body.to_string(),
689                compacted_at: None,
690            }],
691            metadata: MessageMetadata::default(),
692        };
693        msg.rebuild_content();
694        msg
695    }
696
697    #[test]
698    fn score_blocks_task_aware_skips_system_prompt() {
699        let tc = TokenCounter::default();
700        let messages = vec![
701            Message::from_legacy(Role::System, "system prompt"),
702            make_tool_output_msg("authentication session middleware"),
703        ];
704        let scores = score_blocks_task_aware(&messages, "authentication session", &tc);
705        assert_eq!(scores.len(), 1);
706        assert_eq!(scores[0].msg_index, 1);
707    }
708
709    #[test]
710    fn score_blocks_task_aware_skips_pinned_messages() {
711        use zeph_llm::provider::MessageMetadata;
712        let tc = TokenCounter::default();
713        let mut pinned_meta = MessageMetadata::focus_pinned();
714        pinned_meta.focus_pinned = true;
715        let pinned = Message {
716            role: Role::System,
717            content: "authentication session knowledge".to_string(),
718            parts: vec![],
719            metadata: pinned_meta,
720        };
721        let messages = vec![
722            Message::from_legacy(Role::System, "sys"),
723            pinned,
724            make_tool_output_msg("authentication session"),
725        ];
726        let scores = score_blocks_task_aware(&messages, "authentication session", &tc);
727        assert!(scores.iter().all(|s| s.msg_index != 1));
728    }
729
730    #[test]
731    fn score_blocks_task_aware_relevant_block_scores_higher() {
732        let tc = TokenCounter::default();
733        let messages = vec![
734            Message::from_legacy(Role::System, "sys"),
735            make_tool_output_msg("authentication middleware session token implementation"),
736            make_tool_output_msg("database schema migration foreign key index"),
737        ];
738        let scores = score_blocks_task_aware(&messages, "authentication session token", &tc);
739        assert_eq!(scores.len(), 2);
740        let auth_score = scores.iter().find(|s| s.msg_index == 1).unwrap();
741        let db_score = scores.iter().find(|s| s.msg_index == 2).unwrap();
742        assert!(
743            auth_score.relevance > db_score.relevance,
744            "auth block must score higher than db block"
745        );
746    }
747
748    #[test]
749    fn subgoal_registry_push_active_creates_active_subgoal() {
750        let mut registry = SubgoalRegistry::default();
751        let id = registry.push_active("Implement login endpoint".into(), 1);
752        assert_eq!(registry.subgoals.len(), 1);
753        assert_eq!(registry.subgoals[0].id, id);
754        assert_eq!(registry.subgoals[0].state, SubgoalState::Active);
755    }
756
757    #[test]
758    fn subgoal_registry_complete_active_transitions_state() {
759        let mut registry = SubgoalRegistry::default();
760        registry.push_active("initial subgoal".into(), 1);
761        registry.complete_active(5);
762        assert_eq!(registry.subgoals[0].state, SubgoalState::Completed);
763        assert!(registry.active_subgoal().is_none());
764    }
765
766    #[test]
767    fn subgoal_registry_push_active_auto_completes_existing_active() {
768        let mut registry = SubgoalRegistry::default();
769        registry.push_active("first subgoal".into(), 1);
770        registry.push_active("second subgoal".into(), 6);
771        assert_eq!(registry.subgoals[0].state, SubgoalState::Completed);
772        assert_eq!(registry.subgoals[1].state, SubgoalState::Active);
773        let active_count = registry
774            .subgoals
775            .iter()
776            .filter(|s| s.state == SubgoalState::Active)
777            .count();
778        assert_eq!(active_count, 1);
779    }
780
781    #[test]
782    fn subgoal_registry_extend_active_tags_incrementally() {
783        let mut registry = SubgoalRegistry::default();
784        let id = registry.push_active("subgoal".into(), 3);
785        registry.extend_active(5);
786        assert_eq!(registry.subgoal_state(3), Some(SubgoalState::Active));
787        assert_eq!(registry.subgoal_state(4), Some(SubgoalState::Active));
788        assert_eq!(registry.subgoal_state(5), Some(SubgoalState::Active));
789        assert_eq!(registry.msg_to_subgoal.get(&3), Some(&id));
790        registry.extend_active(7);
791        assert_eq!(registry.subgoal_state(6), Some(SubgoalState::Active));
792        assert_eq!(registry.subgoal_state(7), Some(SubgoalState::Active));
793        assert_eq!(registry.msg_to_subgoal.len(), 5);
794    }
795
796    #[test]
797    fn subgoal_registry_subgoal_state_returns_correct_tier() {
798        let mut registry = SubgoalRegistry::default();
799        registry.push_active("completed subgoal".into(), 1);
800        registry.tag_range(1, 5, SubgoalId(0));
801        registry.complete_active(5);
802        registry.push_active("active subgoal".into(), 6);
803        registry.extend_active(9);
804        assert_eq!(registry.subgoal_state(1), Some(SubgoalState::Completed));
805        assert_eq!(registry.subgoal_state(6), Some(SubgoalState::Active));
806        assert_eq!(registry.subgoal_state(0), None);
807    }
808
809    #[test]
810    fn classify_density_empty_string_is_low() {
811        assert_eq!(classify_density(""), ContentDensity::Low);
812    }
813
814    #[test]
815    fn classify_density_all_structured_is_high() {
816        let content = "```rust\nfn main() {}\n```\n$ cargo build\n";
817        assert_eq!(classify_density(content), ContentDensity::High);
818    }
819
820    #[test]
821    fn classify_density_all_prose_is_low() {
822        let content = "This is a sentence.\nAnother sentence here.\nNo structured content at all.";
823        assert_eq!(classify_density(content), ContentDensity::Low);
824    }
825
826    // ─── run_focus_auto_consolidation tests ──────────────────────────────────
827
828    struct StubProvider {
829        response: &'static str,
830    }
831
832    impl zeph_llm::provider::LlmProvider for StubProvider {
833        async fn chat(&self, _messages: &[Message]) -> Result<String, zeph_llm::LlmError> {
834            Ok(self.response.to_owned())
835        }
836
837        async fn chat_stream(
838            &self,
839            messages: &[Message],
840        ) -> Result<zeph_llm::provider::ChatStream, zeph_llm::LlmError> {
841            let r = self.chat(messages).await?;
842            Ok(Box::pin(futures::stream::once(async move {
843                Ok::<_, zeph_llm::LlmError>(zeph_llm::provider::StreamChunk::Content(r))
844            })))
845        }
846
847        fn supports_streaming(&self) -> bool {
848            false
849        }
850
851        async fn embed(&self, _text: &str) -> Result<Vec<f32>, zeph_llm::LlmError> {
852            Ok(vec![])
853        }
854
855        fn supports_embeddings(&self) -> bool {
856            false
857        }
858
859        fn name(&self) -> &'static str {
860            "stub"
861        }
862    }
863
864    struct HangingProvider;
865
866    impl zeph_llm::provider::LlmProvider for HangingProvider {
867        async fn chat(&self, _messages: &[Message]) -> Result<String, zeph_llm::LlmError> {
868            std::future::pending::<()>().await;
869            unreachable!()
870        }
871
872        async fn chat_stream(
873            &self,
874            _messages: &[Message],
875        ) -> Result<zeph_llm::provider::ChatStream, zeph_llm::LlmError> {
876            std::future::pending::<()>().await;
877            unreachable!()
878        }
879
880        fn supports_streaming(&self) -> bool {
881            false
882        }
883
884        async fn embed(&self, _text: &str) -> Result<Vec<f32>, zeph_llm::LlmError> {
885            Ok(vec![])
886        }
887
888        fn supports_embeddings(&self) -> bool {
889            false
890        }
891
892        fn name(&self) -> &'static str {
893            "hanging"
894        }
895    }
896
897    #[tokio::test]
898    async fn run_focus_auto_consolidation_returns_none_for_small_history() {
899        let messages = vec![
900            Message::from_legacy(Role::System, "sys"),
901            make_tool_output_msg("some tool output here"),
902        ];
903        // min_window = 6, but only 2 messages → None.
904        let result = run_focus_auto_consolidation(
905            &messages,
906            6,
907            StubProvider {
908                response: "- fact one",
909            },
910            4096,
911        )
912        .await
913        .unwrap();
914        assert!(result.is_none());
915    }
916
917    #[tokio::test]
918    async fn run_focus_auto_consolidation_produces_summary() {
919        let mut messages = vec![Message::from_legacy(Role::System, "sys")];
920        for _ in 0..6 {
921            messages.push(make_tool_output_msg(
922                "database schema migration foreign key index",
923            ));
924        }
925        messages.push(Message::from_legacy(
926            Role::User,
927            "Help me with authentication",
928        ));
929
930        let result = run_focus_auto_consolidation(
931            &messages,
932            4,
933            StubProvider {
934                response: "- database schema uses foreign keys",
935            },
936            4096,
937        )
938        .await
939        .unwrap();
940
941        assert!(result.is_some());
942        let summary = result.unwrap();
943        assert!(!summary.is_empty());
944    }
945
946    #[tokio::test]
947    async fn run_focus_auto_consolidation_skips_when_no_user_message() {
948        // S2/S3: when no User message is present, must return None instead of
949        // entering recency mode and eagerly consolidating all history.
950        let mut messages = vec![Message::from_legacy(Role::System, "sys")];
951        for i in 0..8 {
952            messages.push(make_tool_output_msg(&format!("tool output {i}")));
953        }
954
955        let result = run_focus_auto_consolidation(
956            &messages,
957            4,
958            StubProvider {
959                response: "- should not be reached",
960            },
961            4096,
962        )
963        .await
964        .unwrap();
965
966        assert!(
967            result.is_none(),
968            "must return None when no user message is present (S2/S3)"
969        );
970    }
971
972    #[tokio::test]
973    async fn auto_consolidation_timeout_recovers() {
974        let mut messages = vec![Message::from_legacy(Role::System, "sys")];
975        for _ in 0..6 {
976            messages.push(make_tool_output_msg(
977                "database schema migration foreign key index",
978            ));
979        }
980        messages.push(Message::from_legacy(
981            Role::User,
982            "Help me with authentication",
983        ));
984
985        // Wrap in a short timeout to avoid waiting the full 20s internal timeout.
986        let result = tokio::time::timeout(
987            std::time::Duration::from_millis(50),
988            run_focus_auto_consolidation(&messages, 4, HangingProvider, 4096),
989        )
990        .await;
991
992        // Either: outer timeout fires (Err), or inner 20s timeout fires (Ok(Err)).
993        // Both cases must not panic.
994        match result {
995            Err(_elapsed) => {
996                // Outer timeout fired — no panic, correct.
997            }
998            Ok(inner) => {
999                assert!(inner.is_err(), "hanging provider must return an error");
1000            }
1001        }
1002    }
1003}