Skip to main content

token_budget_pool/
lib.rs

//! # token-budget-pool
//!
//! Shared token + dollar budget across N concurrent LLM tasks.
//!
//! Create one `BudgetPool` at the top of an agent run, pass `&pool` to every
//! task that issues LLM calls, and call [`BudgetPool::record`] after each
//! response. The pool serializes the updates and returns [`BudgetExceeded`]
//! when a record would push past any cap; a rejected record leaves the
//! running totals unchanged.
//!
//! ## Example
//!
//! ```
//! use token_budget_pool::{BudgetPool, Caps};
//!
//! let pool = BudgetPool::with_caps(Caps {
//!     max_input_tokens: Some(10_000),
//!     max_output_tokens: Some(5_000),
//!     max_total_tokens: None,
//!     max_cost_usd: Some(1.0),
//! });
//!
//! pool.record(1_000, 500, 0.05).unwrap(); // fits
//! let err = pool.record(20_000, 0, 0.0).unwrap_err(); // input cap blown
//! assert!(format!("{err}").contains("input_tokens"));
//! ```

27#![deny(missing_docs)]
28
29use std::sync::Mutex;
30
/// Upper limits enforced by a budget pool.
///
/// Every field is optional: `None` means that dimension is never checked, and
/// `Caps::default()` therefore enforces nothing at all. Limits are inclusive —
/// a running total exactly equal to its cap is accepted; only a total strictly
/// greater than the cap is rejected.
#[derive(Debug, Clone, Copy, Default, PartialEq)]
pub struct Caps {
    /// Limit on the sum of input tokens over every accepted call.
    pub max_input_tokens: Option<u64>,
    /// Limit on the sum of output tokens over every accepted call.
    pub max_output_tokens: Option<u64>,
    /// Limit on input tokens plus output tokens, combined.
    pub max_total_tokens: Option<u64>,
    /// Limit on the accumulated spend, in US dollars.
    pub max_cost_usd: Option<f64>,
}

/// Running totals accumulated by a budget pool.
///
/// Only accepted calls are reflected here; a call rejected for breaching a
/// cap leaves every field untouched.
#[derive(Debug, Clone, Copy, Default, PartialEq)]
pub struct Totals {
    /// Cumulative input tokens recorded.
    pub input_tokens: u64,
    /// Cumulative output tokens recorded.
    pub output_tokens: u64,
    /// Cumulative dollars recorded.
    pub cost_usd: f64,
    /// Number of accepted `record` calls (rejected calls are not counted).
    pub calls: u64,
}

impl Totals {
    /// Sum of input and output tokens, saturating at `u64::MAX`.
    ///
    /// The fields are public, so any pair of values can appear here; plain `+`
    /// would panic on overflow in debug builds and silently wrap in release
    /// builds, reporting a tiny total for a huge one.
    pub fn total_tokens(&self) -> u64 {
        self.input_tokens.saturating_add(self.output_tokens)
    }
}

/// Returned by `BudgetPool::record` when accepting a call would push a
/// running total past one of the pool's caps.
///
/// Caps are checked in a fixed order — input tokens, output tokens, total
/// tokens, then cost — and only the first breach is reported; later caps may
/// be breached too. **A rejected call does not touch the pool's totals.**
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct BudgetExceeded {
    /// Name of the breached cap: `"input_tokens"`, `"output_tokens"`,
    /// `"total_tokens"`, or `"cost_usd"`.
    pub cap: &'static str,
    /// The configured limit for that cap (token limits are converted to `f64`).
    pub limit: f64,
    /// The running total the rejected call would have produced.
    pub attempted: f64,
}

impl std::fmt::Display for BudgetExceeded {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { cap, limit, attempted } = *self;
        write!(
            f,
            "budget cap `{}` exceeded: limit={}, attempted={}",
            cap, limit, attempted
        )
    }
}

// A budget breach is a leaf condition: there is no underlying source error.
impl std::error::Error for BudgetExceeded {}

92/// Shared budget. Cheap to construct; record() takes a mutex.
93#[derive(Debug)]
94pub struct BudgetPool {
95    caps: Caps,
96    state: Mutex<Totals>,
97}
98
99impl BudgetPool {
100    /// Build a pool with the given caps. All caps default to `None`.
101    pub fn with_caps(caps: Caps) -> Self {
102        Self {
103            caps,
104            state: Mutex::new(Totals::default()),
105        }
106    }
107
108    /// Build an unconstrained pool (no caps).
109    pub fn unconstrained() -> Self {
110        Self::with_caps(Caps::default())
111    }
112
113    /// Record one call's usage. Returns the updated totals on success, or
114    /// [`BudgetExceeded`] (totals unchanged) on cap breach.
115    pub fn record(
116        &self,
117        input_tokens: u64,
118        output_tokens: u64,
119        cost_usd: f64,
120    ) -> Result<Totals, BudgetExceeded> {
121        let mut s = self.state.lock().unwrap();
122
123        let next_in = s.input_tokens + input_tokens;
124        let next_out = s.output_tokens + output_tokens;
125        let next_total = next_in + next_out;
126        let next_cost = s.cost_usd + cost_usd;
127
128        if let Some(cap) = self.caps.max_input_tokens {
129            if next_in > cap {
130                return Err(BudgetExceeded {
131                    cap: "input_tokens",
132                    limit: cap as f64,
133                    attempted: next_in as f64,
134                });
135            }
136        }
137        if let Some(cap) = self.caps.max_output_tokens {
138            if next_out > cap {
139                return Err(BudgetExceeded {
140                    cap: "output_tokens",
141                    limit: cap as f64,
142                    attempted: next_out as f64,
143                });
144            }
145        }
146        if let Some(cap) = self.caps.max_total_tokens {
147            if next_total > cap {
148                return Err(BudgetExceeded {
149                    cap: "total_tokens",
150                    limit: cap as f64,
151                    attempted: next_total as f64,
152                });
153            }
154        }
155        if let Some(cap) = self.caps.max_cost_usd {
156            if next_cost > cap {
157                return Err(BudgetExceeded {
158                    cap: "cost_usd",
159                    limit: cap,
160                    attempted: next_cost,
161                });
162            }
163        }
164
165        s.input_tokens = next_in;
166        s.output_tokens = next_out;
167        s.cost_usd = next_cost;
168        s.calls += 1;
169        Ok(*s)
170    }
171
172    /// Read current totals.
173    pub fn totals(&self) -> Totals {
174        *self.state.lock().unwrap()
175    }
176
177    /// Read the caps.
178    pub fn caps(&self) -> Caps {
179        self.caps
180    }
181
182    /// Reset the pool to zero totals (caps unchanged).
183    pub fn reset(&self) {
184        *self.state.lock().unwrap() = Totals::default();
185    }
186}