1use std::collections::HashMap;
2
3use crate::preset::Preset;
4use crate::types::AgentId;
5
6#[derive(Debug, Clone)]
12pub struct AgentBudget {
13 pub agent_id: AgentId,
14 pub allocated: u32,
16 pub consumed: u32,
18 pub pinned: u32,
20}
21
22impl AgentBudget {
23 fn new(agent_id: AgentId, allocated: u32) -> Self {
24 AgentBudget {
25 agent_id,
26 allocated,
27 consumed: 0,
28 pinned: 0,
29 }
30 }
31
32 fn available(&self) -> u32 {
34 self.allocated
35 .saturating_sub(self.consumed)
36 .saturating_sub(self.pinned)
37 }
38
39 fn consumed_pct(&self) -> f64 {
40 if self.allocated == 0 {
41 return 1.0;
42 }
43 (self.consumed + self.pinned) as f64 / self.allocated as f64
44 }
45
46 fn pinned_pct(&self) -> f64 {
47 if self.allocated == 0 {
48 return 0.0;
49 }
50 self.pinned as f64 / self.allocated as f64
51 }
52}
53
/// Warnings produced by [`BudgetTracker`] as an agent's budget changes.
#[derive(Debug, Clone, PartialEq)]
pub enum BudgetWarning {
    /// Usage moved from below the tracker's warning threshold to at-or-above
    /// it during a single `record_tokens` call.
    ThresholdCrossed {
        /// Agent whose usage crossed the threshold.
        agent: AgentId,
        /// Usage fraction (consumed + pinned over allocated) after the update.
        percentage: f64,
        /// Tokens still available after the update.
        remaining: u32,
    },
    /// Projected usage would overrun the budget.
    ///
    /// NOTE(review): no code in this module's visible source constructs this
    /// variant; presumably callers build it from `BudgetTracker::predict_usage`
    /// results (`current`/`projected` mirroring `UsagePrediction`) — confirm.
    PredictiveOverage {
        agent: AgentId,
        current: f64,
        projected: f64,
    },
    /// More than half of the agent's allocation is pinned.
    PinnedExcessive {
        /// Agent whose pinned share is too high.
        agent: AgentId,
        /// Pinned tokens as a fraction of the allocation.
        pinned_pct: f64,
    },
    /// Consumed tokens (pinned tokens not included) reached the allocation.
    AgentBudgetExhausted { agent: AgentId },
}
81
/// Result of `BudgetTracker::predict_usage`: usage now and after pending tokens.
///
/// Fractions are `(consumed + pinned) / allocated` and may exceed `1.0`.
#[derive(Debug, Clone, PartialEq)]
pub struct UsagePrediction {
    /// Usage fraction before the pending tokens are counted.
    pub current_pct: f64,
    /// Usage fraction if the pending tokens were recorded as consumed.
    pub projected_pct: f64,
    /// Whether `projected_pct` is at or above the tracker's ceiling threshold.
    pub would_exceed_ceiling: bool,
}
93
94#[derive(Debug, Clone)]
96pub struct UsageReport {
97 pub agent_id: AgentId,
98 pub consumed: u32,
99 pub allocated: u32,
100 pub pinned: u32,
101 pub available: u32,
102 pub consumed_pct: f64,
103}
104
105pub struct BudgetTracker {
111 window_size: u32,
112 agents: HashMap<AgentId, AgentBudget>,
113 warning_threshold: f64,
115 ceiling_threshold: f64,
117}
118
119impl BudgetTracker {
120 pub fn new(window_size: u32, preset: &Preset) -> Self {
126 let warning_threshold = preset.budget.warning_threshold;
127 let ceiling_threshold = preset.budget.ceiling_threshold;
128
129 let mut agents: HashMap<AgentId, AgentBudget> = HashMap::new();
130
131 if preset.budget.agents.is_empty() {
132 agents.insert(
133 "default".to_string(),
134 AgentBudget::new("default".to_string(), window_size),
135 );
136 } else {
137 for (name, fraction) in &preset.budget.agents {
138 let allocated = (window_size as f64 * fraction).round() as u32;
139 agents.insert(name.clone(), AgentBudget::new(name.clone(), allocated));
140 }
141 }
142
143 BudgetTracker {
144 window_size,
145 agents,
146 warning_threshold,
147 ceiling_threshold,
148 }
149 }
150
151 pub fn with_thresholds(
153 window_size: u32,
154 warning_threshold: f64,
155 ceiling_threshold: f64,
156 ) -> Self {
157 let mut agents = HashMap::new();
158 agents.insert(
159 "default".to_string(),
160 AgentBudget::new("default".to_string(), window_size),
161 );
162 BudgetTracker {
163 window_size,
164 agents,
165 warning_threshold,
166 ceiling_threshold,
167 }
168 }
169
170 fn ensure_agent(&mut self, agent: &AgentId) {
175 if !self.agents.contains_key(agent) {
176 self.agents.insert(
177 agent.clone(),
178 AgentBudget::new(agent.clone(), self.window_size),
179 );
180 }
181 }
182
183 pub fn record_tokens(&mut self, agent: AgentId, tokens: u32) -> Vec<BudgetWarning> {
191 self.ensure_agent(&agent);
192 let mut warnings = Vec::new();
193
194 let budget = self.agents.get_mut(&agent).unwrap();
195 let before_pct = budget.consumed_pct();
196 budget.consumed = budget.consumed.saturating_add(tokens);
197 let after_pct = budget.consumed_pct();
198
199 if budget.consumed >= budget.allocated {
201 warnings.push(BudgetWarning::AgentBudgetExhausted {
202 agent: agent.clone(),
203 });
204 return warnings;
205 }
206
207 if before_pct < self.warning_threshold && after_pct >= self.warning_threshold {
209 warnings.push(BudgetWarning::ThresholdCrossed {
210 agent: agent.clone(),
211 percentage: after_pct,
212 remaining: budget.available(),
213 });
214 }
215
216 let pinned_pct = budget.pinned_pct();
218 if pinned_pct > 0.5 {
219 warnings.push(BudgetWarning::PinnedExcessive {
220 agent: agent.clone(),
221 pinned_pct,
222 });
223 }
224
225 warnings
226 }
227
228 pub fn predict_usage(&self, agent: AgentId, pending_tokens: u32) -> UsagePrediction {
230 let budget = match self.agents.get(&agent) {
231 Some(b) => b,
232 None => {
233 let current_pct = 0.0;
235 let projected_pct = pending_tokens as f64 / self.window_size as f64;
236 return UsagePrediction {
237 current_pct,
238 projected_pct,
239 would_exceed_ceiling: projected_pct >= self.ceiling_threshold,
240 };
241 }
242 };
243
244 let current_pct = budget.consumed_pct();
245 let projected_consumed = budget.consumed.saturating_add(pending_tokens);
246 let projected_pct = if budget.allocated == 0 {
247 1.0
248 } else {
249 (projected_consumed + budget.pinned) as f64 / budget.allocated as f64
250 };
251
252 UsagePrediction {
253 current_pct,
254 projected_pct,
255 would_exceed_ceiling: projected_pct >= self.ceiling_threshold,
256 }
257 }
258
259 pub fn available(&self, agent: AgentId) -> u32 {
263 match self.agents.get(&agent) {
264 Some(b) => b.available(),
265 None => self.window_size,
266 }
267 }
268
269 pub fn pin_tokens(&mut self, agent: AgentId, tokens: u32) -> Vec<BudgetWarning> {
274 self.ensure_agent(&agent);
275 let mut warnings = Vec::new();
276
277 let budget = self.agents.get_mut(&agent).unwrap();
278 budget.pinned = budget.pinned.saturating_add(tokens);
279
280 if budget.pinned_pct() > 0.5 {
281 warnings.push(BudgetWarning::PinnedExcessive {
282 agent: agent.clone(),
283 pinned_pct: budget.pinned_pct(),
284 });
285 }
286
287 warnings
288 }
289
290 pub fn unpin_tokens(&mut self, agent: AgentId, tokens: u32) {
292 self.ensure_agent(&agent);
293 let budget = self.agents.get_mut(&agent).unwrap();
294 budget.pinned = budget.pinned.saturating_sub(tokens);
295 }
296
297 pub fn usage_report(&self, agent: AgentId) -> UsageReport {
299 match self.agents.get(&agent) {
300 Some(b) => UsageReport {
301 agent_id: agent,
302 consumed: b.consumed,
303 allocated: b.allocated,
304 pinned: b.pinned,
305 available: b.available(),
306 consumed_pct: b.consumed_pct(),
307 },
308 None => UsageReport {
309 agent_id: agent,
310 consumed: 0,
311 allocated: self.window_size,
312 pinned: 0,
313 available: self.window_size,
314 consumed_pct: 0.0,
315 },
316 }
317 }
318
319 pub fn window_size(&self) -> u32 {
321 self.window_size
322 }
323}
324
// Unit and property tests for `BudgetTracker` and its warnings.
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    // Tracker with a 70% warning threshold and an 85% ceiling.
    fn tracker(window: u32) -> BudgetTracker {
        BudgetTracker::with_thresholds(window, 0.70, 0.85)
    }

    // Strategy producing 1-8 lowercase-letter agent ids. No test in this
    // module uses it, hence the `dead_code` allowance.
    #[allow(dead_code)]
    fn arb_agent() -> impl Strategy<Value = AgentId> {
        "[a-z]{1,8}".prop_map(|s| s)
    }

    proptest! {
        // After recording a sequence of amounts, `consumed` equals their
        // saturating sum.
        #[test]
        fn prop_budget_token_count_invariant(
            amounts in prop::collection::vec(0u32..=10_000u32, 0..=50),
        ) {
            let window = 1_000_000u32;
            let mut bt = tracker(window);
            let agent = "agent".to_string();

            let expected: u32 = amounts.iter().copied().fold(0u32, |acc, x| acc.saturating_add(x));

            for &a in &amounts {
                bt.record_tokens(agent.clone(), a);
            }

            let report = bt.usage_report(agent);
            prop_assert_eq!(report.consumed, expected);
        }
    }

    proptest! {
        // Consuming just under the warning threshold, then 2 more tokens,
        // must yield ThresholdCrossed (or Exhausted) once usage is at or above
        // the threshold but still below 100%.
        #[test]
        fn prop_threshold_warning_fires_on_crossing(
            window in 1_000u32..=200_000u32,
            wt_raw in 1_000u32..=8_000u32,
        ) {
            // Warning threshold in [0.10, 0.80]; ceiling kept just above it.
            let warning_threshold = wt_raw as f64 / 10_000.0;
            let ceiling_threshold = (warning_threshold + 0.05).min(0.99);

            let mut bt = BudgetTracker::with_thresholds(window, warning_threshold, ceiling_threshold);
            let agent = "a".to_string();

            let just_below = ((window as f64 * warning_threshold) - 1.0).max(0.0) as u32;
            bt.record_tokens(agent.clone(), just_below);

            let report = bt.usage_report(agent.clone());
            let pct = report.consumed_pct;
            prop_assert!(pct < warning_threshold || just_below == 0,
                "pct={} should be below warning_threshold={}", pct, warning_threshold);

            let push_over = 2u32;
            let warnings = bt.record_tokens(agent.clone(), push_over);

            let new_pct = bt.usage_report(agent.clone()).consumed_pct;
            if new_pct >= warning_threshold && new_pct < 1.0 {
                let has_threshold_warning = warnings.iter().any(|w| {
                    matches!(w, BudgetWarning::ThresholdCrossed { .. })
                        || matches!(w, BudgetWarning::AgentBudgetExhausted { .. })
                });
                prop_assert!(has_threshold_warning,
                    "expected ThresholdCrossed or Exhausted warning at pct={}", new_pct);
            }
        }

        // `would_exceed_ceiling` must agree exactly with
        // `projected_pct >= ceiling_threshold`.
        #[test]
        fn prop_predictive_warning_fires_above_ceiling(
            window in 1_000u32..=200_000u32,
            consumed_raw in 0u32..=8_000u32,
            pending_raw in 0u32..=5_000u32,
        ) {
            let warning_threshold = 0.70;
            let ceiling_threshold = 0.85;
            let mut bt = BudgetTracker::with_thresholds(window, warning_threshold, ceiling_threshold);
            let agent = "a".to_string();

            // Scale the raw basis points (1/10_000) to token counts for this window.
            let consumed = (consumed_raw as f64 / 10_000.0 * window as f64) as u32;
            let pending = (pending_raw as f64 / 10_000.0 * window as f64) as u32;

            bt.record_tokens(agent.clone(), consumed);
            let pred = bt.predict_usage(agent.clone(), pending);

            if pred.projected_pct >= ceiling_threshold {
                prop_assert!(pred.would_exceed_ceiling,
                    "projected_pct={} >= ceiling={} but would_exceed_ceiling=false",
                    pred.projected_pct, ceiling_threshold);
            } else {
                prop_assert!(!pred.would_exceed_ceiling,
                    "projected_pct={} < ceiling={} but would_exceed_ceiling=true",
                    pred.projected_pct, ceiling_threshold);
            }
        }
    }

    proptest! {
        // `available` equals window - consumed - pinned, clamped at zero.
        #[test]
        fn prop_available_equals_window_minus_consumed_minus_pinned(
            window in 1u32..=1_000_000u32,
            consumed in 0u32..=500_000u32,
            pinned in 0u32..=500_000u32,
        ) {
            let mut bt = tracker(window);
            let agent = "a".to_string();

            // Recording 0 tokens creates the agent's budget so it can be edited below.
            bt.record_tokens(agent.clone(), 0);

            {
                // Set the fields directly (private access from the child module).
                let b = bt.agents.get_mut(&agent).unwrap();
                b.consumed = consumed.min(window);
                b.pinned = pinned.min(window);
            }

            let expected = window
                .saturating_sub(consumed.min(window))
                .saturating_sub(pinned.min(window));

            let actual = bt.available(agent);
            prop_assert_eq!(actual, expected);
        }
    }

    proptest! {
        // Pinning strictly more than half the window must warn.
        #[test]
        fn prop_excessive_pin_warning(
            window in 2u32..=1_000_000u32,
            pin_raw in 5_001u32..=10_000u32,
        ) {
            let pinned = (pin_raw as f64 / 10_000.0 * window as f64).ceil() as u32;
            // Ensure the pin is strictly above half the window regardless of rounding.
            let pinned = pinned.max(window / 2 + 1);

            let mut bt = tracker(window);
            let agent = "a".to_string();

            let warnings = bt.pin_tokens(agent.clone(), pinned);

            let has_excessive = warnings
                .iter()
                .any(|w| matches!(w, BudgetWarning::PinnedExcessive { .. }));
            prop_assert!(has_excessive,
                "expected PinnedExcessive warning for pinned={} / window={}", pinned, window);
        }

        // Pinning at most half the window must not warn.
        #[test]
        fn prop_no_excessive_pin_warning_below_threshold(
            window in 2u32..=1_000_000u32,
            pin_raw in 0u32..=5_000u32,
        ) {
            let pinned = (pin_raw as f64 / 10_000.0 * window as f64) as u32;

            let mut bt = tracker(window);
            let agent = "a".to_string();

            let warnings = bt.pin_tokens(agent.clone(), pinned);

            let has_excessive = warnings
                .iter()
                .any(|w| matches!(w, BudgetWarning::PinnedExcessive { .. }));
            prop_assert!(!has_excessive,
                "unexpected PinnedExcessive warning for pinned={} / window={}", pinned, window);
        }
    }

    proptest! {
        // Recording for one agent never changes another agent's consumption.
        #[test]
        fn prop_multi_agent_isolation(
            tokens_a in 0u32..=50_000u32,
            tokens_b in 0u32..=50_000u32,
        ) {
            let window = 1_000_000u32;
            let mut bt = tracker(window);

            let agent_a = "alpha".to_string();
            let agent_b = "beta".to_string();

            // Create both budgets up front.
            bt.record_tokens(agent_a.clone(), 0);
            bt.record_tokens(agent_b.clone(), 0);

            bt.record_tokens(agent_a.clone(), tokens_a);

            let report_a = bt.usage_report(agent_a.clone());
            let report_b = bt.usage_report(agent_b.clone());

            prop_assert_eq!(report_a.consumed, tokens_a,
                "agent_a consumed should be {}", tokens_a);
            prop_assert_eq!(report_b.consumed, 0,
                "agent_b consumed should still be 0 after recording for agent_a");

            bt.record_tokens(agent_b.clone(), tokens_b);

            let report_a2 = bt.usage_report(agent_a.clone());
            let report_b2 = bt.usage_report(agent_b.clone());

            prop_assert_eq!(report_a2.consumed, tokens_a,
                "agent_a consumed should remain {}", tokens_a);
            prop_assert_eq!(report_b2.consumed, tokens_b,
                "agent_b consumed should be {}", tokens_b);
        }
    }

    proptest! {
        // Consuming the full window for a fresh agent (allocated the full
        // window on demand) reports exhaustion.
        #[test]
        fn prop_agent_exhausted_warning(
            window in 100u32..=1_000_000u32,
        ) {
            let mut bt = BudgetTracker::with_thresholds(window, 0.70, 0.85);
            let agent = "worker".to_string();

            let warnings = bt.record_tokens(agent.clone(), window);

            let exhausted = warnings
                .iter()
                .any(|w| matches!(w, BudgetWarning::AgentBudgetExhausted { .. }));
            prop_assert!(exhausted,
                "expected AgentBudgetExhausted after consuming full window={}", window);
        }

        // A single jump to 91% with a 90% warning threshold must warn.
        #[test]
        fn prop_near_budget_warning(
            window in 1_000u32..=1_000_000u32,
        ) {
            let mut bt = BudgetTracker::with_thresholds(window, 0.90, 0.95);
            let agent = "worker".to_string();

            let tokens = (window as f64 * 0.91) as u32;
            let warnings = bt.record_tokens(agent.clone(), tokens);

            let has_warning = warnings.iter().any(|w| {
                matches!(w, BudgetWarning::ThresholdCrossed { .. })
                    || matches!(w, BudgetWarning::AgentBudgetExhausted { .. })
            });
            prop_assert!(has_warning,
                "expected threshold warning at 91% of window={}", window);
        }
    }

    // Pinning reduces availability; unpinning restores it.
    #[test]
    fn test_available_decreases_with_pin() {
        let mut bt = tracker(1_000);
        let agent = "a".to_string();
        bt.record_tokens(agent.clone(), 200);
        assert_eq!(bt.available(agent.clone()), 800);
        bt.pin_tokens(agent.clone(), 100);
        assert_eq!(bt.available(agent.clone()), 700);
        bt.unpin_tokens(agent.clone(), 50);
        assert_eq!(bt.available(agent.clone()), 750);
    }

    // Unpinning more than is pinned clamps at zero pinned tokens.
    #[test]
    fn test_unpin_does_not_go_negative() {
        let mut bt = tracker(1_000);
        let agent = "a".to_string();
        bt.pin_tokens(agent.clone(), 100);
        bt.unpin_tokens(agent.clone(), 200);
        assert_eq!(bt.available(agent.clone()), 1_000);
    }

    // `consumed_pct` counts pinned tokens: (3_000 + 1_000) / 10_000 = 0.4.
    #[test]
    fn test_usage_report_fields() {
        let mut bt = tracker(10_000);
        let agent = "x".to_string();
        bt.record_tokens(agent.clone(), 3_000);
        bt.pin_tokens(agent.clone(), 1_000);
        let r = bt.usage_report(agent);
        assert_eq!(r.consumed, 3_000);
        assert_eq!(r.pinned, 1_000);
        assert_eq!(r.allocated, 10_000);
        assert_eq!(r.available, 6_000);
        assert!((r.consumed_pct - 0.4).abs() < 1e-9);
    }
}