1use std::collections::HashMap;
5use std::sync::{Arc, Mutex};
6
7use thiserror::Error;
8
9#[derive(Debug, Error)]
10#[error("daily budget exhausted: spent {spent_cents:.2} / {budget_cents:.2} cents")]
11pub struct BudgetExhausted {
12 pub spent_cents: f64,
13 pub budget_cents: f64,
14}
15
/// Token pricing for a single model, in cents per 1,000 tokens.
///
/// Prompt and completion tokens are billed at separate rates. The type is just two
/// numbers, so it is `Copy` (cheap to pass by value) and `PartialEq` (comparable in tests
/// and configuration code).
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct ModelPricing {
    /// Cost in cents of 1,000 prompt tokens.
    pub prompt_cents_per_1k: f64,
    /// Cost in cents of 1,000 completion tokens.
    pub completion_cents_per_1k: f64,
}
21
/// Mutable part of a cost tracker: the running total and the day it belongs to.
struct CostState {
    /// Day number (days since the Unix epoch) on which `spent_cents` was accumulated.
    day: u32,
    /// Total spend recorded on `day`, in cents.
    spent_cents: f64,
}
26
27pub struct CostTracker {
28 pricing: HashMap<String, ModelPricing>,
29 state: Arc<Mutex<CostState>>,
30 max_daily_cents: f64,
31 enabled: bool,
32}
33
/// Returns today's day number, counted as whole days since the Unix epoch.
///
/// The value is `seconds_since_epoch / 86_400`, so it advances at midnight UTC.
/// A system clock set before the epoch, or a day count that overflows `u32`,
/// yields day `0`.
fn current_day() -> u32 {
    use std::time::{SystemTime, UNIX_EPOCH};

    const SECONDS_PER_DAY: u64 = 86_400;

    let days_since_epoch = match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs() / SECONDS_PER_DAY,
        Err(_) => 0,
    };
    u32::try_from(days_since_epoch).unwrap_or(0)
}
43
44fn default_pricing() -> HashMap<String, ModelPricing> {
45 let mut m = HashMap::new();
46 m.insert(
48 "claude-sonnet-4-20250514".into(),
49 ModelPricing {
50 prompt_cents_per_1k: 0.3,
51 completion_cents_per_1k: 1.5,
52 },
53 );
54 m.insert(
55 "claude-opus-4-20250514".into(),
56 ModelPricing {
57 prompt_cents_per_1k: 1.5,
58 completion_cents_per_1k: 7.5,
59 },
60 );
61 m.insert(
63 "claude-opus-4-1-20250805".into(),
64 ModelPricing {
65 prompt_cents_per_1k: 1.5,
66 completion_cents_per_1k: 7.5,
67 },
68 );
69 m.insert(
71 "claude-haiku-4-5-20251001".into(),
72 ModelPricing {
73 prompt_cents_per_1k: 0.1,
74 completion_cents_per_1k: 0.5,
75 },
76 );
77 m.insert(
78 "claude-sonnet-4-5-20250929".into(),
79 ModelPricing {
80 prompt_cents_per_1k: 0.3,
81 completion_cents_per_1k: 1.5,
82 },
83 );
84 m.insert(
85 "claude-opus-4-5-20251101".into(),
86 ModelPricing {
87 prompt_cents_per_1k: 0.5,
88 completion_cents_per_1k: 2.5,
89 },
90 );
91 m.insert(
93 "claude-sonnet-4-6".into(),
94 ModelPricing {
95 prompt_cents_per_1k: 0.3,
96 completion_cents_per_1k: 1.5,
97 },
98 );
99 m.insert(
100 "claude-opus-4-6".into(),
101 ModelPricing {
102 prompt_cents_per_1k: 0.5,
103 completion_cents_per_1k: 2.5,
104 },
105 );
106 m.insert(
107 "gpt-4o".into(),
108 ModelPricing {
109 prompt_cents_per_1k: 0.25,
110 completion_cents_per_1k: 1.0,
111 },
112 );
113 m.insert(
114 "gpt-4o-mini".into(),
115 ModelPricing {
116 prompt_cents_per_1k: 0.015,
117 completion_cents_per_1k: 0.06,
118 },
119 );
120 m.insert(
122 "gpt-5".into(),
123 ModelPricing {
124 prompt_cents_per_1k: 0.125,
125 completion_cents_per_1k: 1.0,
126 },
127 );
128 m.insert(
130 "gpt-5-mini".into(),
131 ModelPricing {
132 prompt_cents_per_1k: 0.025,
133 completion_cents_per_1k: 0.2,
134 },
135 );
136 m
137}
138
139impl CostTracker {
140 #[must_use]
141 pub fn new(enabled: bool, max_daily_cents: f64) -> Self {
142 Self {
143 pricing: default_pricing(),
144 state: Arc::new(Mutex::new(CostState {
145 spent_cents: 0.0,
146 day: current_day(),
147 })),
148 max_daily_cents,
149 enabled,
150 }
151 }
152
153 #[must_use]
154 pub fn with_pricing(mut self, model: &str, pricing: ModelPricing) -> Self {
155 self.pricing.insert(model.to_owned(), pricing);
156 self
157 }
158
159 pub fn record_usage(&self, model: &str, prompt_tokens: u64, completion_tokens: u64) {
160 if !self.enabled {
161 return;
162 }
163 let pricing = if let Some(p) = self.pricing.get(model).cloned() {
164 p
165 } else {
166 tracing::warn!(
167 model,
168 "model not found in pricing table; cost recorded as zero"
169 );
170 ModelPricing {
171 prompt_cents_per_1k: 0.0,
172 completion_cents_per_1k: 0.0,
173 }
174 };
175 #[allow(clippy::cast_precision_loss)]
176 let cost = pricing.prompt_cents_per_1k * (prompt_tokens as f64) / 1000.0
177 + pricing.completion_cents_per_1k * (completion_tokens as f64) / 1000.0;
178
179 let mut state = self
180 .state
181 .lock()
182 .unwrap_or_else(std::sync::PoisonError::into_inner);
183 let today = current_day();
184 if state.day != today {
185 state.spent_cents = 0.0;
186 state.day = today;
187 }
188 state.spent_cents += cost;
189 }
190
191 pub fn check_budget(&self) -> Result<(), BudgetExhausted> {
195 if !self.enabled {
196 return Ok(());
197 }
198 let mut state = self
199 .state
200 .lock()
201 .unwrap_or_else(std::sync::PoisonError::into_inner);
202 let today = current_day();
203 if state.day != today {
204 state.spent_cents = 0.0;
205 state.day = today;
206 }
207 if self.max_daily_cents > 0.0 && state.spent_cents >= self.max_daily_cents {
208 return Err(BudgetExhausted {
209 spent_cents: state.spent_cents,
210 budget_cents: self.max_daily_cents,
211 });
212 }
213 Ok(())
214 }
215
216 #[must_use]
218 pub fn max_daily_cents(&self) -> f64 {
219 self.max_daily_cents
220 }
221
222 #[must_use]
223 pub fn current_spend(&self) -> f64 {
224 let state = self
225 .state
226 .lock()
227 .unwrap_or_else(std::sync::PoisonError::into_inner);
228 state.spent_cents
229 }
230}
231
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn cost_tracker_records_usage_and_calculates_cost() {
        let tracker = CostTracker::new(true, 1000.0);
        tracker.record_usage("gpt-4o", 1000, 1000);
        let spend = tracker.current_spend();
        // 0.25 cents per 1K prompt tokens + 1.0 cent per 1K completion tokens.
        assert!((spend - 1.25).abs() < 0.001);
    }

    #[test]
    fn check_budget_passes_when_under_limit() {
        let tracker = CostTracker::new(true, 100.0);
        tracker.record_usage("gpt-4o-mini", 100, 100);
        assert!(tracker.check_budget().is_ok());
    }

    #[test]
    fn check_budget_fails_when_over_limit() {
        let tracker = CostTracker::new(true, 0.01);
        tracker.record_usage("claude-opus-4-20250514", 10000, 10000);
        assert!(tracker.check_budget().is_err());
    }

    #[test]
    fn check_budget_fails_at_exact_limit() {
        // gpt-4o at 1K/1K tokens costs exactly 1.25 cents, so spend == cap.
        let tracker = CostTracker::new(true, 1.25);
        tracker.record_usage("gpt-4o", 1000, 1000);
        assert!(tracker.check_budget().is_err());
    }

    #[test]
    fn budget_exhausted_reports_spend_and_budget() {
        let tracker = CostTracker::new(true, 1.0);
        tracker.record_usage("gpt-4o", 1000, 1000);
        let err = tracker.check_budget().unwrap_err();
        assert!((err.spent_cents - 1.25).abs() < 0.001);
        assert!((err.budget_cents - 1.0).abs() < f64::EPSILON);
        assert_eq!(
            err.to_string(),
            "daily budget exhausted: spent 1.25 / 1.00 cents"
        );
    }

    #[test]
    fn daily_reset_clears_spending() {
        let tracker = CostTracker::new(true, 100.0);
        tracker.record_usage("gpt-4o", 1000, 1000);
        assert!(tracker.current_spend() > 0.0);
        {
            // Pretend the spend was recorded on day 0 so the next call sees a new day.
            let mut state = tracker.state.lock().unwrap();
            state.day = 0;
        }
        assert!(tracker.check_budget().is_ok());
        assert!(tracker.current_spend().abs() < 0.001);
    }

    #[test]
    fn ollama_zero_cost() {
        let tracker = CostTracker::new(true, 100.0);
        tracker.record_usage("llama3:8b", 10000, 10000);
        assert!(tracker.current_spend().abs() < 0.001);
    }

    #[test]
    fn unknown_model_zero_cost() {
        let tracker = CostTracker::new(true, 100.0);
        tracker.record_usage("totally-unknown-model", 5000, 5000);
        assert!(tracker.current_spend().abs() < 0.001);
    }

    #[test]
    fn known_claude_model_has_nonzero_cost() {
        let tracker = CostTracker::new(true, 1000.0);
        tracker.record_usage("claude-haiku-4-5-20251001", 1000, 1000);
        assert!(tracker.current_spend() > 0.0);
    }

    #[test]
    fn gpt5_pricing_is_correct() {
        let tracker = CostTracker::new(true, 1000.0);
        tracker.record_usage("gpt-5", 1000, 1000);
        let spend = tracker.current_spend();
        // 0.125 + 1.0 cents
        assert!((spend - 1.125).abs() < 0.001);
    }

    #[test]
    fn gpt5_mini_pricing_is_correct() {
        let tracker = CostTracker::new(true, 1000.0);
        tracker.record_usage("gpt-5-mini", 1000, 1000);
        let spend = tracker.current_spend();
        // 0.025 + 0.2 cents
        assert!((spend - 0.225).abs() < 0.001);
    }

    #[test]
    fn with_pricing_adds_and_overrides_models() {
        let tracker = CostTracker::new(true, 1000.0)
            .with_pricing(
                "custom-model",
                ModelPricing {
                    prompt_cents_per_1k: 1.0,
                    completion_cents_per_1k: 2.0,
                },
            )
            .with_pricing(
                "gpt-4o",
                ModelPricing {
                    prompt_cents_per_1k: 0.0,
                    completion_cents_per_1k: 0.0,
                },
            );
        tracker.record_usage("custom-model", 1000, 1000);
        assert!((tracker.current_spend() - 3.0).abs() < 0.001);
        // The overridden gpt-4o entry is now free, so the total must not move.
        tracker.record_usage("gpt-4o", 1000, 1000);
        assert!((tracker.current_spend() - 3.0).abs() < 0.001);
    }

    #[test]
    fn disabled_tracker_always_passes() {
        let tracker = CostTracker::new(false, 0.0);
        tracker.record_usage("claude-opus-4-20250514", 1_000_000, 1_000_000);
        assert!(tracker.check_budget().is_ok());
        assert!(tracker.current_spend().abs() < 0.001);
    }

    #[test]
    fn check_budget_unlimited_when_max_daily_cents_is_zero() {
        let tracker = CostTracker::new(true, 0.0);
        tracker.record_usage("claude-opus-4-20250514", 100_000, 100_000);
        assert!(tracker.check_budget().is_ok());
    }
}