1use std::sync::Arc;
7
8use crate::context::{CompactionReport, TokenCounter, compact_sliding_window_with};
9use crate::types::AgentMessage;
10
11pub trait ContextTransformer: Send + Sync {
12 fn transform(
19 &self,
20 messages: &mut Vec<AgentMessage>,
21 overflow: bool,
22 ) -> Option<CompactionReport>;
23}
24
25impl<F: Fn(&mut Vec<AgentMessage>, bool) + Send + Sync> ContextTransformer for F {
27 fn transform(
28 &self,
29 messages: &mut Vec<AgentMessage>,
30 overflow: bool,
31 ) -> Option<CompactionReport> {
32 let before = messages.len();
33 self(messages, overflow);
34 let after = messages.len();
35 if after < before {
36 Some(CompactionReport {
37 dropped_count: before - after,
38 tokens_before: 0, tokens_after: 0,
40 overflow,
41 dropped_messages: Vec::new(), })
43 } else {
44 None
45 }
46 }
47}
48
49pub struct SlidingWindowTransformer {
57 normal_budget: usize,
58 overflow_budget: usize,
59 anchor: usize,
60 token_counter: Option<Arc<dyn TokenCounter>>,
61 cached_prefix_len: usize,
64}
65
66impl SlidingWindowTransformer {
67 #[must_use]
75 pub fn new(normal_budget: usize, overflow_budget: usize, anchor: usize) -> Self {
76 Self {
77 normal_budget,
78 overflow_budget,
79 anchor,
80 token_counter: None,
81 cached_prefix_len: 0,
82 }
83 }
84
85 #[must_use]
86 pub fn with_token_counter(mut self, counter: Arc<dyn TokenCounter>) -> Self {
87 self.token_counter = Some(counter);
88 self
89 }
90
91 #[must_use]
95 pub const fn with_cached_prefix_len(mut self, len: usize) -> Self {
96 self.cached_prefix_len = len;
97 self
98 }
99
100 pub const fn set_cached_prefix_len(&mut self, len: usize) {
102 self.cached_prefix_len = len;
103 }
104}
105
106impl ContextTransformer for SlidingWindowTransformer {
107 fn transform(
108 &self,
109 messages: &mut Vec<AgentMessage>,
110 overflow: bool,
111 ) -> Option<CompactionReport> {
112 let budget = if overflow {
113 self.overflow_budget
114 } else {
115 self.normal_budget
116 };
117
118 let effective_anchor = self.anchor.max(self.cached_prefix_len);
119 let counter_ref = self.token_counter.as_deref();
120 let mut report =
121 compact_sliding_window_with(messages, budget, effective_anchor, counter_ref)?;
122 report.overflow = overflow;
123 Some(report)
124 }
125}
126
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{ContentBlock, LlmMessage, UserMessage};

    /// Builds a user message holding a single text block.
    fn text_message(text: &str) -> AgentMessage {
        AgentMessage::Llm(LlmMessage::User(UserMessage {
            content: vec![ContentBlock::Text {
                text: text.to_owned(),
            }],
            timestamp: 0,
            cache_hint: None,
        }))
    }

    /// Builds `count` user messages, each containing `body` as its text.
    fn repeated_messages(body: &str, count: usize) -> Vec<AgentMessage> {
        (0..count).map(|_| text_message(body)).collect()
    }

    #[test]
    fn sliding_window_transformer_reports_dropped_messages() {
        let transformer = SlidingWindowTransformer::new(250, 100, 1);
        let mut messages = repeated_messages(&"x".repeat(400), 4);

        let report = transformer
            .transform(&mut messages, false)
            .expect("should report compaction");
        assert_eq!(report.dropped_count, 2);
        assert_eq!(report.tokens_before, 400);
        assert!(report.tokens_after < report.tokens_before);
        assert!(!report.overflow);
        assert_eq!(messages.len(), 2);
    }

    #[test]
    fn sliding_window_transformer_no_report_under_budget() {
        let transformer = SlidingWindowTransformer::new(10_000, 5_000, 1);
        let mut messages = vec![text_message("hello"), text_message("world")];

        let report = transformer.transform(&mut messages, false);
        assert!(report.is_none(), "should not report when under budget");
        assert_eq!(messages.len(), 2);
    }

    #[test]
    fn closure_blanket_impl_works() {
        let closure = |msgs: &mut Vec<AgentMessage>, _overflow: bool| {
            if msgs.len() > 2 {
                msgs.truncate(2);
            }
        };

        let mut messages = vec![
            text_message("a"),
            text_message("b"),
            text_message("c"),
            text_message("d"),
        ];

        let report = closure
            .transform(&mut messages, false)
            .expect("truncating closure should report compaction");
        assert_eq!(report.dropped_count, 2);
        assert_eq!(report.tokens_before, 0);
        assert_eq!(report.tokens_after, 0);
        assert_eq!(messages.len(), 2);
    }

    #[test]
    fn closure_blanket_impl_returns_none_when_nothing_dropped() {
        // A closure that leaves the history untouched produces no report.
        let noop = |_msgs: &mut Vec<AgentMessage>, _overflow: bool| {};
        let mut messages = vec![text_message("a"), text_message("b")];
        assert!(noop.transform(&mut messages, true).is_none());
        assert_eq!(messages.len(), 2);

        // A closure that grows the history must not underflow or report.
        let grow = |msgs: &mut Vec<AgentMessage>, _overflow: bool| {
            msgs.push(text_message("extra"));
        };
        assert!(grow.transform(&mut messages, false).is_none());
        assert_eq!(messages.len(), 3);
    }

    #[test]
    fn overflow_uses_smaller_budget() {
        let transformer = SlidingWindowTransformer::new(1000, 150, 1);
        let mut messages = repeated_messages(&"x".repeat(400), 4);

        // The normal budget is large enough: nothing is trimmed.
        assert!(transformer.transform(&mut messages, false).is_none());
        assert_eq!(messages.len(), 4);

        // The overflow budget is smaller: messages are trimmed and flagged.
        let report = transformer
            .transform(&mut messages, true)
            .expect("overflow budget should trigger compaction");
        assert!(report.overflow);
        assert!(messages.len() < 4);
    }

    #[test]
    fn sliding_window_transformer_with_custom_counter() {
        /// Counts one token per byte of user text, 50 per custom message and
        /// 0 for every other message kind.
        struct CharCounter;

        impl TokenCounter for CharCounter {
            fn count_tokens(&self, message: &AgentMessage) -> usize {
                match message {
                    AgentMessage::Llm(LlmMessage::User(user)) => user
                        .content
                        .iter()
                        .map(|block| match block {
                            ContentBlock::Text { text } => text.len(),
                            _ => 0,
                        })
                        .sum(),
                    AgentMessage::Llm(_) => 0,
                    AgentMessage::Custom(_) => 50,
                }
            }
        }

        let body = "x".repeat(400);

        let default_transformer = SlidingWindowTransformer::new(500, 250, 1);
        let mut messages = repeated_messages(&body, 4);
        let report = default_transformer.transform(&mut messages, false);
        assert!(
            report.is_none(),
            "default counter should not trim at budget 500"
        );
        assert_eq!(messages.len(), 4);

        let custom_transformer =
            SlidingWindowTransformer::new(500, 250, 1).with_token_counter(Arc::new(CharCounter));
        let mut messages = repeated_messages(&body, 4);
        let report = custom_transformer.transform(&mut messages, false);
        assert!(report.is_some(), "char counter should trim at budget 500");
        assert!(messages.len() < 4);
    }
}