//! LLM-based context summarization.
//!
//! When the conversation grows beyond the context window, older messages
//! are summarized into a compact representation to preserve important context
//! while reducing token usage.

use crate::brain::{Brain, LlmProvider};
use crate::types::{CompletionRequest, Content, Message, Role};
use std::sync::Arc;

/// Result of compressing a span of conversation history into a summary.
#[derive(Debug, Clone)]
pub struct ContextSummary {
    /// The generated summary text (empty when nothing was summarized).
    pub text: String,
    /// How many messages were folded into this summary.
    pub messages_summarized: usize,
    /// Rough number of tokens saved by the compression.
    pub tokens_saved: usize,
}

22/// Generates summaries of conversation history using the LLM.
23pub struct ContextSummarizer {
24    /// LLM provider for generating summaries.
25    provider: Arc<dyn LlmProvider>,
26}
27
28impl ContextSummarizer {
29    /// Create a new summarizer with the given LLM provider.
30    pub fn new(provider: Arc<dyn LlmProvider>) -> Self {
31        Self { provider }
32    }
33
34    /// Generate a summary of the given messages.
35    pub async fn summarize(&self, messages: &[Message]) -> Result<ContextSummary, SummarizeError> {
36        if messages.is_empty() {
37            return Ok(ContextSummary {
38                text: String::new(),
39                messages_summarized: 0,
40                tokens_saved: 0,
41            });
42        }
43
44        let prompt = build_summarization_prompt(messages);
45
46        let request = CompletionRequest {
47            messages: vec![Message::user(prompt)],
48            tools: None,
49            temperature: 0.3,
50            max_tokens: Some(500),
51            stop_sequences: Vec::new(),
52            model: None,
53        };
54
55        let response = self
56            .provider
57            .complete(request)
58            .await
59            .map_err(|e| SummarizeError::LlmError(e.to_string()))?;
60
61        let summary_text = match &response.message.content {
62            Content::Text { text } => text.clone(),
63            _ => String::from("[Summary unavailable]"),
64        };
65
66        // Estimate tokens saved (rough: original messages minus summary)
67        let original_tokens: usize = messages.iter().map(estimate_message_tokens).sum();
68        let summary_tokens = summary_text.len() / 4; // rough estimate
69
70        Ok(ContextSummary {
71            text: summary_text,
72            messages_summarized: messages.len(),
73            tokens_saved: original_tokens.saturating_sub(summary_tokens),
74        })
75    }
76
77    /// Check if summarization is needed based on context usage.
78    pub fn should_summarize(context_ratio: f32, threshold: f32) -> bool {
79        context_ratio >= threshold
80    }
81}
82
83/// Build the prompt for summarizing messages.
84fn build_summarization_prompt(messages: &[Message]) -> String {
85    let mut prompt = String::from(
86        "Summarize the following conversation concisely, preserving:\n\
87         - Key decisions and conclusions\n\
88         - Important facts and data points\n\
89         - Tool results and their outcomes\n\
90         - Current task goals and progress\n\n\
91         Conversation:\n",
92    );
93
94    for msg in messages {
95        let role = match msg.role {
96            Role::User => "User",
97            Role::Assistant => "Assistant",
98            Role::System => "System",
99            Role::Tool => "Tool",
100        };
101        let text = match &msg.content {
102            Content::Text { text } => text.clone(),
103            Content::ToolCall {
104                name, arguments, ..
105            } => format!("[Tool Call: {} ({})]", name, arguments),
106            Content::ToolResult { output, .. } => {
107                format!("[Tool Result: {}]", output)
108            }
109            Content::MultiPart { parts } => parts
110                .iter()
111                .filter_map(|p| {
112                    if let Content::Text { text } = p {
113                        Some(text.as_str())
114                    } else {
115                        None
116                    }
117                })
118                .collect::<Vec<_>>()
119                .join(" "),
120        };
121        prompt.push_str(&format!("{}: {}\n", role, text));
122    }
123
124    prompt.push_str("\nProvide a concise summary (3-5 sentences) capturing the essential context:");
125    prompt
126}
127
128/// Rough token estimation for a message.
129fn estimate_message_tokens(msg: &Message) -> usize {
130    let text_len = match &msg.content {
131        Content::Text { text } => text.len(),
132        Content::ToolCall { arguments, .. } => arguments.to_string().len(),
133        Content::ToolResult { output, .. } => output.len(),
134        Content::MultiPart { parts } => parts
135            .iter()
136            .map(|p| match p {
137                Content::Text { text } => text.len(),
138                _ => 0,
139            })
140            .sum(),
141    };
142    text_len / 4 + 4 // rough: 4 chars per token + overhead
143}
144
145/// Errors during summarization.
146#[derive(Debug, thiserror::Error)]
147pub enum SummarizeError {
148    #[error("LLM error during summarization: {0}")]
149    LlmError(String),
150}
151
/// Alert level describing how full the context window is.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TokenAlert {
    /// Context is within normal range.
    Normal,
    /// Context is getting large (> 50%).
    Warning,
    /// Context is critically full (> 80%).
    Critical,
    /// Context is near overflow (> 95%).
    Overflow,
}

165impl TokenAlert {
166    /// Determine the alert level from a context ratio.
167    pub fn from_ratio(ratio: f32) -> Self {
168        if ratio > 0.95 {
169            TokenAlert::Overflow
170        } else if ratio > 0.80 {
171            TokenAlert::Critical
172        } else if ratio > 0.50 {
173            TokenAlert::Warning
174        } else {
175            TokenAlert::Normal
176        }
177    }
178}
179
180impl std::fmt::Display for TokenAlert {
181    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
182        match self {
183            TokenAlert::Normal => write!(f, "OK"),
184            TokenAlert::Warning => write!(f, "WARNING"),
185            TokenAlert::Critical => write!(f, "CRITICAL"),
186            TokenAlert::Overflow => write!(f, "OVERFLOW"),
187        }
188    }
189}
190
191/// Token and cost tracking display data.
192#[derive(Debug, Clone)]
193pub struct TokenCostDisplay {
194    /// Total input tokens used.
195    pub input_tokens: usize,
196    /// Total output tokens used.
197    pub output_tokens: usize,
198    /// Total tokens.
199    pub total_tokens: usize,
200    /// Context window size.
201    pub context_window: usize,
202    /// Context fill ratio.
203    pub context_ratio: f32,
204    /// Total cost in USD.
205    pub total_cost: f64,
206    /// Alert level.
207    pub alert: TokenAlert,
208}
209
210impl TokenCostDisplay {
211    /// Create from brain statistics.
212    ///
213    /// Uses total usage tokens as a ratio against the context window.
214    pub fn from_brain(brain: &Brain) -> Self {
215        let usage = brain.total_usage();
216        let cost = brain.total_cost();
217        let context_window = brain.context_window();
218        let ratio = if context_window > 0 {
219            usage.total() as f32 / context_window as f32
220        } else {
221            0.0
222        };
223
224        Self {
225            input_tokens: usage.input_tokens,
226            output_tokens: usage.output_tokens,
227            total_tokens: usage.total(),
228            context_window,
229            context_ratio: ratio,
230            total_cost: cost.total(),
231            alert: TokenAlert::from_ratio(ratio),
232        }
233    }
234
235    /// Format as a display string.
236    pub fn format_display(&self) -> String {
237        format!(
238            "Tokens: {} in / {} out ({} total) | Context: {:.0}% of {} | Cost: ${:.4} | {}",
239            self.input_tokens,
240            self.output_tokens,
241            self.total_tokens,
242            self.context_ratio * 100.0,
243            self.context_window,
244            self.total_cost,
245            self.alert,
246        )
247    }
248}
249
/// Truncate a string to at most `max` bytes without splitting a UTF-8
/// character: if `max` falls inside a multi-byte character, the cut moves
/// backwards to the nearest valid char boundary.
fn truncate_str(s: &str, max: usize) -> &str {
    if s.len() <= max {
        s
    } else {
        // Index 0 is always a char boundary, so this search always succeeds.
        let cut = (0..=max)
            .rev()
            .find(|&i| s.is_char_boundary(i))
            .unwrap_or(0);
        &s[..cut]
    }
}

263/// Smart fallback summary that preserves structured information when LLM-based
264/// summarization fails. Instead of naive truncation, it extracts tool names,
265/// results, and preserves the first/last messages for continuity.
266pub fn smart_fallback_summary(messages: &[Message], max_chars: usize) -> String {
267    if messages.is_empty() {
268        return String::new();
269    }
270
271    let quarter = max_chars / 4;
272    let mut parts = Vec::new();
273
274    // Always include first message (the initial request context)
275    if let Some(first) = messages.first()
276        && let Some(text) = first.content.as_text()
277    {
278        parts.push(format!("[Start] {}", truncate_str(text, quarter)));
279    }
280
281    // Extract tool call summaries (tool name + brief result)
282    for msg in messages.iter() {
283        match &msg.content {
284            Content::ToolCall { name, .. } => {
285                parts.push(format!("[Tool: {}]", name));
286            }
287            Content::ToolResult { output, .. } => {
288                parts.push(format!("[Result: {}]", truncate_str(output, 80)));
289            }
290            _ => {}
291        }
292    }
293
294    // Always include last message if different from first
295    if messages.len() > 1
296        && let Some(last) = messages.last()
297        && let Some(text) = last.content.as_text()
298    {
299        parts.push(format!("[Latest] {}", truncate_str(text, quarter)));
300    }
301
302    let joined = parts.join("\n");
303    if joined.len() > max_chars {
304        format!("{}...", truncate_str(&joined, max_chars))
305    } else {
306        joined
307    }
308}
309
#[cfg(test)]
mod tests {
    use super::*;
    use crate::MockLlmProvider;

    #[test]
    fn test_token_alert_from_ratio() {
        // One probe per alert band; thresholds themselves are exclusive.
        let bands = [
            (0.0_f32, TokenAlert::Normal),
            (0.3, TokenAlert::Normal),
            (0.51, TokenAlert::Warning),
            (0.81, TokenAlert::Critical),
            (0.96, TokenAlert::Overflow),
        ];
        for (ratio, expected) in bands {
            assert_eq!(TokenAlert::from_ratio(ratio), expected);
        }
    }

    #[test]
    fn test_token_alert_display() {
        let labels = [
            (TokenAlert::Normal, "OK"),
            (TokenAlert::Warning, "WARNING"),
            (TokenAlert::Critical, "CRITICAL"),
            (TokenAlert::Overflow, "OVERFLOW"),
        ];
        for (alert, label) in labels {
            assert_eq!(alert.to_string(), label);
        }
    }

    #[test]
    fn test_should_summarize() {
        assert!(!ContextSummarizer::should_summarize(0.5, 0.8));
        assert!(ContextSummarizer::should_summarize(0.85, 0.8));
        assert!(ContextSummarizer::should_summarize(1.0, 0.8));
    }

    #[test]
    fn test_build_summarization_prompt() {
        let history = vec![Message::user("Hello"), Message::assistant("Hi there")];
        let rendered = build_summarization_prompt(&history);
        assert!(rendered.contains("User: Hello"));
        assert!(rendered.contains("Assistant: Hi there"));
        assert!(rendered.contains("Summarize"));
    }

    #[test]
    fn test_estimate_message_tokens() {
        let sample = Message::user("Hello world, this is a test message");
        assert!(estimate_message_tokens(&sample) > 0);
    }

    #[tokio::test]
    async fn test_summarize_empty() {
        let mock = Arc::new(MockLlmProvider::new());
        let sut = ContextSummarizer::new(mock);
        let summary = sut.summarize(&[]).await.unwrap();
        assert_eq!(summary.messages_summarized, 0);
        assert!(summary.text.is_empty());
    }

    #[tokio::test]
    async fn test_summarize_messages() {
        let mock = Arc::new(MockLlmProvider::new());
        let sut = ContextSummarizer::new(mock);
        let history = vec![
            Message::user("Write a function"),
            Message::assistant("Here's the function..."),
        ];
        let summary = sut.summarize(&history).await.unwrap();
        assert_eq!(summary.messages_summarized, 2);
        assert!(!summary.text.is_empty());
    }

    #[test]
    fn test_token_cost_display_format() {
        let display = TokenCostDisplay {
            input_tokens: 1000,
            output_tokens: 500,
            total_tokens: 1500,
            context_window: 128000,
            context_ratio: 0.45,
            total_cost: 0.0123,
            alert: TokenAlert::Normal,
        };
        let line = display.format_display();
        assert!(line.contains("1000 in"));
        assert!(line.contains("500 out"));
        assert!(line.contains("$0.0123"));
        assert!(line.contains("OK"));
    }

    // Smart fallback summary behavior.

    #[test]
    fn test_smart_fallback_preserves_tool_names() {
        let history = vec![
            Message::user("fix the bug"),
            Message::new(
                Role::Assistant,
                Content::tool_call(
                    "c1",
                    "file_read",
                    serde_json::json!({"path": "src/main.rs"}),
                ),
            ),
            Message::new(
                Role::Tool,
                Content::tool_result("c1", "fn main() { println!(\"hello\"); }", false),
            ),
            Message::assistant("I found the issue."),
        ];

        let summary = smart_fallback_summary(&history, 500);

        assert!(
            summary.contains("file_read"),
            "Summary should contain tool name: {}",
            summary
        );
        assert!(
            summary.contains("fix the bug"),
            "Summary should contain first message: {}",
            summary
        );
    }

    #[test]
    fn test_smart_fallback_preserves_first_and_last() {
        let history = vec![
            Message::user("initial request about authentication"),
            Message::assistant("Let me look into that."),
            Message::user("follow up about tokens"),
            Message::assistant("Here is the solution for token handling"),
        ];

        let summary = smart_fallback_summary(&history, 500);

        assert!(
            summary.contains("initial request"),
            "Summary should contain first message: {}",
            summary
        );
        assert!(
            summary.contains("token handling"),
            "Summary should contain last message: {}",
            summary
        );
    }

    #[test]
    fn test_smart_fallback_respects_limit() {
        let long_text = "a".repeat(1000);
        let history = vec![Message::user(&long_text)];

        let summary = smart_fallback_summary(&history, 100);

        assert!(
            summary.len() <= 110, // small margin for formatting
            "Summary should respect limit: len={} > 110",
            summary.len()
        );
    }

    #[test]
    fn test_smart_fallback_empty_messages() {
        let summary = smart_fallback_summary(&[], 500);
        assert!(
            summary.is_empty(),
            "Empty messages should give empty summary"
        );
    }

    #[test]
    fn test_smart_fallback_different_limits() {
        let history = vec![Message::user("x".repeat(1000))];

        let short = smart_fallback_summary(&history, 50);
        let long = smart_fallback_summary(&history, 800);

        assert!(short.len() <= 60);
        assert!(long.len() <= 810);
        assert!(long.len() > short.len());
    }
}