1use std::collections::HashMap;
5use std::sync::Arc;
6
7use parking_lot::Mutex;
8
9use thiserror::Error;
10
/// Error reported when the spend recorded for the current day has reached the daily budget.
///
/// Both amounts are in cents; the `Display` message renders them with two decimals.
#[derive(Debug, Error)]
#[error("daily budget exhausted: spent {spent_cents:.2} / {budget_cents:.2} cents")]
pub struct BudgetExhausted {
    /// Total recorded for the current day at the time of the check, in cents.
    pub spent_cents: f64,
    /// Configured daily budget that was reached or exceeded, in cents.
    pub budget_cents: f64,
}
17
/// Usage totals accumulated for a single provider during the current day.
///
/// `PartialEq` is derived so snapshots can be compared directly (for example in tests).
#[derive(Debug, Clone, Default, PartialEq)]
pub struct ProviderUsage {
    /// Sum of input (prompt) tokens across recorded requests.
    pub input_tokens: u64,
    /// Sum of tokens read from the prompt cache.
    pub cache_read_tokens: u64,
    /// Sum of tokens written to the prompt cache.
    pub cache_write_tokens: u64,
    /// Sum of output (completion) tokens across recorded requests.
    pub output_tokens: u64,
    /// Accumulated cost of the recorded requests, in cents.
    pub cost_cents: f64,
    /// Number of requests recorded for this provider.
    pub request_count: u64,
    /// Model name of the most recently recorded request.
    pub model: String,
}
30
/// Per-model rates, each expressed in cents per 1,000 tokens.
///
/// The `Default` value has every rate at zero, i.e. usage priced with it costs nothing.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct ModelPricing {
    /// Rate for regular input (prompt) tokens.
    pub prompt_cents_per_1k: f64,
    /// Rate for output (completion) tokens.
    pub completion_cents_per_1k: f64,
    /// Rate for tokens read from the prompt cache.
    pub cache_read_cents_per_1k: f64,
    /// Rate for tokens written to the prompt cache.
    pub cache_write_cents_per_1k: f64,
}
40
/// Mutable accounting data kept behind the tracker's mutex.
struct CostState {
    /// Total cost recorded for the current day, in cents.
    spent_cents: f64,
    /// Day index (whole days since the Unix epoch) that the totals belong to.
    day: u32,
    /// Usage totals keyed by provider name.
    providers: HashMap<String, ProviderUsage>,
}
46
/// Records model usage cost and checks it against an optional daily budget.
pub struct CostTracker {
    /// Pricing table keyed by model name.
    pricing: HashMap<String, ModelPricing>,
    /// Shared, mutex-protected running totals for the current day.
    state: Arc<Mutex<CostState>>,
    /// Daily budget in cents; the budget check only applies when this is greater than zero.
    max_daily_cents: f64,
    /// When `false`, usage is not recorded and the budget check always passes.
    enabled: bool,
}
53
/// Returns the number of whole days elapsed since the Unix epoch according to the system clock.
///
/// A clock set before the epoch is treated as zero elapsed time, and a day count that does not
/// fit in `u32` also yields `0`.
fn current_day() -> u32 {
    const SECONDS_PER_DAY: u64 = 86_400;

    let since_epoch = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap_or_default();
    let whole_days = since_epoch.as_secs() / SECONDS_PER_DAY;

    u32::try_from(whole_days).unwrap_or(0)
}
63
64fn claude_pricing(prompt: f64, completion: f64) -> ModelPricing {
65 ModelPricing {
66 prompt_cents_per_1k: prompt,
67 completion_cents_per_1k: completion,
68 cache_read_cents_per_1k: prompt * 0.1,
70 cache_write_cents_per_1k: prompt * 1.25,
71 }
72}
73
74fn openai_pricing(prompt: f64, completion: f64) -> ModelPricing {
75 ModelPricing {
76 prompt_cents_per_1k: prompt,
77 completion_cents_per_1k: completion,
78 cache_read_cents_per_1k: prompt * 0.5,
80 cache_write_cents_per_1k: 0.0,
81 }
82}
83
84fn default_pricing() -> HashMap<String, ModelPricing> {
85 let mut m = HashMap::new();
86 m.insert("claude-sonnet-4-20250514".into(), claude_pricing(0.3, 1.5));
88 m.insert("claude-opus-4-20250514".into(), claude_pricing(1.5, 7.5));
89 m.insert("claude-opus-4-1-20250805".into(), claude_pricing(1.5, 7.5));
91 m.insert("claude-haiku-4-5-20251001".into(), claude_pricing(0.1, 0.5));
93 m.insert(
94 "claude-sonnet-4-5-20250929".into(),
95 claude_pricing(0.3, 1.5),
96 );
97 m.insert("claude-opus-4-5-20251101".into(), claude_pricing(0.5, 2.5));
98 m.insert("claude-sonnet-4-6".into(), claude_pricing(0.3, 1.5));
100 m.insert("claude-opus-4-6".into(), claude_pricing(0.5, 2.5));
101 m.insert("gpt-4o".into(), openai_pricing(0.25, 1.0));
103 m.insert("gpt-4o-mini".into(), openai_pricing(0.015, 0.06));
104 m.insert("gpt-5".into(), openai_pricing(0.125, 1.0));
106 m.insert("gpt-5-mini".into(), openai_pricing(0.025, 0.2));
108 m
109}
110
111fn reset_if_new_day(state: &mut CostState) {
112 let today = current_day();
113 if state.day != today {
114 state.spent_cents = 0.0;
115 state.day = today;
116 state.providers.clear();
117 }
118}
119
120impl CostTracker {
121 #[must_use]
122 pub fn new(enabled: bool, max_daily_cents: f64) -> Self {
123 Self {
124 pricing: default_pricing(),
125 state: Arc::new(Mutex::new(CostState {
126 spent_cents: 0.0,
127 day: current_day(),
128 providers: HashMap::new(),
129 })),
130 max_daily_cents,
131 enabled,
132 }
133 }
134
135 #[must_use]
136 pub fn with_pricing(mut self, model: &str, pricing: ModelPricing) -> Self {
137 self.pricing.insert(model.to_owned(), pricing);
138 self
139 }
140
141 pub fn record_usage(
146 &self,
147 provider_name: &str,
148 model: &str,
149 input_tokens: u64,
150 cache_read_tokens: u64,
151 cache_write_tokens: u64,
152 output_tokens: u64,
153 ) {
154 if !self.enabled {
155 return;
156 }
157 let pricing = if let Some(p) = self.pricing.get(model).cloned() {
158 p
159 } else {
160 tracing::warn!(
161 model,
162 "model not found in pricing table; cost recorded as zero"
163 );
164 ModelPricing {
165 prompt_cents_per_1k: 0.0,
166 completion_cents_per_1k: 0.0,
167 cache_read_cents_per_1k: 0.0,
168 cache_write_cents_per_1k: 0.0,
169 }
170 };
171 #[allow(clippy::cast_precision_loss)]
172 let cost = pricing.prompt_cents_per_1k * (input_tokens as f64) / 1000.0
173 + pricing.completion_cents_per_1k * (output_tokens as f64) / 1000.0
174 + pricing.cache_read_cents_per_1k * (cache_read_tokens as f64) / 1000.0
175 + pricing.cache_write_cents_per_1k * (cache_write_tokens as f64) / 1000.0;
176
177 let mut state = self.state.lock();
178 reset_if_new_day(&mut state);
179 state.spent_cents += cost;
180
181 let entry = state.providers.entry(provider_name.to_owned()).or_default();
182 entry.input_tokens += input_tokens;
183 entry.cache_read_tokens += cache_read_tokens;
184 entry.cache_write_tokens += cache_write_tokens;
185 entry.output_tokens += output_tokens;
186 entry.cost_cents += cost;
187 entry.request_count += 1;
188 model.clone_into(&mut entry.model);
189 }
190
191 pub fn check_budget(&self) -> Result<(), BudgetExhausted> {
195 if !self.enabled {
196 return Ok(());
197 }
198 let mut state = self.state.lock();
199 reset_if_new_day(&mut state);
200 if self.max_daily_cents > 0.0 && state.spent_cents >= self.max_daily_cents {
201 return Err(BudgetExhausted {
202 spent_cents: state.spent_cents,
203 budget_cents: self.max_daily_cents,
204 });
205 }
206 Ok(())
207 }
208
209 #[must_use]
211 pub fn max_daily_cents(&self) -> f64 {
212 self.max_daily_cents
213 }
214
215 #[must_use]
216 pub fn current_spend(&self) -> f64 {
217 let state = self.state.lock();
218 state.spent_cents
219 }
220
221 #[must_use]
223 pub fn provider_breakdown(&self) -> Vec<(String, ProviderUsage)> {
224 let state = self.state.lock();
225 let mut breakdown: Vec<(String, ProviderUsage)> = state
226 .providers
227 .iter()
228 .map(|(k, v)| (k.clone(), v.clone()))
229 .collect();
230 breakdown.sort_by(|a, b| {
231 b.1.cost_cents
232 .partial_cmp(&a.1.cost_cents)
233 .unwrap_or(std::cmp::Ordering::Equal)
234 });
235 breakdown
236 }
237}
238
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used when comparing cent totals.
    const TOLERANCE: f64 = 0.001;

    /// Records a request that used only regular input and output tokens (no cache traffic).
    fn record(tracker: &CostTracker, provider: &str, model: &str, input: u64, output: u64) {
        tracker.record_usage(provider, model, input, 0, 0, output);
    }

    /// Asserts that the tracker's current spend is within `TOLERANCE` of `expected_cents`.
    fn assert_spend(tracker: &CostTracker, expected_cents: f64) {
        let actual = tracker.current_spend();
        assert!((actual - expected_cents).abs() < TOLERANCE);
    }

    /// Moves the tracker's stored day back to day 0 so the next rollover check sees a new day.
    fn rewind_to_day_zero(tracker: &CostTracker) {
        tracker.state.lock().day = 0;
    }

    #[test]
    fn cost_tracker_records_usage_and_calculates_cost() {
        let tracker = CostTracker::new(true, 1000.0);
        record(&tracker, "openai", "gpt-4o", 1000, 1000);
        // 1k prompt tokens at 0.25 + 1k completion tokens at 1.0
        assert_spend(&tracker, 1.25);
    }

    #[test]
    fn check_budget_passes_when_under_limit() {
        let tracker = CostTracker::new(true, 100.0);
        record(&tracker, "openai", "gpt-4o-mini", 100, 100);
        assert!(tracker.check_budget().is_ok());
    }

    #[test]
    fn check_budget_fails_when_over_limit() {
        let tracker = CostTracker::new(true, 0.01);
        record(&tracker, "claude", "claude-opus-4-20250514", 10000, 10000);
        assert!(tracker.check_budget().is_err());
    }

    #[test]
    fn daily_reset_clears_spending() {
        let tracker = CostTracker::new(true, 100.0);
        record(&tracker, "openai", "gpt-4o", 1000, 1000);
        assert!(tracker.current_spend() > 0.0);

        rewind_to_day_zero(&tracker);

        // The budget check performs the rollover.
        assert!(tracker.check_budget().is_ok());
        assert_spend(&tracker, 0.0);
    }

    #[test]
    fn daily_reset_clears_provider_breakdown() {
        let tracker = CostTracker::new(true, 100.0);
        record(&tracker, "openai", "gpt-4o", 1000, 1000);
        assert!(!tracker.provider_breakdown().is_empty());

        rewind_to_day_zero(&tracker);

        assert!(tracker.check_budget().is_ok());
        assert!(tracker.provider_breakdown().is_empty());
    }

    #[test]
    fn ollama_zero_cost() {
        let tracker = CostTracker::new(true, 100.0);
        record(&tracker, "ollama", "llama3:8b", 10000, 10000);
        assert_spend(&tracker, 0.0);
    }

    #[test]
    fn unknown_model_zero_cost() {
        let tracker = CostTracker::new(true, 100.0);
        record(&tracker, "unknown", "totally-unknown-model", 5000, 5000);
        assert_spend(&tracker, 0.0);
    }

    #[test]
    fn known_claude_model_has_nonzero_cost() {
        let tracker = CostTracker::new(true, 1000.0);
        record(&tracker, "claude", "claude-haiku-4-5-20251001", 1000, 1000);
        assert!(tracker.current_spend() > 0.0);
    }

    #[test]
    fn gpt5_pricing_is_correct() {
        let tracker = CostTracker::new(true, 1000.0);
        record(&tracker, "openai", "gpt-5", 1000, 1000);
        // 0.125 (prompt) + 1.0 (completion)
        assert_spend(&tracker, 1.125);
    }

    #[test]
    fn gpt5_mini_pricing_is_correct() {
        let tracker = CostTracker::new(true, 1000.0);
        record(&tracker, "openai", "gpt-5-mini", 1000, 1000);
        // 0.025 (prompt) + 0.2 (completion)
        assert_spend(&tracker, 0.225);
    }

    #[test]
    fn disabled_tracker_always_passes() {
        let tracker = CostTracker::new(false, 0.0);
        record(
            &tracker,
            "claude",
            "claude-opus-4-20250514",
            1_000_000,
            1_000_000,
        );
        assert!(tracker.check_budget().is_ok());
        assert_spend(&tracker, 0.0);
    }

    #[test]
    fn check_budget_unlimited_when_max_daily_cents_is_zero() {
        let tracker = CostTracker::new(true, 0.0);
        record(
            &tracker,
            "claude",
            "claude-opus-4-20250514",
            100_000,
            100_000,
        );
        assert!(tracker.check_budget().is_ok());
    }

    #[test]
    fn per_provider_accumulation() {
        let tracker = CostTracker::new(true, 1000.0);
        record(&tracker, "claude", "claude-haiku-4-5-20251001", 1000, 500);
        record(&tracker, "openai", "gpt-4o", 2000, 1000);
        record(&tracker, "claude", "claude-haiku-4-5-20251001", 500, 200);

        let breakdown = tracker.provider_breakdown();
        assert_eq!(breakdown.len(), 2);

        let usage_of = |provider: &str| -> ProviderUsage {
            breakdown
                .iter()
                .find(|(name, _)| name.as_str() == provider)
                .map(|(_, usage)| usage.clone())
                .unwrap()
        };

        let claude = usage_of("claude");
        assert_eq!(claude.request_count, 2);
        assert_eq!(claude.input_tokens, 1500);
        assert_eq!(claude.output_tokens, 700);

        let openai = usage_of("openai");
        assert_eq!(openai.request_count, 1);
        assert_eq!(openai.input_tokens, 2000);
    }

    #[test]
    fn provider_breakdown_sorted_by_cost_desc() {
        let tracker = CostTracker::new(true, 1000.0);
        record(&tracker, "cheap", "gpt-4o-mini", 100, 100);
        record(&tracker, "expensive", "claude-opus-4-20250514", 10000, 5000);

        let breakdown = tracker.provider_breakdown();
        let top_provider = breakdown.first().map(|(name, _)| name.as_str());
        assert_eq!(top_provider, Some("expensive"));
    }

    #[test]
    fn cache_tokens_included_in_cost() {
        let tracker = CostTracker::new(true, 1000.0);
        // Only cache-read tokens are reported for this request.
        tracker.record_usage("claude", "claude-haiku-4-5-20251001", 0, 1000, 0, 0);
        let spend = tracker.current_spend();
        assert!(spend > 0.0, "cache read should contribute to cost");
    }

    #[test]
    fn cache_write_cost_included_in_total() {
        let tracker = CostTracker::new(true, 1000.0);
        // Only cache-write tokens: 1k tokens at 0.5 * 1.25 = 0.625 cents.
        tracker.record_usage("claude-provider", "claude-opus-4-6", 0, 0, 1000, 0);
        assert_spend(&tracker, 0.625);
    }

    #[test]
    fn provider_breakdown_empty_when_disabled() {
        let tracker = CostTracker::new(false, 100.0);
        tracker.record_usage("claude", "claude-haiku-4-5-20251001", 1000, 0, 0, 1000);
        assert!(tracker.provider_breakdown().is_empty());
    }
}