1use std::collections::HashMap;
5use std::sync::{Arc, Mutex};
6
7use thiserror::Error;
8
/// Error returned by `CostTracker::check_budget` when the current day's spend
/// has reached or exceeded the configured daily budget.
///
/// The rendered message reports both amounts with two decimal places.
#[derive(Debug, Error)]
#[error("daily budget exhausted: spent {spent_cents:.2} / {budget_cents:.2} cents")]
pub struct BudgetExhausted {
    /// Amount accumulated so far for the current day, in cents.
    pub spent_cents: f64,
    /// Daily budget the spend was compared against, in cents.
    pub budget_cents: f64,
}
15
/// Per-model token prices, expressed in cents per 1,000 tokens.
#[derive(Debug, Clone)]
pub struct ModelPricing {
    /// Price in cents for every 1,000 prompt tokens.
    pub prompt_cents_per_1k: f64,
    /// Price in cents for every 1,000 completion tokens.
    pub completion_cents_per_1k: f64,
}
21
/// Mutable accounting state guarded by the tracker's mutex.
struct CostState {
    /// Total cost in cents recorded for `day`.
    spent_cents: f64,
    /// Day index (as produced by `current_day`) that `spent_cents` belongs to;
    /// when it differs from the current day the total is reset to zero.
    day: u32,
}
26
/// Tracks the estimated cost of model usage against a daily budget.
pub struct CostTracker {
    /// Pricing table keyed by model identifier.
    pricing: HashMap<String, ModelPricing>,
    /// Running spend for the current day, protected by a mutex.
    state: Arc<Mutex<CostState>>,
    /// Daily budget in cents; `check_budget` fails once spend reaches this value.
    max_daily_cents: f64,
    /// When `false`, `record_usage` records nothing and `check_budget` always succeeds.
    enabled: bool,
}
33
/// Returns the number of whole days elapsed since the Unix epoch.
///
/// Falls back to day `0` when the system clock reads earlier than the epoch
/// or when the day count does not fit in a `u32`.
fn current_day() -> u32 {
    const SECS_PER_DAY: u64 = 86_400;

    let since_epoch = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap_or_default();
    let days = since_epoch.as_secs() / SECS_PER_DAY;
    u32::try_from(days).unwrap_or(0)
}
43
44fn default_pricing() -> HashMap<String, ModelPricing> {
45 let mut m = HashMap::new();
46 m.insert(
48 "claude-sonnet-4-20250514".into(),
49 ModelPricing {
50 prompt_cents_per_1k: 0.3,
51 completion_cents_per_1k: 1.5,
52 },
53 );
54 m.insert(
55 "claude-opus-4-20250514".into(),
56 ModelPricing {
57 prompt_cents_per_1k: 1.5,
58 completion_cents_per_1k: 7.5,
59 },
60 );
61 m.insert(
63 "claude-opus-4-1-20250805".into(),
64 ModelPricing {
65 prompt_cents_per_1k: 1.5,
66 completion_cents_per_1k: 7.5,
67 },
68 );
69 m.insert(
71 "claude-haiku-4-5-20251001".into(),
72 ModelPricing {
73 prompt_cents_per_1k: 0.1,
74 completion_cents_per_1k: 0.5,
75 },
76 );
77 m.insert(
78 "claude-sonnet-4-5-20250929".into(),
79 ModelPricing {
80 prompt_cents_per_1k: 0.3,
81 completion_cents_per_1k: 1.5,
82 },
83 );
84 m.insert(
85 "claude-opus-4-5-20251101".into(),
86 ModelPricing {
87 prompt_cents_per_1k: 0.5,
88 completion_cents_per_1k: 2.5,
89 },
90 );
91 m.insert(
93 "claude-sonnet-4-6".into(),
94 ModelPricing {
95 prompt_cents_per_1k: 0.3,
96 completion_cents_per_1k: 1.5,
97 },
98 );
99 m.insert(
100 "claude-opus-4-6".into(),
101 ModelPricing {
102 prompt_cents_per_1k: 0.5,
103 completion_cents_per_1k: 2.5,
104 },
105 );
106 m.insert(
107 "gpt-4o".into(),
108 ModelPricing {
109 prompt_cents_per_1k: 0.25,
110 completion_cents_per_1k: 1.0,
111 },
112 );
113 m.insert(
114 "gpt-4o-mini".into(),
115 ModelPricing {
116 prompt_cents_per_1k: 0.015,
117 completion_cents_per_1k: 0.06,
118 },
119 );
120 m
121}
122
123impl CostTracker {
124 #[must_use]
125 pub fn new(enabled: bool, max_daily_cents: f64) -> Self {
126 Self {
127 pricing: default_pricing(),
128 state: Arc::new(Mutex::new(CostState {
129 spent_cents: 0.0,
130 day: current_day(),
131 })),
132 max_daily_cents,
133 enabled,
134 }
135 }
136
137 #[must_use]
138 pub fn with_pricing(mut self, model: &str, pricing: ModelPricing) -> Self {
139 self.pricing.insert(model.to_owned(), pricing);
140 self
141 }
142
143 pub fn record_usage(&self, model: &str, prompt_tokens: u64, completion_tokens: u64) {
144 if !self.enabled {
145 return;
146 }
147 let pricing = if let Some(p) = self.pricing.get(model).cloned() {
148 p
149 } else {
150 tracing::warn!(
151 model,
152 "model not found in pricing table; cost recorded as zero"
153 );
154 ModelPricing {
155 prompt_cents_per_1k: 0.0,
156 completion_cents_per_1k: 0.0,
157 }
158 };
159 #[allow(clippy::cast_precision_loss)]
160 let cost = pricing.prompt_cents_per_1k * (prompt_tokens as f64) / 1000.0
161 + pricing.completion_cents_per_1k * (completion_tokens as f64) / 1000.0;
162
163 let mut state = self
164 .state
165 .lock()
166 .unwrap_or_else(std::sync::PoisonError::into_inner);
167 let today = current_day();
168 if state.day != today {
169 state.spent_cents = 0.0;
170 state.day = today;
171 }
172 state.spent_cents += cost;
173 }
174
175 pub fn check_budget(&self) -> Result<(), BudgetExhausted> {
179 if !self.enabled {
180 return Ok(());
181 }
182 let mut state = self
183 .state
184 .lock()
185 .unwrap_or_else(std::sync::PoisonError::into_inner);
186 let today = current_day();
187 if state.day != today {
188 state.spent_cents = 0.0;
189 state.day = today;
190 }
191 if state.spent_cents >= self.max_daily_cents {
192 return Err(BudgetExhausted {
193 spent_cents: state.spent_cents,
194 budget_cents: self.max_daily_cents,
195 });
196 }
197 Ok(())
198 }
199
200 #[must_use]
201 pub fn current_spend(&self) -> f64 {
202 let state = self
203 .state
204 .lock()
205 .unwrap_or_else(std::sync::PoisonError::into_inner);
206 state.spent_cents
207 }
208}
209
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance used when comparing amounts in cents.
    const TOLERANCE: f64 = 0.001;

    /// Returns `true` when `actual` lies within `TOLERANCE` of `expected`.
    fn is_close(actual: f64, expected: f64) -> bool {
        (actual - expected).abs() < TOLERANCE
    }

    /// 1k prompt + 1k completion tokens on `gpt-4o` cost 0.25 + 1.0 cents.
    #[test]
    fn cost_tracker_records_usage_and_calculates_cost() {
        let tracker = CostTracker::new(true, 1000.0);
        tracker.record_usage("gpt-4o", 1000, 1000);
        assert!(is_close(tracker.current_spend(), 1.25));
    }

    /// A small request stays well under a 100-cent budget.
    #[test]
    fn check_budget_passes_when_under_limit() {
        let tracker = CostTracker::new(true, 100.0);
        tracker.record_usage("gpt-4o-mini", 100, 100);
        let outcome = tracker.check_budget();
        assert!(outcome.is_ok());
    }

    /// A large request against a tiny budget exhausts it.
    #[test]
    fn check_budget_fails_when_over_limit() {
        let tracker = CostTracker::new(true, 0.01);
        tracker.record_usage("claude-opus-4-20250514", 10000, 10000);
        let outcome = tracker.check_budget();
        assert!(outcome.is_err());
    }

    /// Spend recorded for a stale day is discarded on the next budget check.
    #[test]
    fn daily_reset_clears_spending() {
        let tracker = CostTracker::new(true, 100.0);
        tracker.record_usage("gpt-4o", 1000, 1000);
        assert!(tracker.current_spend() > 0.0);

        // Pretend the recorded spend belongs to day zero to force a rollover.
        tracker.state.lock().unwrap().day = 0;

        assert!(tracker.check_budget().is_ok());
        assert!(is_close(tracker.current_spend(), 0.0));
    }

    /// A local model absent from the pricing table costs nothing.
    #[test]
    fn ollama_zero_cost() {
        let tracker = CostTracker::new(true, 100.0);
        tracker.record_usage("llama3:8b", 10000, 10000);
        assert!(is_close(tracker.current_spend(), 0.0));
    }

    /// Any model absent from the pricing table costs nothing.
    #[test]
    fn unknown_model_zero_cost() {
        let tracker = CostTracker::new(true, 100.0);
        tracker.record_usage("totally-unknown-model", 5000, 5000);
        assert!(is_close(tracker.current_spend(), 0.0));
    }

    /// A model present in the pricing table produces a positive cost.
    #[test]
    fn known_claude_model_has_nonzero_cost() {
        let tracker = CostTracker::new(true, 1000.0);
        tracker.record_usage("claude-haiku-4-5-20251001", 1000, 1000);
        assert!(tracker.current_spend() > 0.0);
    }

    /// A disabled tracker records nothing and never reports exhaustion.
    #[test]
    fn disabled_tracker_always_passes() {
        let tracker = CostTracker::new(false, 0.0);
        tracker.record_usage("claude-opus-4-20250514", 1_000_000, 1_000_000);
        assert!(tracker.check_budget().is_ok());
        assert!(is_close(tracker.current_spend(), 0.0));
    }
}