1use std::sync::Arc;
28use std::time::{Duration, Instant};
29
30use crate::core::{AgentStopReason, LoopState, PolicyFn, StepOutcome};
31
/// Optional ceilings applied by [`budget_policy`].
///
/// Every field is independent; a field left as `None` means that particular
/// limit is never checked. `Budget::default()` therefore imposes no limits.
#[derive(Clone, Copy, Debug, Default)]
pub struct Budget {
    /// Ceiling on the accumulated prompt tokens.
    pub max_prompt_tokens: Option<usize>,
    /// Ceiling on the accumulated completion tokens.
    pub max_completion_tokens: Option<usize>,
    /// Ceiling on the accumulated number of tool invocations.
    pub max_tool_invocations: Option<usize>,
    /// Ceiling on the time elapsed since the usage record was created.
    pub max_time: Option<Duration>,
}
40
/// Running totals that [`budget_policy`] compares against a [`Budget`].
#[derive(Clone, Debug)]
pub struct BudgetUsage {
    /// Prompt tokens accumulated so far.
    pub prompt_tokens: usize,
    /// Completion tokens accumulated so far.
    pub completion_tokens: usize,
    /// Tool invocations accumulated so far.
    pub tools: usize,
    /// Instant from which elapsed time is measured.
    pub start_time: Instant,
}
49
50impl Default for BudgetUsage {
51 fn default() -> Self {
52 Self {
53 prompt_tokens: 0,
54 completion_tokens: 0,
55 tools: 0,
56 start_time: Instant::now(),
57 }
58 }
59}
60
61pub fn budget_policy(b: Budget) -> PolicyFn {
65 let usage = Arc::new(std::sync::Mutex::new(BudgetUsage::default()));
66 let usage_cl = usage.clone();
67 PolicyFn(Arc::new(move |_state: &LoopState, last: &StepOutcome| {
68 let usage = usage_cl.clone();
69 let mut u = usage.lock().unwrap();
70 match last {
71 StepOutcome::Next { aux, .. } | StepOutcome::Done { aux, .. } => {
72 u.prompt_tokens += aux.prompt_tokens;
73 u.completion_tokens += aux.completion_tokens;
74 u.tools += aux.tool_invocations;
75 }
76 }
77 if let Some(max) = b.max_time {
79 if u.start_time.elapsed() >= max {
80 return Some(AgentStopReason::TimeBudgetExceeded);
81 }
82 }
83 if let Some(max) = b.max_prompt_tokens {
84 if u.prompt_tokens >= max {
85 return Some(AgentStopReason::TokensBudgetExceeded);
86 }
87 }
88 if let Some(max) = b.max_completion_tokens {
89 if u.completion_tokens >= max {
90 return Some(AgentStopReason::TokensBudgetExceeded);
91 }
92 }
93 if let Some(max) = b.max_tool_invocations {
94 if u.tools >= max {
95 return Some(AgentStopReason::ToolBudgetExceeded);
96 }
97 }
98 None
99 }))
100}
101
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::{AgentPolicy, CompositePolicy, StepAux};

    /// Builds a `StepOutcome::Next` that reports the given usage and carries
    /// no messages or invoked tools.
    fn fake_next_step(prompt: usize, completion: usize, tools: usize) -> StepOutcome {
        StepOutcome::Next {
            messages: vec![],
            aux: StepAux {
                prompt_tokens: prompt,
                completion_tokens: completion,
                tool_invocations: tools,
            },
            invoked_tools: vec![],
        }
    }

    // The policy is evaluated synchronously, so these are plain `#[test]`s;
    // no async runtime is needed.

    #[test]
    fn stops_on_token_budget() {
        let budget = Budget {
            max_prompt_tokens: Some(10),
            ..Default::default()
        };
        let policy = budget_policy(budget);
        let comp = CompositePolicy::new(vec![policy]);
        let state = LoopState { steps: 1 };
        // 5 of 10 prompt tokens used: keep going.
        assert!(comp.decide(&state, &fake_next_step(5, 0, 0)).is_none());
        // 10 of 10 used: the limit is reached.
        assert!(matches!(
            comp.decide(&state, &fake_next_step(5, 0, 0)),
            Some(AgentStopReason::TokensBudgetExceeded)
        ));
    }

    #[test]
    fn stops_on_completion_token_budget() {
        let budget = Budget {
            max_completion_tokens: Some(10),
            ..Default::default()
        };
        let policy = budget_policy(budget);
        let comp = CompositePolicy::new(vec![policy]);
        let state = LoopState { steps: 1 };
        assert!(comp.decide(&state, &fake_next_step(0, 5, 0)).is_none());
        assert!(matches!(
            comp.decide(&state, &fake_next_step(0, 5, 0)),
            Some(AgentStopReason::TokensBudgetExceeded)
        ));
    }

    #[test]
    fn stops_on_tool_budget() {
        let budget = Budget {
            max_tool_invocations: Some(2),
            ..Default::default()
        };
        let policy = budget_policy(budget);
        let comp = CompositePolicy::new(vec![policy]);
        let state = LoopState { steps: 1 };
        assert!(comp.decide(&state, &fake_next_step(0, 0, 1)).is_none());
        assert!(matches!(
            comp.decide(&state, &fake_next_step(0, 0, 1)),
            Some(AgentStopReason::ToolBudgetExceeded)
        ));
    }

    #[test]
    fn stops_on_time_budget() {
        // A zero-length time budget is already used up on the first check.
        let budget = Budget {
            max_time: Some(Duration::from_secs(0)),
            ..Default::default()
        };
        let policy = budget_policy(budget);
        let comp = CompositePolicy::new(vec![policy]);
        let state = LoopState { steps: 1 };
        assert!(matches!(
            comp.decide(&state, &fake_next_step(0, 0, 0)),
            Some(AgentStopReason::TimeBudgetExceeded)
        ));
    }

    #[test]
    fn unlimited_budget_never_stops() {
        let policy = budget_policy(Budget::default());
        let comp = CompositePolicy::new(vec![policy]);
        let state = LoopState { steps: 1 };
        for _ in 0..3 {
            assert!(comp
                .decide(&state, &fake_next_step(1_000, 1_000, 10))
                .is_none());
        }
    }
}