1use crate::utils::generate_short_id;
6use parking_lot::RwLock;
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::time::Instant;
10
/// Serializable summary of recorded token usage and cost.
///
/// Produced by `CostTracker::generate_report`, which copies the totals from a
/// usage snapshot and stamps the report with the generation time.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CostReport {
    /// Input tokens summed over every recorded usage entry.
    pub total_input_tokens: u64,
    /// Output tokens summed over every recorded usage entry.
    pub total_output_tokens: u64,
    /// Accumulated cost in USD as tracked by the cost tracker.
    pub total_cost_usd: f64,
    /// Per-model token and cost totals, keyed by model name.
    pub model_costs: HashMap<String, ModelCost>,
    /// Time the report was generated, as an RFC 3339 string in UTC.
    pub timestamp: String,
}
25
/// Token and cost totals for a single model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelCost {
    /// Input tokens attributed to this model.
    pub input_tokens: u64,
    /// Output tokens attributed to this model.
    pub output_tokens: u64,
    /// Cost in USD attributed to this model.
    pub cost_usd: f64,
}
33
/// Price of a model, expressed per one million tokens.
///
/// The results of `calculate_cost` are stored in `cost_usd` fields, so the
/// prices are treated as USD.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelPricing {
    /// Price charged for one million input tokens.
    pub input_price_per_million: f64,
    /// Price charged for one million output tokens.
    pub output_price_per_million: f64,
}
42
43impl ModelPricing {
44 pub fn calculate_cost(&self, input_tokens: u64, output_tokens: u64) -> f64 {
46 let input_cost = (input_tokens as f64 / 1_000_000.0) * self.input_price_per_million;
47 let output_cost = (output_tokens as f64 / 1_000_000.0) * self.output_price_per_million;
48 input_cost + output_cost
49 }
50}
51
52fn default_pricing() -> HashMap<String, ModelPricing> {
54 let mut pricing = HashMap::new();
55
56 pricing.insert(
58 "claude-opus-4-6".to_string(),
59 ModelPricing {
60 input_price_per_million: 15.0,
61 output_price_per_million: 75.0,
62 },
63 );
64 pricing.insert(
65 "claude-sonnet-4-6".to_string(),
66 ModelPricing {
67 input_price_per_million: 3.0,
68 output_price_per_million: 15.0,
69 },
70 );
71 pricing.insert(
72 "claude-haiku-4-5".to_string(),
73 ModelPricing {
74 input_price_per_million: 0.8,
75 output_price_per_million: 4.0,
76 },
77 );
78
79 pricing.insert(
81 "claude-sonnet-4-6".to_string(),
82 ModelPricing {
83 input_price_per_million: 3.0,
84 output_price_per_million: 15.0,
85 },
86 );
87 pricing.insert(
88 "gpt-4o-mini".to_string(),
89 ModelPricing {
90 input_price_per_million: 0.15,
91 output_price_per_million: 0.6,
92 },
93 );
94
95 pricing
96}
97
/// Tracks token usage and accumulated cost, with an optional budget limit.
///
/// All mutable state sits behind `RwLock`s, so every method takes `&self`.
pub struct CostTracker {
    /// Recorded usage entries, keyed by an ID from `generate_short_id`.
    usage: RwLock<HashMap<String, UsageRecord>>,
    /// Price table keyed by model name; filled once in `new` and never mutated.
    pricing: HashMap<String, ModelPricing>,
    /// Optional spending cap; `None` means no limit is enforced.
    budget_limit: RwLock<Option<f64>>,
    /// Running total of the cost of all accepted usage.
    current_cost: RwLock<f64>,
}
109
/// One accepted `record_usage` call.
#[derive(Debug, Clone)]
struct UsageRecord {
    /// Model name exactly as passed to `record_usage`.
    model: String,
    /// Input tokens reported for this call.
    input_tokens: u64,
    /// Output tokens reported for this call.
    output_tokens: u64,
    // Stored when the record is created but not read anywhere in the visible
    // code, hence the lint suppression.
    #[allow(dead_code)]
    timestamp: Instant,
}
119
120impl CostTracker {
121 pub fn new() -> Self {
122 Self {
123 usage: RwLock::new(HashMap::new()),
124 pricing: default_pricing(),
125 budget_limit: RwLock::new(None),
126 current_cost: RwLock::new(0.0),
127 }
128 }
129
130 pub fn set_budget_limit(&self, limit: f64) {
132 *self.budget_limit.write() = Some(limit);
133 }
134
135 pub fn record_usage(
137 &self,
138 model: &str,
139 input_tokens: u64,
140 output_tokens: u64,
141 ) -> anyhow::Result<()> {
142 let pricing = self.pricing.get(model).cloned().unwrap_or(ModelPricing {
144 input_price_per_million: 3.0,
146 output_price_per_million: 15.0,
147 });
148
149 let cost = pricing.calculate_cost(input_tokens, output_tokens);
150
151 let current = *self.current_cost.read();
153 let limit = *self.budget_limit.read();
154
155 if let Some(limit) = limit {
156 if current + cost > limit {
157 return Err(anyhow::anyhow!(
158 "Budget limit exceeded: current {:.4}, new {:.4}, limit {:.2}",
159 current,
160 current + cost,
161 limit
162 ));
163 }
164 }
165
166 let record_id = generate_short_id();
168 self.usage.write().insert(
169 record_id,
170 UsageRecord {
171 model: model.to_string(),
172 input_tokens,
173 output_tokens,
174 timestamp: Instant::now(),
175 },
176 );
177
178 *self.current_cost.write() += cost;
180
181 Ok(())
182 }
183
184 pub fn get_current_usage(&self) -> UsageSnapshot {
186 let usage = self.usage.read();
187 let mut model_costs = HashMap::new();
188 let mut total_input = 0;
189 let mut total_output = 0;
190
191 for record in usage.values() {
192 let entry = model_costs
193 .entry(record.model.clone())
194 .or_insert(ModelCost {
195 input_tokens: 0,
196 output_tokens: 0,
197 cost_usd: 0.0,
198 });
199
200 entry.input_tokens += record.input_tokens;
201 entry.output_tokens += record.output_tokens;
202
203 let pricing = self
204 .pricing
205 .get(&record.model)
206 .cloned()
207 .unwrap_or(ModelPricing {
208 input_price_per_million: 3.0,
209 output_price_per_million: 15.0,
210 });
211
212 entry.cost_usd += pricing.calculate_cost(record.input_tokens, record.output_tokens);
213
214 total_input += record.input_tokens;
215 total_output += record.output_tokens;
216 }
217
218 UsageSnapshot {
219 total_input_tokens: total_input,
220 total_output_tokens: total_output,
221 total_cost_usd: *self.current_cost.read(),
222 model_costs,
223 budget_remaining: self
224 .budget_limit
225 .read()
226 .map(|limit| limit - *self.current_cost.read()),
227 }
228 }
229
230 pub fn estimate_next_step(
232 &self,
233 model: &str,
234 estimated_input: u64,
235 estimated_output: u64,
236 ) -> CostEstimate {
237 let pricing = self.pricing.get(model).cloned().unwrap_or(ModelPricing {
238 input_price_per_million: 3.0,
239 output_price_per_million: 15.0,
240 });
241
242 let estimated_cost = pricing.calculate_cost(estimated_input, estimated_output);
243
244 CostEstimate {
245 min_tokens: estimated_input,
246 max_tokens: estimated_input + estimated_output,
247 estimated_cost_usd: estimated_cost,
248 confidence: "medium".to_string(), }
250 }
251
252 pub fn generate_report(&self) -> CostReport {
254 let snapshot = self.get_current_usage();
255
256 CostReport {
257 total_input_tokens: snapshot.total_input_tokens,
258 total_output_tokens: snapshot.total_output_tokens,
259 total_cost_usd: snapshot.total_cost_usd,
260 model_costs: snapshot.model_costs,
261 timestamp: chrono::Utc::now().to_rfc3339(),
262 }
263 }
264
265 pub fn reset(&self) {
267 self.usage.write().clear();
268 *self.current_cost.write() = 0.0;
269 }
270}
271
/// Point-in-time view of a tracker's usage, returned by
/// `CostTracker::get_current_usage`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UsageSnapshot {
    /// Input tokens summed over every recorded usage entry.
    pub total_input_tokens: u64,
    /// Output tokens summed over every recorded usage entry.
    pub total_output_tokens: u64,
    /// The tracker's running cost total in USD.
    pub total_cost_usd: f64,
    /// Per-model token and cost totals, keyed by model name.
    pub model_costs: HashMap<String, ModelCost>,
    /// Budget limit minus the running cost; `None` when no limit is set.
    pub budget_remaining: Option<f64>,
}
281
/// Cost estimate for a prospective call, returned by
/// `CostTracker::estimate_next_step`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CostEstimate {
    /// Lower token bound; set to the estimated input token count.
    pub min_tokens: u64,
    /// Upper token bound; set to estimated input plus estimated output.
    pub max_tokens: u64,
    /// Estimated cost in USD for the estimated input and output tokens.
    pub estimated_cost_usd: f64,
    /// Free-form confidence label; `estimate_next_step` always writes "medium".
    pub confidence: String,
}
290
291impl Default for CostTracker {
292 fn default() -> Self {
293 Self::new()
294 }
295}
296
#[cfg(test)]
mod tests {
    use super::*;

    /// Cost of 1000 input + 500 output tokens at $3 / $15 per million:
    /// 0.003 + 0.0075.
    const SONNET_1000_500_COST: f64 = 0.0105;

    /// Asserts that two floating-point amounts agree to within rounding noise.
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < 1e-9,
            "expected {expected}, got {actual}"
        );
    }

    #[test]
    fn test_pricing_calculation() {
        let pricing = ModelPricing {
            input_price_per_million: 3.0,
            output_price_per_million: 15.0,
        };

        // Pin the exact amount; a loose range check would also accept
        // badly wrong prices.
        assert_close(pricing.calculate_cost(1000, 500), SONNET_1000_500_COST);
        assert_close(pricing.calculate_cost(0, 0), 0.0);
    }

    #[test]
    fn test_usage_tracking() {
        let tracker = CostTracker::new();

        tracker
            .record_usage("claude-sonnet-4-6", 1000, 500)
            .unwrap();

        let snapshot = tracker.get_current_usage();
        assert_eq!(snapshot.total_input_tokens, 1000);
        assert_eq!(snapshot.total_output_tokens, 500);
        assert_close(snapshot.total_cost_usd, SONNET_1000_500_COST);
    }

    #[test]
    fn test_budget_limit() {
        let tracker = CostTracker::new();
        tracker.set_budget_limit(0.01);

        // 100 + 50 tokens cost about $0.00105, well under the limit.
        tracker.record_usage("claude-sonnet-4-6", 100, 50).unwrap();

        // 10000 + 5000 tokens cost about $0.105, which exceeds the limit.
        let result = tracker.record_usage("claude-sonnet-4-6", 10000, 5000);
        assert!(result.is_err());

        // A rejected call must leave the recorded usage untouched.
        let snapshot = tracker.get_current_usage();
        assert_eq!(snapshot.total_input_tokens, 100);
        assert_eq!(snapshot.total_output_tokens, 50);
        assert_close(snapshot.total_cost_usd, 0.00105);
    }

    #[test]
    fn test_multiple_models() {
        let tracker = CostTracker::new();

        tracker.record_usage("claude-opus-4-6", 1000, 500).unwrap();
        tracker
            .record_usage("claude-sonnet-4-6", 2000, 1000)
            .unwrap();
        tracker.record_usage("claude-haiku-4-5", 500, 250).unwrap();
        tracker.record_usage("gpt-4o", 1500, 750).unwrap();
        tracker.record_usage("gpt-4o-mini", 3000, 1500).unwrap();

        let snapshot = tracker.get_current_usage();

        assert_eq!(snapshot.total_input_tokens, 8000);
        assert_eq!(snapshot.total_output_tokens, 4000);

        for model in [
            "claude-opus-4-6",
            "claude-sonnet-4-6",
            "claude-haiku-4-5",
            "gpt-4o",
            "gpt-4o-mini",
        ] {
            assert!(
                snapshot.model_costs.contains_key(model),
                "missing per-model entry for {model}"
            );
        }

        assert!(snapshot.total_cost_usd > 0.0);

        let opus_cost = snapshot
            .model_costs
            .get("claude-opus-4-6")
            .unwrap()
            .cost_usd;
        let haiku_cost = snapshot
            .model_costs
            .get("claude-haiku-4-5")
            .unwrap()
            .cost_usd;

        assert!(
            opus_cost > haiku_cost,
            "Opus should be more expensive than Haiku"
        );
    }

    #[test]
    fn test_budget_reset() {
        let tracker = CostTracker::new();
        tracker.set_budget_limit(1.0);

        tracker
            .record_usage("claude-sonnet-4-6", 5000, 2500)
            .unwrap();
        let snapshot = tracker.get_current_usage();
        assert!(snapshot.total_cost_usd > 0.0);
        assert!(snapshot.budget_remaining.is_some());
        assert!(snapshot.budget_remaining.unwrap() < 1.0);

        tracker.reset();

        let snapshot = tracker.get_current_usage();
        assert_eq!(snapshot.total_input_tokens, 0);
        assert_eq!(snapshot.total_output_tokens, 0);
        assert_eq!(snapshot.total_cost_usd, 0.0);
        assert!(snapshot.model_costs.is_empty());
        // Resetting clears usage but keeps the configured limit.
        assert_eq!(snapshot.budget_remaining, Some(1.0));

        tracker
            .record_usage("claude-sonnet-4-6", 1000, 500)
            .unwrap();
        let snapshot = tracker.get_current_usage();
        assert!(snapshot.total_cost_usd > 0.0);
    }

    #[test]
    fn test_concurrent_recording() {
        use std::sync::Arc;
        use std::thread;

        let tracker = Arc::new(CostTracker::new());
        let mut handles = vec![];

        for i in 0..10 {
            let t = Arc::clone(&tracker);
            handles.push(thread::spawn(move || {
                let model = match i % 3 {
                    0 => "claude-opus-4-6",
                    1 => "claude-sonnet-4-6",
                    _ => "claude-haiku-4-5",
                };
                t.record_usage(model, 100, 50).unwrap()
            }));
        }

        for handle in handles {
            handle.join().unwrap();
        }

        let snapshot = tracker.get_current_usage();
        assert_eq!(snapshot.total_input_tokens, 1000);
        assert_eq!(snapshot.total_output_tokens, 500);
    }

    #[test]
    fn test_unknown_model_pricing() {
        let tracker = CostTracker::new();

        tracker.record_usage("unknown-model", 1000, 500).unwrap();

        let snapshot = tracker.get_current_usage();
        assert!(snapshot.model_costs.contains_key("unknown-model"));
        let cost = snapshot.model_costs.get("unknown-model").unwrap().cost_usd;
        // Unknown models are priced at the $3 / $15 per million fallback.
        assert_close(cost, SONNET_1000_500_COST);
    }

    #[test]
    fn test_estimate_next_step() {
        let tracker = CostTracker::new();

        let estimate = tracker.estimate_next_step("claude-sonnet-4-6", 1000, 500);
        assert_eq!(estimate.min_tokens, 1000);
        assert_eq!(estimate.max_tokens, 1500);
        assert_close(estimate.estimated_cost_usd, SONNET_1000_500_COST);
        assert_eq!(estimate.confidence, "medium");

        // Estimating must not record any usage.
        assert_eq!(tracker.get_current_usage().total_input_tokens, 0);
    }

    #[test]
    fn test_generate_report() {
        let tracker = CostTracker::new();

        tracker
            .record_usage("claude-sonnet-4-6", 1000, 500)
            .unwrap();

        let report = tracker.generate_report();
        assert_eq!(report.total_input_tokens, 1000);
        assert_eq!(report.total_output_tokens, 500);
        assert!(!report.timestamp.is_empty());
    }

    #[test]
    fn test_budget_remaining_calculation() {
        let tracker = CostTracker::new();
        tracker.set_budget_limit(1.0);

        tracker
            .record_usage("claude-sonnet-4-6", 1000, 500)
            .unwrap();

        let snapshot = tracker.get_current_usage();
        assert!(snapshot.budget_remaining.is_some());
        let remaining = snapshot.budget_remaining.unwrap();

        assert_close(remaining, 1.0 - SONNET_1000_500_COST);
    }
}