1use serde::{Deserialize, Serialize};
4
5use crate::types::{AgentMessage, ContentBlock, LlmMessage};
6
7pub trait TokenCounter: Send + Sync {
14 fn count_tokens(&self, message: &AgentMessage) -> usize;
16}
17
18#[derive(Debug, Clone, Copy, Default)]
23pub struct DefaultTokenCounter;
24
25impl TokenCounter for DefaultTokenCounter {
26 fn count_tokens(&self, message: &AgentMessage) -> usize {
27 match message {
28 AgentMessage::Llm(llm) => {
29 let chars: usize = content_blocks(llm)
30 .iter()
31 .map(|b| match b {
32 ContentBlock::Text { text } => text.len(),
33 ContentBlock::Thinking { thinking, .. } => thinking.len(),
34 ContentBlock::ToolCall { arguments, .. } => arguments.to_string().len(),
35 ContentBlock::Image { .. } => 0,
36 ContentBlock::Extension { data, .. } => data.to_string().len(),
37 })
38 .sum();
39 chars / 4
40 }
41 AgentMessage::Custom(_) => 100,
42 }
43 }
44}
45
46pub fn estimate_tokens(msg: &AgentMessage) -> usize {
50 DefaultTokenCounter.count_tokens(msg)
51}
52
53fn content_blocks(msg: &LlmMessage) -> &[ContentBlock] {
54 match msg {
55 LlmMessage::User(m) => &m.content,
56 LlmMessage::Assistant(m) => &m.content,
57 LlmMessage::ToolResult(m) => &m.content,
58 }
59}
60
61fn is_tool_result(messages: &[AgentMessage], idx: usize) -> bool {
62 matches!(
63 messages.get(idx),
64 Some(AgentMessage::Llm(LlmMessage::ToolResult(_)))
65 )
66}
67
68#[derive(Debug, Clone, Serialize, Deserialize)]
70pub struct CompactionReport {
71 pub dropped_count: usize,
73 pub tokens_before: usize,
75 pub tokens_after: usize,
77 pub overflow: bool,
79 #[serde(default, skip_serializing_if = "Vec::is_empty")]
85 pub dropped_messages: Vec<LlmMessage>,
86}
87
88pub fn compact_sliding_window(
99 messages: &mut Vec<AgentMessage>,
100 budget: usize,
101 anchor: usize,
102) -> Option<CompactionReport> {
103 compact_sliding_window_with(messages, budget, anchor, None)
104}
105
106pub fn compact_sliding_window_with(
108 messages: &mut Vec<AgentMessage>,
109 budget: usize,
110 anchor: usize,
111 counter: Option<&dyn TokenCounter>,
112) -> Option<CompactionReport> {
113 let default = DefaultTokenCounter;
114 let counter: &dyn TokenCounter = counter.unwrap_or(&default);
115
116 let count = |m: &AgentMessage| counter.count_tokens(m);
117
118 let tokens_before: usize = messages.iter().map(count).sum();
119 if tokens_before <= budget {
120 return None;
121 }
122
123 let len = messages.len();
124 let effective_anchor = anchor.min(len);
125
126 let anchor_tokens: usize = messages[..effective_anchor].iter().map(count).sum();
128
129 let remaining_budget = budget.saturating_sub(anchor_tokens);
130
131 let mut tail_tokens = 0;
133 let mut tail_start = len;
134
135 for i in (effective_anchor..len).rev() {
136 let msg_tokens = count(&messages[i]);
137 if tail_tokens + msg_tokens > remaining_budget {
138 break;
139 }
140 tail_tokens += msg_tokens;
141 tail_start = i;
142 }
143
144 while tail_start > effective_anchor && tail_start < len && is_tool_result(messages, tail_start)
148 {
149 tail_start -= 1;
150 }
151
152 if tail_start <= effective_anchor {
154 return None;
155 }
156
157 let dropped_count = tail_start - effective_anchor;
158
159 let dropped_messages: Vec<LlmMessage> = messages[effective_anchor..tail_start]
161 .iter()
162 .filter_map(|m| match m {
163 AgentMessage::Llm(llm) => Some(llm.clone()),
164 AgentMessage::Custom(_) => None,
165 })
166 .collect();
167
168 let tail: Vec<AgentMessage> = messages.drain(tail_start..).collect();
170 messages.truncate(effective_anchor);
171 messages.extend(tail);
172
173 let tokens_after: usize = messages.iter().map(count).sum();
174
175 Some(CompactionReport {
176 dropped_count,
177 tokens_before,
178 tokens_after,
179 overflow: false,
180 dropped_messages,
181 })
182}
183
184#[deprecated(since = "0.5.0", note = "Use SlidingWindowTransformer instead")]
193pub fn sliding_window(
194 normal_budget: usize,
195 overflow_budget: usize,
196 anchor: usize,
197) -> impl Fn(&mut Vec<AgentMessage>, bool) + Send + Sync {
198 move |messages: &mut Vec<AgentMessage>, overflow: bool| {
199 let budget = if overflow {
200 overflow_budget
201 } else {
202 normal_budget
203 };
204 compact_sliding_window(messages, budget, anchor);
205 }
206}
207
208pub fn is_context_overflow(
214 messages: &[AgentMessage],
215 model: &crate::types::ModelSpec,
216 counter: Option<&dyn TokenCounter>,
217) -> bool {
218 let max_window = model
219 .capabilities
220 .as_ref()
221 .and_then(|c| c.max_context_window);
222
223 let Some(max_window) = max_window else {
224 return false;
225 };
226
227 let default = DefaultTokenCounter;
228 let counter: &dyn TokenCounter = counter.unwrap_or(&default);
229
230 let total_tokens: usize = messages.iter().map(|m| counter.count_tokens(m)).sum();
231 total_tokens as u64 > max_window
232}
233
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{
        AssistantMessage, ContentBlock, Cost, LlmMessage, StopReason, ToolResultMessage, Usage,
        UserMessage,
    };

    /// User message with a single text block.
    fn text_message(text: &str) -> AgentMessage {
        AgentMessage::Llm(LlmMessage::User(UserMessage {
            content: vec![ContentBlock::Text {
                text: text.to_owned(),
            }],
            timestamp: 0,
            cache_hint: None,
        }))
    }

    /// Tool-call content block named "test" with the given JSON arguments.
    fn tool_call_block(id: &str, arguments: serde_json::Value) -> ContentBlock {
        ContentBlock::ToolCall {
            id: id.into(),
            name: "test".into(),
            arguments,
            partial_json: None,
        }
    }

    /// Assistant message (stop reason `ToolUse`) carrying the given content blocks.
    fn assistant_message(content: Vec<ContentBlock>) -> AgentMessage {
        AgentMessage::Llm(LlmMessage::Assistant(AssistantMessage {
            content,
            provider: String::new(),
            model_id: String::new(),
            usage: Usage::default(),
            cost: Cost::default(),
            stop_reason: StopReason::ToolUse,
            error_message: None,
            error_kind: None,
            timestamp: 0,
            cache_hint: None,
        }))
    }

    /// Assistant message with one `{}`-argument tool call.
    /// Costs 0 tokens under the default counter ("{}" is 2 bytes).
    fn tool_call_message(id: &str) -> AgentMessage {
        assistant_message(vec![tool_call_block(id, serde_json::json!({}))])
    }

    /// Tool-result message answering `id` with a single text block.
    fn tool_result_message(id: &str, text: &str) -> AgentMessage {
        AgentMessage::Llm(LlmMessage::ToolResult(ToolResultMessage {
            tool_call_id: id.into(),
            content: vec![ContentBlock::Text { text: text.into() }],
            is_error: false,
            timestamp: 0,
            details: serde_json::Value::Null,
            cache_hint: None,
        }))
    }

    fn has_tool_result(messages: &[AgentMessage]) -> bool {
        messages
            .iter()
            .any(|m| matches!(m, AgentMessage::Llm(LlmMessage::ToolResult(_))))
    }

    fn has_tool_call(messages: &[AgentMessage]) -> bool {
        messages.iter().any(|m| {
            matches!(m, AgentMessage::Llm(LlmMessage::Assistant(a))
                if a.content.iter().any(|b| matches!(b, ContentBlock::ToolCall { .. })))
        })
    }

    #[test]
    #[allow(deprecated)]
    fn under_budget_no_change() {
        let compact = sliding_window(10_000, 5_000, 1);
        let mut messages = vec![text_message("hello"), text_message("world")];
        compact(&mut messages, false);
        assert_eq!(messages.len(), 2);
    }

    #[test]
    #[allow(deprecated)]
    fn over_budget_trims_middle() {
        // 400 bytes -> 100 tokens each.
        let body = "x".repeat(400);
        let compact = sliding_window(250, 100, 1);
        let mut messages = vec![
            text_message(&body),
            text_message(&body),
            text_message(&body),
            text_message(&body),
        ];
        compact(&mut messages, false);
        assert_eq!(messages.len(), 2);
    }

    #[test]
    #[allow(deprecated)]
    fn overflow_uses_smaller_budget() {
        let body = "x".repeat(400);
        let compact = sliding_window(1000, 150, 1);
        let mut messages = vec![
            text_message(&body),
            text_message(&body),
            text_message(&body),
            text_message(&body),
        ];
        compact(&mut messages, false);
        // 400 tokens fit the normal budget of 1000.
        assert_eq!(messages.len(), 4);

        compact(&mut messages, true);
        // 400 tokens exceed the overflow budget of 150.
        assert!(messages.len() < 4);
    }

    #[test]
    #[allow(deprecated)]
    fn preserves_tool_result_pair() {
        // Costs: 100, 100, 0 (call), 1 (result) = 201 > 150, so compaction runs.
        let compact = sliding_window(150, 100, 1);

        let body = "x".repeat(400);
        let mut messages = vec![
            text_message(&body),
            text_message(&body),
            tool_call_message("tc1"),
            tool_result_message("tc1", "result"),
        ];

        compact(&mut messages, false);

        assert_eq!(messages.len(), 3, "the middle text message should be dropped");
        assert!(has_tool_result(&messages));
        assert!(has_tool_call(&messages));
    }

    #[test]
    #[allow(deprecated)]
    fn empty_messages_no_change() {
        let compact = sliding_window(100, 50, 1);
        let mut messages: Vec<AgentMessage> = vec![];
        compact(&mut messages, false);
        assert!(messages.is_empty());
    }

    #[test]
    #[allow(deprecated)]
    fn single_message_preserved() {
        let body = "x".repeat(4000);
        let compact = sliding_window(10, 5, 1);
        let mut messages = vec![text_message(&body)];
        compact(&mut messages, false);
        assert_eq!(messages.len(), 1);
    }

    #[test]
    #[allow(deprecated)]
    fn anchor_messages_always_kept() {
        let body = "x".repeat(400);
        let compact = sliding_window(50, 25, 2);
        let mut messages = vec![
            text_message(&body),
            text_message(&body),
            text_message(&body),
            text_message(&body),
        ];
        compact(&mut messages, false);

        assert!(messages.len() >= 2);
        for msg in &messages[..2] {
            if let AgentMessage::Llm(LlmMessage::User(u)) = msg {
                assert_eq!(u.content[0], ContentBlock::Text { text: body.clone() });
            } else {
                panic!("expected user message in anchor position");
            }
        }
    }

    #[test]
    #[allow(deprecated)]
    fn all_messages_under_budget_with_large_system_prompt() {
        let compact = sliding_window(500, 250, 1);
        let mut messages = vec![
            text_message(&"a".repeat(400)),
            text_message(&"b".repeat(400)),
        ];
        compact(&mut messages, false);
        assert_eq!(messages.len(), 2);
    }

    #[test]
    #[allow(deprecated)]
    fn tool_result_at_boundary_preserved() {
        let body = "x".repeat(400);
        let compact = sliding_window(250, 100, 1);
        let mut messages = vec![
            text_message(&body),
            text_message(&body),
            tool_call_message("tc1"),
            tool_result_message("tc1", &body),
        ];
        compact(&mut messages, false);

        if has_tool_result(&messages) {
            assert!(
                has_tool_call(&messages),
                "tool result kept without its preceding tool call"
            );
        }
    }

    #[test]
    #[allow(deprecated)]
    fn consecutive_tool_pairs_preserved() {
        // Costs: 100, 100, then four zero-cost tool messages = 200 > 150.
        let compact = sliding_window(150, 100, 1);
        let body = "x".repeat(400);
        let mut messages = vec![
            text_message(&body),
            text_message(&body),
            tool_call_message("tc1"),
            tool_result_message("tc1", "r1"),
            tool_call_message("tc2"),
            tool_result_message("tc2", "r2"),
        ];
        compact(&mut messages, false);

        assert_eq!(messages.len(), 5, "only the middle text message should be dropped");
        for msg in &messages {
            if let AgentMessage::Llm(LlmMessage::ToolResult(tr)) = msg {
                let call_present = messages.iter().any(|m| {
                    matches!(m, AgentMessage::Llm(LlmMessage::Assistant(a))
                        if a.content.iter().any(|b| matches!(b, ContentBlock::ToolCall { id, .. } if id == &tr.tool_call_id)))
                });
                assert!(
                    call_present,
                    "tool result {} kept without its call",
                    tr.tool_call_id
                );
            }
        }
    }

    #[test]
    fn tail_extended_back_over_tool_results_to_their_call() {
        let body = "x".repeat(400);
        // `{"q":"x…x"}` is 408 bytes -> two such calls cost 816 / 4 = 204 tokens.
        let padding = serde_json::json!({ "q": body.clone() });
        let mut messages = vec![
            text_message(&body), // anchor, 100
            text_message(&body), // 100
            assistant_message(vec![
                tool_call_block("tc1", padding.clone()),
                tool_call_block("tc2", padding),
            ]), // 204
            tool_result_message("tc1", &body), // 100
            tool_result_message("tc2", &body), // 100
        ];

        // Remaining budget 250 - 100 = 150 fits only the last tool result, so the
        // tail must walk back over both results to include the assistant call.
        let report = compact_sliding_window(&mut messages, 250, 1).unwrap();

        assert_eq!(report.dropped_count, 1);
        assert_eq!(report.tokens_before, 604);
        assert_eq!(report.tokens_after, 504);
        assert_eq!(messages.len(), 4);
        assert!(matches!(
            messages[1],
            AgentMessage::Llm(LlmMessage::Assistant(_))
        ));
        assert!(matches!(
            messages[2],
            AgentMessage::Llm(LlmMessage::ToolResult(_))
        ));
        assert!(matches!(
            messages[3],
            AgentMessage::Llm(LlmMessage::ToolResult(_))
        ));
    }

    #[test]
    #[allow(deprecated)]
    fn custom_messages_token_estimation() {
        #[derive(Debug)]
        struct TestCustom;
        impl crate::types::CustomMessage for TestCustom {
            fn as_any(&self) -> &dyn std::any::Any {
                self
            }
        }

        // Each custom message costs a flat 100 tokens: 200 > 150.
        let compact = sliding_window(150, 50, 1);
        let mut messages: Vec<AgentMessage> = vec![
            AgentMessage::Custom(Box::new(TestCustom)),
            AgentMessage::Custom(Box::new(TestCustom)),
        ];
        compact(&mut messages, false);
        assert_eq!(messages.len(), 1);
    }

    #[test]
    #[allow(deprecated)]
    fn overflow_budget_smaller_than_normal() {
        let body = "x".repeat(400);
        let compact = sliding_window(350, 150, 1);

        let mut normal_msgs = vec![
            text_message(&body),
            text_message(&body),
            text_message(&body),
            text_message(&body),
        ];
        compact(&mut normal_msgs, false);
        let normal_count = normal_msgs.len();

        let mut overflow_msgs = vec![
            text_message(&body),
            text_message(&body),
            text_message(&body),
            text_message(&body),
        ];
        compact(&mut overflow_msgs, true);
        let overflow_count = overflow_msgs.len();

        assert!(
            overflow_count < normal_count,
            "overflow budget ({overflow_count} msgs) should be more aggressive than normal ({normal_count} msgs)"
        );
    }

    #[test]
    fn default_token_counter_matches_estimate_tokens() {
        let msg = text_message(&"x".repeat(400));
        assert_eq!(
            DefaultTokenCounter.count_tokens(&msg),
            estimate_tokens(&msg)
        );
        assert_eq!(DefaultTokenCounter.count_tokens(&msg), 100);
    }

    #[test]
    fn default_token_counter_custom_message_flat_100() {
        #[derive(Debug)]
        struct TestCustom;
        impl crate::types::CustomMessage for TestCustom {
            fn as_any(&self) -> &dyn std::any::Any {
                self
            }
        }

        let msg = AgentMessage::Custom(Box::new(TestCustom));
        assert_eq!(DefaultTokenCounter.count_tokens(&msg), 100);
    }

    /// Test counter: one token per byte of text; other blocks are free; custom = 50.
    struct CharCounter;

    impl TokenCounter for CharCounter {
        fn count_tokens(&self, message: &AgentMessage) -> usize {
            match message {
                AgentMessage::Llm(llm) => content_blocks(llm)
                    .iter()
                    .map(|b| match b {
                        ContentBlock::Text { text } => text.len(),
                        _ => 0,
                    })
                    .sum(),
                AgentMessage::Custom(_) => 50,
            }
        }
    }

    #[test]
    fn custom_counter_used_by_compact_sliding_window_with() {
        let body = "x".repeat(400);
        let mut messages = vec![
            text_message(&body),
            text_message(&body),
            text_message(&body),
        ];

        let result = compact_sliding_window_with(&mut messages, 500, 1, Some(&CharCounter));
        assert!(result.is_some());
        assert_eq!(messages.len(), 1);
        let r = result.unwrap();
        assert_eq!(r.tokens_before, 1200);
        assert_eq!(r.tokens_after, 400);
    }

    #[test]
    fn custom_counter_no_compaction_when_under_budget() {
        let body = "x".repeat(100);
        let mut messages = vec![text_message(&body), text_message(&body)];

        let result = compact_sliding_window_with(&mut messages, 500, 1, Some(&CharCounter));
        assert!(result.is_none());
        assert_eq!(messages.len(), 2);
    }

    #[test]
    fn compact_sliding_window_backward_compat() {
        let body = "x".repeat(400);
        let mut messages = vec![
            text_message(&body),
            text_message(&body),
            text_message(&body),
        ];
        let result = compact_sliding_window(&mut messages, 250, 1);
        assert!(result.is_some());
        assert_eq!(messages.len(), 2);
    }

    #[test]
    fn compaction_report_includes_dropped_messages() {
        let body = "x".repeat(400);
        let mut messages = vec![
            text_message(&body),
            text_message(&body),
            text_message(&body),
            text_message(&body),
        ];
        let report = compact_sliding_window_with(&mut messages, 250, 1, None).unwrap();

        assert_eq!(report.dropped_count, 2);
        assert_eq!(report.dropped_messages.len(), 2);
        assert_eq!(report.tokens_before, 400);
        assert_eq!(report.tokens_after, 200);
        assert_eq!(messages.len(), 2);
    }

    #[test]
    fn compaction_report_dropped_messages_empty_when_no_compaction() {
        let mut messages = vec![text_message("hello"), text_message("world")];
        let result = compact_sliding_window_with(&mut messages, 10_000, 1, None);
        assert!(result.is_none());
    }

    fn model_with_window(window: u64) -> crate::types::ModelSpec {
        crate::types::ModelSpec {
            provider: "test".into(),
            model_id: "test-model".into(),
            thinking_level: crate::types::ThinkingLevel::default(),
            thinking_budgets: None,
            provider_config: None,
            capabilities: Some(
                crate::types::ModelCapabilities::none().with_max_context_window(window),
            ),
        }
    }

    fn model_no_window() -> crate::types::ModelSpec {
        crate::types::ModelSpec {
            provider: "test".into(),
            model_id: "test-model".into(),
            thinking_level: crate::types::ThinkingLevel::default(),
            thinking_budgets: None,
            provider_config: None,
            capabilities: None,
        }
    }

    #[test]
    fn overflow_within_budget_returns_false() {
        let messages = vec![text_message(&"x".repeat(400))];
        assert!(!is_context_overflow(
            &messages,
            &model_with_window(1000),
            None
        ));
    }

    #[test]
    fn overflow_exceeding_budget_returns_true() {
        let messages = vec![
            text_message(&"x".repeat(400)),
            text_message(&"x".repeat(400)),
        ];
        assert!(is_context_overflow(
            &messages,
            &model_with_window(150),
            None
        ));
    }

    #[test]
    fn overflow_no_window_returns_false() {
        let messages = vec![text_message(&"x".repeat(40_000))];
        assert!(!is_context_overflow(&messages, &model_no_window(), None));
    }

    #[test]
    fn overflow_custom_counter() {
        let messages = vec![text_message(&"x".repeat(400))];
        // CharCounter: 400 tokens > 300.
        assert!(is_context_overflow(
            &messages,
            &model_with_window(300),
            Some(&CharCounter)
        ));
        // Default counter: 100 tokens <= 300.
        assert!(!is_context_overflow(
            &messages,
            &model_with_window(300),
            None
        ));
    }

    #[test]
    fn overflow_empty_messages_returns_false() {
        let messages: Vec<AgentMessage> = vec![];
        assert!(!is_context_overflow(
            &messages,
            &model_with_window(100),
            None
        ));
    }
}