// rustant_core/summarizer.rs
//! LLM-based context summarization.
//!
//! When the conversation grows beyond the context window, older messages
//! are summarized into a compact representation to preserve important context
//! while reducing token usage.
6
7use crate::brain::{Brain, LlmProvider};
8use crate::types::{CompletionRequest, Content, Message, Role};
9use std::sync::Arc;
10
/// Summary of conversation context for compression.
///
/// Produced by [`ContextSummarizer::summarize`]; callers can use
/// `tokens_saved` to decide whether the compression was worthwhile.
#[derive(Debug, Clone)]
pub struct ContextSummary {
    /// The generated summary text. Empty when no messages were summarized.
    pub text: String,
    /// Number of messages that were summarized.
    pub messages_summarized: usize,
    /// Estimated tokens saved (heuristic: original size minus summary size,
    /// at ~4 chars per token; never negative).
    pub tokens_saved: usize,
}
21
/// Generates summaries of conversation history using the LLM.
pub struct ContextSummarizer {
    /// LLM provider for generating summaries. `Arc` so the summarizer can
    /// share the same provider instance as the rest of the agent.
    provider: Arc<dyn LlmProvider>,
}
27
28impl ContextSummarizer {
29    /// Create a new summarizer with the given LLM provider.
30    pub fn new(provider: Arc<dyn LlmProvider>) -> Self {
31        Self { provider }
32    }
33
34    /// Generate a summary of the given messages.
35    pub async fn summarize(&self, messages: &[Message]) -> Result<ContextSummary, SummarizeError> {
36        if messages.is_empty() {
37            return Ok(ContextSummary {
38                text: String::new(),
39                messages_summarized: 0,
40                tokens_saved: 0,
41            });
42        }
43
44        let prompt = build_summarization_prompt(messages);
45
46        let request = CompletionRequest {
47            messages: vec![Message::user(prompt)],
48            tools: None,
49            temperature: 0.3,
50            max_tokens: Some(500),
51            stop_sequences: Vec::new(),
52            model: None,
53        };
54
55        let response = self
56            .provider
57            .complete(request)
58            .await
59            .map_err(|e| SummarizeError::LlmError(e.to_string()))?;
60
61        let summary_text = match &response.message.content {
62            Content::Text { text } => text.clone(),
63            _ => String::from("[Summary unavailable]"),
64        };
65
66        // Estimate tokens saved (rough: original messages minus summary)
67        let original_tokens: usize = messages.iter().map(estimate_message_tokens).sum();
68        let summary_tokens = summary_text.len() / 4; // rough estimate
69
70        Ok(ContextSummary {
71            text: summary_text,
72            messages_summarized: messages.len(),
73            tokens_saved: original_tokens.saturating_sub(summary_tokens),
74        })
75    }
76
77    /// Check if summarization is needed based on context usage.
78    pub fn should_summarize(context_ratio: f32, threshold: f32) -> bool {
79        context_ratio >= threshold
80    }
81}
82
83/// Build the prompt for summarizing messages.
84fn build_summarization_prompt(messages: &[Message]) -> String {
85    let mut prompt = String::from(
86        "Summarize the following conversation concisely, preserving:\n\
87         - Key decisions and conclusions\n\
88         - Important facts and data points\n\
89         - Tool results and their outcomes\n\
90         - Current task goals and progress\n\n\
91         Conversation:\n",
92    );
93
94    for msg in messages {
95        let role = match msg.role {
96            Role::User => "User",
97            Role::Assistant => "Assistant",
98            Role::System => "System",
99            Role::Tool => "Tool",
100        };
101        let text = match &msg.content {
102            Content::Text { text } => text.clone(),
103            Content::ToolCall {
104                name, arguments, ..
105            } => format!("[Tool Call: {} ({})]", name, arguments),
106            Content::ToolResult { output, .. } => {
107                format!("[Tool Result: {}]", output)
108            }
109            Content::MultiPart { parts } => parts
110                .iter()
111                .filter_map(|p| {
112                    if let Content::Text { text } = p {
113                        Some(text.as_str())
114                    } else {
115                        None
116                    }
117                })
118                .collect::<Vec<_>>()
119                .join(" "),
120        };
121        prompt.push_str(&format!("{}: {}\n", role, text));
122    }
123
124    prompt.push_str("\nProvide a concise summary (3-5 sentences) capturing the essential context:");
125    prompt
126}
127
128/// Rough token estimation for a message.
129fn estimate_message_tokens(msg: &Message) -> usize {
130    let text_len = match &msg.content {
131        Content::Text { text } => text.len(),
132        Content::ToolCall { arguments, .. } => arguments.to_string().len(),
133        Content::ToolResult { output, .. } => output.len(),
134        Content::MultiPart { parts } => parts
135            .iter()
136            .map(|p| match p {
137                Content::Text { text } => text.len(),
138                _ => 0,
139            })
140            .sum(),
141    };
142    text_len / 4 + 4 // rough: 4 chars per token + overhead
143}
144
/// Errors during summarization.
#[derive(Debug, thiserror::Error)]
pub enum SummarizeError {
    /// The underlying LLM provider returned an error; the provider's
    /// error message is preserved as a string.
    #[error("LLM error during summarization: {0}")]
    LlmError(String),
}
151
/// Token budget alerts.
///
/// Thresholds (exclusive lower bounds): Normal ≤ 50% < Warning ≤ 80% <
/// Critical ≤ 95% < Overflow.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TokenAlert {
    /// Context is within normal range.
    Normal,
    /// Context is getting large (> 50%).
    Warning,
    /// Context is critically full (> 80%).
    Critical,
    /// Context is near overflow (> 95%).
    Overflow,
}

impl TokenAlert {
    /// Determine the alert level from a context ratio.
    pub fn from_ratio(ratio: f32) -> Self {
        // Checked from most to least severe; first matching band wins.
        match ratio {
            r if r > 0.95 => TokenAlert::Overflow,
            r if r > 0.80 => TokenAlert::Critical,
            r if r > 0.50 => TokenAlert::Warning,
            _ => TokenAlert::Normal,
        }
    }
}

impl std::fmt::Display for TokenAlert {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            TokenAlert::Normal => "OK",
            TokenAlert::Warning => "WARNING",
            TokenAlert::Critical => "CRITICAL",
            TokenAlert::Overflow => "OVERFLOW",
        };
        f.write_str(label)
    }
}
190
/// Token and cost tracking display data.
///
/// A plain snapshot struct; populate via [`TokenCostDisplay::from_brain`]
/// and render with [`TokenCostDisplay::format_display`].
#[derive(Debug, Clone)]
pub struct TokenCostDisplay {
    /// Total input tokens used.
    pub input_tokens: usize,
    /// Total output tokens used.
    pub output_tokens: usize,
    /// Total tokens (input + output).
    pub total_tokens: usize,
    /// Context window size, in tokens.
    pub context_window: usize,
    /// Context fill ratio (total tokens / window; 0.0 when window is 0).
    pub context_ratio: f32,
    /// Total cost in USD.
    pub total_cost: f64,
    /// Alert level derived from `context_ratio`.
    pub alert: TokenAlert,
}
209
210impl TokenCostDisplay {
211    /// Create from brain statistics.
212    ///
213    /// Uses total usage tokens as a ratio against the context window.
214    pub fn from_brain(brain: &Brain) -> Self {
215        let usage = brain.total_usage();
216        let cost = brain.total_cost();
217        let context_window = brain.context_window();
218        let ratio = if context_window > 0 {
219            usage.total() as f32 / context_window as f32
220        } else {
221            0.0
222        };
223
224        Self {
225            input_tokens: usage.input_tokens,
226            output_tokens: usage.output_tokens,
227            total_tokens: usage.total(),
228            context_window,
229            context_ratio: ratio,
230            total_cost: cost.total(),
231            alert: TokenAlert::from_ratio(ratio),
232        }
233    }
234
235    /// Format as a display string.
236    pub fn format_display(&self) -> String {
237        format!(
238            "Tokens: {} in / {} out ({} total) | Context: {:.0}% of {} | Cost: ${:.4} | {}",
239            self.input_tokens,
240            self.output_tokens,
241            self.total_tokens,
242            self.context_ratio * 100.0,
243            self.context_window,
244            self.total_cost,
245            self.alert,
246        )
247    }
248}
249
/// Truncate a string to at most `max` bytes without splitting a UTF-8
/// character: the cut point is moved back to the nearest char boundary.
fn truncate_str(s: &str, max: usize) -> &str {
    if s.len() <= max {
        return s;
    }
    // Walk backwards from `max` to a valid boundary; index 0 is always one,
    // so this terminates.
    let mut cut = max;
    while !s.is_char_boundary(cut) {
        cut -= 1;
    }
    &s[..cut]
}
262
263/// Smart fallback summary that preserves structured information when LLM-based
264/// summarization fails. Instead of naive truncation, it extracts tool names,
265/// results, and preserves the first/last messages for continuity.
266pub fn smart_fallback_summary(messages: &[Message], max_chars: usize) -> String {
267    if messages.is_empty() {
268        return String::new();
269    }
270
271    let quarter = max_chars / 4;
272    let mut parts = Vec::new();
273
274    // Always include first message (the initial request context)
275    if let Some(first) = messages.first() {
276        if let Some(text) = first.content.as_text() {
277            parts.push(format!("[Start] {}", truncate_str(text, quarter)));
278        }
279    }
280
281    // Extract tool call summaries (tool name + brief result)
282    for msg in messages.iter() {
283        match &msg.content {
284            Content::ToolCall { name, .. } => {
285                parts.push(format!("[Tool: {}]", name));
286            }
287            Content::ToolResult { output, .. } => {
288                parts.push(format!("[Result: {}]", truncate_str(output, 80)));
289            }
290            _ => {}
291        }
292    }
293
294    // Always include last message if different from first
295    if messages.len() > 1 {
296        if let Some(last) = messages.last() {
297            if let Some(text) = last.content.as_text() {
298                parts.push(format!("[Latest] {}", truncate_str(text, quarter)));
299            }
300        }
301    }
302
303    let joined = parts.join("\n");
304    if joined.len() > max_chars {
305        format!("{}...", truncate_str(&joined, max_chars))
306    } else {
307        joined
308    }
309}
310
#[cfg(test)]
mod tests {
    use super::*;
    use crate::MockLlmProvider;

    // Bands are exclusive at the lower edge: 0.50 is still Normal, 0.51 is
    // Warning, etc.
    #[test]
    fn test_token_alert_from_ratio() {
        assert_eq!(TokenAlert::from_ratio(0.0), TokenAlert::Normal);
        assert_eq!(TokenAlert::from_ratio(0.3), TokenAlert::Normal);
        assert_eq!(TokenAlert::from_ratio(0.51), TokenAlert::Warning);
        assert_eq!(TokenAlert::from_ratio(0.81), TokenAlert::Critical);
        assert_eq!(TokenAlert::from_ratio(0.96), TokenAlert::Overflow);
    }

    #[test]
    fn test_token_alert_display() {
        assert_eq!(TokenAlert::Normal.to_string(), "OK");
        assert_eq!(TokenAlert::Warning.to_string(), "WARNING");
        assert_eq!(TokenAlert::Critical.to_string(), "CRITICAL");
        assert_eq!(TokenAlert::Overflow.to_string(), "OVERFLOW");
    }

    // Threshold comparison is inclusive (>=): ratio == threshold triggers.
    #[test]
    fn test_should_summarize() {
        assert!(!ContextSummarizer::should_summarize(0.5, 0.8));
        assert!(ContextSummarizer::should_summarize(0.85, 0.8));
        assert!(ContextSummarizer::should_summarize(1.0, 0.8));
    }

    #[test]
    fn test_build_summarization_prompt() {
        let messages = vec![Message::user("Hello"), Message::assistant("Hi there")];
        let prompt = build_summarization_prompt(&messages);
        assert!(prompt.contains("User: Hello"));
        assert!(prompt.contains("Assistant: Hi there"));
        assert!(prompt.contains("Summarize"));
    }

    #[test]
    fn test_estimate_message_tokens() {
        let msg = Message::user("Hello world, this is a test message");
        let tokens = estimate_message_tokens(&msg);
        assert!(tokens > 0);
    }

    // Empty input must short-circuit without calling the provider.
    #[tokio::test]
    async fn test_summarize_empty() {
        let provider = Arc::new(MockLlmProvider::new());
        let summarizer = ContextSummarizer::new(provider);
        let result = summarizer.summarize(&[]).await.unwrap();
        assert_eq!(result.messages_summarized, 0);
        assert!(result.text.is_empty());
    }

    #[tokio::test]
    async fn test_summarize_messages() {
        let provider = Arc::new(MockLlmProvider::new());
        let summarizer = ContextSummarizer::new(provider);
        let messages = vec![
            Message::user("Write a function"),
            Message::assistant("Here's the function..."),
        ];
        let result = summarizer.summarize(&messages).await.unwrap();
        assert_eq!(result.messages_summarized, 2);
        assert!(!result.text.is_empty());
    }

    #[test]
    fn test_token_cost_display_format() {
        let display = TokenCostDisplay {
            input_tokens: 1000,
            output_tokens: 500,
            total_tokens: 1500,
            context_window: 128000,
            context_ratio: 0.45,
            total_cost: 0.0123,
            alert: TokenAlert::Normal,
        };
        let formatted = display.format_display();
        assert!(formatted.contains("1000 in"));
        assert!(formatted.contains("500 out"));
        assert!(formatted.contains("$0.0123"));
        assert!(formatted.contains("OK"));
    }

    // --- Gap 2: Smart fallback summary tests ---

    #[test]
    fn test_smart_fallback_preserves_tool_names() {
        let messages = vec![
            Message::user("fix the bug"),
            Message::new(
                Role::Assistant,
                Content::tool_call(
                    "c1",
                    "file_read",
                    serde_json::json!({"path": "src/main.rs"}),
                ),
            ),
            Message::new(
                Role::Tool,
                Content::tool_result("c1", "fn main() { println!(\"hello\"); }", false),
            ),
            Message::assistant("I found the issue."),
        ];

        let summary = smart_fallback_summary(&messages, 500);

        assert!(
            summary.contains("file_read"),
            "Summary should contain tool name: {}",
            summary
        );
        assert!(
            summary.contains("fix the bug"),
            "Summary should contain first message: {}",
            summary
        );
    }

    #[test]
    fn test_smart_fallback_preserves_first_and_last() {
        let messages = vec![
            Message::user("initial request about authentication"),
            Message::assistant("Let me look into that."),
            Message::user("follow up about tokens"),
            Message::assistant("Here is the solution for token handling"),
        ];

        let summary = smart_fallback_summary(&messages, 500);

        assert!(
            summary.contains("initial request"),
            "Summary should contain first message: {}",
            summary
        );
        assert!(
            summary.contains("token handling"),
            "Summary should contain last message: {}",
            summary
        );
    }

    // The limit check allows a small margin because the fallback appends
    // "..." after truncating to max_chars.
    #[test]
    fn test_smart_fallback_respects_limit() {
        let long_text = "a".repeat(1000);
        let messages = vec![Message::user(&long_text)];

        let summary = smart_fallback_summary(&messages, 100);

        assert!(
            summary.len() <= 110, // small margin for formatting
            "Summary should respect limit: len={} > 110",
            summary.len()
        );
    }

    #[test]
    fn test_smart_fallback_empty_messages() {
        let summary = smart_fallback_summary(&[], 500);
        assert!(
            summary.is_empty(),
            "Empty messages should give empty summary"
        );
    }

    #[test]
    fn test_smart_fallback_different_limits() {
        let messages = vec![Message::user("x".repeat(1000))];

        let short = smart_fallback_summary(&messages, 50);
        let long = smart_fallback_summary(&messages, 800);

        assert!(short.len() <= 60);
        assert!(long.len() <= 810);
        assert!(long.len() > short.len());
    }
}