1use std::collections::HashMap;
5use std::sync::{Arc, Mutex};
6
7use thiserror::Error;
8
/// Error returned by `CostTracker::check_budget` when today's recorded spend has
/// reached the configured daily cap.
#[derive(Debug, Error)]
#[error("daily budget exhausted: spent {spent_cents:.2} / {budget_cents:.2} cents")]
pub struct BudgetExhausted {
    /// Amount recorded for the current day, in cents.
    pub spent_cents: f64,
    /// The configured daily cap, in cents.
    pub budget_cents: f64,
}
15
/// Token prices for one model, expressed in cents per 1,000 tokens.
///
/// The `Default` value has both rates at zero, i.e. a model whose usage is free.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct ModelPricing {
    /// Cost in cents of 1,000 prompt (input) tokens.
    pub prompt_cents_per_1k: f64,
    /// Cost in cents of 1,000 completion (output) tokens.
    pub completion_cents_per_1k: f64,
}
21
/// Mutable accumulator guarded by the tracker's mutex.
struct CostState {
    /// Total cents recorded during `day`.
    spent_cents: f64,
    /// Day index (whole days since the Unix epoch, as returned by `current_day`)
    /// that `spent_cents` belongs to.
    day: u32,
}
26
/// Accumulates the estimated cost of model usage for the current day and
/// enforces an optional daily spending cap.
pub struct CostTracker {
    /// Per-model token prices, keyed by model identifier.
    pricing: HashMap<String, ModelPricing>,
    /// Running total for the current day.
    // NOTE(review): nothing visible here clones this `Arc`; confirm whether it is
    // shared elsewhere or whether a plain `Mutex` would suffice.
    state: Arc<Mutex<CostState>>,
    /// Daily cap in cents; a value that is not strictly positive means "no cap".
    max_daily_cents: f64,
    /// When `false`, usage is not recorded and budget checks always pass.
    enabled: bool,
}
33
/// Returns the number of whole days elapsed since the Unix epoch.
///
/// A system clock set before the epoch is treated as day 0. A day count that does
/// not fit in `u32` also collapses to 0.
fn current_day() -> u32 {
    use std::time::{SystemTime, UNIX_EPOCH};

    const SECONDS_PER_DAY: u64 = 86_400;

    let seconds_since_epoch = match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    };
    u32::try_from(seconds_since_epoch / SECONDS_PER_DAY).unwrap_or(0)
}
43
44fn default_pricing() -> HashMap<String, ModelPricing> {
45 let mut m = HashMap::new();
46 m.insert(
48 "claude-sonnet-4-20250514".into(),
49 ModelPricing {
50 prompt_cents_per_1k: 0.3,
51 completion_cents_per_1k: 1.5,
52 },
53 );
54 m.insert(
55 "claude-opus-4-20250514".into(),
56 ModelPricing {
57 prompt_cents_per_1k: 1.5,
58 completion_cents_per_1k: 7.5,
59 },
60 );
61 m.insert(
63 "claude-opus-4-1-20250805".into(),
64 ModelPricing {
65 prompt_cents_per_1k: 1.5,
66 completion_cents_per_1k: 7.5,
67 },
68 );
69 m.insert(
71 "claude-haiku-4-5-20251001".into(),
72 ModelPricing {
73 prompt_cents_per_1k: 0.1,
74 completion_cents_per_1k: 0.5,
75 },
76 );
77 m.insert(
78 "claude-sonnet-4-5-20250929".into(),
79 ModelPricing {
80 prompt_cents_per_1k: 0.3,
81 completion_cents_per_1k: 1.5,
82 },
83 );
84 m.insert(
85 "claude-opus-4-5-20251101".into(),
86 ModelPricing {
87 prompt_cents_per_1k: 0.5,
88 completion_cents_per_1k: 2.5,
89 },
90 );
91 m.insert(
93 "claude-sonnet-4-6".into(),
94 ModelPricing {
95 prompt_cents_per_1k: 0.3,
96 completion_cents_per_1k: 1.5,
97 },
98 );
99 m.insert(
100 "claude-opus-4-6".into(),
101 ModelPricing {
102 prompt_cents_per_1k: 0.5,
103 completion_cents_per_1k: 2.5,
104 },
105 );
106 m.insert(
107 "gpt-4o".into(),
108 ModelPricing {
109 prompt_cents_per_1k: 0.25,
110 completion_cents_per_1k: 1.0,
111 },
112 );
113 m.insert(
114 "gpt-4o-mini".into(),
115 ModelPricing {
116 prompt_cents_per_1k: 0.015,
117 completion_cents_per_1k: 0.06,
118 },
119 );
120 m.insert(
122 "gpt-5".into(),
123 ModelPricing {
124 prompt_cents_per_1k: 0.125,
125 completion_cents_per_1k: 1.0,
126 },
127 );
128 m.insert(
130 "gpt-5-mini".into(),
131 ModelPricing {
132 prompt_cents_per_1k: 0.025,
133 completion_cents_per_1k: 0.2,
134 },
135 );
136 m
137}
138
139impl CostTracker {
140 #[must_use]
141 pub fn new(enabled: bool, max_daily_cents: f64) -> Self {
142 Self {
143 pricing: default_pricing(),
144 state: Arc::new(Mutex::new(CostState {
145 spent_cents: 0.0,
146 day: current_day(),
147 })),
148 max_daily_cents,
149 enabled,
150 }
151 }
152
153 #[must_use]
154 pub fn with_pricing(mut self, model: &str, pricing: ModelPricing) -> Self {
155 self.pricing.insert(model.to_owned(), pricing);
156 self
157 }
158
159 pub fn record_usage(&self, model: &str, prompt_tokens: u64, completion_tokens: u64) {
160 if !self.enabled {
161 return;
162 }
163 let pricing = if let Some(p) = self.pricing.get(model).cloned() {
164 p
165 } else {
166 tracing::warn!(
167 model,
168 "model not found in pricing table; cost recorded as zero"
169 );
170 ModelPricing {
171 prompt_cents_per_1k: 0.0,
172 completion_cents_per_1k: 0.0,
173 }
174 };
175 #[allow(clippy::cast_precision_loss)]
176 let cost = pricing.prompt_cents_per_1k * (prompt_tokens as f64) / 1000.0
177 + pricing.completion_cents_per_1k * (completion_tokens as f64) / 1000.0;
178
179 let mut state = self
180 .state
181 .lock()
182 .unwrap_or_else(std::sync::PoisonError::into_inner);
183 let today = current_day();
184 if state.day != today {
185 state.spent_cents = 0.0;
186 state.day = today;
187 }
188 state.spent_cents += cost;
189 }
190
191 pub fn check_budget(&self) -> Result<(), BudgetExhausted> {
195 if !self.enabled {
196 return Ok(());
197 }
198 let mut state = self
199 .state
200 .lock()
201 .unwrap_or_else(std::sync::PoisonError::into_inner);
202 let today = current_day();
203 if state.day != today {
204 state.spent_cents = 0.0;
205 state.day = today;
206 }
207 if self.max_daily_cents > 0.0 && state.spent_cents >= self.max_daily_cents {
208 return Err(BudgetExhausted {
209 spent_cents: state.spent_cents,
210 budget_cents: self.max_daily_cents,
211 });
212 }
213 Ok(())
214 }
215
216 #[must_use]
217 pub fn current_spend(&self) -> f64 {
218 let state = self
219 .state
220 .lock()
221 .unwrap_or_else(std::sync::PoisonError::into_inner);
222 state.spent_cents
223 }
224}
225
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for comparing floating-point spend values.
    const TOLERANCE: f64 = 0.001;

    /// Asserts that the tracker's reported spend is within `TOLERANCE` of `expected`.
    fn assert_spend_near(tracker: &CostTracker, expected: f64) {
        assert!((tracker.current_spend() - expected).abs() < TOLERANCE);
    }

    #[test]
    fn cost_tracker_records_usage_and_calculates_cost() {
        let tracker = CostTracker::new(true, 1000.0);
        tracker.record_usage("gpt-4o", 1000, 1000);
        // 1k prompt tokens at 0.25 cents plus 1k completion tokens at 1.0 cents.
        assert_spend_near(&tracker, 1.25);
    }

    #[test]
    fn check_budget_passes_when_under_limit() {
        let tracker = CostTracker::new(true, 100.0);
        tracker.record_usage("gpt-4o-mini", 100, 100);
        assert!(tracker.check_budget().is_ok());
    }

    #[test]
    fn check_budget_fails_when_over_limit() {
        // A 0.01-cent cap is exceeded by a single large request.
        let tracker = CostTracker::new(true, 0.01);
        tracker.record_usage("claude-opus-4-20250514", 10000, 10000);
        assert!(tracker.check_budget().is_err());
    }

    #[test]
    fn daily_reset_clears_spending() {
        let tracker = CostTracker::new(true, 100.0);
        tracker.record_usage("gpt-4o", 1000, 1000);
        assert!(tracker.current_spend() > 0.0);

        // Rewind the stored day so the next budget check sees a day change.
        tracker.state.lock().unwrap().day = 0;

        assert!(tracker.check_budget().is_ok());
        assert_spend_near(&tracker, 0.0);
    }

    #[test]
    fn ollama_zero_cost() {
        // `llama3:8b` has no pricing entry, so its usage is recorded as free.
        let tracker = CostTracker::new(true, 100.0);
        tracker.record_usage("llama3:8b", 10000, 10000);
        assert_spend_near(&tracker, 0.0);
    }

    #[test]
    fn unknown_model_zero_cost() {
        let tracker = CostTracker::new(true, 100.0);
        tracker.record_usage("totally-unknown-model", 5000, 5000);
        assert_spend_near(&tracker, 0.0);
    }

    #[test]
    fn known_claude_model_has_nonzero_cost() {
        let tracker = CostTracker::new(true, 1000.0);
        tracker.record_usage("claude-haiku-4-5-20251001", 1000, 1000);
        assert!(tracker.current_spend() > 0.0);
    }

    #[test]
    fn gpt5_pricing_is_correct() {
        let tracker = CostTracker::new(true, 1000.0);
        tracker.record_usage("gpt-5", 1000, 1000);
        // 0.125 cents (prompt) + 1.0 cents (completion).
        assert_spend_near(&tracker, 1.125);
    }

    #[test]
    fn gpt5_mini_pricing_is_correct() {
        let tracker = CostTracker::new(true, 1000.0);
        tracker.record_usage("gpt-5-mini", 1000, 1000);
        // 0.025 cents (prompt) + 0.2 cents (completion).
        assert_spend_near(&tracker, 0.225);
    }

    #[test]
    fn disabled_tracker_always_passes() {
        let tracker = CostTracker::new(false, 0.0);
        tracker.record_usage("claude-opus-4-20250514", 1_000_000, 1_000_000);
        assert!(tracker.check_budget().is_ok());
        assert_spend_near(&tracker, 0.0);
    }

    #[test]
    fn check_budget_unlimited_when_max_daily_cents_is_zero() {
        // A zero cap means "no limit", so even heavy usage passes.
        let tracker = CostTracker::new(true, 0.0);
        tracker.record_usage("claude-opus-4-20250514", 100_000, 100_000);
        assert!(tracker.check_budget().is_ok());
    }
}