1use crate::brain::{Brain, LlmProvider};
8use crate::types::{CompletionRequest, Content, Message, Role};
9use std::sync::Arc;
10
/// Result of compressing a slice of conversation messages into a short
/// free-text summary, with bookkeeping about the compression.
#[derive(Debug, Clone)]
pub struct ContextSummary {
 /// Generated summary text (empty when no messages were given).
 pub text: String,
 /// Number of messages folded into the summary.
 pub messages_summarized: usize,
 /// Estimated tokens saved: original estimate minus summary estimate
 /// (saturating at zero; both are ~4-chars-per-token heuristics).
 pub tokens_saved: usize,
}
21
/// Compresses conversation history by asking an LLM provider to summarize it.
pub struct ContextSummarizer {
 /// Backend used to generate summaries.
 provider: Arc<dyn LlmProvider>,
}
27
28impl ContextSummarizer {
29 pub fn new(provider: Arc<dyn LlmProvider>) -> Self {
31 Self { provider }
32 }
33
34 pub async fn summarize(&self, messages: &[Message]) -> Result<ContextSummary, SummarizeError> {
36 if messages.is_empty() {
37 return Ok(ContextSummary {
38 text: String::new(),
39 messages_summarized: 0,
40 tokens_saved: 0,
41 });
42 }
43
44 let prompt = build_summarization_prompt(messages);
45
46 let request = CompletionRequest {
47 messages: vec![Message::user(prompt)],
48 tools: None,
49 temperature: 0.3,
50 max_tokens: Some(500),
51 stop_sequences: Vec::new(),
52 model: None,
53 };
54
55 let response = self
56 .provider
57 .complete(request)
58 .await
59 .map_err(|e| SummarizeError::LlmError(e.to_string()))?;
60
61 let summary_text = match &response.message.content {
62 Content::Text { text } => text.clone(),
63 _ => String::from("[Summary unavailable]"),
64 };
65
66 let original_tokens: usize = messages.iter().map(estimate_message_tokens).sum();
68 let summary_tokens = summary_text.len() / 4; Ok(ContextSummary {
71 text: summary_text,
72 messages_summarized: messages.len(),
73 tokens_saved: original_tokens.saturating_sub(summary_tokens),
74 })
75 }
76
77 pub fn should_summarize(context_ratio: f32, threshold: f32) -> bool {
79 context_ratio >= threshold
80 }
81}
82
83fn build_summarization_prompt(messages: &[Message]) -> String {
85 let mut prompt = String::from(
86 "Summarize the following conversation concisely, preserving:\n\
87 - Key decisions and conclusions\n\
88 - Important facts and data points\n\
89 - Tool results and their outcomes\n\
90 - Current task goals and progress\n\n\
91 Conversation:\n",
92 );
93
94 for msg in messages {
95 let role = match msg.role {
96 Role::User => "User",
97 Role::Assistant => "Assistant",
98 Role::System => "System",
99 Role::Tool => "Tool",
100 };
101 let text = match &msg.content {
102 Content::Text { text } => text.clone(),
103 Content::ToolCall {
104 name, arguments, ..
105 } => format!("[Tool Call: {} ({})]", name, arguments),
106 Content::ToolResult { output, .. } => {
107 format!("[Tool Result: {}]", output)
108 }
109 Content::MultiPart { parts } => parts
110 .iter()
111 .filter_map(|p| {
112 if let Content::Text { text } = p {
113 Some(text.as_str())
114 } else {
115 None
116 }
117 })
118 .collect::<Vec<_>>()
119 .join(" "),
120 };
121 prompt.push_str(&format!("{}: {}\n", role, text));
122 }
123
124 prompt.push_str("\nProvide a concise summary (3-5 sentences) capturing the essential context:");
125 prompt
126}
127
128fn estimate_message_tokens(msg: &Message) -> usize {
130 let text_len = match &msg.content {
131 Content::Text { text } => text.len(),
132 Content::ToolCall { arguments, .. } => arguments.to_string().len(),
133 Content::ToolResult { output, .. } => output.len(),
134 Content::MultiPart { parts } => parts
135 .iter()
136 .map(|p| match p {
137 Content::Text { text } => text.len(),
138 _ => 0,
139 })
140 .sum(),
141 };
142 text_len / 4 + 4 }
144
/// Errors produced while summarizing conversation context.
#[derive(Debug, thiserror::Error)]
pub enum SummarizeError {
 /// The underlying LLM provider call failed; carries the provider's
 /// error message as a string.
 #[error("LLM error during summarization: {0}")]
 LlmError(String),
}
151
/// Severity level for context-window usage, derived from the
/// used-tokens / window-size ratio (see [`TokenAlert::from_ratio`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TokenAlert {
 /// Usage at or below 50% of the context window.
 Normal,
 /// Usage above 50%.
 Warning,
 /// Usage above 80%.
 Critical,
 /// Usage above 95% — truncation is imminent.
 Overflow,
}
164
165impl TokenAlert {
166 pub fn from_ratio(ratio: f32) -> Self {
168 if ratio > 0.95 {
169 TokenAlert::Overflow
170 } else if ratio > 0.80 {
171 TokenAlert::Critical
172 } else if ratio > 0.50 {
173 TokenAlert::Warning
174 } else {
175 TokenAlert::Normal
176 }
177 }
178}
179
180impl std::fmt::Display for TokenAlert {
181 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
182 match self {
183 TokenAlert::Normal => write!(f, "OK"),
184 TokenAlert::Warning => write!(f, "WARNING"),
185 TokenAlert::Critical => write!(f, "CRITICAL"),
186 TokenAlert::Overflow => write!(f, "OVERFLOW"),
187 }
188 }
189}
190
/// Snapshot of token usage and cost, ready for status-line display.
#[derive(Debug, Clone)]
pub struct TokenCostDisplay {
 /// Tokens sent to the model.
 pub input_tokens: usize,
 /// Tokens generated by the model.
 pub output_tokens: usize,
 /// Sum of input and output tokens.
 pub total_tokens: usize,
 /// Size of the model's context window, in tokens.
 pub context_window: usize,
 /// `total_tokens / context_window` (0.0 when the window is zero).
 pub context_ratio: f32,
 /// Accumulated cost in dollars.
 pub total_cost: f64,
 /// Alert level derived from `context_ratio`.
 pub alert: TokenAlert,
}
209
210impl TokenCostDisplay {
211 pub fn from_brain(brain: &Brain) -> Self {
215 let usage = brain.total_usage();
216 let cost = brain.total_cost();
217 let context_window = brain.context_window();
218 let ratio = if context_window > 0 {
219 usage.total() as f32 / context_window as f32
220 } else {
221 0.0
222 };
223
224 Self {
225 input_tokens: usage.input_tokens,
226 output_tokens: usage.output_tokens,
227 total_tokens: usage.total(),
228 context_window,
229 context_ratio: ratio,
230 total_cost: cost.total(),
231 alert: TokenAlert::from_ratio(ratio),
232 }
233 }
234
235 pub fn format_display(&self) -> String {
237 format!(
238 "Tokens: {} in / {} out ({} total) | Context: {:.0}% of {} | Cost: ${:.4} | {}",
239 self.input_tokens,
240 self.output_tokens,
241 self.total_tokens,
242 self.context_ratio * 100.0,
243 self.context_window,
244 self.total_cost,
245 self.alert,
246 )
247 }
248}
249
/// Truncate `s` to at most `max` bytes, backing up to the nearest char
/// boundary so the returned slice is always valid UTF-8.
fn truncate_str(s: &str, max: usize) -> &str {
    if s.len() <= max {
        return s;
    }
    // Scan backwards from `max` for the closest char boundary.
    // Index 0 is always a boundary, so the search cannot fail.
    let cut = (0..=max)
        .rev()
        .find(|&i| s.is_char_boundary(i))
        .unwrap_or(0);
    &s[..cut]
}
262
263pub fn smart_fallback_summary(messages: &[Message], max_chars: usize) -> String {
267 if messages.is_empty() {
268 return String::new();
269 }
270
271 let quarter = max_chars / 4;
272 let mut parts = Vec::new();
273
274 if let Some(first) = messages.first() {
276 if let Some(text) = first.content.as_text() {
277 parts.push(format!("[Start] {}", truncate_str(text, quarter)));
278 }
279 }
280
281 for msg in messages.iter() {
283 match &msg.content {
284 Content::ToolCall { name, .. } => {
285 parts.push(format!("[Tool: {}]", name));
286 }
287 Content::ToolResult { output, .. } => {
288 parts.push(format!("[Result: {}]", truncate_str(output, 80)));
289 }
290 _ => {}
291 }
292 }
293
294 if messages.len() > 1 {
296 if let Some(last) = messages.last() {
297 if let Some(text) = last.content.as_text() {
298 parts.push(format!("[Latest] {}", truncate_str(text, quarter)));
299 }
300 }
301 }
302
303 let joined = parts.join("\n");
304 if joined.len() > max_chars {
305 format!("{}...", truncate_str(&joined, max_chars))
306 } else {
307 joined
308 }
309}
310
#[cfg(test)]
mod tests {
    use super::*;
    use crate::MockLlmProvider;

    // --- TokenAlert ---

    #[test]
    fn test_token_alert_from_ratio() {
        assert_eq!(TokenAlert::from_ratio(0.0), TokenAlert::Normal);
        assert_eq!(TokenAlert::from_ratio(0.3), TokenAlert::Normal);
        assert_eq!(TokenAlert::from_ratio(0.51), TokenAlert::Warning);
        assert_eq!(TokenAlert::from_ratio(0.81), TokenAlert::Critical);
        assert_eq!(TokenAlert::from_ratio(0.96), TokenAlert::Overflow);
    }

    #[test]
    fn test_token_alert_display() {
        assert_eq!(TokenAlert::Normal.to_string(), "OK");
        assert_eq!(TokenAlert::Warning.to_string(), "WARNING");
        assert_eq!(TokenAlert::Critical.to_string(), "CRITICAL");
        assert_eq!(TokenAlert::Overflow.to_string(), "OVERFLOW");
    }

    // --- ContextSummarizer ---

    #[test]
    fn test_should_summarize() {
        assert!(!ContextSummarizer::should_summarize(0.5, 0.8));
        assert!(ContextSummarizer::should_summarize(0.85, 0.8));
        assert!(ContextSummarizer::should_summarize(1.0, 0.8));
    }

    #[test]
    fn test_build_summarization_prompt() {
        let messages = vec![Message::user("Hello"), Message::assistant("Hi there")];
        let prompt = build_summarization_prompt(&messages);
        assert!(prompt.contains("User: Hello"));
        assert!(prompt.contains("Assistant: Hi there"));
        assert!(prompt.contains("Summarize"));
    }

    #[test]
    fn test_estimate_message_tokens() {
        let msg = Message::user("Hello world, this is a test message");
        let tokens = estimate_message_tokens(&msg);
        assert!(tokens > 0);
    }

    #[tokio::test]
    async fn test_summarize_empty() {
        let provider = Arc::new(MockLlmProvider::new());
        let summarizer = ContextSummarizer::new(provider);
        let result = summarizer.summarize(&[]).await.unwrap();
        assert_eq!(result.messages_summarized, 0);
        assert!(result.text.is_empty());
    }

    #[tokio::test]
    async fn test_summarize_messages() {
        let provider = Arc::new(MockLlmProvider::new());
        let summarizer = ContextSummarizer::new(provider);
        let messages = vec![
            Message::user("Write a function"),
            Message::assistant("Here's the function..."),
        ];
        let result = summarizer.summarize(&messages).await.unwrap();
        assert_eq!(result.messages_summarized, 2);
        assert!(!result.text.is_empty());
    }

    // --- TokenCostDisplay ---

    #[test]
    fn test_token_cost_display_format() {
        let display = TokenCostDisplay {
            input_tokens: 1000,
            output_tokens: 500,
            total_tokens: 1500,
            context_window: 128000,
            context_ratio: 0.45,
            total_cost: 0.0123,
            alert: TokenAlert::Normal,
        };
        let formatted = display.format_display();
        assert!(formatted.contains("1000 in"));
        assert!(formatted.contains("500 out"));
        assert!(formatted.contains("$0.0123"));
        assert!(formatted.contains("OK"));
    }

    // --- smart_fallback_summary ---

    #[test]
    fn test_smart_fallback_preserves_tool_names() {
        let messages = vec![
            Message::user("fix the bug"),
            Message::new(
                Role::Assistant,
                Content::tool_call(
                    "c1",
                    "file_read",
                    serde_json::json!({"path": "src/main.rs"}),
                ),
            ),
            Message::new(
                Role::Tool,
                Content::tool_result("c1", "fn main() { println!(\"hello\"); }", false),
            ),
            Message::assistant("I found the issue."),
        ];

        let summary = smart_fallback_summary(&messages, 500);

        assert!(
            summary.contains("file_read"),
            "Summary should contain tool name: {}",
            summary
        );
        assert!(
            summary.contains("fix the bug"),
            "Summary should contain first message: {}",
            summary
        );
    }

    #[test]
    fn test_smart_fallback_preserves_first_and_last() {
        let messages = vec![
            Message::user("initial request about authentication"),
            Message::assistant("Let me look into that."),
            Message::user("follow up about tokens"),
            Message::assistant("Here is the solution for token handling"),
        ];

        let summary = smart_fallback_summary(&messages, 500);

        assert!(
            summary.contains("initial request"),
            "Summary should contain first message: {}",
            summary
        );
        assert!(
            summary.contains("token handling"),
            "Summary should contain last message: {}",
            summary
        );
    }

    #[test]
    fn test_smart_fallback_respects_limit() {
        let long_text = "a".repeat(1000);
        let messages = vec![Message::user(&long_text)];

        let summary = smart_fallback_summary(&messages, 100);

        // The trailing "..." may push the result slightly past the limit.
        assert!(
            summary.len() <= 110,
            "Summary should respect limit: len={} > 110",
            summary.len()
        );
    }

    #[test]
    fn test_smart_fallback_empty_messages() {
        let summary = smart_fallback_summary(&[], 500);
        assert!(
            summary.is_empty(),
            "Empty messages should give empty summary"
        );
    }

    #[test]
    fn test_smart_fallback_different_limits() {
        let messages = vec![Message::user("x".repeat(1000))];

        let short = smart_fallback_summary(&messages, 50);
        let long = smart_fallback_summary(&messages, 800);

        assert!(short.len() <= 60);
        assert!(long.len() <= 810);
        assert!(long.len() > short.len());
    }
}