Skip to main content

stakpak_agent_core/
budget_context.rs

1use crate::context::{
2    ContextReducer, dedup_tool_results, merge_consecutive_same_role, remove_orphaned_tool_results,
3    strip_dangling_tool_calls,
4};
5use stakai::{ContentPart, Message, MessageContent, Model, Role, Tool};
6
/// Text written in place of message text and tool-result content when a message is trimmed.
const TRIMMED_CONTENT_PLACEHOLDER: &str = "[trimmed]";
/// Divisor used to turn a byte length into an estimated token count.
const BYTES_PER_TOKEN: f64 = 3.5;
/// Multiplier applied to raw token estimates to leave a 5% margin for estimation error.
const SAFETY_BUFFER_FACTOR: f64 = 1.05;
/// Fraction of the budget threshold that trimming aims for once it kicks in.
///
/// Trimming below the threshold (rather than just to it) leaves headroom, so the
/// trim boundary can stay put for several turns. That keeps the Anthropic prompt
/// cache stable; if the boundary moved forward on every turn, every request would
/// invalidate the cache.
const TRIM_HEADROOM_FACTOR: f64 = 0.75;
14
/// Context reducer that replaces the content of older assistant and tool messages
/// with a placeholder once the estimated token usage exceeds a configured share
/// of the model's context window.
#[derive(Debug, Clone)]
pub struct BudgetAwareContextReducer {
    /// Number of most recent assistant messages that the first trimming pass
    /// leaves untouched. With `0`, that pass covers the whole conversation.
    keep_last_n_assistant_messages: usize,
    /// Fraction of the available context window (context limit minus the
    /// requested output tokens) above which trimming starts.
    context_budget_threshold: f32,
}
20
21impl BudgetAwareContextReducer {
22    pub fn new(keep_last_n_assistant_messages: usize, context_budget_threshold: f32) -> Self {
23        Self {
24            keep_last_n_assistant_messages,
25            context_budget_threshold,
26        }
27    }
28
29    fn bytes_to_tokens(bytes: usize) -> u64 {
30        (bytes as f64 / BYTES_PER_TOKEN).ceil() as u64
31    }
32
33    fn estimate_content_part_tokens(part: &ContentPart) -> u64 {
34        match part {
35            ContentPart::Text { text, .. } => Self::bytes_to_tokens(text.len()),
36            ContentPart::ToolCall {
37                name, arguments, ..
38            } => {
39                let content_bytes = name.len() + arguments.to_string().len();
40                Self::bytes_to_tokens(content_bytes + 30)
41            }
42            ContentPart::ToolResult { content, .. } => {
43                let content_bytes = content
44                    .as_str()
45                    .map(|value| value.len())
46                    .unwrap_or_else(|| content.to_string().len());
47                Self::bytes_to_tokens(content_bytes + 30)
48            }
49            ContentPart::Image { .. } => 2000,
50        }
51    }
52
53    fn estimate_message_tokens_raw(msg: &Message) -> u64 {
54        let content_tokens = match &msg.content {
55            MessageContent::Text(text) => Self::bytes_to_tokens(text.len()),
56            MessageContent::Parts(parts) => {
57                let part_tokens: u64 = parts.iter().map(Self::estimate_content_part_tokens).sum();
58                let part_overhead = parts.len() as u64 * 3;
59                part_tokens + part_overhead
60            }
61        };
62
63        content_tokens + 8
64    }
65
66    fn estimate_tokens_raw(messages: &[Message]) -> u64 {
67        messages.iter().map(Self::estimate_message_tokens_raw).sum()
68    }
69
70    fn add_safety_buffer(raw_tokens: u64) -> u64 {
71        (raw_tokens as f64 * SAFETY_BUFFER_FACTOR).ceil() as u64
72    }
73
74    pub fn estimate_tokens(messages: &[Message]) -> u64 {
75        Self::add_safety_buffer(Self::estimate_tokens_raw(messages))
76    }
77
78    pub fn estimate_tool_overhead(tools: &[Tool]) -> u64 {
79        tools
80            .iter()
81            .map(|tool| {
82                let schema_len = tool.function.parameters.to_string().len();
83                let tool_bytes =
84                    tool.function.name.len() + tool.function.description.len() + schema_len;
85                let adjusted_bytes = (tool_bytes as f64 * 1.2).ceil() as usize;
86                Self::bytes_to_tokens(adjusted_bytes) + 8
87            })
88            .sum()
89    }
90
91    fn trim_message(msg: &mut Message) {
92        match &mut msg.content {
93            MessageContent::Text(text) => {
94                *text = TRIMMED_CONTENT_PLACEHOLDER.to_string();
95            }
96            MessageContent::Parts(parts) => {
97                for part in parts.iter_mut() {
98                    match part {
99                        ContentPart::Text { text, .. } => {
100                            *text = TRIMMED_CONTENT_PLACEHOLDER.to_string();
101                        }
102                        ContentPart::ToolResult { content, .. } => {
103                            *content = serde_json::json!(TRIMMED_CONTENT_PLACEHOLDER);
104                        }
105                        ContentPart::ToolCall { .. } | ContentPart::Image { .. } => {}
106                    }
107                }
108            }
109        }
110    }
111
112    fn trim_message_with_delta(msg: &mut Message) -> i64 {
113        let before = Self::estimate_message_tokens_raw(msg);
114        Self::trim_message(msg);
115        let after = Self::estimate_message_tokens_raw(msg);
116        after as i64 - before as i64
117    }
118
119    fn metadata_trimmed_up_to(metadata: &serde_json::Value) -> usize {
120        metadata
121            .get("trimmed_up_to_message_index")
122            .and_then(serde_json::Value::as_u64)
123            .unwrap_or(0) as usize
124    }
125
126    fn ensure_metadata_object(metadata: &mut serde_json::Value) {
127        if !metadata.is_object() {
128            *metadata = serde_json::json!({});
129        }
130    }
131}
132
133impl ContextReducer for BudgetAwareContextReducer {
134    fn reduce(
135        &self,
136        messages: Vec<Message>,
137        model: &Model,
138        max_output_tokens: u32,
139        tools: &[Tool],
140        metadata: &mut serde_json::Value,
141    ) -> Vec<Message> {
142        let messages = merge_consecutive_same_role(messages);
143        let messages = dedup_tool_results(messages);
144        let messages = strip_dangling_tool_calls(messages);
145        let mut messages = remove_orphaned_tool_results(messages);
146
147        let available_context_window = model.limit.context.saturating_sub(max_output_tokens as u64);
148        let threshold = (available_context_window as f32 * self.context_budget_threshold) as u64;
149        let trim_target = (threshold as f64 * TRIM_HEADROOM_FACTOR) as u64;
150        let tool_overhead = Self::estimate_tool_overhead(tools);
151
152        let prev_trimmed_up_to = Self::metadata_trimmed_up_to(metadata);
153        let mut raw_tokens = Self::estimate_tokens_raw(&messages);
154
155        if prev_trimmed_up_to == 0
156            && Self::add_safety_buffer(raw_tokens) + tool_overhead <= threshold
157        {
158            return messages;
159        }
160
161        let len = messages.len();
162        let mut keep_n_trim_end = if self.keep_last_n_assistant_messages > 0 {
163            0
164        } else {
165            len
166        };
167
168        if self.keep_last_n_assistant_messages > 0 {
169            let mut assistant_count = 0usize;
170            for i in (0..len).rev() {
171                if messages[i].role == Role::Assistant {
172                    assistant_count += 1;
173                    if assistant_count >= self.keep_last_n_assistant_messages {
174                        keep_n_trim_end = i;
175                        break;
176                    }
177                }
178            }
179        }
180
181        let prev_clamped = prev_trimmed_up_to.min(len);
182        for msg in &mut messages[..prev_clamped] {
183            if msg.role == Role::Assistant || msg.role == Role::Tool {
184                let delta = Self::trim_message_with_delta(msg);
185                raw_tokens = (raw_tokens as i64 + delta).max(0) as u64;
186            }
187        }
188
189        let effective_estimated_tokens = Self::add_safety_buffer(raw_tokens) + tool_overhead;
190
191        let effective_trim_end = if effective_estimated_tokens > threshold {
192            let mut candidate = if keep_n_trim_end > 0 {
193                for msg in messages
194                    .iter_mut()
195                    .take(keep_n_trim_end.min(len))
196                    .skip(prev_clamped)
197                {
198                    if msg.role == Role::Assistant || msg.role == Role::Tool {
199                        let delta = Self::trim_message_with_delta(msg);
200                        raw_tokens = (raw_tokens as i64 + delta).max(0) as u64;
201                    }
202                }
203                keep_n_trim_end
204            } else {
205                prev_trimmed_up_to
206            };
207
208            let mut current_tokens = Self::add_safety_buffer(raw_tokens) + tool_overhead;
209            if current_tokens > trim_target {
210                let mut scan_idx = candidate;
211                while scan_idx < len {
212                    if messages[scan_idx].role == Role::Assistant
213                        || messages[scan_idx].role == Role::Tool
214                    {
215                        let delta = Self::trim_message_with_delta(&mut messages[scan_idx]);
216                        raw_tokens = (raw_tokens as i64 + delta).max(0) as u64;
217                        candidate = scan_idx + 1;
218
219                        current_tokens = Self::add_safety_buffer(raw_tokens) + tool_overhead;
220                        if current_tokens <= trim_target {
221                            break;
222                        }
223                    }
224                    scan_idx += 1;
225                }
226            }
227
228            candidate.max(prev_trimmed_up_to)
229        } else {
230            prev_trimmed_up_to
231        };
232
233        // Final pass: ensure every message in [prev_clamped..effective_trim_end] is trimmed.
234        // Some of these may have been trimmed already by Phase 2 above — that's harmless
235        // because trim_message_with_delta is idempotent (delta=0 on already-trimmed content).
236        let clamped_end = effective_trim_end.min(len);
237        for msg in messages.iter_mut().take(clamped_end).skip(prev_clamped) {
238            if msg.role == Role::Assistant || msg.role == Role::Tool {
239                let delta = Self::trim_message_with_delta(msg);
240                raw_tokens = (raw_tokens as i64 + delta).max(0) as u64;
241            }
242        }
243
244        Self::ensure_metadata_object(metadata);
245        if let Some(obj) = metadata.as_object_mut() {
246            obj.insert(
247                "trimmed_up_to_message_index".to_string(),
248                serde_json::json!(effective_trim_end),
249            );
250        }
251
252        messages
253    }
254}