Skip to main content

zeph_core/
cost.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use std::collections::HashMap;
5use std::sync::Arc;
6
7use parking_lot::Mutex;
8
9use thiserror::Error;
10
11#[derive(Debug, Error)]
12#[error("daily budget exhausted: spent {spent_cents:.2} / {budget_cents:.2} cents")]
13pub struct BudgetExhausted {
14    pub spent_cents: f64,
15    pub budget_cents: f64,
16}
17
/// Aggregated token counts and cost for one provider over the current session/day.
///
/// All counters start at zero (see the `Default` derive) and grow as calls are recorded.
#[derive(Clone, Debug, Default)]
pub struct ProviderUsage {
    /// Total prompt (input) tokens recorded.
    pub input_tokens: u64,
    /// Total tokens served from the provider's prompt cache.
    pub cache_read_tokens: u64,
    /// Total tokens written into the provider's prompt cache.
    pub cache_write_tokens: u64,
    /// Total completion (output) tokens recorded.
    pub output_tokens: u64,
    /// Accumulated cost in cents.
    pub cost_cents: f64,
    /// Number of calls recorded for this provider.
    pub request_count: u64,
    /// Most recent model used with this provider. Informational only: it is
    /// overwritten on every call, so it may not match earlier requests.
    pub model: String,
}
30
/// Price list for a single model. Every rate is expressed in cents per 1k tokens.
#[derive(Clone, Debug)]
pub struct ModelPricing {
    /// Rate for regular prompt (input) tokens.
    pub prompt_cents_per_1k: f64,
    /// Rate for completion (output) tokens.
    pub completion_cents_per_1k: f64,
    /// Rate for cache reads (cache hits). Claude: 10% of prompt; `OpenAI`: 50%; others: 0%.
    pub cache_read_cents_per_1k: f64,
    /// Rate for cache writes (cache creation). Claude: 125% of prompt; others: 0%.
    pub cache_write_cents_per_1k: f64,
}
40
41struct CostState {
42    spent_cents: f64,
43    day: u32,
44    providers: HashMap<String, ProviderUsage>,
45}
46
47pub struct CostTracker {
48    pricing: HashMap<String, ModelPricing>,
49    state: Arc<Mutex<CostState>>,
50    max_daily_cents: f64,
51    enabled: bool,
52}
53
/// Returns the current UTC day as the number of whole days since the Unix epoch.
///
/// Falls back to `0` when the system clock reads earlier than the epoch, or when
/// the day count does not fit in a `u32`.
fn current_day() -> u32 {
    use std::time::{SystemTime, UNIX_EPOCH};

    const SECONDS_PER_DAY: u64 = 86_400;

    // A clock set before 1970 yields an error; treat that as zero elapsed time.
    let elapsed_secs = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map_or(0, |since_epoch| since_epoch.as_secs());
    let days_since_epoch = elapsed_secs / SECONDS_PER_DAY;
    u32::try_from(days_since_epoch).unwrap_or(0)
}
63
64fn claude_pricing(prompt: f64, completion: f64) -> ModelPricing {
65    ModelPricing {
66        prompt_cents_per_1k: prompt,
67        completion_cents_per_1k: completion,
68        // Claude: cache read = 10% of prompt, cache write = 125% of prompt
69        cache_read_cents_per_1k: prompt * 0.1,
70        cache_write_cents_per_1k: prompt * 1.25,
71    }
72}
73
74fn openai_pricing(prompt: f64, completion: f64) -> ModelPricing {
75    ModelPricing {
76        prompt_cents_per_1k: prompt,
77        completion_cents_per_1k: completion,
78        // OpenAI: cache read = 50% of prompt, no cache write charge
79        cache_read_cents_per_1k: prompt * 0.5,
80        cache_write_cents_per_1k: 0.0,
81    }
82}
83
84fn default_pricing() -> HashMap<String, ModelPricing> {
85    let mut m = HashMap::new();
86    // Claude 4 (sonnet-4 / opus-4 base releases)
87    m.insert("claude-sonnet-4-20250514".into(), claude_pricing(0.3, 1.5));
88    m.insert("claude-opus-4-20250514".into(), claude_pricing(1.5, 7.5));
89    // Claude 4.1 Opus ($15/$75 per 1M tokens)
90    m.insert("claude-opus-4-1-20250805".into(), claude_pricing(1.5, 7.5));
91    // Claude 4.5 family
92    m.insert("claude-haiku-4-5-20251001".into(), claude_pricing(0.1, 0.5));
93    m.insert(
94        "claude-sonnet-4-5-20250929".into(),
95        claude_pricing(0.3, 1.5),
96    );
97    m.insert("claude-opus-4-5-20251101".into(), claude_pricing(0.5, 2.5));
98    // Claude 4.6 family
99    m.insert("claude-sonnet-4-6".into(), claude_pricing(0.3, 1.5));
100    m.insert("claude-opus-4-6".into(), claude_pricing(0.5, 2.5));
101    // OpenAI
102    m.insert("gpt-4o".into(), openai_pricing(0.25, 1.0));
103    m.insert("gpt-4o-mini".into(), openai_pricing(0.015, 0.06));
104    // GPT-5 family ($1.25/$10 per 1M tokens)
105    m.insert("gpt-5".into(), openai_pricing(0.125, 1.0));
106    // GPT-5 mini ($0.25/$2 per 1M tokens)
107    m.insert("gpt-5-mini".into(), openai_pricing(0.025, 0.2));
108    m
109}
110
111fn reset_if_new_day(state: &mut CostState) {
112    let today = current_day();
113    if state.day != today {
114        state.spent_cents = 0.0;
115        state.day = today;
116        state.providers.clear();
117    }
118}
119
120impl CostTracker {
121    #[must_use]
122    pub fn new(enabled: bool, max_daily_cents: f64) -> Self {
123        Self {
124            pricing: default_pricing(),
125            state: Arc::new(Mutex::new(CostState {
126                spent_cents: 0.0,
127                day: current_day(),
128                providers: HashMap::new(),
129            })),
130            max_daily_cents,
131            enabled,
132        }
133    }
134
135    #[must_use]
136    pub fn with_pricing(mut self, model: &str, pricing: ModelPricing) -> Self {
137        self.pricing.insert(model.to_owned(), pricing);
138        self
139    }
140
141    /// Record token usage for a single LLM call, attributed to `provider_name`.
142    ///
143    /// `provider_kind` must be the value returned by `AnyProvider::provider_kind_str()`:
144    /// `"ollama"` or `"candle"` for local providers, `"cloud"` for API providers.
145    /// Local providers always have zero cost by design; the missing-pricing WARN is
146    /// suppressed for them to avoid log floods on every Ollama call.
147    ///
148    /// Cache token counts are optional (pass 0 when not available). Cost is computed
149    /// using model-specific pricing including cache read/write rates.
150    #[allow(clippy::too_many_arguments)]
151    pub fn record_usage(
152        &self,
153        provider_name: &str,
154        provider_kind: &str,
155        model: &str,
156        input_tokens: u64,
157        cache_read_tokens: u64,
158        cache_write_tokens: u64,
159        output_tokens: u64,
160    ) {
161        if !self.enabled {
162            return;
163        }
164        let pricing = if let Some(p) = self.pricing.get(model).cloned() {
165            p
166        } else {
167            let is_local = matches!(provider_kind, "ollama" | "candle" | "local");
168            if is_local {
169                tracing::debug!(model, "local model; cost recorded as zero");
170            } else {
171                tracing::warn!(
172                    model,
173                    "model not found in pricing table; cost recorded as zero"
174                );
175            }
176            ModelPricing {
177                prompt_cents_per_1k: 0.0,
178                completion_cents_per_1k: 0.0,
179                cache_read_cents_per_1k: 0.0,
180                cache_write_cents_per_1k: 0.0,
181            }
182        };
183        #[allow(clippy::cast_precision_loss)]
184        let cost = pricing.prompt_cents_per_1k * (input_tokens as f64) / 1000.0
185            + pricing.completion_cents_per_1k * (output_tokens as f64) / 1000.0
186            + pricing.cache_read_cents_per_1k * (cache_read_tokens as f64) / 1000.0
187            + pricing.cache_write_cents_per_1k * (cache_write_tokens as f64) / 1000.0;
188
189        let mut state = self.state.lock();
190        reset_if_new_day(&mut state);
191        state.spent_cents += cost;
192
193        let entry = state.providers.entry(provider_name.to_owned()).or_default();
194        entry.input_tokens += input_tokens;
195        entry.cache_read_tokens += cache_read_tokens;
196        entry.cache_write_tokens += cache_write_tokens;
197        entry.output_tokens += output_tokens;
198        entry.cost_cents += cost;
199        entry.request_count += 1;
200        model.clone_into(&mut entry.model);
201    }
202
203    /// # Errors
204    ///
205    /// Returns `BudgetExhausted` when daily spend exceeds the configured limit.
206    pub fn check_budget(&self) -> Result<(), BudgetExhausted> {
207        if !self.enabled {
208            return Ok(());
209        }
210        let mut state = self.state.lock();
211        reset_if_new_day(&mut state);
212        if self.max_daily_cents > 0.0 && state.spent_cents >= self.max_daily_cents {
213            return Err(BudgetExhausted {
214                spent_cents: state.spent_cents,
215                budget_cents: self.max_daily_cents,
216            });
217        }
218        Ok(())
219    }
220
221    /// Returns the configured daily budget in cents. Zero means unlimited.
222    #[must_use]
223    pub fn max_daily_cents(&self) -> f64 {
224        self.max_daily_cents
225    }
226
227    #[must_use]
228    pub fn current_spend(&self) -> f64 {
229        let state = self.state.lock();
230        state.spent_cents
231    }
232
233    /// Returns per-provider breakdown sorted by cost descending.
234    #[must_use]
235    pub fn provider_breakdown(&self) -> Vec<(String, ProviderUsage)> {
236        let state = self.state.lock();
237        let mut breakdown: Vec<(String, ProviderUsage)> = state
238            .providers
239            .iter()
240            .map(|(k, v)| (k.clone(), v.clone()))
241            .collect();
242        breakdown.sort_by(|a, b| {
243            b.1.cost_cents
244                .partial_cmp(&a.1.cost_cents)
245                .unwrap_or(std::cmp::Ordering::Equal)
246        });
247        breakdown
248    }
249}
250
#[cfg(test)]
mod tests {
    use super::*;

    /// Records a call for a cloud provider with no cache tokens.
    fn record(tracker: &CostTracker, provider: &str, model: &str, input: u64, output: u64) {
        tracker.record_usage(provider, "cloud", model, input, 0, 0, output);
    }

    #[test]
    fn cost_tracker_records_usage_and_calculates_cost() {
        let tracker = CostTracker::new(true, 1000.0);
        record(&tracker, "openai", "gpt-4o", 1000, 1000);
        // 0.25 + 1.0 = 1.25
        let spend = tracker.current_spend();
        assert!((spend - 1.25).abs() < 0.001);
    }

    #[test]
    fn check_budget_passes_when_under_limit() {
        let tracker = CostTracker::new(true, 100.0);
        record(&tracker, "openai", "gpt-4o-mini", 100, 100);
        assert!(tracker.check_budget().is_ok());
    }

    #[test]
    fn check_budget_fails_when_over_limit() {
        let tracker = CostTracker::new(true, 0.01);
        record(&tracker, "claude", "claude-opus-4-20250514", 10000, 10000);
        assert!(tracker.check_budget().is_err());
    }

    #[test]
    fn check_budget_fails_when_exactly_at_limit() {
        // gpt-4o: 0.25 + 1.0 = 1.25, which is exactly representable as f64.
        let tracker = CostTracker::new(true, 1.25);
        record(&tracker, "openai", "gpt-4o", 1000, 1000);
        let err = tracker.check_budget().unwrap_err();
        assert!((err.spent_cents - 1.25).abs() < 1e-9);
        assert!((err.budget_cents - 1.25).abs() < 1e-9);
    }

    #[test]
    fn daily_reset_clears_spending() {
        let tracker = CostTracker::new(true, 100.0);
        record(&tracker, "openai", "gpt-4o", 1000, 1000);
        assert!(tracker.current_spend() > 0.0);
        // Simulate day change
        {
            let mut state = tracker.state.lock();
            state.day = 0; // force a past day
        }
        // check_budget should reset
        assert!(tracker.check_budget().is_ok());
        assert!((tracker.current_spend() - 0.0).abs() < 0.001);
    }

    #[test]
    fn daily_reset_clears_provider_breakdown() {
        let tracker = CostTracker::new(true, 100.0);
        record(&tracker, "openai", "gpt-4o", 1000, 1000);
        assert!(!tracker.provider_breakdown().is_empty());
        // Simulate day change
        {
            let mut state = tracker.state.lock();
            state.day = 0;
        }
        assert!(tracker.check_budget().is_ok());
        assert!(tracker.provider_breakdown().is_empty());
    }

    #[test]
    fn ollama_zero_cost() {
        // Use the real local provider kind so the local (non-WARN) path is exercised;
        // the `record` helper would tag the call as "cloud".
        let tracker = CostTracker::new(true, 100.0);
        tracker.record_usage("ollama", "ollama", "llama3:8b", 10000, 0, 0, 10000);
        assert!((tracker.current_spend() - 0.0).abs() < 0.001);
    }

    #[test]
    fn ollama_unknown_model_no_warn_no_panic() {
        // Local providers should silently record zero cost for unknown models.
        let tracker = CostTracker::new(true, 100.0);
        tracker.record_usage(
            "local",
            "ollama",
            "totally-unknown-ollama-model",
            5000,
            0,
            0,
            5000,
        );
        assert!((tracker.current_spend() - 0.0).abs() < 0.001);
    }

    #[test]
    fn cloud_unknown_model_still_records_zero_cost() {
        // Cloud providers record zero cost for unknown models (WARN emitted separately).
        let tracker = CostTracker::new(true, 100.0);
        tracker.record_usage(
            "openai",
            "cloud",
            "totally-unknown-cloud-model",
            5000,
            0,
            0,
            5000,
        );
        assert!((tracker.current_spend() - 0.0).abs() < 0.001);
    }

    #[test]
    fn unknown_model_zero_cost() {
        let tracker = CostTracker::new(true, 100.0);
        record(&tracker, "unknown", "totally-unknown-model", 5000, 5000);
        assert!((tracker.current_spend() - 0.0).abs() < 0.001);
    }

    #[test]
    fn known_claude_model_has_nonzero_cost() {
        let tracker = CostTracker::new(true, 1000.0);
        record(&tracker, "claude", "claude-haiku-4-5-20251001", 1000, 1000);
        assert!(tracker.current_spend() > 0.0);
    }

    #[test]
    fn gpt5_pricing_is_correct() {
        let tracker = CostTracker::new(true, 1000.0);
        record(&tracker, "openai", "gpt-5", 1000, 1000);
        // 0.125 + 1.0 = 1.125
        let spend = tracker.current_spend();
        assert!((spend - 1.125).abs() < 0.001);
    }

    #[test]
    fn gpt5_mini_pricing_is_correct() {
        let tracker = CostTracker::new(true, 1000.0);
        record(&tracker, "openai", "gpt-5-mini", 1000, 1000);
        // 0.025 + 0.2 = 0.225
        let spend = tracker.current_spend();
        assert!((spend - 0.225).abs() < 0.001);
    }

    #[test]
    fn with_pricing_adds_custom_model() {
        let tracker = CostTracker::new(true, 1000.0).with_pricing(
            "custom-model",
            ModelPricing {
                prompt_cents_per_1k: 2.0,
                completion_cents_per_1k: 4.0,
                cache_read_cents_per_1k: 1.0,
                cache_write_cents_per_1k: 3.0,
            },
        );
        tracker.record_usage("custom", "cloud", "custom-model", 1000, 1000, 1000, 1000);
        // 2.0 (input) + 4.0 (output) + 1.0 (cache read) + 3.0 (cache write) = 10.0
        assert!((tracker.current_spend() - 10.0).abs() < 0.001);
    }

    #[test]
    fn disabled_tracker_always_passes() {
        let tracker = CostTracker::new(false, 0.0);
        record(
            &tracker,
            "claude",
            "claude-opus-4-20250514",
            1_000_000,
            1_000_000,
        );
        assert!(tracker.check_budget().is_ok());
        assert!((tracker.current_spend() - 0.0).abs() < 0.001);
    }

    #[test]
    fn check_budget_unlimited_when_max_daily_cents_is_zero() {
        let tracker = CostTracker::new(true, 0.0);
        record(
            &tracker,
            "claude",
            "claude-opus-4-20250514",
            100_000,
            100_000,
        );
        assert!(tracker.check_budget().is_ok());
    }

    #[test]
    fn per_provider_accumulation() {
        let tracker = CostTracker::new(true, 1000.0);
        record(&tracker, "claude", "claude-haiku-4-5-20251001", 1000, 500);
        record(&tracker, "openai", "gpt-4o", 2000, 1000);
        record(&tracker, "claude", "claude-haiku-4-5-20251001", 500, 200);

        let breakdown = tracker.provider_breakdown();
        assert_eq!(breakdown.len(), 2);

        let claude = breakdown.iter().find(|(n, _)| n == "claude").unwrap();
        assert_eq!(claude.1.request_count, 2);
        assert_eq!(claude.1.input_tokens, 1500);
        assert_eq!(claude.1.output_tokens, 700);

        let openai = breakdown.iter().find(|(n, _)| n == "openai").unwrap();
        assert_eq!(openai.1.request_count, 1);
        assert_eq!(openai.1.input_tokens, 2000);
    }

    #[test]
    fn provider_breakdown_sorted_by_cost_desc() {
        let tracker = CostTracker::new(true, 1000.0);
        // gpt-4o-mini: cheap; claude-opus: expensive
        record(&tracker, "cheap", "gpt-4o-mini", 100, 100);
        record(&tracker, "expensive", "claude-opus-4-20250514", 10000, 5000);

        let breakdown = tracker.provider_breakdown();
        assert_eq!(breakdown[0].0, "expensive");
    }

    #[test]
    fn cache_tokens_included_in_cost() {
        let tracker = CostTracker::new(true, 1000.0);
        // claude-haiku prompt=0.1, cache_read=0.01 per 1k
        // 1000 cache_read tokens = 0.01 cents; 0 input/output for isolation
        tracker.record_usage(
            "claude",
            "cloud",
            "claude-haiku-4-5-20251001",
            0,
            1000,
            0,
            0,
        );
        let spend = tracker.current_spend();
        assert!(spend > 0.0, "cache read should contribute to cost");
        assert!(
            (spend - 0.01).abs() < 1e-6,
            "cache read should be billed at 10% of the prompt rate"
        );
    }

    #[test]
    fn cache_write_cost_included_in_total() {
        let tracker = CostTracker::new(true, 1000.0);
        // Claude pricing: cache_write = 125% of prompt price
        // claude-opus-4-6: prompt = 0.5 cents/1k
        // 1000 cache_write tokens = (0.5 * 1.25 * 1000) / 1000 = 0.625 cents
        tracker.record_usage("claude-provider", "cloud", "claude-opus-4-6", 0, 0, 1000, 0);
        let cost = tracker.current_spend();
        assert!((cost - 0.625).abs() < 0.001);
    }

    #[test]
    fn provider_breakdown_empty_when_disabled() {
        let tracker = CostTracker::new(false, 100.0);
        tracker.record_usage(
            "claude",
            "cloud",
            "claude-haiku-4-5-20251001",
            1000,
            0,
            0,
            1000,
        );
        assert!(tracker.provider_breakdown().is_empty());
    }
}
483}