1use crate::contracts::thread::{Message, Role};
8use crate::engine::token_estimator::{estimate_message_tokens, estimate_messages_tokens};
9
10pub use tirea_contract::runtime::inference::ContextWindowPolicy;
12
13#[derive(Debug)]
15pub struct TruncationResult<'a> {
16 pub messages: Vec<&'a Message>,
18 pub truncated_count: usize,
20 pub estimated_total_tokens: usize,
22}
23
24pub fn truncate_to_budget<'a>(
31 system_messages: &'a [Message],
32 history_messages: &'a [Message],
33 tool_tokens: usize,
34 policy: &ContextWindowPolicy,
35) -> TruncationResult<'a> {
36 let available = policy
37 .max_context_tokens
38 .saturating_sub(policy.max_output_tokens)
39 .saturating_sub(tool_tokens);
40
41 let system_tokens = estimate_messages_tokens(system_messages);
42 let history_budget = available.saturating_sub(system_tokens);
43
44 let split = find_split_point(history_messages, history_budget, policy.min_recent_messages);
48
49 let kept = &history_messages[split..];
50 let kept_tokens = estimate_messages_tokens(kept);
51 let truncated_count = split;
52
53 let mut messages: Vec<&Message> = Vec::with_capacity(system_messages.len() + kept.len());
54 for msg in system_messages {
55 messages.push(msg);
56 }
57 for msg in kept {
58 messages.push(msg);
59 }
60
61 TruncationResult {
62 messages,
63 truncated_count,
64 estimated_total_tokens: system_tokens + kept_tokens + tool_tokens,
65 }
66}
67
68fn find_split_point(history: &[Message], budget_tokens: usize, min_recent: usize) -> usize {
74 if history.is_empty() {
75 return 0;
76 }
77
78 let must_keep = min_recent.min(history.len());
80 let must_keep_start = history.len().saturating_sub(must_keep);
81
82 let mut used_tokens = 0usize;
84 let mut candidate_split = history.len(); for i in (0..history.len()).rev() {
87 let msg_tokens = estimate_message_tokens(&history[i]);
88 let new_total = used_tokens + msg_tokens;
89
90 if i >= must_keep_start {
92 used_tokens = new_total;
93 candidate_split = i;
94 continue;
95 }
96
97 if new_total > budget_tokens {
99 break;
100 }
101
102 used_tokens = new_total;
103 candidate_split = i;
104 }
105
106 adjust_split_for_tool_pairs(history, candidate_split)
110}
111
112fn adjust_split_for_tool_pairs(history: &[Message], mut split: usize) -> usize {
119 if split == 0 || split >= history.len() {
120 return split;
121 }
122
123 while split > 0 && history[split].role == Role::Tool {
126 split -= 1;
127 }
128
129 if split > 0 {
132 let last_dropped = &history[split - 1];
133 if last_dropped.role == Role::Assistant && last_dropped.tool_calls.is_some() {
134 while split < history.len() && history[split].role == Role::Tool {
136 split += 1;
137 }
138 }
139 }
140
141 split
142}
143
#[cfg(test)]
mod tests {
    use super::*;
    use crate::contracts::thread::ToolCall;
    use serde_json::json;

    // Short constructors that keep the fixtures below readable.

    fn user(content: &str) -> Message {
        Message::user(content)
    }

    fn assistant(content: &str) -> Message {
        Message::assistant(content)
    }

    fn assistant_with_calls(content: &str, calls: Vec<ToolCall>) -> Message {
        Message::assistant_with_tool_calls(content, calls)
    }

    fn tool_result(call_id: &str, content: &str) -> Message {
        Message::tool(call_id, content)
    }

    fn system(content: &str) -> Message {
        Message::system(content)
    }

    #[test]
    fn no_truncation_when_within_budget() {
        let sys = vec![system("You are helpful.")];
        let history = vec![user("Hi"), assistant("Hello!")];
        let policy = ContextWindowPolicy {
            max_context_tokens: 200_000,
            max_output_tokens: 8_192,
            ..Default::default()
        };

        let result = truncate_to_budget(&sys, &history, 0, &policy);
        assert_eq!(result.truncated_count, 0);
        // One system message plus both history messages.
        assert_eq!(result.messages.len(), 3);
    }

    #[test]
    fn truncation_drops_oldest_messages() {
        let sys = vec![system("sys")];
        let history: Vec<Message> = (0..100)
            .map(|i| {
                if i % 2 == 0 {
                    user(&format!("message {i}"))
                } else {
                    assistant(&format!("response {i}"))
                }
            })
            .collect();

        let policy = ContextWindowPolicy {
            max_context_tokens: 200,
            max_output_tokens: 50,
            min_recent_messages: 4,
            ..Default::default()
        };

        let result = truncate_to_budget(&sys, &history, 10, &policy);
        assert!(result.truncated_count > 0);

        // Everything after the single system message is retained history.
        let kept_history = result.messages.len() - 1;
        assert!(kept_history >= 4);

        // Dropped and kept messages together account for the whole history ...
        assert_eq!(result.truncated_count + kept_history, history.len());

        // ... and what is kept is the newest end of it.
        let newest_kept = result.messages.last().copied().expect("messages are not empty");
        let newest_input = history.last().expect("history is not empty");
        assert!(
            std::ptr::eq(newest_kept, newest_input),
            "the most recent history message must be retained"
        );
    }

    #[test]
    fn tool_pair_not_broken() {
        let sys = vec![system("sys")];
        let history = vec![
            user("Do something"),
            assistant_with_calls(
                "Using tool",
                vec![ToolCall::new("c1", "search", json!({"q": "x"}))],
            ),
            tool_result("c1", "found it"),
            assistant("Here is the answer."),
            user("Thanks"),
            assistant("You're welcome!"),
        ];

        let policy = ContextWindowPolicy {
            max_context_tokens: 120,
            max_output_tokens: 30,
            min_recent_messages: 2,
            ..Default::default()
        };

        let result = truncate_to_budget(&sys, &history, 10, &policy);

        // Index 0 is the system message; index 1, if present, is the first kept
        // history message.
        if let Some(first_kept) = result.messages.get(1) {
            assert_ne!(
                first_kept.role,
                Role::Tool,
                "First kept history message should not be an orphaned tool result"
            );
        }
    }

    #[test]
    fn min_recent_always_preserved() {
        let sys = vec![system("sys")];
        let history: Vec<Message> = (0..20).map(|i| user(&format!("msg {i}"))).collect();

        let policy = ContextWindowPolicy {
            max_context_tokens: 50,
            max_output_tokens: 10,
            min_recent_messages: 5,
            ..Default::default()
        };

        let result = truncate_to_budget(&sys, &history, 0, &policy);
        let kept_history = result.messages.len() - 1;
        assert!(kept_history >= 5, "must keep at least min_recent_messages");
    }

    #[test]
    fn adjust_split_moves_back_for_orphaned_tool_result() {
        let history = vec![
            user("a"),
            assistant_with_calls("b", vec![ToolCall::new("c1", "t", json!({}))]),
            tool_result("c1", "r"),
            user("c"),
        ];

        // Splitting at the tool result would drop the assistant that requested it.
        let adjusted = adjust_split_for_tool_pairs(&history, 2);
        assert_eq!(adjusted, 1, "should include the assistant with tool calls");
    }

    #[test]
    fn adjust_split_walks_back_to_start_when_only_tool_results_precede() {
        let history = vec![
            tool_result("c0", "r0"),
            tool_result("c1", "r1"),
            user("c"),
        ];

        // There is no non-tool message before the split, so the walk-back only stops
        // at the very beginning of the history.
        let adjusted = adjust_split_for_tool_pairs(&history, 1);
        assert_eq!(adjusted, 0, "nothing can be dropped in front of leading tool results");
    }

    #[test]
    fn adjust_split_leaves_boundary_indices_untouched() {
        let history = vec![
            user("a"),
            assistant_with_calls("b", vec![ToolCall::new("c1", "t", json!({}))]),
            tool_result("c1", "r"),
            user("c"),
        ];

        assert_eq!(
            adjust_split_for_tool_pairs(&history, 0),
            0,
            "keeping everything needs no adjustment"
        );
        assert_eq!(
            adjust_split_for_tool_pairs(&history, history.len()),
            history.len(),
            "dropping everything needs no adjustment"
        );
    }

    #[test]
    fn empty_history() {
        let sys = vec![system("sys")];
        let history: Vec<Message> = vec![];
        let policy = ContextWindowPolicy::default();

        let result = truncate_to_budget(&sys, &history, 0, &policy);
        assert_eq!(result.truncated_count, 0);
        assert_eq!(result.messages.len(), 1);
    }

    #[test]
    fn adjust_split_handles_multiple_consecutive_tool_results() {
        let history = vec![
            user("start"),
            assistant_with_calls(
                "calling two tools",
                vec![
                    ToolCall::new("c1", "t1", json!({})),
                    ToolCall::new("c2", "t2", json!({})),
                ],
            ),
            tool_result("c1", "result1"),
            tool_result("c2", "result2"),
            user("continue"),
        ];

        // Split on the first tool result.
        let adjusted = adjust_split_for_tool_pairs(&history, 2);
        assert_eq!(adjusted, 1, "should include assistant with both tool calls");

        // Split on the second tool result.
        let adjusted = adjust_split_for_tool_pairs(&history, 3);
        assert_eq!(
            adjusted, 1,
            "should walk back through all consecutive tool results"
        );
    }

    #[test]
    fn adjust_split_drops_orphaned_results_after_dropped_assistant() {
        let history = vec![
            user("start"),
            assistant_with_calls("calling", vec![ToolCall::new("c1", "t1", json!({}))]),
            tool_result("c1", "result"),
            user("next question"),
            assistant("answer"),
        ];

        // The whole tool exchange sits before the split, so nothing is left dangling.
        let adjusted = adjust_split_for_tool_pairs(&history, 3);
        assert_eq!(adjusted, 3, "split at user boundary should be stable");
    }

    #[test]
    fn all_system_messages_preserved_with_empty_history() {
        let sys = vec![
            system("system line 1"),
            system("system line 2"),
            system("system line 3"),
        ];
        let history: Vec<Message> = vec![];
        let policy = ContextWindowPolicy {
            max_context_tokens: 100,
            max_output_tokens: 10,
            min_recent_messages: 5,
            ..Default::default()
        };

        let result = truncate_to_budget(&sys, &history, 0, &policy);
        assert_eq!(result.messages.len(), 3, "all system messages preserved");
        assert_eq!(result.truncated_count, 0);
    }

    #[test]
    fn tool_tokens_reduce_available_budget() {
        let sys = vec![system("sys")];
        let history: Vec<Message> = (0..50)
            .map(|i| user(&format!("message {i} with some extra content padding")))
            .collect();

        let policy = ContextWindowPolicy {
            max_context_tokens: 500,
            max_output_tokens: 100,
            min_recent_messages: 2,
            ..Default::default()
        };

        let result_no_tools = truncate_to_budget(&sys, &history, 0, &policy);
        let result_with_tools = truncate_to_budget(&sys, &history, 200, &policy);

        assert!(
            result_with_tools.truncated_count > result_no_tools.truncated_count,
            "tool token overhead should cause more truncation"
        );
    }

    #[test]
    fn default_policy_values() {
        let p = ContextWindowPolicy::default();
        assert_eq!(p.max_context_tokens, 200_000);
        assert_eq!(p.max_output_tokens, 16_384);
        assert_eq!(p.min_recent_messages, 10);
        assert!(p.enable_prompt_cache);
    }
}