1use saorsa_ai::message::Message;
4use saorsa_ai::tokens::{estimate_conversation_tokens, estimate_message_tokens};
5
/// Strategy used by [`compact`] when a conversation exceeds its token budget.
///
/// Only [`CompactionStrategy::TruncateOldest`] is implemented at the moment;
/// the other variants are accepted but currently fall back to truncation.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq)]
pub enum CompactionStrategy {
    /// Drop the oldest messages until the conversation fits.
    TruncateOldest,
    /// Intended to summarize blocks of older messages; not implemented yet and
    /// currently behaves like `TruncateOldest`.
    SummarizeBlocks,
    /// Intended to combine summarization with truncation; not implemented yet
    /// and currently behaves like `TruncateOldest`.
    Hybrid,
}
16
17#[derive(Debug, Clone)]
19pub struct CompactionConfig {
20 pub max_tokens: u32,
22 pub preserve_recent_count: usize,
24 pub strategy: CompactionStrategy,
26}
27
28impl Default for CompactionConfig {
29 fn default() -> Self {
30 Self {
31 max_tokens: 100_000,
32 preserve_recent_count: 5,
33 strategy: CompactionStrategy::TruncateOldest,
34 }
35 }
36}
37
/// Summary of what a call to [`compact`] did.
#[derive(Clone, Debug)]
pub struct CompactionStats {
    /// Estimated tokens (messages plus system prompt) before compaction.
    pub original_tokens: u32,
    /// Estimated tokens (messages plus system prompt) after compaction.
    pub compacted_tokens: u32,
    /// Number of messages dropped from the front of the conversation.
    pub messages_removed: usize,
}
48
49pub fn compact(
54 messages: &[Message],
55 system: Option<&str>,
56 config: &CompactionConfig,
57) -> (Vec<Message>, CompactionStats) {
58 let original_tokens = estimate_conversation_tokens(messages, system);
59
60 if original_tokens <= config.max_tokens {
62 return (
63 messages.to_vec(),
64 CompactionStats {
65 original_tokens,
66 compacted_tokens: original_tokens,
67 messages_removed: 0,
68 },
69 );
70 }
71
72 match config.strategy {
73 CompactionStrategy::TruncateOldest => {
74 truncate_oldest(messages, system, config, original_tokens)
75 }
76 CompactionStrategy::SummarizeBlocks | CompactionStrategy::Hybrid => {
77 truncate_oldest(messages, system, config, original_tokens)
79 }
80 }
81}
82
83fn truncate_oldest(
88 messages: &[Message],
89 system: Option<&str>,
90 config: &CompactionConfig,
91 original_tokens: u32,
92) -> (Vec<Message>, CompactionStats) {
93 let system_tokens = system.map_or(0, saorsa_ai::tokens::estimate_tokens);
94
95 let non_system = messages;
97
98 let recent_start = non_system
99 .len()
100 .saturating_sub(config.preserve_recent_count);
101 let old_messages = &non_system[..recent_start];
102 let recent_messages = &non_system[recent_start..];
103
104 let recent_tokens: u32 = recent_messages.iter().map(estimate_message_tokens).sum();
106
107 let available_for_old = config
109 .max_tokens
110 .saturating_sub(system_tokens)
111 .saturating_sub(recent_tokens);
112
113 let mut kept_old = Vec::new();
115 let mut current_tokens = 0u32;
116
117 for msg in old_messages.iter().rev() {
118 let msg_tokens = estimate_message_tokens(msg);
119 if current_tokens + msg_tokens <= available_for_old {
120 kept_old.push((*msg).clone());
121 current_tokens += msg_tokens;
122 } else {
123 break;
124 }
125 }
126 kept_old.reverse();
127
128 let mut result = Vec::new();
130 result.extend(kept_old);
131 result.extend(recent_messages.iter().map(|m| (*m).clone()));
132
133 let compacted_tokens = estimate_conversation_tokens(&result, system);
134 let messages_removed = messages.len() - result.len();
135
136 (
137 result,
138 CompactionStats {
139 original_tokens,
140 compacted_tokens,
141 messages_removed,
142 },
143 )
144}
145
#[cfg(test)]
mod tests {
    use super::*;
    use saorsa_ai::message::{Message, Role};

    /// Shorthand for a user-role message.
    fn user(text: &str) -> Message {
        Message::user(text)
    }

    /// Shorthand for an assistant-role message.
    fn assistant(text: &str) -> Message {
        Message::assistant(text)
    }

    #[test]
    fn test_no_compaction_when_under_limit() {
        let messages = vec![user("Hello"), assistant("Hi")];
        let config = CompactionConfig {
            max_tokens: 100_000,
            ..Default::default()
        };

        let (compacted, stats) = compact(&messages, None, &config);

        assert_eq!(compacted.len(), messages.len());
        assert_eq!(stats.messages_removed, 0);
        assert_eq!(stats.original_tokens, stats.compacted_tokens);
    }

    #[test]
    fn test_truncate_oldest_removes_old_messages() {
        let large_text = "x".repeat(1000);
        let messages = vec![
            user(&large_text),
            assistant(&large_text),
            user(&large_text),
            assistant(&large_text),
            user("Recent message"),
            assistant("Recent response"),
        ];
        let config = CompactionConfig {
            max_tokens: 100,
            preserve_recent_count: 2,
            strategy: CompactionStrategy::TruncateOldest,
        };

        let (compacted, stats) = compact(&messages, None, &config);

        assert!(compacted.len() >= 2);
        assert!(stats.messages_removed > 0);
        assert!(stats.compacted_tokens <= config.max_tokens);
    }

    #[test]
    fn test_recent_messages_always_preserved() {
        let large_text = "a".repeat(1000);
        let messages = vec![
            user(&large_text),
            assistant("Old response"),
            user("Recent 1"),
            assistant("Recent 2"),
        ];
        let config = CompactionConfig {
            max_tokens: 100,
            preserve_recent_count: 2,
            strategy: CompactionStrategy::TruncateOldest,
        };

        let (compacted, _stats) = compact(&messages, None, &config);

        assert!(compacted.len() >= 2);
        let last_two = &compacted[compacted.len() - 2..];
        assert_eq!(last_two[0].role, Role::User);
        assert_eq!(last_two[1].role, Role::Assistant);
    }

    #[test]
    fn test_recent_messages_kept_even_when_over_budget() {
        let large_text = "z".repeat(1000);
        let messages = vec![user(&large_text), assistant(&large_text)];
        let config = CompactionConfig {
            max_tokens: 10,
            preserve_recent_count: 5,
            strategy: CompactionStrategy::TruncateOldest,
        };

        let (compacted, stats) = compact(&messages, None, &config);

        // preserve_recent_count exceeds the message count, so nothing may be
        // dropped even though the result cannot fit in the budget.
        assert_eq!(compacted.len(), messages.len());
        assert_eq!(stats.messages_removed, 0);
        assert!(stats.compacted_tokens > config.max_tokens);
    }

    #[test]
    fn test_truncation_stops_at_first_message_that_does_not_fit() {
        let large_text = "g".repeat(1000);
        let messages = vec![
            user("Tiny early"),
            assistant(&large_text),
            user("Tiny later"),
            assistant("Recent"),
        ];
        let config = CompactionConfig {
            max_tokens: 100,
            preserve_recent_count: 1,
            strategy: CompactionStrategy::TruncateOldest,
        };

        let (compacted, stats) = compact(&messages, None, &config);

        // The oversized message blocks everything older than it, even though
        // "Tiny early" alone would fit: the surviving history has no gaps.
        assert_eq!(compacted.len(), 2);
        assert_eq!(stats.messages_removed, 2);
        assert_eq!(compacted[0].role, Role::User);
        assert_eq!(compacted[1].role, Role::Assistant);
    }

    #[test]
    fn test_compaction_with_system_prompt() {
        let large_text = "a".repeat(1000);
        let messages = vec![user(&large_text), assistant("Response")];
        let system = Some("System prompt here");
        let config = CompactionConfig {
            max_tokens: 100,
            preserve_recent_count: 1,
            strategy: CompactionStrategy::TruncateOldest,
        };

        let (_compacted, stats) = compact(&messages, system, &config);

        assert!(stats.compacted_tokens <= config.max_tokens);
    }

    #[test]
    fn test_compaction_achieves_target() {
        let a_text = "a".repeat(1000);
        let b_text = "b".repeat(1000);
        let c_text = "c".repeat(1000);
        let d_text = "d".repeat(1000);

        let messages = vec![
            user(&a_text),
            assistant(&b_text),
            user(&c_text),
            assistant(&d_text),
            user("Recent"),
        ];
        let config = CompactionConfig {
            max_tokens: 100,
            preserve_recent_count: 1,
            strategy: CompactionStrategy::TruncateOldest,
        };

        let (compacted, stats) = compact(&messages, None, &config);

        assert!(stats.compacted_tokens <= config.max_tokens);
        assert!(stats.messages_removed > 0);
        assert!(compacted.len() < messages.len());
    }

    #[test]
    fn test_statistics_tracked_correctly() {
        let messages = vec![
            user("Message 1"),
            assistant("Response 1"),
            user("Message 2"),
        ];
        let config = CompactionConfig {
            max_tokens: 20,
            preserve_recent_count: 1,
            strategy: CompactionStrategy::TruncateOldest,
        };

        let (compacted, stats) = compact(&messages, None, &config);

        assert_eq!(stats.messages_removed, messages.len() - compacted.len());
        assert!(stats.original_tokens > 0);
        assert!(stats.compacted_tokens > 0);
        assert!(stats.compacted_tokens <= stats.original_tokens);
    }

    #[test]
    fn test_default_config() {
        let config = CompactionConfig::default();
        assert_eq!(config.max_tokens, 100_000);
        assert_eq!(config.preserve_recent_count, 5);
        assert_eq!(config.strategy, CompactionStrategy::TruncateOldest);
    }
}