Skip to main content

tirea_agent_loop/engine/
context_window.rs

//! Context window management: policy-driven message truncation.
//!
//! When the conversation history exceeds the model's context-window budget,
//! the oldest history messages are dropped while preserving tool-call/result
//! pairs and a minimum number of recent messages.
6
7use crate::contracts::thread::{Message, Role};
8use crate::engine::token_estimator::{estimate_message_tokens, estimate_messages_tokens};
9
10// Re-export from tirea-contract (canonical location).
11pub use tirea_contract::runtime::inference::ContextWindowPolicy;
12
13/// Result of truncation.
14#[derive(Debug)]
15pub struct TruncationResult<'a> {
16    /// Messages to include in the inference request (system + kept history).
17    pub messages: Vec<&'a Message>,
18    /// Number of history messages dropped.
19    pub truncated_count: usize,
20    /// Estimated total tokens after truncation (system + history + tools).
21    pub estimated_total_tokens: usize,
22}
23
24/// Truncate conversation history to fit within the context window budget.
25///
26/// System messages are never truncated. History messages are dropped from the
27/// oldest end, but tool-call/result pairs are kept or dropped as a unit.
28///
29/// Returns references into the input slices.
30pub fn truncate_to_budget<'a>(
31    system_messages: &'a [Message],
32    history_messages: &'a [Message],
33    tool_tokens: usize,
34    policy: &ContextWindowPolicy,
35) -> TruncationResult<'a> {
36    let available = policy
37        .max_context_tokens
38        .saturating_sub(policy.max_output_tokens)
39        .saturating_sub(tool_tokens);
40
41    let system_tokens = estimate_messages_tokens(system_messages);
42    let history_budget = available.saturating_sub(system_tokens);
43
44    // Find a safe split point in history that fits the budget.
45    // We keep messages from the end (most recent first) and find how far
46    // back we can go.
47    let split = find_split_point(history_messages, history_budget, policy.min_recent_messages);
48
49    let kept = &history_messages[split..];
50    let kept_tokens = estimate_messages_tokens(kept);
51    let truncated_count = split;
52
53    let mut messages: Vec<&Message> = Vec::with_capacity(system_messages.len() + kept.len());
54    for msg in system_messages {
55        messages.push(msg);
56    }
57    for msg in kept {
58        messages.push(msg);
59    }
60
61    TruncationResult {
62        messages,
63        truncated_count,
64        estimated_total_tokens: system_tokens + kept_tokens + tool_tokens,
65    }
66}
67
68/// Find the index at which to split history: messages[split..] are kept.
69///
70/// Respects tool-call/result pair boundaries: if dropping a message would
71/// orphan a tool result (no matching assistant call) or orphan a tool call
72/// (no matching result), we adjust the split to keep the pair together.
73fn find_split_point(history: &[Message], budget_tokens: usize, min_recent: usize) -> usize {
74    if history.is_empty() {
75        return 0;
76    }
77
78    // Minimum: keep at least min_recent messages (or all if fewer).
79    let must_keep = min_recent.min(history.len());
80    let must_keep_start = history.len().saturating_sub(must_keep);
81
82    // Walk backward from the end, accumulating tokens.
83    let mut used_tokens = 0usize;
84    let mut candidate_split = history.len(); // start with keeping nothing
85
86    for i in (0..history.len()).rev() {
87        let msg_tokens = estimate_message_tokens(&history[i]);
88        let new_total = used_tokens + msg_tokens;
89
90        // If we're in the must-keep zone, always include.
91        if i >= must_keep_start {
92            used_tokens = new_total;
93            candidate_split = i;
94            continue;
95        }
96
97        // If adding this message exceeds budget, stop.
98        if new_total > budget_tokens {
99            break;
100        }
101
102        used_tokens = new_total;
103        candidate_split = i;
104    }
105
106    // Adjust split to respect tool-call/result pair boundaries.
107    // We don't want to keep a Tool result message without its preceding
108    // Assistant tool-call message, or vice versa.
109    adjust_split_for_tool_pairs(history, candidate_split)
110}
111
112/// Adjust the split point so that tool-call/result pairs are not broken.
113///
114/// If the first kept message is a Tool result, we need to also keep the
115/// preceding Assistant message that contains the matching tool call.
116/// If the last dropped message is an Assistant with tool calls, we need
117/// to also drop the following Tool result messages.
118fn adjust_split_for_tool_pairs(history: &[Message], mut split: usize) -> usize {
119    if split == 0 || split >= history.len() {
120        return split;
121    }
122
123    // If the first kept message is a Tool response, move split backward
124    // to include the paired Assistant message.
125    while split > 0 && history[split].role == Role::Tool {
126        split -= 1;
127    }
128
129    // If the last dropped message (split-1) is an Assistant with tool_calls,
130    // move split forward to drop its orphaned tool results too.
131    if split > 0 {
132        let last_dropped = &history[split - 1];
133        if last_dropped.role == Role::Assistant && last_dropped.tool_calls.is_some() {
134            // Drop the tool results that follow.
135            while split < history.len() && history[split].role == Role::Tool {
136                split += 1;
137            }
138        }
139    }
140
141    split
142}
143
#[cfg(test)]
mod tests {
    use super::*;
    use crate::contracts::thread::ToolCall;
    use serde_json::json;

    fn user(content: &str) -> Message {
        Message::user(content)
    }

    fn assistant(content: &str) -> Message {
        Message::assistant(content)
    }

    fn assistant_with_calls(content: &str, calls: Vec<ToolCall>) -> Message {
        Message::assistant_with_tool_calls(content, calls)
    }

    fn tool_result(call_id: &str, content: &str) -> Message {
        Message::tool(call_id, content)
    }

    fn system(content: &str) -> Message {
        Message::system(content)
    }

    #[test]
    fn no_truncation_when_within_budget() {
        let sys = vec![system("You are helpful.")];
        let history = vec![user("Hi"), assistant("Hello!")];
        let policy = ContextWindowPolicy {
            max_context_tokens: 200_000,
            max_output_tokens: 8_192,
            ..Default::default()
        };

        let result = truncate_to_budget(&sys, &history, 0, &policy);
        assert_eq!(result.truncated_count, 0);
        assert_eq!(result.messages.len(), 3); // 1 system + 2 history
    }

    #[test]
    fn truncation_drops_oldest_messages() {
        let sys = vec![system("sys")];
        let history: Vec<Message> = (0..100)
            .map(|i| {
                if i % 2 == 0 {
                    user(&format!("message {i}"))
                } else {
                    assistant(&format!("response {i}"))
                }
            })
            .collect();

        let policy = ContextWindowPolicy {
            max_context_tokens: 200, // very tight budget
            max_output_tokens: 50,
            min_recent_messages: 4,
            ..Default::default()
        };

        let result = truncate_to_budget(&sys, &history, 10, &policy);
        assert!(result.truncated_count > 0);

        // System messages come first and are never dropped.
        assert!(std::ptr::eq(result.messages[0], &sys[0]));

        // Must keep at least min_recent_messages.
        let kept_history = &result.messages[1..];
        assert!(kept_history.len() >= 4);

        // What is kept is exactly the newest suffix of the history.
        assert_eq!(kept_history.len(), history.len() - result.truncated_count);
        for (kept, original) in kept_history.iter().zip(&history[result.truncated_count..]) {
            assert!(std::ptr::eq(*kept, original), "kept history must be the newest suffix");
        }
    }

    #[test]
    fn tool_pair_not_broken() {
        let sys = vec![system("sys")];
        let history = vec![
            user("Do something"),
            assistant_with_calls(
                "Using tool",
                vec![ToolCall::new("c1", "search", json!({"q": "x"}))],
            ),
            tool_result("c1", "found it"),
            assistant("Here is the answer."),
            user("Thanks"),
            assistant("You're welcome!"),
        ];

        // Budget that can fit ~3-4 messages but not all 6.
        let policy = ContextWindowPolicy {
            max_context_tokens: 120,
            max_output_tokens: 30,
            min_recent_messages: 2,
            ..Default::default()
        };

        let result = truncate_to_budget(&sys, &history, 10, &policy);

        // Check that no Tool message is the first kept history message
        // (which would mean its paired Assistant was dropped).
        let kept_history: Vec<_> = result.messages.iter().skip(1).collect();
        if !kept_history.is_empty() {
            assert_ne!(
                kept_history[0].role,
                Role::Tool,
                "First kept history message should not be an orphaned tool result"
            );
        }
    }

    #[test]
    fn min_recent_always_preserved() {
        let sys = vec![system("sys")];
        let history: Vec<Message> = (0..20).map(|i| user(&format!("msg {i}"))).collect();

        let policy = ContextWindowPolicy {
            max_context_tokens: 50, // impossibly tight
            max_output_tokens: 10,
            min_recent_messages: 5,
            ..Default::default()
        };

        let result = truncate_to_budget(&sys, &history, 0, &policy);
        let kept_history = result.messages.len() - 1;
        assert!(kept_history >= 5, "must keep at least min_recent_messages");
    }

    #[test]
    fn adjust_split_moves_back_for_orphaned_tool_result() {
        let history = vec![
            user("a"),                                                            // 0
            assistant_with_calls("b", vec![ToolCall::new("c1", "t", json!({}))]), // 1
            tool_result("c1", "r"),                                               // 2
            user("c"),                                                            // 3
        ];

        // If naive split is 2 (keep [2,3]), tool result at 2 is orphaned.
        let adjusted = adjust_split_for_tool_pairs(&history, 2);
        assert_eq!(adjusted, 1, "should include the assistant with tool calls");
    }

    #[test]
    fn adjust_split_leaves_boundary_splits_unchanged() {
        let history = vec![
            tool_result("c0", "stale"), // 0
            user("a"),                  // 1
        ];

        // Keep-everything and keep-nothing splits are never adjusted.
        assert_eq!(adjust_split_for_tool_pairs(&history, 0), 0);
        assert_eq!(adjust_split_for_tool_pairs(&history, 2), 2);
    }

    #[test]
    fn adjust_split_walks_back_to_start_over_leading_tool_results() {
        let history = vec![
            tool_result("c0", "r0"), // 0
            tool_result("c1", "r1"), // 1
            user("a"),               // 2
        ];

        // A tool run reaching the start of history pulls the split back to 0.
        assert_eq!(adjust_split_for_tool_pairs(&history, 1), 0);
    }

    #[test]
    fn adjust_split_keeps_split_after_dropped_assistant_without_results() {
        // The dropped assistant's tool results are not in the history at all,
        // so there is nothing to move.
        let history = vec![
            assistant_with_calls("b", vec![ToolCall::new("c1", "t", json!({}))]), // 0
            user("c"),                                                            // 1
        ];

        assert_eq!(adjust_split_for_tool_pairs(&history, 1), 1);
    }

    #[test]
    fn empty_history() {
        let sys = vec![system("sys")];
        let history: Vec<Message> = vec![];
        let policy = ContextWindowPolicy::default();

        let result = truncate_to_budget(&sys, &history, 0, &policy);
        assert_eq!(result.truncated_count, 0);
        assert_eq!(result.messages.len(), 1);
    }

    #[test]
    fn adjust_split_handles_multiple_consecutive_tool_results() {
        // Assistant with 2 tool calls followed by 2 tool results
        let history = vec![
            user("start"), // 0
            assistant_with_calls(
                "calling two tools",
                vec![
                    ToolCall::new("c1", "t1", json!({})),
                    ToolCall::new("c2", "t2", json!({})),
                ],
            ), // 1
            tool_result("c1", "result1"), // 2
            tool_result("c2", "result2"), // 3
            user("continue"), // 4
        ];

        // Naive split at 2 (first kept = tool result) → should move back to 1
        let adjusted = adjust_split_for_tool_pairs(&history, 2);
        assert_eq!(adjusted, 1, "should include assistant with both tool calls");

        // Naive split at 3 (first kept = tool result) → should move back to 1
        let adjusted = adjust_split_for_tool_pairs(&history, 3);
        assert_eq!(
            adjusted, 1,
            "should walk back through all consecutive tool results"
        );
    }

    #[test]
    fn adjust_split_stable_at_user_boundary() {
        let history = vec![
            user("start"),                                                               // 0
            assistant_with_calls("calling", vec![ToolCall::new("c1", "t1", json!({}))]), // 1
            tool_result("c1", "result"),                                                 // 2
            user("next question"),                                                       // 3
            assistant("answer"),                                                         // 4
        ];

        // Split at 3 keeps [3,4] and drops the complete call/result pair [1,2].
        // history[3] is a user message, so no adjustment is needed.
        let adjusted = adjust_split_for_tool_pairs(&history, 3);
        assert_eq!(adjusted, 3, "split at user boundary should be stable");
    }

    #[test]
    fn all_system_messages_preserved_with_empty_history() {
        let sys = vec![
            system("system line 1"),
            system("system line 2"),
            system("system line 3"),
        ];
        let history: Vec<Message> = vec![];
        let policy = ContextWindowPolicy {
            max_context_tokens: 100,
            max_output_tokens: 10,
            min_recent_messages: 5,
            ..Default::default()
        };

        let result = truncate_to_budget(&sys, &history, 0, &policy);
        assert_eq!(result.messages.len(), 3, "all system messages preserved");
        assert_eq!(result.truncated_count, 0);
    }

    #[test]
    fn tool_tokens_reduce_available_budget() {
        let sys = vec![system("sys")];
        let history: Vec<Message> = (0..50)
            .map(|i| user(&format!("message {i} with some extra content padding")))
            .collect();

        let policy = ContextWindowPolicy {
            max_context_tokens: 500,
            max_output_tokens: 100,
            min_recent_messages: 2,
            ..Default::default()
        };

        let result_no_tools = truncate_to_budget(&sys, &history, 0, &policy);
        let result_with_tools = truncate_to_budget(&sys, &history, 200, &policy);

        assert!(
            result_with_tools.truncated_count > result_no_tools.truncated_count,
            "tool token overhead should cause more truncation"
        );
    }

    #[test]
    fn default_policy_values() {
        let p = ContextWindowPolicy::default();
        assert_eq!(p.max_context_tokens, 200_000);
        assert_eq!(p.max_output_tokens, 16_384);
        assert_eq!(p.min_recent_messages, 10);
        assert!(p.enable_prompt_cache);
    }
}