1use std::collections::HashMap;
5use std::sync::{Arc, Mutex};
6
7use thiserror::Error;
8
9#[derive(Debug, Error)]
10#[error("daily budget exhausted: spent {spent_cents:.2} / {budget_cents:.2} cents")]
11pub struct BudgetExhausted {
12 pub spent_cents: f64,
13 pub budget_cents: f64,
14}
15
/// Token rates for a single model, expressed in cents per 1 000 tokens.
///
/// Prompt (input) and completion (output) tokens are priced independently.
#[derive(Clone, Debug)]
pub struct ModelPricing {
    /// Cents charged per 1 000 prompt tokens.
    pub prompt_cents_per_1k: f64,
    /// Cents charged per 1 000 completion tokens.
    pub completion_cents_per_1k: f64,
}
21
/// Mutable spend bookkeeping kept behind the tracker's mutex.
struct CostState {
    /// Cents spent so far during `day`.
    spent_cents: f64,
    /// Day index (whole days since the Unix epoch) that `spent_cents` belongs to.
    day: u32,
}
26
27pub struct CostTracker {
28 pricing: HashMap<String, ModelPricing>,
29 state: Arc<Mutex<CostState>>,
30 max_daily_cents: f64,
31 enabled: bool,
32}
33
/// Returns the number of whole days elapsed since the Unix epoch (UTC day boundaries).
///
/// Yields 0 if the system clock reads earlier than the epoch, or if the day count
/// does not fit in a `u32`.
fn current_day() -> u32 {
    use std::time::{SystemTime, UNIX_EPOCH};

    const SECONDS_PER_DAY: u64 = 24 * 60 * 60;

    let seconds_since_epoch = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map_or(0, |elapsed| elapsed.as_secs());
    let day_index = seconds_since_epoch / SECONDS_PER_DAY;
    u32::try_from(day_index).unwrap_or_default()
}
43
44fn default_pricing() -> HashMap<String, ModelPricing> {
45 let mut m = HashMap::new();
46 m.insert(
47 "claude-sonnet-4-20250514".into(),
48 ModelPricing {
49 prompt_cents_per_1k: 0.3,
50 completion_cents_per_1k: 1.5,
51 },
52 );
53 m.insert(
54 "claude-opus-4-20250514".into(),
55 ModelPricing {
56 prompt_cents_per_1k: 1.5,
57 completion_cents_per_1k: 7.5,
58 },
59 );
60 m.insert(
61 "gpt-4o".into(),
62 ModelPricing {
63 prompt_cents_per_1k: 0.25,
64 completion_cents_per_1k: 1.0,
65 },
66 );
67 m.insert(
68 "gpt-4o-mini".into(),
69 ModelPricing {
70 prompt_cents_per_1k: 0.015,
71 completion_cents_per_1k: 0.06,
72 },
73 );
74 m
75}
76
77impl CostTracker {
78 #[must_use]
79 pub fn new(enabled: bool, max_daily_cents: f64) -> Self {
80 Self {
81 pricing: default_pricing(),
82 state: Arc::new(Mutex::new(CostState {
83 spent_cents: 0.0,
84 day: current_day(),
85 })),
86 max_daily_cents,
87 enabled,
88 }
89 }
90
91 #[must_use]
92 pub fn with_pricing(mut self, model: &str, pricing: ModelPricing) -> Self {
93 self.pricing.insert(model.to_owned(), pricing);
94 self
95 }
96
97 pub fn record_usage(&self, model: &str, prompt_tokens: u64, completion_tokens: u64) {
98 if !self.enabled {
99 return;
100 }
101 let pricing = self.pricing.get(model).cloned().unwrap_or(ModelPricing {
102 prompt_cents_per_1k: 0.0,
103 completion_cents_per_1k: 0.0,
104 });
105 #[allow(clippy::cast_precision_loss)]
106 let cost = pricing.prompt_cents_per_1k * (prompt_tokens as f64) / 1000.0
107 + pricing.completion_cents_per_1k * (completion_tokens as f64) / 1000.0;
108
109 let mut state = self
110 .state
111 .lock()
112 .unwrap_or_else(std::sync::PoisonError::into_inner);
113 let today = current_day();
114 if state.day != today {
115 state.spent_cents = 0.0;
116 state.day = today;
117 }
118 state.spent_cents += cost;
119 }
120
121 pub fn check_budget(&self) -> Result<(), BudgetExhausted> {
125 if !self.enabled {
126 return Ok(());
127 }
128 let mut state = self
129 .state
130 .lock()
131 .unwrap_or_else(std::sync::PoisonError::into_inner);
132 let today = current_day();
133 if state.day != today {
134 state.spent_cents = 0.0;
135 state.day = today;
136 }
137 if state.spent_cents >= self.max_daily_cents {
138 return Err(BudgetExhausted {
139 spent_cents: state.spent_cents,
140 budget_cents: self.max_daily_cents,
141 });
142 }
143 Ok(())
144 }
145
146 #[must_use]
147 pub fn current_spend(&self) -> f64 {
148 let state = self
149 .state
150 .lock()
151 .unwrap_or_else(std::sync::PoisonError::into_inner);
152 state.spent_cents
153 }
154}
155
#[cfg(test)]
mod tests {
    use super::*;

    /// Float comparison with the tolerance shared by every test in this module.
    fn approx_eq(actual: f64, expected: f64) -> bool {
        (actual - expected).abs() < 0.001
    }

    #[test]
    fn cost_tracker_records_usage_and_calculates_cost() {
        let tracker = CostTracker::new(true, 1000.0);
        tracker.record_usage("gpt-4o", 1000, 1000);
        // gpt-4o: 1k prompt tokens at 0.25 + 1k completion tokens at 1.0 = 1.25 cents.
        assert!(approx_eq(tracker.current_spend(), 1.25));
    }

    #[test]
    fn check_budget_passes_when_under_limit() {
        let tracker = CostTracker::new(true, 100.0);
        // gpt-4o-mini: 100 * 0.015/1k + 100 * 0.06/1k = 0.0075 cents, well under 100.
        tracker.record_usage("gpt-4o-mini", 100, 100);
        assert!(tracker.check_budget().is_ok());
    }

    #[test]
    fn check_budget_fails_when_over_limit() {
        let tracker = CostTracker::new(true, 0.01);
        // opus: 10k * 1.5/1k + 10k * 7.5/1k = 90 cents, far over the 0.01 budget.
        tracker.record_usage("claude-opus-4-20250514", 10000, 10000);
        assert!(tracker.check_budget().is_err());
    }

    #[test]
    fn daily_reset_clears_spending() {
        let tracker = CostTracker::new(true, 100.0);
        tracker.record_usage("gpt-4o", 1000, 1000);
        assert!(tracker.current_spend() > 0.0);

        // Pretend the recorded spend belongs to day 0 so the next access sees a new day.
        tracker.state.lock().unwrap().day = 0;

        // check_budget notices the day change and zeroes the running total.
        assert!(tracker.check_budget().is_ok());
        assert!(approx_eq(tracker.current_spend(), 0.0));
    }

    #[test]
    fn ollama_zero_cost() {
        let tracker = CostTracker::new(true, 100.0);
        // Local models have no pricing entry, so they are free.
        tracker.record_usage("llama3:8b", 10000, 10000);
        assert!(approx_eq(tracker.current_spend(), 0.0));
    }

    #[test]
    fn unknown_model_zero_cost() {
        let tracker = CostTracker::new(true, 100.0);
        tracker.record_usage("totally-unknown-model", 5000, 5000);
        assert!(approx_eq(tracker.current_spend(), 0.0));
    }

    #[test]
    fn disabled_tracker_always_passes() {
        let tracker = CostTracker::new(false, 0.0);
        // A zero budget would normally fail immediately; disabling bypasses it.
        tracker.record_usage("claude-opus-4-20250514", 1_000_000, 1_000_000);
        assert!(tracker.check_budget().is_ok());
        assert!(approx_eq(tracker.current_spend(), 0.0));
    }
}