stakpak_agent_core/
budget_context.rs1use crate::context::{
2 ContextReducer, dedup_tool_results, merge_consecutive_same_role, remove_orphaned_tool_results,
3 strip_dangling_tool_calls,
4};
5use stakai::{ContentPart, Message, MessageContent, Model, Role, Tool};
6
/// Replacement text written into message bodies whose content has been trimmed.
const TRIMMED_CONTENT_PLACEHOLDER: &str = "[trimmed]";
/// Heuristic average number of bytes per token used for all token estimates.
const BYTES_PER_TOKEN: f64 = 3.5;
/// Multiplier padding raw token estimates to guard against undercounting.
const SAFETY_BUFFER_FACTOR: f64 = 1.05;
/// Fraction of the budget threshold targeted when trimming, so the estimate
/// lands well below the threshold and trimming does not immediately re-trigger.
const TRIM_HEADROOM_FACTOR: f64 = 0.75;
14
/// A `ContextReducer` that, once estimated token usage exceeds a configured
/// fraction of the model's available context window, trims the content of
/// older assistant/tool messages in place (replacing it with a placeholder)
/// rather than dropping messages.
#[derive(Debug, Clone)]
pub struct BudgetAwareContextReducer {
    // Number of most-recent assistant messages protected from the first
    // trimming pass (they may still be trimmed as a last resort when the
    // estimate remains above the trim target).
    keep_last_n_assistant_messages: usize,
    // Fraction of the available context window at which trimming kicks in.
    // Presumably in 0.0..=1.0 -- TODO confirm expected range with callers.
    context_budget_threshold: f32,
}
20
21impl BudgetAwareContextReducer {
22 pub fn new(keep_last_n_assistant_messages: usize, context_budget_threshold: f32) -> Self {
23 Self {
24 keep_last_n_assistant_messages,
25 context_budget_threshold,
26 }
27 }
28
29 fn bytes_to_tokens(bytes: usize) -> u64 {
30 (bytes as f64 / BYTES_PER_TOKEN).ceil() as u64
31 }
32
33 fn estimate_content_part_tokens(part: &ContentPart) -> u64 {
34 match part {
35 ContentPart::Text { text, .. } => Self::bytes_to_tokens(text.len()),
36 ContentPart::ToolCall {
37 name, arguments, ..
38 } => {
39 let content_bytes = name.len() + arguments.to_string().len();
40 Self::bytes_to_tokens(content_bytes + 30)
41 }
42 ContentPart::ToolResult { content, .. } => {
43 let content_bytes = content
44 .as_str()
45 .map(|value| value.len())
46 .unwrap_or_else(|| content.to_string().len());
47 Self::bytes_to_tokens(content_bytes + 30)
48 }
49 ContentPart::Image { .. } => 2000,
50 }
51 }
52
53 fn estimate_message_tokens_raw(msg: &Message) -> u64 {
54 let content_tokens = match &msg.content {
55 MessageContent::Text(text) => Self::bytes_to_tokens(text.len()),
56 MessageContent::Parts(parts) => {
57 let part_tokens: u64 = parts.iter().map(Self::estimate_content_part_tokens).sum();
58 let part_overhead = parts.len() as u64 * 3;
59 part_tokens + part_overhead
60 }
61 };
62
63 content_tokens + 8
64 }
65
66 fn estimate_tokens_raw(messages: &[Message]) -> u64 {
67 messages.iter().map(Self::estimate_message_tokens_raw).sum()
68 }
69
70 fn add_safety_buffer(raw_tokens: u64) -> u64 {
71 (raw_tokens as f64 * SAFETY_BUFFER_FACTOR).ceil() as u64
72 }
73
74 pub fn estimate_tokens(messages: &[Message]) -> u64 {
75 Self::add_safety_buffer(Self::estimate_tokens_raw(messages))
76 }
77
78 pub fn estimate_tool_overhead(tools: &[Tool]) -> u64 {
79 tools
80 .iter()
81 .map(|tool| {
82 let schema_len = tool.function.parameters.to_string().len();
83 let tool_bytes =
84 tool.function.name.len() + tool.function.description.len() + schema_len;
85 let adjusted_bytes = (tool_bytes as f64 * 1.2).ceil() as usize;
86 Self::bytes_to_tokens(adjusted_bytes) + 8
87 })
88 .sum()
89 }
90
91 fn trim_message(msg: &mut Message) {
92 match &mut msg.content {
93 MessageContent::Text(text) => {
94 *text = TRIMMED_CONTENT_PLACEHOLDER.to_string();
95 }
96 MessageContent::Parts(parts) => {
97 for part in parts.iter_mut() {
98 match part {
99 ContentPart::Text { text, .. } => {
100 *text = TRIMMED_CONTENT_PLACEHOLDER.to_string();
101 }
102 ContentPart::ToolResult { content, .. } => {
103 *content = serde_json::json!(TRIMMED_CONTENT_PLACEHOLDER);
104 }
105 ContentPart::ToolCall { .. } | ContentPart::Image { .. } => {}
106 }
107 }
108 }
109 }
110 }
111
112 fn trim_message_with_delta(msg: &mut Message) -> i64 {
113 let before = Self::estimate_message_tokens_raw(msg);
114 Self::trim_message(msg);
115 let after = Self::estimate_message_tokens_raw(msg);
116 after as i64 - before as i64
117 }
118
119 fn metadata_trimmed_up_to(metadata: &serde_json::Value) -> usize {
120 metadata
121 .get("trimmed_up_to_message_index")
122 .and_then(serde_json::Value::as_u64)
123 .unwrap_or(0) as usize
124 }
125
126 fn ensure_metadata_object(metadata: &mut serde_json::Value) {
127 if !metadata.is_object() {
128 *metadata = serde_json::json!({});
129 }
130 }
131}
132
impl ContextReducer for BudgetAwareContextReducer {
    /// Normalizes the conversation, then trims the content of older
    /// assistant/tool messages in place once the estimated token usage
    /// (messages + tool definitions, safety-buffered) crosses the budget
    /// threshold. The furthest index trimmed is persisted in `metadata`
    /// under `"trimmed_up_to_message_index"` so subsequent calls re-apply
    /// the same trimming and the watermark only ever advances.
    fn reduce(
        &self,
        messages: Vec<Message>,
        model: &Model,
        max_output_tokens: u32,
        tools: &[Tool],
        metadata: &mut serde_json::Value,
    ) -> Vec<Message> {
        // Structural cleanup before any token accounting.
        let messages = merge_consecutive_same_role(messages);
        let messages = dedup_tool_results(messages);
        let messages = strip_dangling_tool_calls(messages);
        let mut messages = remove_orphaned_tool_results(messages);

        // Budget: context window minus reserved output, scaled by the
        // configured threshold. `trim_target` sits below `threshold`
        // (TRIM_HEADROOM_FACTOR) so trimming overshoots and leaves headroom.
        let available_context_window = model.limit.context.saturating_sub(max_output_tokens as u64);
        let threshold = (available_context_window as f32 * self.context_budget_threshold) as u64;
        let trim_target = (threshold as f64 * TRIM_HEADROOM_FACTOR) as u64;
        let tool_overhead = Self::estimate_tool_overhead(tools);

        let prev_trimmed_up_to = Self::metadata_trimmed_up_to(metadata);
        let mut raw_tokens = Self::estimate_tokens_raw(&messages);

        // Fast path: nothing was trimmed previously and we are under budget.
        // (When a watermark exists we must fall through to re-apply it.)
        if prev_trimmed_up_to == 0
            && Self::add_safety_buffer(raw_tokens) + tool_overhead <= threshold
        {
            return messages;
        }

        // `keep_n_trim_end` is the exclusive upper bound of the first-pass
        // trim range; messages at/after it carry the protected last N
        // assistant turns. With no protection configured, everything is
        // eligible (bound = len).
        let len = messages.len();
        let mut keep_n_trim_end = if self.keep_last_n_assistant_messages > 0 {
            0
        } else {
            len
        };

        if self.keep_last_n_assistant_messages > 0 {
            // Walk backwards to locate the Nth most recent assistant message;
            // if fewer than N exist, the bound stays 0 (nothing eligible).
            let mut assistant_count = 0usize;
            for i in (0..len).rev() {
                if messages[i].role == Role::Assistant {
                    assistant_count += 1;
                    if assistant_count >= self.keep_last_n_assistant_messages {
                        keep_n_trim_end = i;
                        break;
                    }
                }
            }
        }

        // Re-apply trimming recorded by earlier calls (idempotent), keeping
        // the running raw-token estimate in sync via the returned delta.
        let prev_clamped = prev_trimmed_up_to.min(len);
        for msg in &mut messages[..prev_clamped] {
            if msg.role == Role::Assistant || msg.role == Role::Tool {
                let delta = Self::trim_message_with_delta(msg);
                raw_tokens = (raw_tokens as i64 + delta).max(0) as u64;
            }
        }

        let effective_estimated_tokens = Self::add_safety_buffer(raw_tokens) + tool_overhead;

        let effective_trim_end = if effective_estimated_tokens > threshold {
            // First pass: trim all eligible assistant/tool messages up to the
            // protected tail (skipping the already re-trimmed prefix).
            let mut candidate = if keep_n_trim_end > 0 {
                for msg in messages
                    .iter_mut()
                    .take(keep_n_trim_end.min(len))
                    .skip(prev_clamped)
                {
                    if msg.role == Role::Assistant || msg.role == Role::Tool {
                        let delta = Self::trim_message_with_delta(msg);
                        raw_tokens = (raw_tokens as i64 + delta).max(0) as u64;
                    }
                }
                keep_n_trim_end
            } else {
                prev_trimmed_up_to
            };

            // Second pass (last resort): still above the trim target, so
            // advance into the protected tail one assistant/tool message at a
            // time until the estimate drops to the target or we run out.
            let mut current_tokens = Self::add_safety_buffer(raw_tokens) + tool_overhead;
            if current_tokens > trim_target {
                let mut scan_idx = candidate;
                while scan_idx < len {
                    if messages[scan_idx].role == Role::Assistant
                        || messages[scan_idx].role == Role::Tool
                    {
                        let delta = Self::trim_message_with_delta(&mut messages[scan_idx]);
                        raw_tokens = (raw_tokens as i64 + delta).max(0) as u64;
                        candidate = scan_idx + 1;

                        current_tokens = Self::add_safety_buffer(raw_tokens) + tool_overhead;
                        if current_tokens <= trim_target {
                            break;
                        }
                    }
                    scan_idx += 1;
                }
            }

            // The trim watermark never moves backwards.
            candidate.max(prev_trimmed_up_to)
        } else {
            prev_trimmed_up_to
        };

        // Final sweep: ensure every assistant/tool message below the
        // watermark is trimmed. This is a no-op for messages handled above
        // since trimming is idempotent.
        let clamped_end = effective_trim_end.min(len);
        for msg in messages.iter_mut().take(clamped_end).skip(prev_clamped) {
            if msg.role == Role::Assistant || msg.role == Role::Tool {
                let delta = Self::trim_message_with_delta(msg);
                raw_tokens = (raw_tokens as i64 + delta).max(0) as u64;
            }
        }

        // Persist the watermark so the next call resumes from here.
        Self::ensure_metadata_object(metadata);
        if let Some(obj) = metadata.as_object_mut() {
            obj.insert(
                "trimmed_up_to_message_index".to_string(),
                serde_json::json!(effective_trim_end),
            );
        }

        messages
    }
}