Skip to main content

tt_plan_core/
cost.rs

1//! Cost projection for a single replayed request. The math is deliberately
2//! minimal so the determinism contract is easy to audit: same inputs in,
3//! same `f64` out.
4//!
5//! `compute_baseline_cost` re-derives the historical cost from the same
6//! pricing table — it's the denominator we compare against. Both helpers
7//! charge cached input tokens at `cached_input_per_million` when set,
8//! falling back to the full input rate when the pricing entry doesn't
9//! advertise a cache discount.
10
11use crate::types::{ModelPricing, RequestLog};
12
/// The outcome of re-pricing one replayed request.
///
/// `Copy` and `PartialEq` are derived because the struct is a single `f64`:
/// callers can pass it by value and compare two projections directly with
/// `==` / `assert_eq!`. `Eq` is intentionally absent, since `f64` does not
/// implement it.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct ProjectedCost {
    /// The recomputed cost, USD, under the proposed model + pricing.
    pub cost_usd: f64,
}
19
20/// Project the cost of one request under a different model + pricing entry.
21///
22/// `target_model` is taken purely for traceability — the math uses
23/// `pricing` directly. Cached-token rate falls back to the non-cached
24/// input rate when the pricing entry doesn't advertise a discount.
25#[must_use]
26pub fn project_cost(
27    req: &RequestLog,
28    _target_model: &str,
29    pricing: &ModelPricing,
30) -> ProjectedCost {
31    let cached = req.cached_tokens.min(req.input_tokens);
32    let non_cached_input = req.input_tokens.saturating_sub(cached);
33    let cached_rate = pricing
34        .cached_input_per_million
35        .unwrap_or(pricing.input_per_million);
36    let cost = (f64::from(non_cached_input)) * pricing.input_per_million / 1_000_000.0
37        + (f64::from(cached)) * cached_rate / 1_000_000.0
38        + (f64::from(req.output_tokens)) * pricing.output_per_million / 1_000_000.0;
39    ProjectedCost { cost_usd: cost }
40}
41
42/// Re-derive the baseline cost from a pricing entry. Used by tests and
43/// any caller that wants to validate the historical `cost_usd` field
44/// against today's pricing snapshot.
45#[must_use]
46pub fn compute_baseline_cost(req: &RequestLog, pricing: &ModelPricing) -> f64 {
47    project_cost(req, &req.model, pricing).cost_usd
48}
49
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::TimeZone;
    use uuid::Uuid;

    // Rates shared by every test, USD per million tokens.
    const INPUT_RATE: f64 = 3.0;
    const OUTPUT_RATE: f64 = 15.0;
    const CACHED_RATE: f64 = 0.3;

    /// Build the pricing entry used throughout, varying only the cache discount.
    fn pricing_with_cache(cached_input_per_million: Option<f64>) -> ModelPricing {
        ModelPricing {
            input_per_million: INPUT_RATE,
            output_per_million: OUTPUT_RATE,
            cached_input_per_million,
        }
    }

    /// Build a request whose only interesting fields are the three token counts.
    fn sample_request(input: u32, output: u32, cached: u32) -> RequestLog {
        RequestLog {
            id: Uuid::nil(),
            org_id: Uuid::nil(),
            ts: chrono::Utc.with_ymd_and_hms(2026, 5, 1, 0, 0, 0).unwrap(),
            provider: "anthropic".into(),
            model: "claude-3-5-sonnet".into(),
            input_tokens: input,
            output_tokens: output,
            cached_tokens: cached,
            cost_usd: 0.0,
            baseline_cost_usd: 0.0,
            cached: false,
            cache_layer: None,
            matched_route_id: None,
            latency_ms: 0,
            upstream_latency_ms: None,
            status: 200,
            tag: None,
            embedding: None,
            finish_reason: None,
            body: None,
            response_body: None,
        }
    }

    /// Assert `got` is within `tol` of `want`. `#[track_caller]` makes a
    /// failure point at the calling test rather than at this helper.
    #[track_caller]
    fn assert_close(got: f64, want: f64, tol: f64) {
        assert!((got - want).abs() < tol, "got {}, want {}", got, want);
    }

    #[test]
    fn project_cost_with_full_pricing() {
        let pricing = pricing_with_cache(Some(CACHED_RATE));
        let req = sample_request(1_000_000, 1_000_000, 0);
        let p = project_cost(&req, "x", &pricing);
        // 1M input @ $3 + 1M output @ $15 = $18.
        assert_close(p.cost_usd, 18.0, 1e-9);
    }

    #[test]
    fn project_cost_charges_cached_at_discount() {
        let pricing = pricing_with_cache(Some(CACHED_RATE));
        let req = sample_request(1_000_000, 0, 500_000);
        let p = project_cost(&req, "x", &pricing);
        // 500K non-cached @ $3/1M + 500K cached @ $0.30/1M = $1.50 + $0.15 = $1.65
        assert_close(p.cost_usd, 1.65, 1e-9);
    }

    #[test]
    fn project_cost_falls_back_to_full_rate_when_no_cache_discount() {
        let pricing = pricing_with_cache(None);
        let req = sample_request(1_000_000, 0, 500_000);
        let p = project_cost(&req, "x", &pricing);
        // All input charged at the full rate -> $3.00.
        assert_close(p.cost_usd, 3.0, 1e-9);
    }

    #[test]
    fn project_cost_clamps_cached_to_input() {
        let pricing = pricing_with_cache(Some(CACHED_RATE));
        // cached_tokens > input_tokens — should clamp.
        let req = sample_request(1_000, 0, 5_000);
        let p = project_cost(&req, "x", &pricing);
        // All 1000 charged at cached rate ($0.30/1M).
        let want = 1_000.0 * CACHED_RATE / 1_000_000.0;
        assert_close(p.cost_usd, want, 1e-12);
    }

    #[test]
    fn project_cost_is_zero_for_empty_request() {
        let pricing = pricing_with_cache(Some(CACHED_RATE));
        let req = sample_request(0, 0, 0);
        let p = project_cost(&req, "x", &pricing);
        assert_close(p.cost_usd, 0.0, 1e-12);
    }

    #[test]
    fn compute_baseline_cost_matches_project_cost() {
        let pricing = pricing_with_cache(Some(CACHED_RATE));
        let req = sample_request(1_000_000, 250_000, 500_000);
        let baseline = compute_baseline_cost(&req, &pricing);
        let projected = project_cost(&req, &req.model, &pricing).cost_usd;
        // Same inputs must give the very same f64, not merely a close one.
        assert_eq!(baseline.to_bits(), projected.to_bits());
        // 500K @ $3 + 500K @ $0.30 + 250K @ $15 = $1.50 + $0.15 + $3.75 = $5.40
        assert_close(baseline, 5.4, 1e-9);
    }
}