Skip to main content

zeph_core/
cost.rs

// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use std::collections::HashMap;
5use std::sync::Arc;
6
7use parking_lot::Mutex;
8
9use thiserror::Error;
10
11#[derive(Debug, Error)]
12#[error("daily budget exhausted: spent {spent_cents:.2} / {budget_cents:.2} cents")]
13pub struct BudgetExhausted {
14    pub spent_cents: f64,
15    pub budget_cents: f64,
16}
17
/// Running token and cost totals for one provider during the current session/day.
///
/// All counters start at zero (`Default`) and only ever grow while calls are recorded.
#[derive(Debug, Clone, Default)]
pub struct ProviderUsage {
    /// Input tokens summed over all recorded calls.
    pub input_tokens: u64,
    /// Cache-read tokens summed over all recorded calls.
    pub cache_read_tokens: u64,
    /// Cache-write tokens summed over all recorded calls.
    pub cache_write_tokens: u64,
    /// Output tokens summed over all recorded calls.
    pub output_tokens: u64,
    /// Total cost attributed to this provider, in cents.
    pub cost_cents: f64,
    /// Number of calls recorded for this provider.
    pub request_count: u64,
    /// Model named by the most recent call. Informational only: a provider may
    /// serve a different model on every call.
    pub model: String,
}
30
/// Price list for a single model. Every rate is expressed in cents per 1,000 tokens.
#[derive(Debug, Clone)]
pub struct ModelPricing {
    /// Rate for regular (uncached) input tokens.
    pub prompt_cents_per_1k: f64,
    /// Rate for generated output tokens.
    pub completion_cents_per_1k: f64,
    /// Rate for tokens served from the prompt cache (cache hit).
    /// Claude: 10% of prompt; `OpenAI`: 50%; others: 0%.
    pub cache_read_cents_per_1k: f64,
    /// Rate for tokens written into the prompt cache (cache creation).
    /// Claude: 125% of prompt; others: 0%.
    pub cache_write_cents_per_1k: f64,
}
40
41struct CostState {
42    spent_cents: f64,
43    day: u32,
44    providers: HashMap<String, ProviderUsage>,
45}
46
47pub struct CostTracker {
48    pricing: HashMap<String, ModelPricing>,
49    state: Arc<Mutex<CostState>>,
50    max_daily_cents: f64,
51    enabled: bool,
52}
53
/// Returns the current UTC day number, i.e. whole days elapsed since the Unix epoch.
///
/// Falls back to `0` when the system clock reads earlier than the epoch or when
/// the day count does not fit in a `u32`.
fn current_day() -> u32 {
    use std::time::{SystemTime, UNIX_EPOCH};

    const SECS_PER_DAY: u64 = 86_400;

    // A clock set before the epoch yields an error; treat that as zero elapsed time.
    let since_epoch = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default();
    let days = since_epoch.as_secs() / SECS_PER_DAY;
    u32::try_from(days).unwrap_or(0)
}
63
64fn claude_pricing(prompt: f64, completion: f64) -> ModelPricing {
65    ModelPricing {
66        prompt_cents_per_1k: prompt,
67        completion_cents_per_1k: completion,
68        // Claude: cache read = 10% of prompt, cache write = 125% of prompt
69        cache_read_cents_per_1k: prompt * 0.1,
70        cache_write_cents_per_1k: prompt * 1.25,
71    }
72}
73
74fn openai_pricing(prompt: f64, completion: f64) -> ModelPricing {
75    ModelPricing {
76        prompt_cents_per_1k: prompt,
77        completion_cents_per_1k: completion,
78        // OpenAI: cache read = 50% of prompt, no cache write charge
79        cache_read_cents_per_1k: prompt * 0.5,
80        cache_write_cents_per_1k: 0.0,
81    }
82}
83
84fn default_pricing() -> HashMap<String, ModelPricing> {
85    let mut m = HashMap::new();
86    // Claude 4 (sonnet-4 / opus-4 base releases)
87    m.insert("claude-sonnet-4-20250514".into(), claude_pricing(0.3, 1.5));
88    m.insert("claude-opus-4-20250514".into(), claude_pricing(1.5, 7.5));
89    // Claude 4.1 Opus ($15/$75 per 1M tokens)
90    m.insert("claude-opus-4-1-20250805".into(), claude_pricing(1.5, 7.5));
91    // Claude 4.5 family
92    m.insert("claude-haiku-4-5-20251001".into(), claude_pricing(0.1, 0.5));
93    m.insert(
94        "claude-sonnet-4-5-20250929".into(),
95        claude_pricing(0.3, 1.5),
96    );
97    m.insert("claude-opus-4-5-20251101".into(), claude_pricing(0.5, 2.5));
98    // Claude 4.6 family
99    m.insert("claude-sonnet-4-6".into(), claude_pricing(0.3, 1.5));
100    m.insert("claude-opus-4-6".into(), claude_pricing(0.5, 2.5));
101    // OpenAI
102    m.insert("gpt-4o".into(), openai_pricing(0.25, 1.0));
103    m.insert("gpt-4o-mini".into(), openai_pricing(0.015, 0.06));
104    // GPT-5 family ($1.25/$10 per 1M tokens)
105    m.insert("gpt-5".into(), openai_pricing(0.125, 1.0));
106    // GPT-5 mini ($0.25/$2 per 1M tokens)
107    m.insert("gpt-5-mini".into(), openai_pricing(0.025, 0.2));
108    m
109}
110
111fn reset_if_new_day(state: &mut CostState) {
112    let today = current_day();
113    if state.day != today {
114        state.spent_cents = 0.0;
115        state.day = today;
116        state.providers.clear();
117    }
118}
119
120impl CostTracker {
121    #[must_use]
122    pub fn new(enabled: bool, max_daily_cents: f64) -> Self {
123        Self {
124            pricing: default_pricing(),
125            state: Arc::new(Mutex::new(CostState {
126                spent_cents: 0.0,
127                day: current_day(),
128                providers: HashMap::new(),
129            })),
130            max_daily_cents,
131            enabled,
132        }
133    }
134
135    #[must_use]
136    pub fn with_pricing(mut self, model: &str, pricing: ModelPricing) -> Self {
137        self.pricing.insert(model.to_owned(), pricing);
138        self
139    }
140
141    /// Record token usage for a single LLM call, attributed to `provider_name`.
142    ///
143    /// Cache token counts are optional (pass 0 when not available). Cost is computed
144    /// using model-specific pricing including cache read/write rates.
145    pub fn record_usage(
146        &self,
147        provider_name: &str,
148        model: &str,
149        input_tokens: u64,
150        cache_read_tokens: u64,
151        cache_write_tokens: u64,
152        output_tokens: u64,
153    ) {
154        if !self.enabled {
155            return;
156        }
157        let pricing = if let Some(p) = self.pricing.get(model).cloned() {
158            p
159        } else {
160            tracing::warn!(
161                model,
162                "model not found in pricing table; cost recorded as zero"
163            );
164            ModelPricing {
165                prompt_cents_per_1k: 0.0,
166                completion_cents_per_1k: 0.0,
167                cache_read_cents_per_1k: 0.0,
168                cache_write_cents_per_1k: 0.0,
169            }
170        };
171        #[allow(clippy::cast_precision_loss)]
172        let cost = pricing.prompt_cents_per_1k * (input_tokens as f64) / 1000.0
173            + pricing.completion_cents_per_1k * (output_tokens as f64) / 1000.0
174            + pricing.cache_read_cents_per_1k * (cache_read_tokens as f64) / 1000.0
175            + pricing.cache_write_cents_per_1k * (cache_write_tokens as f64) / 1000.0;
176
177        let mut state = self.state.lock();
178        reset_if_new_day(&mut state);
179        state.spent_cents += cost;
180
181        let entry = state.providers.entry(provider_name.to_owned()).or_default();
182        entry.input_tokens += input_tokens;
183        entry.cache_read_tokens += cache_read_tokens;
184        entry.cache_write_tokens += cache_write_tokens;
185        entry.output_tokens += output_tokens;
186        entry.cost_cents += cost;
187        entry.request_count += 1;
188        model.clone_into(&mut entry.model);
189    }
190
191    /// # Errors
192    ///
193    /// Returns `BudgetExhausted` when daily spend exceeds the configured limit.
194    pub fn check_budget(&self) -> Result<(), BudgetExhausted> {
195        if !self.enabled {
196            return Ok(());
197        }
198        let mut state = self.state.lock();
199        reset_if_new_day(&mut state);
200        if self.max_daily_cents > 0.0 && state.spent_cents >= self.max_daily_cents {
201            return Err(BudgetExhausted {
202                spent_cents: state.spent_cents,
203                budget_cents: self.max_daily_cents,
204            });
205        }
206        Ok(())
207    }
208
209    /// Returns the configured daily budget in cents. Zero means unlimited.
210    #[must_use]
211    pub fn max_daily_cents(&self) -> f64 {
212        self.max_daily_cents
213    }
214
215    #[must_use]
216    pub fn current_spend(&self) -> f64 {
217        let state = self.state.lock();
218        state.spent_cents
219    }
220
221    /// Returns per-provider breakdown sorted by cost descending.
222    #[must_use]
223    pub fn provider_breakdown(&self) -> Vec<(String, ProviderUsage)> {
224        let state = self.state.lock();
225        let mut breakdown: Vec<(String, ProviderUsage)> = state
226            .providers
227            .iter()
228            .map(|(k, v)| (k.clone(), v.clone()))
229            .collect();
230        breakdown.sort_by(|a, b| {
231            b.1.cost_cents
232                .partial_cmp(&a.1.cost_cents)
233                .unwrap_or(std::cmp::Ordering::Equal)
234        });
235        breakdown
236    }
237}
238
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance, in cents, for floating-point cost comparisons.
    const TOLERANCE: f64 = 1e-9;

    /// Records a call that carries no cache token counts.
    fn record(tracker: &CostTracker, provider: &str, model: &str, input: u64, output: u64) {
        tracker.record_usage(provider, model, input, 0, 0, output);
    }

    /// Asserts that the tracker's current spend equals `expected_cents` within `TOLERANCE`.
    fn assert_spend(tracker: &CostTracker, expected_cents: f64) {
        let spend = tracker.current_spend();
        assert!(
            (spend - expected_cents).abs() < TOLERANCE,
            "expected {expected_cents} cents, got {spend}"
        );
    }

    /// Simulates a day change by stamping the tracker's state with day 0.
    fn force_past_day(tracker: &CostTracker) {
        tracker.state.lock().day = 0;
    }

    #[test]
    fn cost_tracker_records_usage_and_calculates_cost() {
        let tracker = CostTracker::new(true, 1000.0);
        record(&tracker, "openai", "gpt-4o", 1000, 1000);
        // 0.25 + 1.0 = 1.25
        assert_spend(&tracker, 1.25);
    }

    #[test]
    fn check_budget_passes_when_under_limit() {
        let tracker = CostTracker::new(true, 100.0);
        record(&tracker, "openai", "gpt-4o-mini", 100, 100);
        assert!(tracker.check_budget().is_ok());
    }

    #[test]
    fn check_budget_fails_when_over_limit() {
        let tracker = CostTracker::new(true, 0.01);
        record(&tracker, "claude", "claude-opus-4-20250514", 10000, 10000);
        assert!(tracker.check_budget().is_err());
    }

    #[test]
    fn check_budget_fails_when_spend_equals_limit() {
        // gpt-4o: 1000 in + 1000 out = exactly 1.25 cents, matching the budget.
        let tracker = CostTracker::new(true, 1.25);
        record(&tracker, "openai", "gpt-4o", 1000, 1000);
        let err = tracker.check_budget().unwrap_err();
        assert!((err.spent_cents - 1.25).abs() < TOLERANCE);
        assert!((err.budget_cents - 1.25).abs() < TOLERANCE);
        assert!(err.to_string().contains("daily budget exhausted"));
    }

    #[test]
    fn daily_reset_clears_spending() {
        let tracker = CostTracker::new(true, 100.0);
        record(&tracker, "openai", "gpt-4o", 1000, 1000);
        assert!(tracker.current_spend() > 0.0);
        force_past_day(&tracker);
        // check_budget should reset
        assert!(tracker.check_budget().is_ok());
        assert_spend(&tracker, 0.0);
    }

    #[test]
    fn daily_reset_clears_provider_breakdown() {
        let tracker = CostTracker::new(true, 100.0);
        record(&tracker, "openai", "gpt-4o", 1000, 1000);
        assert!(!tracker.provider_breakdown().is_empty());
        force_past_day(&tracker);
        assert!(tracker.check_budget().is_ok());
        assert!(tracker.provider_breakdown().is_empty());
    }

    #[test]
    fn ollama_zero_cost() {
        let tracker = CostTracker::new(true, 100.0);
        // Local models have no entry in the pricing table.
        record(&tracker, "ollama", "llama3:8b", 10000, 10000);
        assert_spend(&tracker, 0.0);
    }

    #[test]
    fn unknown_model_zero_cost() {
        let tracker = CostTracker::new(true, 100.0);
        record(&tracker, "unknown", "totally-unknown-model", 5000, 5000);
        assert_spend(&tracker, 0.0);
    }

    #[test]
    fn known_claude_model_has_nonzero_cost() {
        let tracker = CostTracker::new(true, 1000.0);
        record(&tracker, "claude", "claude-haiku-4-5-20251001", 1000, 1000);
        assert!(tracker.current_spend() > 0.0);
    }

    #[test]
    fn gpt5_pricing_is_correct() {
        let tracker = CostTracker::new(true, 1000.0);
        record(&tracker, "openai", "gpt-5", 1000, 1000);
        // 0.125 + 1.0 = 1.125
        assert_spend(&tracker, 1.125);
    }

    #[test]
    fn gpt5_mini_pricing_is_correct() {
        let tracker = CostTracker::new(true, 1000.0);
        record(&tracker, "openai", "gpt-5-mini", 1000, 1000);
        // 0.025 + 0.2 = 0.225
        assert_spend(&tracker, 0.225);
    }

    #[test]
    fn disabled_tracker_always_passes() {
        let tracker = CostTracker::new(false, 0.0);
        record(
            &tracker,
            "claude",
            "claude-opus-4-20250514",
            1_000_000,
            1_000_000,
        );
        assert!(tracker.check_budget().is_ok());
        assert_spend(&tracker, 0.0);
    }

    #[test]
    fn check_budget_unlimited_when_max_daily_cents_is_zero() {
        let tracker = CostTracker::new(true, 0.0);
        record(
            &tracker,
            "claude",
            "claude-opus-4-20250514",
            100_000,
            100_000,
        );
        assert!(tracker.check_budget().is_ok());
    }

    #[test]
    fn max_daily_cents_returns_configured_budget() {
        let tracker = CostTracker::new(true, 42.5);
        assert!((tracker.max_daily_cents() - 42.5).abs() < TOLERANCE);
    }

    #[test]
    fn with_pricing_adds_and_overrides_models() {
        let custom = ModelPricing {
            prompt_cents_per_1k: 2.0,
            completion_cents_per_1k: 4.0,
            cache_read_cents_per_1k: 1.0,
            cache_write_cents_per_1k: 3.0,
        };
        // A model that is absent from the built-in table becomes billable.
        let added = CostTracker::new(true, 1000.0).with_pricing("local-model", custom.clone());
        added.record_usage("local", "local-model", 1000, 1000, 1000, 1000);
        // 2.0 + 4.0 + 1.0 + 3.0 = 10.0
        assert_spend(&added, 10.0);

        // An existing entry is replaced rather than kept.
        let replaced = CostTracker::new(true, 1000.0).with_pricing("gpt-4o", custom);
        record(&replaced, "openai", "gpt-4o", 1000, 1000);
        // 2.0 + 4.0 = 6.0 instead of the built-in 1.25
        assert_spend(&replaced, 6.0);
    }

    #[test]
    fn per_provider_accumulation() {
        let tracker = CostTracker::new(true, 1000.0);
        record(&tracker, "claude", "claude-haiku-4-5-20251001", 1000, 500);
        record(&tracker, "openai", "gpt-4o", 2000, 1000);
        record(&tracker, "claude", "claude-haiku-4-5-20251001", 500, 200);

        let breakdown = tracker.provider_breakdown();
        assert_eq!(breakdown.len(), 2);

        let claude = breakdown.iter().find(|(n, _)| n == "claude").unwrap();
        assert_eq!(claude.1.request_count, 2);
        assert_eq!(claude.1.input_tokens, 1500);
        assert_eq!(claude.1.output_tokens, 700);

        let openai = breakdown.iter().find(|(n, _)| n == "openai").unwrap();
        assert_eq!(openai.1.request_count, 1);
        assert_eq!(openai.1.input_tokens, 2000);
    }

    #[test]
    fn provider_breakdown_sorted_by_cost_desc() {
        let tracker = CostTracker::new(true, 1000.0);
        // gpt-4o-mini: cheap; claude-opus: expensive
        record(&tracker, "cheap", "gpt-4o-mini", 100, 100);
        record(&tracker, "expensive", "claude-opus-4-20250514", 10000, 5000);

        let breakdown = tracker.provider_breakdown();
        assert_eq!(breakdown[0].0, "expensive");
        assert_eq!(breakdown[1].0, "cheap");
    }

    #[test]
    fn cache_tokens_included_in_cost() {
        let tracker = CostTracker::new(true, 1000.0);
        // claude-haiku prompt=0.1, cache_read=0.01 per 1k
        // 1000 cache_read tokens = 0.01 cents; 0 input/output for isolation
        tracker.record_usage("claude", "claude-haiku-4-5-20251001", 0, 1000, 0, 0);
        assert!(
            tracker.current_spend() > 0.0,
            "cache read should contribute to cost"
        );
        assert_spend(&tracker, 0.01);
    }

    #[test]
    fn cache_write_cost_included_in_total() {
        let tracker = CostTracker::new(true, 1000.0);
        // Claude pricing: cache_write = 125% of prompt price
        // claude-opus-4-6: prompt = 0.5 cents/1k
        // 1000 cache_write tokens = (0.5 * 1.25 * 1000) / 1000 = 0.625 cents
        tracker.record_usage("claude-provider", "claude-opus-4-6", 0, 0, 1000, 0);
        assert_spend(&tracker, 0.625);
    }

    #[test]
    fn provider_breakdown_empty_when_disabled() {
        let tracker = CostTracker::new(false, 100.0);
        tracker.record_usage("claude", "claude-haiku-4-5-20251001", 1000, 0, 0, 1000);
        assert!(tracker.provider_breakdown().is_empty());
    }
}