saorsa_agent/context/compaction.rs

//! Context compaction strategies for managing conversation token limits.

use saorsa_ai::message::Message;
use saorsa_ai::tokens::{estimate_conversation_tokens, estimate_message_tokens};

/// Strategy for compacting conversation history.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CompactionStrategy {
    /// Remove oldest messages first, preserving the most recent messages
    /// (the system prompt is passed separately and is never removed).
    TruncateOldest,
    /// Summarize blocks of messages (not yet implemented; falls back to truncation).
    SummarizeBlocks,
    /// Hybrid approach: truncate old, summarize middle blocks (summarization not yet
    /// implemented; currently falls back to truncation).
    Hybrid,
}

/// Configuration for context compaction.
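///
/// # Example
///
/// A minimal sketch of overriding individual defaults, mirroring the usage in the
/// tests below (marked `ignore` because the doctest import path is assumed rather
/// than taken from the crate's documented API):
///
/// ```ignore
/// let config = CompactionConfig {
///     max_tokens: 50_000,          // illustrative limit, not a recommended value
///     preserve_recent_count: 10,   // illustrative count
///     ..Default::default()
/// };
/// ```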
#[derive(Debug, Clone)]
pub struct CompactionConfig {
    /// Maximum tokens to keep after compaction.
    pub max_tokens: u32,
    /// Number of most recent messages to always preserve.
    pub preserve_recent_count: usize,
    /// Compaction strategy to use.
    pub strategy: CompactionStrategy,
}

impl Default for CompactionConfig {
    fn default() -> Self {
        Self {
            max_tokens: 100_000,
            preserve_recent_count: 5,
            strategy: CompactionStrategy::TruncateOldest,
        }
    }
}

/// Statistics from a compaction operation.
#[derive(Debug, Clone)]
pub struct CompactionStats {
    /// Original token count before compaction.
    pub original_tokens: u32,
    /// Token count after compaction.
    pub compacted_tokens: u32,
    /// Number of messages removed.
    pub messages_removed: usize,
}

/// Compact a conversation history according to the given configuration.
///
/// The system prompt (passed separately) is always accounted for in the token budget,
/// and the most recent N messages are always preserved.
/// Returns the compacted message list and statistics.
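///
/// # Example
///
/// A minimal sketch (marked `ignore` because the doctest import paths are assumed
/// rather than taken from this crate's documented API):
///
/// ```ignore
/// use saorsa_ai::message::Message;
///
/// let messages = vec![Message::user("Hello"), Message::assistant("Hi")];
/// let config = CompactionConfig::default();
/// let (compacted, stats) = compact(&messages, None, &config);
///
/// // Two short messages fall well under the default 100_000-token limit,
/// // so nothing is removed.
/// assert_eq!(stats.messages_removed, 0);
/// assert_eq!(compacted.len(), messages.len());
/// ```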
pub fn compact(
    messages: &[Message],
    system: Option<&str>,
    config: &CompactionConfig,
) -> (Vec<Message>, CompactionStats) {
    let original_tokens = estimate_conversation_tokens(messages, system);

    // If we're already under the limit, no compaction needed
    if original_tokens <= config.max_tokens {
        return (
            messages.to_vec(),
            CompactionStats {
                original_tokens,
                compacted_tokens: original_tokens,
                messages_removed: 0,
            },
        );
    }

    match config.strategy {
        CompactionStrategy::TruncateOldest => {
            truncate_oldest(messages, system, config, original_tokens)
        }
        CompactionStrategy::SummarizeBlocks | CompactionStrategy::Hybrid => {
            // For now, fall back to truncate (summarization not implemented)
            truncate_oldest(messages, system, config, original_tokens)
        }
    }
}

/// Truncate oldest messages first, preserving recent messages.
///
/// Note: saorsa-ai Message doesn't have a "system" role - system prompts are separate.
/// This function preserves recent messages and fits as many older messages as possible.
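///
/// The token budget for older messages is `max_tokens` minus the system-prompt tokens
/// and the tokens of the preserved recent messages (saturating at zero); older messages
/// are re-admitted newest-first until that budget is exhausted.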
fn truncate_oldest(
    messages: &[Message],
    system: Option<&str>,
    config: &CompactionConfig,
    original_tokens: u32,
) -> (Vec<Message>, CompactionStats) {
    let system_tokens = system.map_or(0, saorsa_ai::tokens::estimate_tokens);

    // All messages are either User or Assistant (no system role in Message)
    let non_system = messages;

    let recent_start = non_system
        .len()
        .saturating_sub(config.preserve_recent_count);
    let old_messages = &non_system[..recent_start];
    let recent_messages = &non_system[recent_start..];

    // Calculate tokens for recent messages
    let recent_tokens: u32 = recent_messages.iter().map(estimate_message_tokens).sum();

    // Available tokens for old messages
    let available_for_old = config
        .max_tokens
        .saturating_sub(system_tokens)
        .saturating_sub(recent_tokens);

    // Keep as many old messages as fit
    let mut kept_old = Vec::new();
    let mut current_tokens = 0u32;

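    // Walk old messages newest-first so the most recent of them are retained;
    // stop at the first message that no longer fits the remaining budget.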
    for msg in old_messages.iter().rev() {
        let msg_tokens = estimate_message_tokens(msg);
        if current_tokens + msg_tokens <= available_for_old {
            kept_old.push((*msg).clone());
            current_tokens += msg_tokens;
        } else {
            break;
        }
    }
    kept_old.reverse();

    // Reconstruct message list: kept_old + recent
    let mut result = Vec::new();
    result.extend(kept_old);
    result.extend(recent_messages.iter().map(|m| (*m).clone()));

    let compacted_tokens = estimate_conversation_tokens(&result, system);
    let messages_removed = messages.len() - result.len();

    (
        result,
        CompactionStats {
            original_tokens,
            compacted_tokens,
            messages_removed,
        },
    )
}

#[cfg(test)]
mod tests {
    use super::*;
    use saorsa_ai::message::{Message, Role};

    fn make_message(role: &str, text: &str) -> Message {
        match role {
            "user" => Message::user(text),
            "assistant" => Message::assistant(text),
            _ => unreachable!("Invalid role"),
        }
    }

    #[test]
    fn test_no_compaction_when_under_limit() {
        let messages = vec![
            make_message("user", "Hello"),
            make_message("assistant", "Hi"),
        ];
        let config = CompactionConfig {
            max_tokens: 100_000,
            ..Default::default()
        };

        let (compacted, stats) = compact(&messages, None, &config);

        assert_eq!(compacted.len(), messages.len());
        assert_eq!(stats.messages_removed, 0);
        assert_eq!(stats.original_tokens, stats.compacted_tokens);
    }

    #[test]
    fn test_truncate_oldest_removes_old_messages() {
        let large_text = "x".repeat(1000);
        let messages = vec![
            make_message("user", &large_text),
            make_message("assistant", &large_text),
            make_message("user", &large_text),
            make_message("assistant", &large_text),
            make_message("user", "Recent message"),
            make_message("assistant", "Recent response"),
        ];
        let config = CompactionConfig {
            max_tokens: 100, // Low limit to force removal of large old messages
            preserve_recent_count: 2,
            strategy: CompactionStrategy::TruncateOldest,
        };

        let (compacted, stats) = compact(&messages, None, &config);

        // Should preserve at least the recent 2
        assert!(compacted.len() >= 2);
        assert!(stats.messages_removed > 0);
        assert!(stats.compacted_tokens <= config.max_tokens);
    }

    #[test]
    fn test_recent_messages_always_preserved() {
        let large_text = "a".repeat(1000);
        let messages = vec![
            make_message("user", &large_text), // Large old message
            make_message("assistant", "Old response"),
            make_message("user", "Recent 1"),
            make_message("assistant", "Recent 2"),
        ];
        let config = CompactionConfig {
            max_tokens: 100,
            preserve_recent_count: 2,
            strategy: CompactionStrategy::TruncateOldest,
        };

        let (compacted, _stats) = compact(&messages, None, &config);

        // Last 2 messages should always be there
        assert!(compacted.len() >= 2);
        let last_two = &compacted[compacted.len() - 2..];
        assert_eq!(last_two[0].role, Role::User);
        assert_eq!(last_two[1].role, Role::Assistant);
    }

    #[test]
    fn test_compaction_with_system_prompt() {
        let large_text = "a".repeat(1000);
        let messages = vec![
            make_message("user", &large_text),
            make_message("assistant", "Response"),
        ];
        let system = Some("System prompt here");
        let config = CompactionConfig {
            max_tokens: 100,
            preserve_recent_count: 1,
            strategy: CompactionStrategy::TruncateOldest,
        };

        let (_compacted, stats) = compact(&messages, system, &config);

        // Should compact while accounting for system prompt tokens
        assert!(stats.compacted_tokens <= config.max_tokens);
    }

    #[test]
    fn test_compaction_achieves_target() {
        let a_text = "a".repeat(1000);
        let b_text = "b".repeat(1000);
        let c_text = "c".repeat(1000);
        let d_text = "d".repeat(1000);

        let messages = vec![
            make_message("user", &a_text),
            make_message("assistant", &b_text),
            make_message("user", &c_text),
            make_message("assistant", &d_text),
            make_message("user", "Recent"),
        ];
        let config = CompactionConfig {
            max_tokens: 100,
            preserve_recent_count: 1,
            strategy: CompactionStrategy::TruncateOldest,
        };

        let (compacted, stats) = compact(&messages, None, &config);

        // Should be significantly reduced
        assert!(stats.compacted_tokens <= config.max_tokens);
        assert!(stats.messages_removed > 0);
        assert!(compacted.len() < messages.len());
    }

    #[test]
    fn test_statistics_tracked_correctly() {
        let messages = vec![
            make_message("user", "Message 1"),
            make_message("assistant", "Response 1"),
            make_message("user", "Message 2"),
        ];
        let config = CompactionConfig {
            max_tokens: 20,
            preserve_recent_count: 1,
            strategy: CompactionStrategy::TruncateOldest,
        };

        let (compacted, stats) = compact(&messages, None, &config);

        assert_eq!(stats.messages_removed, messages.len() - compacted.len());
        assert!(stats.original_tokens > 0);
        assert!(stats.compacted_tokens > 0);
        assert!(stats.compacted_tokens <= stats.original_tokens);
    }

    #[test]
    fn test_default_config() {
        let config = CompactionConfig::default();
        assert_eq!(config.max_tokens, 100_000);
        assert_eq!(config.preserve_recent_count, 5);
        assert_eq!(config.strategy, CompactionStrategy::TruncateOldest);
    }
}