Skip to main content

zeph_core/
cost.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use std::collections::HashMap;
5use std::sync::Arc;
6
7use parking_lot::Mutex;
8
9use thiserror::Error;
10
/// Error returned by [`CostTracker::check_budget`] once the recorded daily spend has
/// reached the configured daily budget.
///
/// Both amounts are in cents; the `Display` message renders them with two decimals.
#[derive(Debug, Error)]
#[error("daily budget exhausted: spent {spent_cents:.2} / {budget_cents:.2} cents")]
pub struct BudgetExhausted {
    /// Amount recorded as spent for the current day, in cents.
    pub spent_cents: f64,
    /// Configured daily budget, in cents.
    pub budget_cents: f64,
}
17
/// Per-provider usage and cost breakdown for the current session/day.
///
/// One entry is kept per provider name. Every recorded call adds its token counts and
/// cost to the matching entry; all entries are dropped when the tracker rolls over to
/// a new day.
#[derive(Debug, Clone, Default)]
pub struct ProviderUsage {
    /// Sum of input tokens over all recorded calls.
    pub input_tokens: u64,
    /// Sum of cache-read tokens over all recorded calls.
    pub cache_read_tokens: u64,
    /// Sum of cache-write tokens over all recorded calls.
    pub cache_write_tokens: u64,
    /// Sum of output tokens over all recorded calls.
    pub output_tokens: u64,
    /// Accumulated cost of all recorded calls, in cents.
    pub cost_cents: f64,
    /// Number of recorded calls.
    pub request_count: u64,
    /// Last model seen for this provider (informational only — may change per-call).
    pub model: String,
}
30
/// Token prices for one model, each expressed in cents per 1,000 tokens.
#[derive(Debug, Clone)]
pub struct ModelPricing {
    /// Price applied to input tokens.
    pub prompt_cents_per_1k: f64,
    /// Price applied to output tokens.
    pub completion_cents_per_1k: f64,
    /// Cache read (cache hit) price. Claude: 10% of prompt; `OpenAI`: 50%; others: 0%.
    pub cache_read_cents_per_1k: f64,
    /// Cache write (cache creation) price. Claude: 125% of prompt; others: 0%.
    pub cache_write_cents_per_1k: f64,
}
40
/// Mutable accounting data kept behind the tracker's mutex.
///
/// All totals belong to `day` and are cleared when the day number changes.
struct CostState {
    /// Total cost recorded for `day`, in cents.
    spent_cents: f64,
    /// UTC day number (days since the Unix epoch) the totals belong to.
    day: u32,
    /// Per-provider totals, keyed by provider name.
    providers: HashMap<String, ProviderUsage>,
    /// Number of successful tasks recorded for `day`.
    successful_tasks: u64,
}
47
/// Tracks LLM token usage and cost per day and checks it against a daily budget.
pub struct CostTracker {
    /// Pricing table keyed by model name.
    pricing: HashMap<String, ModelPricing>,
    /// Running totals for the current UTC day.
    state: Arc<Mutex<CostState>>,
    /// Daily budget in cents; the limit is only enforced when this is greater than zero.
    max_daily_cents: f64,
    /// When `false`, recording calls and budget checks return without touching the state.
    enabled: bool,
}
54
/// Returns the current UTC day number, i.e. whole days elapsed since the Unix epoch.
///
/// Falls back to day 0 when the system clock reads earlier than the epoch or when the
/// day count does not fit into a `u32`.
fn current_day() -> u32 {
    const SECS_PER_DAY: u64 = 86_400;

    let secs_since_epoch = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .map_or(0, |elapsed| elapsed.as_secs());

    u32::try_from(secs_since_epoch / SECS_PER_DAY).unwrap_or(0)
}
64
65fn claude_pricing(prompt: f64, completion: f64) -> ModelPricing {
66    ModelPricing {
67        prompt_cents_per_1k: prompt,
68        completion_cents_per_1k: completion,
69        // Claude: cache read = 10% of prompt, cache write = 125% of prompt
70        cache_read_cents_per_1k: prompt * 0.1,
71        cache_write_cents_per_1k: prompt * 1.25,
72    }
73}
74
75fn openai_pricing(prompt: f64, completion: f64) -> ModelPricing {
76    ModelPricing {
77        prompt_cents_per_1k: prompt,
78        completion_cents_per_1k: completion,
79        // OpenAI: cache read = 50% of prompt, no cache write charge
80        cache_read_cents_per_1k: prompt * 0.5,
81        cache_write_cents_per_1k: 0.0,
82    }
83}
84
85fn default_pricing() -> HashMap<String, ModelPricing> {
86    let mut m = HashMap::new();
87    // Claude 4 (sonnet-4 / opus-4 base releases)
88    m.insert("claude-sonnet-4-20250514".into(), claude_pricing(0.3, 1.5));
89    m.insert("claude-opus-4-20250514".into(), claude_pricing(1.5, 7.5));
90    // Claude 4.1 Opus ($15/$75 per 1M tokens)
91    m.insert("claude-opus-4-1-20250805".into(), claude_pricing(1.5, 7.5));
92    // Claude 4.5 family
93    m.insert("claude-haiku-4-5-20251001".into(), claude_pricing(0.1, 0.5));
94    m.insert(
95        "claude-sonnet-4-5-20250929".into(),
96        claude_pricing(0.3, 1.5),
97    );
98    m.insert("claude-opus-4-5-20251101".into(), claude_pricing(0.5, 2.5));
99    // Claude 4.6 family
100    m.insert("claude-sonnet-4-6".into(), claude_pricing(0.3, 1.5));
101    m.insert("claude-opus-4-6".into(), claude_pricing(0.5, 2.5));
102    // OpenAI
103    m.insert("gpt-4o".into(), openai_pricing(0.25, 1.0));
104    m.insert("gpt-4o-mini".into(), openai_pricing(0.015, 0.06));
105    // GPT-5 family ($1.25/$10 per 1M tokens)
106    m.insert("gpt-5".into(), openai_pricing(0.125, 1.0));
107    // GPT-5 mini ($0.25/$2 per 1M tokens)
108    m.insert("gpt-5-mini".into(), openai_pricing(0.025, 0.2));
109    m
110}
111
112fn reset_if_new_day(state: &mut CostState) {
113    let today = current_day();
114    if state.day != today {
115        state.spent_cents = 0.0;
116        state.day = today;
117        state.providers.clear();
118        state.successful_tasks = 0;
119    }
120}
121
122impl CostTracker {
123    #[must_use]
124    pub fn new(enabled: bool, max_daily_cents: f64) -> Self {
125        Self {
126            pricing: default_pricing(),
127            state: Arc::new(Mutex::new(CostState {
128                spent_cents: 0.0,
129                day: current_day(),
130                providers: HashMap::new(),
131                successful_tasks: 0,
132            })),
133            max_daily_cents,
134            enabled,
135        }
136    }
137
138    #[must_use]
139    pub fn with_pricing(mut self, model: &str, pricing: ModelPricing) -> Self {
140        self.pricing.insert(model.to_owned(), pricing);
141        self
142    }
143
144    /// Record token usage for a single LLM call, attributed to `provider_name`.
145    ///
146    /// `provider_kind` must be the value returned by `AnyProvider::provider_kind_str()`:
147    /// `"ollama"` or `"candle"` for local providers, `"cloud"` for API providers.
148    /// Local providers always have zero cost by design; the missing-pricing WARN is
149    /// suppressed for them to avoid log floods on every Ollama call.
150    ///
151    /// Cache token counts are optional (pass 0 when not available). Cost is computed
152    /// using model-specific pricing including cache read/write rates.
153    #[allow(clippy::too_many_arguments)] // function with many required inputs; a *Params struct would be more verbose without simplifying the call site
154    pub fn record_usage(
155        &self,
156        provider_name: &str,
157        provider_kind: &str,
158        model: &str,
159        input_tokens: u64,
160        cache_read_tokens: u64,
161        cache_write_tokens: u64,
162        output_tokens: u64,
163    ) {
164        if !self.enabled {
165            return;
166        }
167        let pricing = if let Some(p) = self.pricing.get(model).cloned() {
168            p
169        } else {
170            let is_local = matches!(provider_kind, "ollama" | "candle" | "local");
171            if is_local {
172                tracing::debug!(model, "local model; cost recorded as zero");
173            } else {
174                tracing::warn!(
175                    model,
176                    "model not found in pricing table; cost recorded as zero"
177                );
178            }
179            ModelPricing {
180                prompt_cents_per_1k: 0.0,
181                completion_cents_per_1k: 0.0,
182                cache_read_cents_per_1k: 0.0,
183                cache_write_cents_per_1k: 0.0,
184            }
185        };
186        #[allow(clippy::cast_precision_loss)]
187        let cost = pricing.prompt_cents_per_1k * (input_tokens as f64) / 1000.0
188            + pricing.completion_cents_per_1k * (output_tokens as f64) / 1000.0
189            + pricing.cache_read_cents_per_1k * (cache_read_tokens as f64) / 1000.0
190            + pricing.cache_write_cents_per_1k * (cache_write_tokens as f64) / 1000.0;
191
192        let mut state = self.state.lock();
193        reset_if_new_day(&mut state);
194        state.spent_cents += cost;
195
196        let entry = state.providers.entry(provider_name.to_owned()).or_default();
197        entry.input_tokens += input_tokens;
198        entry.cache_read_tokens += cache_read_tokens;
199        entry.cache_write_tokens += cache_write_tokens;
200        entry.output_tokens += output_tokens;
201        entry.cost_cents += cost;
202        entry.request_count += 1;
203        model.clone_into(&mut entry.model);
204    }
205
206    /// # Errors
207    ///
208    /// Returns `BudgetExhausted` when daily spend exceeds the configured limit.
209    pub fn check_budget(&self) -> Result<(), BudgetExhausted> {
210        if !self.enabled {
211            return Ok(());
212        }
213        let mut state = self.state.lock();
214        reset_if_new_day(&mut state);
215        if self.max_daily_cents > 0.0 && state.spent_cents >= self.max_daily_cents {
216            return Err(BudgetExhausted {
217                spent_cents: state.spent_cents,
218                budget_cents: self.max_daily_cents,
219            });
220        }
221        Ok(())
222    }
223
224    /// Returns the configured daily budget in cents. Zero means unlimited.
225    #[must_use]
226    pub fn max_daily_cents(&self) -> f64 {
227        self.max_daily_cents
228    }
229
230    #[must_use]
231    pub fn current_spend(&self) -> f64 {
232        let state = self.state.lock();
233        state.spent_cents
234    }
235
236    /// Increment the successful-task counter.
237    ///
238    /// Call after each turn that completes without error and produces a usable agent response.
239    pub fn record_successful_task(&self) {
240        if !self.enabled {
241            return;
242        }
243        let mut state = self.state.lock();
244        reset_if_new_day(&mut state);
245        state.successful_tasks += 1;
246    }
247
248    /// Returns cost-per-successful-task in cents, or `None` if no tasks recorded yet.
249    #[must_use]
250    pub fn cps(&self) -> Option<f64> {
251        let state = self.state.lock();
252        if state.successful_tasks == 0 {
253            return None;
254        }
255        #[allow(clippy::cast_precision_loss)]
256        Some(state.spent_cents / state.successful_tasks as f64)
257    }
258
259    /// Returns total number of successful tasks recorded today.
260    #[must_use]
261    pub fn successful_tasks(&self) -> u64 {
262        self.state.lock().successful_tasks
263    }
264
265    /// Returns per-provider breakdown sorted by cost descending.
266    #[must_use]
267    pub fn provider_breakdown(&self) -> Vec<(String, ProviderUsage)> {
268        let state = self.state.lock();
269        let mut breakdown: Vec<(String, ProviderUsage)> = state
270            .providers
271            .iter()
272            .map(|(k, v)| (k.clone(), v.clone()))
273            .collect();
274        breakdown.sort_by(|a, b| {
275            b.1.cost_cents
276                .partial_cmp(&a.1.cost_cents)
277                .unwrap_or(std::cmp::Ordering::Equal)
278        });
279        breakdown
280    }
281}
282
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance used when comparing floating-point cent amounts.
    const EPS: f64 = 1e-6;

    /// Records a cloud call without any cache tokens.
    fn record(tracker: &CostTracker, provider: &str, model: &str, input: u64, output: u64) {
        tracker.record_usage(provider, "cloud", model, input, 0, 0, output);
    }

    /// Asserts that `actual` equals `expected` cents within `EPS`.
    fn assert_cents(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < EPS,
            "expected {expected} cents, got {actual}"
        );
    }

    /// Makes the tracker believe its totals belong to a long-past day.
    fn force_past_day(tracker: &CostTracker) {
        tracker.state.lock().day = 0;
    }

    #[test]
    fn cost_tracker_records_usage_and_calculates_cost() {
        let tracker = CostTracker::new(true, 1000.0);
        record(&tracker, "openai", "gpt-4o", 1000, 1000);
        // 0.25 (input) + 1.0 (output) = 1.25
        assert_cents(tracker.current_spend(), 1.25);
    }

    #[test]
    fn with_pricing_adds_custom_model() {
        let tracker = CostTracker::new(true, 1000.0).with_pricing(
            "custom-model",
            ModelPricing {
                prompt_cents_per_1k: 2.0,
                completion_cents_per_1k: 4.0,
                cache_read_cents_per_1k: 0.0,
                cache_write_cents_per_1k: 0.0,
            },
        );
        record(&tracker, "custom", "custom-model", 1000, 500);
        // 2.0 (input) + 2.0 (output) = 4.0
        assert_cents(tracker.current_spend(), 4.0);
    }

    #[test]
    fn max_daily_cents_returns_configured_budget() {
        let tracker = CostTracker::new(true, 42.5);
        assert_cents(tracker.max_daily_cents(), 42.5);
    }

    #[test]
    fn check_budget_passes_when_under_limit() {
        let tracker = CostTracker::new(true, 100.0);
        record(&tracker, "openai", "gpt-4o-mini", 100, 100);
        assert!(tracker.check_budget().is_ok());
    }

    #[test]
    fn check_budget_fails_when_over_limit() {
        let tracker = CostTracker::new(true, 0.01);
        record(&tracker, "claude", "claude-opus-4-20250514", 10000, 10000);
        assert!(tracker.check_budget().is_err());
    }

    #[test]
    fn check_budget_fails_when_exactly_at_limit() {
        // gpt-4o, 1000 in / 1000 out costs exactly 1.25 cents, which equals the budget.
        let tracker = CostTracker::new(true, 1.25);
        record(&tracker, "openai", "gpt-4o", 1000, 1000);
        let err = tracker
            .check_budget()
            .expect_err("spend equal to the budget must exhaust it");
        assert_cents(err.spent_cents, 1.25);
        assert_cents(err.budget_cents, 1.25);
        assert_eq!(
            err.to_string(),
            "daily budget exhausted: spent 1.25 / 1.25 cents"
        );
    }

    #[test]
    fn daily_reset_clears_spending() {
        let tracker = CostTracker::new(true, 100.0);
        record(&tracker, "openai", "gpt-4o", 1000, 1000);
        assert!(tracker.current_spend() > 0.0);
        force_past_day(&tracker);
        // check_budget should reset
        assert!(tracker.check_budget().is_ok());
        assert_cents(tracker.current_spend(), 0.0);
    }

    #[test]
    fn daily_reset_clears_provider_breakdown() {
        let tracker = CostTracker::new(true, 100.0);
        record(&tracker, "openai", "gpt-4o", 1000, 1000);
        assert!(!tracker.provider_breakdown().is_empty());
        force_past_day(&tracker);
        assert!(tracker.check_budget().is_ok());
        assert!(tracker.provider_breakdown().is_empty());
    }

    #[test]
    fn ollama_zero_cost() {
        // Must go through the local-provider path, so pass the kind explicitly
        // instead of using `record` (which always reports "cloud").
        let tracker = CostTracker::new(true, 100.0);
        tracker.record_usage("ollama", "ollama", "llama3:8b", 10000, 0, 0, 10000);
        assert_cents(tracker.current_spend(), 0.0);
    }

    #[test]
    fn candle_and_local_kinds_zero_cost() {
        let tracker = CostTracker::new(true, 100.0);
        tracker.record_usage("candle", "candle", "some-candle-model", 2000, 0, 0, 2000);
        tracker.record_usage("local", "local", "some-local-model", 2000, 0, 0, 2000);
        assert_cents(tracker.current_spend(), 0.0);
        assert_eq!(tracker.provider_breakdown().len(), 2);
    }

    #[test]
    fn ollama_unknown_model_no_warn_no_panic() {
        // Local providers should silently record zero cost for unknown models.
        let tracker = CostTracker::new(true, 100.0);
        tracker.record_usage(
            "local",
            "ollama",
            "totally-unknown-ollama-model",
            5000,
            0,
            0,
            5000,
        );
        assert_cents(tracker.current_spend(), 0.0);
    }

    #[test]
    fn cloud_unknown_model_still_records_zero_cost() {
        // Cloud providers record zero cost for unknown models (WARN emitted separately).
        let tracker = CostTracker::new(true, 100.0);
        tracker.record_usage(
            "openai",
            "cloud",
            "totally-unknown-cloud-model",
            5000,
            0,
            0,
            5000,
        );
        assert_cents(tracker.current_spend(), 0.0);
    }

    #[test]
    fn unknown_model_zero_cost() {
        // Usage is still attributed to the provider even though the cost is zero.
        let tracker = CostTracker::new(true, 100.0);
        record(&tracker, "unknown", "totally-unknown-model", 5000, 5000);
        assert_cents(tracker.current_spend(), 0.0);

        let breakdown = tracker.provider_breakdown();
        assert_eq!(breakdown.len(), 1);
        assert_eq!(breakdown[0].0, "unknown");
        assert_eq!(breakdown[0].1.request_count, 1);
        assert_eq!(breakdown[0].1.input_tokens, 5000);
        assert_eq!(breakdown[0].1.output_tokens, 5000);
    }

    #[test]
    fn known_claude_model_has_nonzero_cost() {
        let tracker = CostTracker::new(true, 1000.0);
        record(&tracker, "claude", "claude-haiku-4-5-20251001", 1000, 1000);
        // 0.1 (input) + 0.5 (output) = 0.6
        assert_cents(tracker.current_spend(), 0.6);
    }

    #[test]
    fn gpt5_pricing_is_correct() {
        let tracker = CostTracker::new(true, 1000.0);
        record(&tracker, "openai", "gpt-5", 1000, 1000);
        // 0.125 + 1.0 = 1.125
        assert_cents(tracker.current_spend(), 1.125);
    }

    #[test]
    fn gpt5_mini_pricing_is_correct() {
        let tracker = CostTracker::new(true, 1000.0);
        record(&tracker, "openai", "gpt-5-mini", 1000, 1000);
        // 0.025 + 0.2 = 0.225
        assert_cents(tracker.current_spend(), 0.225);
    }

    #[test]
    fn disabled_tracker_always_passes() {
        let tracker = CostTracker::new(false, 0.0);
        record(
            &tracker,
            "claude",
            "claude-opus-4-20250514",
            1_000_000,
            1_000_000,
        );
        tracker.record_successful_task();
        assert!(tracker.check_budget().is_ok());
        assert_cents(tracker.current_spend(), 0.0);
        assert_eq!(tracker.successful_tasks(), 0);
    }

    #[test]
    fn check_budget_unlimited_when_max_daily_cents_is_zero() {
        let tracker = CostTracker::new(true, 0.0);
        record(
            &tracker,
            "claude",
            "claude-opus-4-20250514",
            100_000,
            100_000,
        );
        assert!(tracker.current_spend() > 0.0);
        assert!(tracker.check_budget().is_ok());
    }

    #[test]
    fn per_provider_accumulation() {
        let tracker = CostTracker::new(true, 1000.0);
        record(&tracker, "claude", "claude-haiku-4-5-20251001", 1000, 500);
        record(&tracker, "openai", "gpt-4o", 2000, 1000);
        record(&tracker, "claude", "claude-haiku-4-5-20251001", 500, 200);

        let breakdown = tracker.provider_breakdown();
        assert_eq!(breakdown.len(), 2);

        let claude = breakdown.iter().find(|(n, _)| n == "claude").unwrap();
        assert_eq!(claude.1.request_count, 2);
        assert_eq!(claude.1.input_tokens, 1500);
        assert_eq!(claude.1.output_tokens, 700);
        assert_eq!(claude.1.model, "claude-haiku-4-5-20251001");
        // (0.1 + 0.25) + (0.05 + 0.1) = 0.5
        assert_cents(claude.1.cost_cents, 0.5);

        let openai = breakdown.iter().find(|(n, _)| n == "openai").unwrap();
        assert_eq!(openai.1.request_count, 1);
        assert_eq!(openai.1.input_tokens, 2000);
        assert_eq!(openai.1.output_tokens, 1000);
    }

    #[test]
    fn provider_breakdown_sorted_by_cost_desc() {
        let tracker = CostTracker::new(true, 1000.0);
        // gpt-4o-mini: cheap; claude-opus: expensive
        record(&tracker, "cheap", "gpt-4o-mini", 100, 100);
        record(&tracker, "expensive", "claude-opus-4-20250514", 10000, 5000);

        let breakdown = tracker.provider_breakdown();
        assert_eq!(breakdown.len(), 2);
        assert_eq!(breakdown[0].0, "expensive");
        assert_eq!(breakdown[1].0, "cheap");
    }

    #[test]
    fn cache_tokens_included_in_cost() {
        let tracker = CostTracker::new(true, 1000.0);
        // claude-haiku prompt=0.1, cache_read=0.01 per 1k
        // 1000 cache_read tokens = 0.01 cents; 0 input/output for isolation
        tracker.record_usage(
            "claude",
            "cloud",
            "claude-haiku-4-5-20251001",
            0,
            1000,
            0,
            0,
        );
        assert_cents(tracker.current_spend(), 0.01);

        let breakdown = tracker.provider_breakdown();
        assert_eq!(breakdown[0].1.cache_read_tokens, 1000);
    }

    #[test]
    fn cache_write_cost_included_in_total() {
        let tracker = CostTracker::new(true, 1000.0);
        // Claude pricing: cache_write = 125% of prompt price
        // claude-opus-4-6: prompt = 0.5 cents/1k
        // 1000 cache_write tokens = (0.5 * 1.25 * 1000) / 1000 = 0.625 cents
        tracker.record_usage("claude-provider", "cloud", "claude-opus-4-6", 0, 0, 1000, 0);
        assert_cents(tracker.current_spend(), 0.625);
    }

    #[test]
    fn openai_cache_read_charged_and_cache_write_free() {
        let tracker = CostTracker::new(true, 1000.0);
        // gpt-4o: prompt = 0.25 cents/1k, cache_read = 50% = 0.125, cache_write = 0
        tracker.record_usage("openai", "cloud", "gpt-4o", 0, 1000, 1000, 0);
        assert_cents(tracker.current_spend(), 0.125);
    }

    #[test]
    fn provider_breakdown_empty_when_disabled() {
        let tracker = CostTracker::new(false, 100.0);
        tracker.record_usage(
            "claude",
            "cloud",
            "claude-haiku-4-5-20251001",
            1000,
            0,
            0,
            1000,
        );
        assert!(tracker.provider_breakdown().is_empty());
    }

    #[test]
    fn cps_none_when_no_tasks() {
        let tracker = CostTracker::new(true, 100.0);
        assert!(tracker.cps().is_none());
        assert_eq!(tracker.successful_tasks(), 0);
    }

    #[test]
    fn cps_calculated_correctly() {
        let tracker = CostTracker::new(true, 100.0);
        // 0.25 (input) + 1.0 (output) = 1.25 cents
        record(&tracker, "openai", "gpt-4o", 1000, 1000);
        tracker.record_successful_task();
        tracker.record_successful_task();
        assert_eq!(tracker.successful_tasks(), 2);
        let cps = tracker.cps().expect("cps should be Some after tasks");
        // 1.25 / 2 = 0.625
        assert_cents(cps, 0.625);
    }

    #[test]
    fn cps_resets_on_new_day() {
        let tracker = CostTracker::new(true, 100.0);
        record(&tracker, "openai", "gpt-4o", 1000, 1000);
        tracker.record_successful_task();
        assert_eq!(tracker.successful_tasks(), 1);
        force_past_day(&tracker);
        // check_budget rolls the state over to the current day
        assert!(tracker.check_budget().is_ok());
        assert_eq!(tracker.successful_tasks(), 0);
        assert!(tracker.cps().is_none());
    }
}