Skip to main content

swink_agent/
context_transformer.rs

1//! Pluggable context transformation with compaction awareness.
2//!
3//! Replaces the bare `TransformContextFn` closure with a trait that supports
4//! both transformation and compaction reporting.
5
6use std::sync::Arc;
7
8use crate::context::{CompactionReport, TokenCounter, compact_sliding_window_with};
9use crate::types::AgentMessage;
10
11pub trait ContextTransformer: Send + Sync {
12    /// Transform the context messages in-place.
13    ///
14    /// Called synchronously before each LLM call. The `overflow` flag is true
15    /// when the previous turn exceeded the context window.
16    ///
17    /// Returns `Some(CompactionReport)` if messages were dropped, `None` otherwise.
18    fn transform(
19        &self,
20        messages: &mut Vec<AgentMessage>,
21        overflow: bool,
22    ) -> Option<CompactionReport>;
23}
24
25/// Blanket impl for existing closures (backward compat).
26impl<F: Fn(&mut Vec<AgentMessage>, bool) + Send + Sync> ContextTransformer for F {
27    fn transform(
28        &self,
29        messages: &mut Vec<AgentMessage>,
30        overflow: bool,
31    ) -> Option<CompactionReport> {
32        let before = messages.len();
33        self(messages, overflow);
34        let after = messages.len();
35        if after < before {
36            Some(CompactionReport {
37                dropped_count: before - after,
38                tokens_before: 0, // bare closures can't report token counts
39                tokens_after: 0,
40                overflow,
41                dropped_messages: Vec::new(), // bare closures don't have access to the dropped slice
42            })
43        } else {
44            None
45        }
46    }
47}
48
49/// Built-in sliding window context transformer with compaction reporting.
50///
51/// Wraps the same logic as [`sliding_window`](crate::sliding_window) but
52/// captures compaction metrics for reporting.
53///
54/// Accepts an optional [`TokenCounter`] for pluggable token estimation.
55/// When none is provided, the default `chars / 4` heuristic is used.
56pub struct SlidingWindowTransformer {
57    normal_budget: usize,
58    overflow_budget: usize,
59    anchor: usize,
60    token_counter: Option<Arc<dyn TokenCounter>>,
61    /// When caching is active, protects this many messages from compaction.
62    /// The effective anchor becomes `max(anchor, cached_prefix_len)`.
63    cached_prefix_len: usize,
64}
65
66impl SlidingWindowTransformer {
67    /// Create a new sliding window transformer.
68    ///
69    /// # Arguments
70    ///
71    /// * `normal_budget` - Token budget under normal operation.
72    /// * `overflow_budget` - Smaller token budget used when overflow is signaled.
73    /// * `anchor` - Number of messages at the start to always preserve.
74    #[must_use]
75    pub fn new(normal_budget: usize, overflow_budget: usize, anchor: usize) -> Self {
76        Self {
77            normal_budget,
78            overflow_budget,
79            anchor,
80            token_counter: None,
81            cached_prefix_len: 0,
82        }
83    }
84
85    #[must_use]
86    pub fn with_token_counter(mut self, counter: Arc<dyn TokenCounter>) -> Self {
87        self.token_counter = Some(counter);
88        self
89    }
90
91    /// Set the cached prefix length to protect from compaction.
92    ///
93    /// When caching is active, the effective anchor is `max(anchor, cached_prefix_len)`.
94    #[must_use]
95    pub const fn with_cached_prefix_len(mut self, len: usize) -> Self {
96        self.cached_prefix_len = len;
97        self
98    }
99
100    /// Update the cached prefix length (for runtime updates from the turn pipeline).
101    pub const fn set_cached_prefix_len(&mut self, len: usize) {
102        self.cached_prefix_len = len;
103    }
104}
105
106impl ContextTransformer for SlidingWindowTransformer {
107    fn transform(
108        &self,
109        messages: &mut Vec<AgentMessage>,
110        overflow: bool,
111    ) -> Option<CompactionReport> {
112        let budget = if overflow {
113            self.overflow_budget
114        } else {
115            self.normal_budget
116        };
117
118        let effective_anchor = self.anchor.max(self.cached_prefix_len);
119        let counter_ref = self.token_counter.as_deref();
120        let mut report =
121            compact_sliding_window_with(messages, budget, effective_anchor, counter_ref)?;
122        report.overflow = overflow;
123        Some(report)
124    }
125}
126
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{ContentBlock, LlmMessage, UserMessage};

    /// Build a user message containing a single text block.
    fn text_message(text: &str) -> AgentMessage {
        AgentMessage::Llm(LlmMessage::User(UserMessage {
            content: vec![ContentBlock::Text {
                text: text.to_owned(),
            }],
            timestamp: 0,
            cache_hint: None,
        }))
    }

    /// Build `count` identical user messages carrying `text`.
    fn repeated_messages(text: &str, count: usize) -> Vec<AgentMessage> {
        (0..count).map(|_| text_message(text)).collect()
    }

    #[test]
    fn sliding_window_transformer_reports_dropped_messages() {
        let transformer = SlidingWindowTransformer::new(250, 100, 1);
        // Each message: 400 chars / 4 = 100 tokens
        let mut messages = repeated_messages(&"x".repeat(400), 4);

        let report = transformer
            .transform(&mut messages, false)
            .expect("should report compaction");
        assert_eq!(report.dropped_count, 2);
        assert_eq!(report.tokens_before, 400);
        assert!(report.tokens_after < report.tokens_before);
        assert!(!report.overflow);
        assert_eq!(messages.len(), 2);
    }

    #[test]
    fn sliding_window_transformer_no_report_under_budget() {
        let transformer = SlidingWindowTransformer::new(10_000, 5_000, 1);
        let mut messages = vec![text_message("hello"), text_message("world")];

        let report = transformer.transform(&mut messages, false);
        assert!(report.is_none(), "should not report when under budget");
        assert_eq!(messages.len(), 2);
    }

    #[test]
    fn closure_blanket_impl_works() {
        let closure = |msgs: &mut Vec<AgentMessage>, _overflow: bool| {
            if msgs.len() > 2 {
                msgs.truncate(2);
            }
        };

        let mut messages = vec![
            text_message("a"),
            text_message("b"),
            text_message("c"),
            text_message("d"),
        ];

        let report = closure
            .transform(&mut messages, false)
            .expect("shrinking closure should report compaction");
        assert_eq!(report.dropped_count, 2);
        // Bare closures can't report token counts
        assert_eq!(report.tokens_before, 0);
        assert_eq!(report.tokens_after, 0);
        assert!(!report.overflow);
        assert_eq!(messages.len(), 2);
    }

    #[test]
    fn closure_blanket_impl_passes_overflow_and_leaves_dropped_slice_empty() {
        let closure = |msgs: &mut Vec<AgentMessage>, _overflow: bool| msgs.truncate(1);
        let mut messages = repeated_messages("a", 3);

        let report = closure
            .transform(&mut messages, true)
            .expect("shrinking closure should report compaction");
        assert_eq!(report.dropped_count, 2);
        assert!(report.overflow, "overflow flag should be passed through");
        // Bare closures never capture the messages they removed.
        assert!(report.dropped_messages.is_empty());
        assert_eq!(messages.len(), 1);
    }

    #[test]
    fn closure_blanket_impl_no_report_when_not_shrinking() {
        let noop = |_msgs: &mut Vec<AgentMessage>, _overflow: bool| {};
        let mut messages = repeated_messages("a", 2);
        assert!(noop.transform(&mut messages, false).is_none());
        assert_eq!(messages.len(), 2);

        let grow = |msgs: &mut Vec<AgentMessage>, _overflow: bool| {
            msgs.push(text_message("extra"));
        };
        assert!(grow.transform(&mut messages, false).is_none());
        assert_eq!(messages.len(), 3);
    }

    #[test]
    fn overflow_uses_smaller_budget() {
        let transformer = SlidingWindowTransformer::new(1000, 150, 1);
        let mut messages = repeated_messages(&"x".repeat(400), 4);

        // Under normal budget (1000), total is 400 tokens -- no trim.
        let report = transformer.transform(&mut messages, false);
        assert!(report.is_none());
        assert_eq!(messages.len(), 4);

        // Under overflow budget (150), should trim.
        let report = transformer
            .transform(&mut messages, true)
            .expect("overflow budget should trim");
        assert!(report.overflow);
        assert!(messages.len() < 4);
    }

    #[test]
    fn sliding_window_transformer_with_custom_counter() {
        /// Counts every character as one token (4x the default heuristic).
        struct CharCounter;

        impl TokenCounter for CharCounter {
            fn count_tokens(&self, message: &AgentMessage) -> usize {
                match message {
                    AgentMessage::Llm(llm) => {
                        let blocks = match llm {
                            LlmMessage::User(m) => &m.content,
                            _ => return 0,
                        };
                        blocks
                            .iter()
                            .map(|b| match b {
                                ContentBlock::Text { text } => text.len(),
                                _ => 0,
                            })
                            .sum()
                    }
                    AgentMessage::Custom(_) => 50,
                }
            }
        }

        // Each message: 400 chars.
        // Default counter: 400/4 = 100 tokens each.
        // CharCounter: 400 tokens each.
        let body = "x".repeat(400);

        // With default counter, 4 * 100 = 400 tokens. Budget 500 => no trim.
        let default_transformer = SlidingWindowTransformer::new(500, 250, 1);
        let mut messages = repeated_messages(&body, 4);
        let report = default_transformer.transform(&mut messages, false);
        assert!(
            report.is_none(),
            "default counter should not trim at budget 500"
        );
        assert_eq!(messages.len(), 4);

        // With CharCounter, 4 * 400 = 1600 tokens. Budget 500 => trims.
        let custom_transformer =
            SlidingWindowTransformer::new(500, 250, 1).with_token_counter(Arc::new(CharCounter));
        let mut messages = repeated_messages(&body, 4);
        let report = custom_transformer.transform(&mut messages, false);
        assert!(report.is_some(), "char counter should trim at budget 500");
        assert!(messages.len() < 4);
    }
}