1use crate::brain::{Brain, LlmProvider};
8use crate::types::{CompletionRequest, Content, Message, Role};
9use std::sync::Arc;
10
/// Result of condensing a run of conversation messages into shorter text.
#[derive(Debug, Clone)]
pub struct ContextSummary {
 // The generated summary text (empty when no messages were summarized).
 pub text: String,
 // Number of messages folded into this summary.
 pub messages_summarized: usize,
 // Estimated tokens saved: original estimate minus summary estimate.
 // Character-based approximation (~4 chars/token), not an exact count.
 pub tokens_saved: usize,
}
21
/// Summarizes conversation history by delegating to an LLM provider.
pub struct ContextSummarizer {
 // Provider used to generate the summarization completion.
 provider: Arc<dyn LlmProvider>,
}
27
28impl ContextSummarizer {
29 pub fn new(provider: Arc<dyn LlmProvider>) -> Self {
31 Self { provider }
32 }
33
34 pub async fn summarize(&self, messages: &[Message]) -> Result<ContextSummary, SummarizeError> {
36 if messages.is_empty() {
37 return Ok(ContextSummary {
38 text: String::new(),
39 messages_summarized: 0,
40 tokens_saved: 0,
41 });
42 }
43
44 let prompt = build_summarization_prompt(messages);
45
46 let request = CompletionRequest {
47 messages: vec![Message::user(prompt)],
48 tools: None,
49 temperature: 0.3,
50 max_tokens: Some(500),
51 stop_sequences: Vec::new(),
52 model: None,
53 };
54
55 let response = self
56 .provider
57 .complete(request)
58 .await
59 .map_err(|e| SummarizeError::LlmError(e.to_string()))?;
60
61 let summary_text = match &response.message.content {
62 Content::Text { text } => text.clone(),
63 _ => String::from("[Summary unavailable]"),
64 };
65
66 let original_tokens: usize = messages.iter().map(estimate_message_tokens).sum();
68 let summary_tokens = summary_text.len() / 4; Ok(ContextSummary {
71 text: summary_text,
72 messages_summarized: messages.len(),
73 tokens_saved: original_tokens.saturating_sub(summary_tokens),
74 })
75 }
76
77 pub fn should_summarize(context_ratio: f32, threshold: f32) -> bool {
79 context_ratio >= threshold
80 }
81}
82
83fn build_summarization_prompt(messages: &[Message]) -> String {
85 let mut prompt = String::from(
86 "Summarize the following conversation concisely, preserving:\n\
87 - Key decisions and conclusions\n\
88 - Important facts and data points\n\
89 - Tool results and their outcomes\n\
90 - Current task goals and progress\n\n\
91 Conversation:\n",
92 );
93
94 for msg in messages {
95 let role = match msg.role {
96 Role::User => "User",
97 Role::Assistant => "Assistant",
98 Role::System => "System",
99 Role::Tool => "Tool",
100 };
101 let text = match &msg.content {
102 Content::Text { text } => text.clone(),
103 Content::ToolCall {
104 name, arguments, ..
105 } => format!("[Tool Call: {} ({})]", name, arguments),
106 Content::ToolResult { output, .. } => {
107 format!("[Tool Result: {}]", output)
108 }
109 Content::MultiPart { parts } => parts
110 .iter()
111 .filter_map(|p| {
112 if let Content::Text { text } = p {
113 Some(text.as_str())
114 } else {
115 None
116 }
117 })
118 .collect::<Vec<_>>()
119 .join(" "),
120 };
121 prompt.push_str(&format!("{}: {}\n", role, text));
122 }
123
124 prompt.push_str("\nProvide a concise summary (3-5 sentences) capturing the essential context:");
125 prompt
126}
127
128fn estimate_message_tokens(msg: &Message) -> usize {
130 let text_len = match &msg.content {
131 Content::Text { text } => text.len(),
132 Content::ToolCall { arguments, .. } => arguments.to_string().len(),
133 Content::ToolResult { output, .. } => output.len(),
134 Content::MultiPart { parts } => parts
135 .iter()
136 .map(|p| match p {
137 Content::Text { text } => text.len(),
138 _ => 0,
139 })
140 .sum(),
141 };
142 text_len / 4 + 4 }
144
/// Errors that can occur while summarizing context.
#[derive(Debug, thiserror::Error)]
pub enum SummarizeError {
 // The underlying LLM completion call failed; carries the provider's
 // error rendered as a string.
 #[error("LLM error during summarization: {0}")]
 LlmError(String),
}
151
/// Alert level derived from how full the context window is.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TokenAlert {
 // Usage at or below 50% of the window.
 Normal,
 // Usage above 50%.
 Warning,
 // Usage above 80%.
 Critical,
 // Usage above 95% — context is effectively exhausted.
 Overflow,
}
164
165impl TokenAlert {
166 pub fn from_ratio(ratio: f32) -> Self {
168 if ratio > 0.95 {
169 TokenAlert::Overflow
170 } else if ratio > 0.80 {
171 TokenAlert::Critical
172 } else if ratio > 0.50 {
173 TokenAlert::Warning
174 } else {
175 TokenAlert::Normal
176 }
177 }
178}
179
180impl std::fmt::Display for TokenAlert {
181 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
182 match self {
183 TokenAlert::Normal => write!(f, "OK"),
184 TokenAlert::Warning => write!(f, "WARNING"),
185 TokenAlert::Critical => write!(f, "CRITICAL"),
186 TokenAlert::Overflow => write!(f, "OVERFLOW"),
187 }
188 }
189}
190
/// Snapshot of token usage, context pressure, and accumulated cost,
/// suitable for rendering in a status line.
#[derive(Debug, Clone)]
pub struct TokenCostDisplay {
 // Cumulative input (prompt) tokens.
 pub input_tokens: usize,
 // Cumulative output (completion) tokens.
 pub output_tokens: usize,
 // input_tokens + output_tokens.
 pub total_tokens: usize,
 // Size of the model's context window, in tokens.
 pub context_window: usize,
 // total_tokens / context_window (0.0 when the window is 0).
 pub context_ratio: f32,
 // Total accumulated cost in dollars.
 pub total_cost: f64,
 // Alert level derived from context_ratio.
 pub alert: TokenAlert,
}
209
210impl TokenCostDisplay {
211 pub fn from_brain(brain: &Brain) -> Self {
215 let usage = brain.total_usage();
216 let cost = brain.total_cost();
217 let context_window = brain.context_window();
218 let ratio = if context_window > 0 {
219 usage.total() as f32 / context_window as f32
220 } else {
221 0.0
222 };
223
224 Self {
225 input_tokens: usage.input_tokens,
226 output_tokens: usage.output_tokens,
227 total_tokens: usage.total(),
228 context_window,
229 context_ratio: ratio,
230 total_cost: cost.total(),
231 alert: TokenAlert::from_ratio(ratio),
232 }
233 }
234
235 pub fn format_display(&self) -> String {
237 format!(
238 "Tokens: {} in / {} out ({} total) | Context: {:.0}% of {} | Cost: ${:.4} | {}",
239 self.input_tokens,
240 self.output_tokens,
241 self.total_tokens,
242 self.context_ratio * 100.0,
243 self.context_window,
244 self.total_cost,
245 self.alert,
246 )
247 }
248}
249
/// Returns the longest prefix of `s` that fits in `max` bytes without
/// splitting a UTF-8 code point.
fn truncate_str(s: &str, max: usize) -> &str {
    if s.len() <= max {
        s
    } else {
        // Scan back from `max` to the nearest char boundary so the slice
        // below can never panic mid-codepoint.
        let cut = (0..=max)
            .rev()
            .find(|&i| s.is_char_boundary(i))
            .unwrap_or(0);
        &s[..cut]
    }
}
262
263pub fn smart_fallback_summary(messages: &[Message], max_chars: usize) -> String {
267 if messages.is_empty() {
268 return String::new();
269 }
270
271 let quarter = max_chars / 4;
272 let mut parts = Vec::new();
273
274 if let Some(first) = messages.first()
276 && let Some(text) = first.content.as_text()
277 {
278 parts.push(format!("[Start] {}", truncate_str(text, quarter)));
279 }
280
281 for msg in messages.iter() {
283 match &msg.content {
284 Content::ToolCall { name, .. } => {
285 parts.push(format!("[Tool: {}]", name));
286 }
287 Content::ToolResult { output, .. } => {
288 parts.push(format!("[Result: {}]", truncate_str(output, 80)));
289 }
290 _ => {}
291 }
292 }
293
294 if messages.len() > 1
296 && let Some(last) = messages.last()
297 && let Some(text) = last.content.as_text()
298 {
299 parts.push(format!("[Latest] {}", truncate_str(text, quarter)));
300 }
301
302 let joined = parts.join("\n");
303 if joined.len() > max_chars {
304 format!("{}...", truncate_str(&joined, max_chars))
305 } else {
306 joined
307 }
308}
309
// Unit tests: pure helpers are tested directly; summarizer paths use the
// crate's MockLlmProvider so no real LLM is contacted.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::MockLlmProvider;

    // Threshold boundaries for alert classification (>0.50, >0.80, >0.95).
    #[test]
    fn test_token_alert_from_ratio() {
        assert_eq!(TokenAlert::from_ratio(0.0), TokenAlert::Normal);
        assert_eq!(TokenAlert::from_ratio(0.3), TokenAlert::Normal);
        assert_eq!(TokenAlert::from_ratio(0.51), TokenAlert::Warning);
        assert_eq!(TokenAlert::from_ratio(0.81), TokenAlert::Critical);
        assert_eq!(TokenAlert::from_ratio(0.96), TokenAlert::Overflow);
    }

    #[test]
    fn test_token_alert_display() {
        assert_eq!(TokenAlert::Normal.to_string(), "OK");
        assert_eq!(TokenAlert::Warning.to_string(), "WARNING");
        assert_eq!(TokenAlert::Critical.to_string(), "CRITICAL");
        assert_eq!(TokenAlert::Overflow.to_string(), "OVERFLOW");
    }

    // Summarization triggers at or above the threshold, not below.
    #[test]
    fn test_should_summarize() {
        assert!(!ContextSummarizer::should_summarize(0.5, 0.8));
        assert!(ContextSummarizer::should_summarize(0.85, 0.8));
        assert!(ContextSummarizer::should_summarize(1.0, 0.8));
    }

    // Prompt must embed each message as "Role: text" plus the instruction.
    #[test]
    fn test_build_summarization_prompt() {
        let messages = vec![Message::user("Hello"), Message::assistant("Hi there")];
        let prompt = build_summarization_prompt(&messages);
        assert!(prompt.contains("User: Hello"));
        assert!(prompt.contains("Assistant: Hi there"));
        assert!(prompt.contains("Summarize"));
    }

    #[test]
    fn test_estimate_message_tokens() {
        let msg = Message::user("Hello world, this is a test message");
        let tokens = estimate_message_tokens(&msg);
        assert!(tokens > 0);
    }

    // Empty input short-circuits without calling the provider.
    #[tokio::test]
    async fn test_summarize_empty() {
        let provider = Arc::new(MockLlmProvider::new());
        let summarizer = ContextSummarizer::new(provider);
        let result = summarizer.summarize(&[]).await.unwrap();
        assert_eq!(result.messages_summarized, 0);
        assert!(result.text.is_empty());
    }

    #[tokio::test]
    async fn test_summarize_messages() {
        let provider = Arc::new(MockLlmProvider::new());
        let summarizer = ContextSummarizer::new(provider);
        let messages = vec![
            Message::user("Write a function"),
            Message::assistant("Here's the function..."),
        ];
        let result = summarizer.summarize(&messages).await.unwrap();
        assert_eq!(result.messages_summarized, 2);
        assert!(!result.text.is_empty());
    }

    #[test]
    fn test_token_cost_display_format() {
        let display = TokenCostDisplay {
            input_tokens: 1000,
            output_tokens: 500,
            total_tokens: 1500,
            context_window: 128000,
            context_ratio: 0.45,
            total_cost: 0.0123,
            alert: TokenAlert::Normal,
        };
        let formatted = display.format_display();
        assert!(formatted.contains("1000 in"));
        assert!(formatted.contains("500 out"));
        assert!(formatted.contains("$0.0123"));
        assert!(formatted.contains("OK"));
    }

    // Fallback summary must retain tool names even when text is truncated.
    #[test]
    fn test_smart_fallback_preserves_tool_names() {
        let messages = vec![
            Message::user("fix the bug"),
            Message::new(
                Role::Assistant,
                Content::tool_call(
                    "c1",
                    "file_read",
                    serde_json::json!({"path": "src/main.rs"}),
                ),
            ),
            Message::new(
                Role::Tool,
                Content::tool_result("c1", "fn main() { println!(\"hello\"); }", false),
            ),
            Message::assistant("I found the issue."),
        ];

        let summary = smart_fallback_summary(&messages, 500);

        assert!(
            summary.contains("file_read"),
            "Summary should contain tool name: {}",
            summary
        );
        assert!(
            summary.contains("fix the bug"),
            "Summary should contain first message: {}",
            summary
        );
    }

    #[test]
    fn test_smart_fallback_preserves_first_and_last() {
        let messages = vec![
            Message::user("initial request about authentication"),
            Message::assistant("Let me look into that."),
            Message::user("follow up about tokens"),
            Message::assistant("Here is the solution for token handling"),
        ];

        let summary = smart_fallback_summary(&messages, 500);

        assert!(
            summary.contains("initial request"),
            "Summary should contain first message: {}",
            summary
        );
        assert!(
            summary.contains("token handling"),
            "Summary should contain last message: {}",
            summary
        );
    }

    // Limit is soft: the "..." suffix may add a few chars past max_chars.
    #[test]
    fn test_smart_fallback_respects_limit() {
        let long_text = "a".repeat(1000);
        let messages = vec![Message::user(&long_text)];

        let summary = smart_fallback_summary(&messages, 100);

        assert!(
            summary.len() <= 110, "Summary should respect limit: len={} > 110",
            summary.len()
        );
    }

    #[test]
    fn test_smart_fallback_empty_messages() {
        let summary = smart_fallback_summary(&[], 500);
        assert!(
            summary.is_empty(),
            "Empty messages should give empty summary"
        );
    }

    #[test]
    fn test_smart_fallback_different_limits() {
        let messages = vec![Message::user("x".repeat(1000))];

        let short = smart_fallback_summary(&messages, 50);
        let long = smart_fallback_summary(&messages, 800);

        assert!(short.len() <= 60);
        assert!(long.len() <= 810);
        assert!(long.len() > short.len());
    }
}