1use std::collections::HashMap;
5use std::sync::Arc;
6
7use parking_lot::Mutex;
8
9use thiserror::Error;
10
/// Error returned by [`CostTracker::check_budget`] when the spend recorded
/// for the current day has reached the configured daily budget.
#[derive(Debug, Error)]
#[error("daily budget exhausted: spent {spent_cents:.2} / {budget_cents:.2} cents")]
pub struct BudgetExhausted {
    /// Spend recorded for the current day at the time of the check, in cents.
    pub spent_cents: f64,
    /// Configured daily budget the spend was compared against, in cents.
    pub budget_cents: f64,
}
17
/// Usage totals accumulated for a single provider name during the current day.
///
/// One entry is kept per provider name and updated on every recorded request;
/// all entries are discarded when the tracker rolls over to a new day.
#[derive(Debug, Clone, Default)]
pub struct ProviderUsage {
    /// Sum of input tokens over all recorded requests.
    pub input_tokens: u64,
    /// Sum of cache-read tokens over all recorded requests.
    pub cache_read_tokens: u64,
    /// Sum of cache-write tokens over all recorded requests.
    pub cache_write_tokens: u64,
    /// Sum of output tokens over all recorded requests.
    pub output_tokens: u64,
    /// Sum of the computed cost of all recorded requests, in cents.
    pub cost_cents: f64,
    /// Number of requests recorded for this provider.
    pub request_count: u64,
    /// Model name of the most recently recorded request (overwritten on each record).
    pub model: String,
}
30
/// Per-model price list, expressed in cents per 1,000 tokens for each token class.
#[derive(Debug, Clone)]
pub struct ModelPricing {
    /// Price of input (prompt) tokens, in cents per 1,000 tokens.
    pub prompt_cents_per_1k: f64,
    /// Price of output (completion) tokens, in cents per 1,000 tokens.
    pub completion_cents_per_1k: f64,
    /// Price of cache-read tokens, in cents per 1,000 tokens.
    pub cache_read_cents_per_1k: f64,
    /// Price of cache-write tokens, in cents per 1,000 tokens.
    pub cache_write_cents_per_1k: f64,
}
40
41struct CostState {
42 spent_cents: f64,
43 day: u32,
44 providers: HashMap<String, ProviderUsage>,
45 successful_tasks: u64,
46}
47
48pub struct CostTracker {
49 pricing: HashMap<String, ModelPricing>,
50 state: Arc<Mutex<CostState>>,
51 max_daily_cents: f64,
52 enabled: bool,
53}
54
/// Returns the number of whole days elapsed since the Unix epoch according to
/// the system clock.
///
/// Yields `0` when the clock reads earlier than the epoch or when the day
/// count does not fit into a `u32`.
fn current_day() -> u32 {
    use std::time::{SystemTime, UNIX_EPOCH};

    const SECS_PER_DAY: u64 = 86_400;

    // A clock set before the epoch is treated as "zero seconds elapsed".
    let elapsed_secs = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map_or(0, |elapsed| elapsed.as_secs());
    let days = elapsed_secs / SECS_PER_DAY;
    u32::try_from(days).unwrap_or(0)
}
64
65fn claude_pricing(prompt: f64, completion: f64) -> ModelPricing {
66 ModelPricing {
67 prompt_cents_per_1k: prompt,
68 completion_cents_per_1k: completion,
69 cache_read_cents_per_1k: prompt * 0.1,
71 cache_write_cents_per_1k: prompt * 1.25,
72 }
73}
74
75fn openai_pricing(prompt: f64, completion: f64) -> ModelPricing {
76 ModelPricing {
77 prompt_cents_per_1k: prompt,
78 completion_cents_per_1k: completion,
79 cache_read_cents_per_1k: prompt * 0.5,
81 cache_write_cents_per_1k: 0.0,
82 }
83}
84
85fn default_pricing() -> HashMap<String, ModelPricing> {
86 let mut m = HashMap::new();
87 m.insert("claude-sonnet-4-20250514".into(), claude_pricing(0.3, 1.5));
89 m.insert("claude-opus-4-20250514".into(), claude_pricing(1.5, 7.5));
90 m.insert("claude-opus-4-1-20250805".into(), claude_pricing(1.5, 7.5));
92 m.insert("claude-haiku-4-5-20251001".into(), claude_pricing(0.1, 0.5));
94 m.insert(
95 "claude-sonnet-4-5-20250929".into(),
96 claude_pricing(0.3, 1.5),
97 );
98 m.insert("claude-opus-4-5-20251101".into(), claude_pricing(0.5, 2.5));
99 m.insert("claude-sonnet-4-6".into(), claude_pricing(0.3, 1.5));
101 m.insert("claude-opus-4-6".into(), claude_pricing(0.5, 2.5));
102 m.insert("gpt-4o".into(), openai_pricing(0.25, 1.0));
104 m.insert("gpt-4o-mini".into(), openai_pricing(0.015, 0.06));
105 m.insert("gpt-5".into(), openai_pricing(0.125, 1.0));
107 m.insert("gpt-5-mini".into(), openai_pricing(0.025, 0.2));
109 m
110}
111
112fn reset_if_new_day(state: &mut CostState) {
113 let today = current_day();
114 if state.day != today {
115 state.spent_cents = 0.0;
116 state.day = today;
117 state.providers.clear();
118 state.successful_tasks = 0;
119 }
120}
121
122impl CostTracker {
123 #[must_use]
124 pub fn new(enabled: bool, max_daily_cents: f64) -> Self {
125 Self {
126 pricing: default_pricing(),
127 state: Arc::new(Mutex::new(CostState {
128 spent_cents: 0.0,
129 day: current_day(),
130 providers: HashMap::new(),
131 successful_tasks: 0,
132 })),
133 max_daily_cents,
134 enabled,
135 }
136 }
137
138 #[must_use]
139 pub fn with_pricing(mut self, model: &str, pricing: ModelPricing) -> Self {
140 self.pricing.insert(model.to_owned(), pricing);
141 self
142 }
143
144 #[allow(clippy::too_many_arguments)] pub fn record_usage(
155 &self,
156 provider_name: &str,
157 provider_kind: &str,
158 model: &str,
159 input_tokens: u64,
160 cache_read_tokens: u64,
161 cache_write_tokens: u64,
162 output_tokens: u64,
163 ) {
164 if !self.enabled {
165 return;
166 }
167 let pricing = if let Some(p) = self.pricing.get(model).cloned() {
168 p
169 } else {
170 let is_local = matches!(provider_kind, "ollama" | "candle" | "local");
171 if is_local {
172 tracing::debug!(model, "local model; cost recorded as zero");
173 } else {
174 tracing::warn!(
175 model,
176 "model not found in pricing table; cost recorded as zero"
177 );
178 }
179 ModelPricing {
180 prompt_cents_per_1k: 0.0,
181 completion_cents_per_1k: 0.0,
182 cache_read_cents_per_1k: 0.0,
183 cache_write_cents_per_1k: 0.0,
184 }
185 };
186 #[allow(clippy::cast_precision_loss)]
187 let cost = pricing.prompt_cents_per_1k * (input_tokens as f64) / 1000.0
188 + pricing.completion_cents_per_1k * (output_tokens as f64) / 1000.0
189 + pricing.cache_read_cents_per_1k * (cache_read_tokens as f64) / 1000.0
190 + pricing.cache_write_cents_per_1k * (cache_write_tokens as f64) / 1000.0;
191
192 let mut state = self.state.lock();
193 reset_if_new_day(&mut state);
194 state.spent_cents += cost;
195
196 let entry = state.providers.entry(provider_name.to_owned()).or_default();
197 entry.input_tokens += input_tokens;
198 entry.cache_read_tokens += cache_read_tokens;
199 entry.cache_write_tokens += cache_write_tokens;
200 entry.output_tokens += output_tokens;
201 entry.cost_cents += cost;
202 entry.request_count += 1;
203 model.clone_into(&mut entry.model);
204 }
205
206 pub fn check_budget(&self) -> Result<(), BudgetExhausted> {
210 if !self.enabled {
211 return Ok(());
212 }
213 let mut state = self.state.lock();
214 reset_if_new_day(&mut state);
215 if self.max_daily_cents > 0.0 && state.spent_cents >= self.max_daily_cents {
216 return Err(BudgetExhausted {
217 spent_cents: state.spent_cents,
218 budget_cents: self.max_daily_cents,
219 });
220 }
221 Ok(())
222 }
223
224 #[must_use]
226 pub fn max_daily_cents(&self) -> f64 {
227 self.max_daily_cents
228 }
229
230 #[must_use]
231 pub fn current_spend(&self) -> f64 {
232 let state = self.state.lock();
233 state.spent_cents
234 }
235
236 pub fn record_successful_task(&self) {
240 if !self.enabled {
241 return;
242 }
243 let mut state = self.state.lock();
244 reset_if_new_day(&mut state);
245 state.successful_tasks += 1;
246 }
247
248 #[must_use]
250 pub fn cps(&self) -> Option<f64> {
251 let state = self.state.lock();
252 if state.successful_tasks == 0 {
253 return None;
254 }
255 #[allow(clippy::cast_precision_loss)]
256 Some(state.spent_cents / state.successful_tasks as f64)
257 }
258
259 #[must_use]
261 pub fn successful_tasks(&self) -> u64 {
262 self.state.lock().successful_tasks
263 }
264
265 #[must_use]
267 pub fn provider_breakdown(&self) -> Vec<(String, ProviderUsage)> {
268 let state = self.state.lock();
269 let mut breakdown: Vec<(String, ProviderUsage)> = state
270 .providers
271 .iter()
272 .map(|(k, v)| (k.clone(), v.clone()))
273 .collect();
274 breakdown.sort_by(|a, b| {
275 b.1.cost_cents
276 .partial_cmp(&a.1.cost_cents)
277 .unwrap_or(std::cmp::Ordering::Equal)
278 });
279 breakdown
280 }
281}
282
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used when comparing cent amounts.
    const EPS: f64 = 0.001;

    /// Records one request with provider kind `"cloud"` and no cache tokens.
    fn record(tracker: &CostTracker, provider: &str, model: &str, input: u64, output: u64) {
        tracker.record_usage(provider, "cloud", model, input, 0, 0, output);
    }

    /// Asserts that two cent amounts differ by less than [`EPS`].
    #[track_caller]
    fn assert_close(actual: f64, expected: f64) {
        assert!((actual - expected).abs() < EPS);
    }

    /// Rewinds the stored day to `0` so the next rollover check sees a new day.
    fn rewind_day(tracker: &CostTracker) {
        tracker.state.lock().day = 0;
    }

    #[test]
    fn cost_tracker_records_usage_and_calculates_cost() {
        let tracker = CostTracker::new(true, 1000.0);
        record(&tracker, "openai", "gpt-4o", 1000, 1000);
        // gpt-4o: 0.25 (prompt) + 1.0 (completion) cents per 1k tokens each.
        assert_close(tracker.current_spend(), 1.25);
    }

    #[test]
    fn check_budget_passes_when_under_limit() {
        let tracker = CostTracker::new(true, 100.0);
        record(&tracker, "openai", "gpt-4o-mini", 100, 100);
        assert!(tracker.check_budget().is_ok());
    }

    #[test]
    fn check_budget_fails_when_over_limit() {
        let tracker = CostTracker::new(true, 0.01);
        record(&tracker, "claude", "claude-opus-4-20250514", 10000, 10000);
        assert!(tracker.check_budget().is_err());
    }

    #[test]
    fn daily_reset_clears_spending() {
        let tracker = CostTracker::new(true, 100.0);
        record(&tracker, "openai", "gpt-4o", 1000, 1000);
        assert!(tracker.current_spend() > 0.0);

        rewind_day(&tracker);

        // The budget check performs the rollover.
        assert!(tracker.check_budget().is_ok());
        assert_close(tracker.current_spend(), 0.0);
    }

    #[test]
    fn daily_reset_clears_provider_breakdown() {
        let tracker = CostTracker::new(true, 100.0);
        record(&tracker, "openai", "gpt-4o", 1000, 1000);
        assert!(!tracker.provider_breakdown().is_empty());

        rewind_day(&tracker);

        assert!(tracker.check_budget().is_ok());
        assert!(tracker.provider_breakdown().is_empty());
    }

    #[test]
    fn ollama_zero_cost() {
        let tracker = CostTracker::new(true, 100.0);
        record(&tracker, "ollama", "llama3:8b", 10000, 10000);
        assert_close(tracker.current_spend(), 0.0);
    }

    #[test]
    fn ollama_unknown_model_no_warn_no_panic() {
        let tracker = CostTracker::new(true, 100.0);
        tracker.record_usage(
            "local",
            "ollama",
            "totally-unknown-ollama-model",
            5000,
            0,
            0,
            5000,
        );
        assert_close(tracker.current_spend(), 0.0);
    }

    #[test]
    fn cloud_unknown_model_still_records_zero_cost() {
        let tracker = CostTracker::new(true, 100.0);
        tracker.record_usage(
            "openai",
            "cloud",
            "totally-unknown-cloud-model",
            5000,
            0,
            0,
            5000,
        );
        assert_close(tracker.current_spend(), 0.0);
    }

    #[test]
    fn unknown_model_zero_cost() {
        let tracker = CostTracker::new(true, 100.0);
        record(&tracker, "unknown", "totally-unknown-model", 5000, 5000);
        assert_close(tracker.current_spend(), 0.0);
    }

    #[test]
    fn known_claude_model_has_nonzero_cost() {
        let tracker = CostTracker::new(true, 1000.0);
        record(&tracker, "claude", "claude-haiku-4-5-20251001", 1000, 1000);
        assert!(tracker.current_spend() > 0.0);
    }

    #[test]
    fn gpt5_pricing_is_correct() {
        let tracker = CostTracker::new(true, 1000.0);
        record(&tracker, "openai", "gpt-5", 1000, 1000);
        // gpt-5: 0.125 (prompt) + 1.0 (completion).
        assert_close(tracker.current_spend(), 1.125);
    }

    #[test]
    fn gpt5_mini_pricing_is_correct() {
        let tracker = CostTracker::new(true, 1000.0);
        record(&tracker, "openai", "gpt-5-mini", 1000, 1000);
        // gpt-5-mini: 0.025 (prompt) + 0.2 (completion).
        assert_close(tracker.current_spend(), 0.225);
    }

    #[test]
    fn disabled_tracker_always_passes() {
        let tracker = CostTracker::new(false, 0.0);
        record(
            &tracker,
            "claude",
            "claude-opus-4-20250514",
            1_000_000,
            1_000_000,
        );
        assert!(tracker.check_budget().is_ok());
        assert_close(tracker.current_spend(), 0.0);
    }

    #[test]
    fn check_budget_unlimited_when_max_daily_cents_is_zero() {
        let tracker = CostTracker::new(true, 0.0);
        record(
            &tracker,
            "claude",
            "claude-opus-4-20250514",
            100_000,
            100_000,
        );
        assert!(tracker.check_budget().is_ok());
    }

    #[test]
    fn per_provider_accumulation() {
        let tracker = CostTracker::new(true, 1000.0);
        record(&tracker, "claude", "claude-haiku-4-5-20251001", 1000, 500);
        record(&tracker, "openai", "gpt-4o", 2000, 1000);
        record(&tracker, "claude", "claude-haiku-4-5-20251001", 500, 200);

        let breakdown = tracker.provider_breakdown();
        assert_eq!(breakdown.len(), 2);

        let (_, claude) = breakdown.iter().find(|(n, _)| n == "claude").unwrap();
        assert_eq!(claude.request_count, 2);
        assert_eq!(claude.input_tokens, 1500);
        assert_eq!(claude.output_tokens, 700);

        let (_, openai) = breakdown.iter().find(|(n, _)| n == "openai").unwrap();
        assert_eq!(openai.request_count, 1);
        assert_eq!(openai.input_tokens, 2000);
    }

    #[test]
    fn provider_breakdown_sorted_by_cost_desc() {
        let tracker = CostTracker::new(true, 1000.0);
        record(&tracker, "cheap", "gpt-4o-mini", 100, 100);
        record(&tracker, "expensive", "claude-opus-4-20250514", 10000, 5000);

        let breakdown = tracker.provider_breakdown();
        assert_eq!(breakdown[0].0, "expensive");
    }

    #[test]
    fn cache_tokens_included_in_cost() {
        let tracker = CostTracker::new(true, 1000.0);
        // Only cache-read tokens are non-zero.
        tracker.record_usage(
            "claude",
            "cloud",
            "claude-haiku-4-5-20251001",
            0,
            1000,
            0,
            0,
        );
        let spend = tracker.current_spend();
        assert!(spend > 0.0, "cache read should contribute to cost");
    }

    #[test]
    fn cache_write_cost_included_in_total() {
        let tracker = CostTracker::new(true, 1000.0);
        // Only cache-write tokens are non-zero: 0.5 * 1.25 cents per 1k.
        tracker.record_usage("claude-provider", "cloud", "claude-opus-4-6", 0, 0, 1000, 0);
        assert_close(tracker.current_spend(), 0.625);
    }

    #[test]
    fn provider_breakdown_empty_when_disabled() {
        let tracker = CostTracker::new(false, 100.0);
        tracker.record_usage(
            "claude",
            "cloud",
            "claude-haiku-4-5-20251001",
            1000,
            0,
            0,
            1000,
        );
        assert!(tracker.provider_breakdown().is_empty());
    }

    #[test]
    fn cps_none_when_no_tasks() {
        let tracker = CostTracker::new(true, 100.0);
        assert!(tracker.cps().is_none());
        assert_eq!(tracker.successful_tasks(), 0);
    }

    #[test]
    fn cps_calculated_correctly() {
        let tracker = CostTracker::new(true, 100.0);
        record(&tracker, "openai", "gpt-4o", 1000, 1000);
        for _ in 0..2 {
            tracker.record_successful_task();
        }
        assert_eq!(tracker.successful_tasks(), 2);

        // 1.25 cents spent over 2 tasks.
        let cps = tracker.cps().expect("cps should be Some after tasks");
        assert!((cps - 0.625).abs() < 0.001, "cps={cps}");
    }

    #[test]
    fn cps_resets_on_new_day() {
        let tracker = CostTracker::new(true, 100.0);
        record(&tracker, "openai", "gpt-4o", 1000, 1000);
        tracker.record_successful_task();
        assert_eq!(tracker.successful_tasks(), 1);

        rewind_day(&tracker);

        assert!(tracker.check_budget().is_ok());
        assert_eq!(tracker.successful_tasks(), 0);
        assert!(tracker.cps().is_none());
    }
}