1use std::collections::HashMap;
5use std::sync::Arc;
6
7use parking_lot::Mutex;
8
9use thiserror::Error;
10
11#[derive(Debug, Error)]
12#[error("daily budget exhausted: spent {spent_cents:.2} / {budget_cents:.2} cents")]
13pub struct BudgetExhausted {
14 pub spent_cents: f64,
15 pub budget_cents: f64,
16}
17
/// Token and cost totals accumulated for a single provider during the
/// current day.
///
/// `Default` yields an all-zero record with an empty model name, which is the
/// starting point for a provider seen for the first time.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct ProviderUsage {
    /// Sum of uncached prompt tokens across all recorded requests.
    pub input_tokens: u64,
    /// Sum of prompt tokens reported as cache reads.
    pub cache_read_tokens: u64,
    /// Sum of prompt tokens reported as cache writes.
    pub cache_write_tokens: u64,
    /// Sum of completion tokens across all recorded requests.
    pub output_tokens: u64,
    /// Accumulated cost of all recorded requests, in cents.
    pub cost_cents: f64,
    /// Number of requests recorded for this provider.
    pub request_count: u64,
    /// Model name of the most recently recorded request.
    pub model: String,
}
30
/// Per-model price list, expressed in cents per 1000 tokens for each token
/// bucket.
///
/// `Default` yields a pricing where every rate is zero, i.e. a model that
/// costs nothing.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct ModelPricing {
    /// Rate for uncached prompt (input) tokens.
    pub prompt_cents_per_1k: f64,
    /// Rate for completion (output) tokens.
    pub completion_cents_per_1k: f64,
    /// Rate for prompt tokens served from the provider's cache.
    pub cache_read_cents_per_1k: f64,
    /// Rate for prompt tokens written into the provider's cache.
    pub cache_write_cents_per_1k: f64,
}
40
41struct CostState {
42 spent_cents: f64,
43 day: u32,
44 providers: HashMap<String, ProviderUsage>,
45}
46
47pub struct CostTracker {
48 pricing: HashMap<String, ModelPricing>,
49 state: Arc<Mutex<CostState>>,
50 max_daily_cents: f64,
51 enabled: bool,
52}
53
/// Returns the number of whole days elapsed since the Unix epoch according to
/// the system clock.
///
/// Falls back to day `0` when the clock reads earlier than the epoch or when
/// the day count does not fit in a `u32`.
fn current_day() -> u32 {
    const SECS_PER_DAY: u64 = 86_400;

    let since_epoch = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        // A clock set before the epoch yields a zero duration rather than an error.
        .unwrap_or_default();
    let whole_days = since_epoch.as_secs() / SECS_PER_DAY;
    u32::try_from(whole_days).unwrap_or(0)
}
63
64fn claude_pricing(prompt: f64, completion: f64) -> ModelPricing {
65 ModelPricing {
66 prompt_cents_per_1k: prompt,
67 completion_cents_per_1k: completion,
68 cache_read_cents_per_1k: prompt * 0.1,
70 cache_write_cents_per_1k: prompt * 1.25,
71 }
72}
73
74fn openai_pricing(prompt: f64, completion: f64) -> ModelPricing {
75 ModelPricing {
76 prompt_cents_per_1k: prompt,
77 completion_cents_per_1k: completion,
78 cache_read_cents_per_1k: prompt * 0.5,
80 cache_write_cents_per_1k: 0.0,
81 }
82}
83
84fn default_pricing() -> HashMap<String, ModelPricing> {
85 let mut m = HashMap::new();
86 m.insert("claude-sonnet-4-20250514".into(), claude_pricing(0.3, 1.5));
88 m.insert("claude-opus-4-20250514".into(), claude_pricing(1.5, 7.5));
89 m.insert("claude-opus-4-1-20250805".into(), claude_pricing(1.5, 7.5));
91 m.insert("claude-haiku-4-5-20251001".into(), claude_pricing(0.1, 0.5));
93 m.insert(
94 "claude-sonnet-4-5-20250929".into(),
95 claude_pricing(0.3, 1.5),
96 );
97 m.insert("claude-opus-4-5-20251101".into(), claude_pricing(0.5, 2.5));
98 m.insert("claude-sonnet-4-6".into(), claude_pricing(0.3, 1.5));
100 m.insert("claude-opus-4-6".into(), claude_pricing(0.5, 2.5));
101 m.insert("gpt-4o".into(), openai_pricing(0.25, 1.0));
103 m.insert("gpt-4o-mini".into(), openai_pricing(0.015, 0.06));
104 m.insert("gpt-5".into(), openai_pricing(0.125, 1.0));
106 m.insert("gpt-5-mini".into(), openai_pricing(0.025, 0.2));
108 m
109}
110
111fn reset_if_new_day(state: &mut CostState) {
112 let today = current_day();
113 if state.day != today {
114 state.spent_cents = 0.0;
115 state.day = today;
116 state.providers.clear();
117 }
118}
119
120impl CostTracker {
121 #[must_use]
122 pub fn new(enabled: bool, max_daily_cents: f64) -> Self {
123 Self {
124 pricing: default_pricing(),
125 state: Arc::new(Mutex::new(CostState {
126 spent_cents: 0.0,
127 day: current_day(),
128 providers: HashMap::new(),
129 })),
130 max_daily_cents,
131 enabled,
132 }
133 }
134
135 #[must_use]
136 pub fn with_pricing(mut self, model: &str, pricing: ModelPricing) -> Self {
137 self.pricing.insert(model.to_owned(), pricing);
138 self
139 }
140
141 #[allow(clippy::too_many_arguments)]
151 pub fn record_usage(
152 &self,
153 provider_name: &str,
154 provider_kind: &str,
155 model: &str,
156 input_tokens: u64,
157 cache_read_tokens: u64,
158 cache_write_tokens: u64,
159 output_tokens: u64,
160 ) {
161 if !self.enabled {
162 return;
163 }
164 let pricing = if let Some(p) = self.pricing.get(model).cloned() {
165 p
166 } else {
167 let is_local = matches!(provider_kind, "ollama" | "candle" | "local");
168 if is_local {
169 tracing::debug!(model, "local model; cost recorded as zero");
170 } else {
171 tracing::warn!(
172 model,
173 "model not found in pricing table; cost recorded as zero"
174 );
175 }
176 ModelPricing {
177 prompt_cents_per_1k: 0.0,
178 completion_cents_per_1k: 0.0,
179 cache_read_cents_per_1k: 0.0,
180 cache_write_cents_per_1k: 0.0,
181 }
182 };
183 #[allow(clippy::cast_precision_loss)]
184 let cost = pricing.prompt_cents_per_1k * (input_tokens as f64) / 1000.0
185 + pricing.completion_cents_per_1k * (output_tokens as f64) / 1000.0
186 + pricing.cache_read_cents_per_1k * (cache_read_tokens as f64) / 1000.0
187 + pricing.cache_write_cents_per_1k * (cache_write_tokens as f64) / 1000.0;
188
189 let mut state = self.state.lock();
190 reset_if_new_day(&mut state);
191 state.spent_cents += cost;
192
193 let entry = state.providers.entry(provider_name.to_owned()).or_default();
194 entry.input_tokens += input_tokens;
195 entry.cache_read_tokens += cache_read_tokens;
196 entry.cache_write_tokens += cache_write_tokens;
197 entry.output_tokens += output_tokens;
198 entry.cost_cents += cost;
199 entry.request_count += 1;
200 model.clone_into(&mut entry.model);
201 }
202
203 pub fn check_budget(&self) -> Result<(), BudgetExhausted> {
207 if !self.enabled {
208 return Ok(());
209 }
210 let mut state = self.state.lock();
211 reset_if_new_day(&mut state);
212 if self.max_daily_cents > 0.0 && state.spent_cents >= self.max_daily_cents {
213 return Err(BudgetExhausted {
214 spent_cents: state.spent_cents,
215 budget_cents: self.max_daily_cents,
216 });
217 }
218 Ok(())
219 }
220
221 #[must_use]
223 pub fn max_daily_cents(&self) -> f64 {
224 self.max_daily_cents
225 }
226
227 #[must_use]
228 pub fn current_spend(&self) -> f64 {
229 let state = self.state.lock();
230 state.spent_cents
231 }
232
233 #[must_use]
235 pub fn provider_breakdown(&self) -> Vec<(String, ProviderUsage)> {
236 let state = self.state.lock();
237 let mut breakdown: Vec<(String, ProviderUsage)> = state
238 .providers
239 .iter()
240 .map(|(k, v)| (k.clone(), v.clone()))
241 .collect();
242 breakdown.sort_by(|a, b| {
243 b.1.cost_cents
244 .partial_cmp(&a.1.cost_cents)
245 .unwrap_or(std::cmp::Ordering::Equal)
246 });
247 breakdown
248 }
249}
250
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance used when comparing cent amounts computed in floating point.
    const EPSILON_CENTS: f64 = 0.001;

    /// Asserts that `actual` is within [`EPSILON_CENTS`] of `expected`,
    /// reporting both values on failure.
    #[track_caller]
    fn assert_cents(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < EPSILON_CENTS,
            "expected {} cents, got {}",
            expected,
            actual
        );
    }

    /// Records a request from a cloud provider with no cache traffic.
    fn record(tracker: &CostTracker, provider: &str, model: &str, input: u64, output: u64) {
        tracker.record_usage(provider, "cloud", model, input, 0, 0, output);
    }

    /// Makes the stored day stale so the next reset check sees a new day.
    fn rewind_day(tracker: &CostTracker) {
        tracker.state.lock().day = 0;
    }

    #[test]
    fn cost_tracker_records_usage_and_calculates_cost() {
        let tracker = CostTracker::new(true, 1000.0);
        record(&tracker, "openai", "gpt-4o", 1000, 1000);
        // 0.25 (prompt) + 1.0 (completion) cents per 1k tokens.
        assert_cents(tracker.current_spend(), 1.25);
    }

    #[test]
    fn check_budget_passes_when_under_limit() {
        let tracker = CostTracker::new(true, 100.0);
        record(&tracker, "openai", "gpt-4o-mini", 100, 100);
        assert!(tracker.check_budget().is_ok());
    }

    #[test]
    fn check_budget_fails_when_over_limit() {
        let tracker = CostTracker::new(true, 0.01);
        record(&tracker, "claude", "claude-opus-4-20250514", 10000, 10000);

        let err = tracker.check_budget().unwrap_err();
        // 10k prompt tokens at 1.5 plus 10k completion tokens at 7.5 cents per 1k.
        assert_cents(err.spent_cents, 90.0);
        assert_cents(err.budget_cents, 0.01);
    }

    #[test]
    fn daily_reset_clears_spending() {
        let tracker = CostTracker::new(true, 100.0);
        record(&tracker, "openai", "gpt-4o", 1000, 1000);
        assert!(tracker.current_spend() > 0.0);

        rewind_day(&tracker);

        assert!(tracker.check_budget().is_ok());
        assert_cents(tracker.current_spend(), 0.0);
    }

    #[test]
    fn daily_reset_clears_provider_breakdown() {
        let tracker = CostTracker::new(true, 100.0);
        record(&tracker, "openai", "gpt-4o", 1000, 1000);
        assert!(!tracker.provider_breakdown().is_empty());

        rewind_day(&tracker);

        assert!(tracker.check_budget().is_ok());
        assert!(tracker.provider_breakdown().is_empty());
    }

    #[test]
    fn ollama_zero_cost() {
        let tracker = CostTracker::new(true, 100.0);
        // Use the local provider kind so the local-model branch is the one exercised.
        tracker.record_usage("ollama", "ollama", "llama3:8b", 10000, 0, 0, 10000);
        assert_cents(tracker.current_spend(), 0.0);
    }

    #[test]
    fn ollama_unknown_model_no_warn_no_panic() {
        let tracker = CostTracker::new(true, 100.0);
        tracker.record_usage(
            "local",
            "ollama",
            "totally-unknown-ollama-model",
            5000,
            0,
            0,
            5000,
        );
        assert_cents(tracker.current_spend(), 0.0);
    }

    #[test]
    fn cloud_unknown_model_still_records_zero_cost() {
        let tracker = CostTracker::new(true, 100.0);
        tracker.record_usage(
            "openai",
            "cloud",
            "totally-unknown-cloud-model",
            5000,
            0,
            0,
            5000,
        );
        assert_cents(tracker.current_spend(), 0.0);
    }

    #[test]
    fn unknown_model_zero_cost() {
        let tracker = CostTracker::new(true, 100.0);
        record(&tracker, "unknown", "totally-unknown-model", 5000, 5000);
        assert_cents(tracker.current_spend(), 0.0);
    }

    #[test]
    fn known_claude_model_has_nonzero_cost() {
        let tracker = CostTracker::new(true, 1000.0);
        record(&tracker, "claude", "claude-haiku-4-5-20251001", 1000, 1000);
        assert!(tracker.current_spend() > 0.0);
    }

    #[test]
    fn gpt5_pricing_is_correct() {
        let tracker = CostTracker::new(true, 1000.0);
        record(&tracker, "openai", "gpt-5", 1000, 1000);
        assert_cents(tracker.current_spend(), 1.125);
    }

    #[test]
    fn gpt5_mini_pricing_is_correct() {
        let tracker = CostTracker::new(true, 1000.0);
        record(&tracker, "openai", "gpt-5-mini", 1000, 1000);
        assert_cents(tracker.current_spend(), 0.225);
    }

    #[test]
    fn disabled_tracker_always_passes() {
        let tracker = CostTracker::new(false, 0.0);
        record(
            &tracker,
            "claude",
            "claude-opus-4-20250514",
            1_000_000,
            1_000_000,
        );
        assert!(tracker.check_budget().is_ok());
        assert_cents(tracker.current_spend(), 0.0);
    }

    #[test]
    fn check_budget_unlimited_when_max_daily_cents_is_zero() {
        let tracker = CostTracker::new(true, 0.0);
        record(
            &tracker,
            "claude",
            "claude-opus-4-20250514",
            100_000,
            100_000,
        );
        assert!(tracker.check_budget().is_ok());
    }

    #[test]
    fn per_provider_accumulation() {
        let tracker = CostTracker::new(true, 1000.0);
        record(&tracker, "claude", "claude-haiku-4-5-20251001", 1000, 500);
        record(&tracker, "openai", "gpt-4o", 2000, 1000);
        record(&tracker, "claude", "claude-haiku-4-5-20251001", 500, 200);

        let breakdown = tracker.provider_breakdown();
        assert_eq!(breakdown.len(), 2);

        let usage_of = |provider: &str| {
            breakdown
                .iter()
                .find(|(name, _)| name == provider)
                .map(|(_, usage)| usage)
                .unwrap()
        };

        let claude = usage_of("claude");
        assert_eq!(claude.request_count, 2);
        assert_eq!(claude.input_tokens, 1500);
        assert_eq!(claude.output_tokens, 700);
        assert_eq!(claude.model, "claude-haiku-4-5-20251001");

        let openai = usage_of("openai");
        assert_eq!(openai.request_count, 1);
        assert_eq!(openai.input_tokens, 2000);
    }

    #[test]
    fn provider_breakdown_sorted_by_cost_desc() {
        let tracker = CostTracker::new(true, 1000.0);
        record(&tracker, "cheap", "gpt-4o-mini", 100, 100);
        record(&tracker, "expensive", "claude-opus-4-20250514", 10000, 5000);

        let breakdown = tracker.provider_breakdown();
        assert_eq!(breakdown[0].0, "expensive");
        assert_eq!(breakdown[1].0, "cheap");
    }

    #[test]
    fn cache_tokens_included_in_cost() {
        let tracker = CostTracker::new(true, 1000.0);
        tracker.record_usage(
            "claude",
            "cloud",
            "claude-haiku-4-5-20251001",
            0,
            1000,
            0,
            0,
        );
        let spend = tracker.current_spend();
        assert!(spend > 0.0, "cache read should contribute to cost");
    }

    #[test]
    fn cache_write_cost_included_in_total() {
        let tracker = CostTracker::new(true, 1000.0);
        tracker.record_usage("claude-provider", "cloud", "claude-opus-4-6", 0, 0, 1000, 0);
        // Prompt rate 0.5 times the 1.25 cache-write multiplier.
        assert_cents(tracker.current_spend(), 0.625);
    }

    #[test]
    fn provider_breakdown_empty_when_disabled() {
        let tracker = CostTracker::new(false, 100.0);
        tracker.record_usage(
            "claude",
            "cloud",
            "claude-haiku-4-5-20251001",
            1000,
            0,
            0,
            1000,
        );
        assert!(tracker.provider_breakdown().is_empty());
    }
}
483}