1use crate::types::{ModelPricing, RequestLog};
12
/// Result of pricing a request's token usage with a [`ModelPricing`] table.
///
/// Produced by [`project_cost`]. Derives `PartialEq` so projections can be
/// compared directly, e.g. in tests or when checking whether a reroute changes cost.
#[derive(Debug, Clone, PartialEq)]
pub struct ProjectedCost {
    /// Total cost in US dollars: non-cached input, cached input and output combined.
    pub cost_usd: f64,
}
19
20#[must_use]
26pub fn project_cost(
27 req: &RequestLog,
28 _target_model: &str,
29 pricing: &ModelPricing,
30) -> ProjectedCost {
31 let cached = req.cached_tokens.min(req.input_tokens);
32 let non_cached_input = req.input_tokens.saturating_sub(cached);
33 let cached_rate = pricing
34 .cached_input_per_million
35 .unwrap_or(pricing.input_per_million);
36 let cost = (f64::from(non_cached_input)) * pricing.input_per_million / 1_000_000.0
37 + (f64::from(cached)) * cached_rate / 1_000_000.0
38 + (f64::from(req.output_tokens)) * pricing.output_per_million / 1_000_000.0;
39 ProjectedCost { cost_usd: cost }
40}
41
42#[must_use]
46pub fn compute_baseline_cost(req: &RequestLog, pricing: &ModelPricing) -> f64 {
47 project_cost(req, &req.model, pricing).cost_usd
48}
49
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::TimeZone;
    use uuid::Uuid;

    /// Absolute tolerance for dollar amounts in the ~1–20 USD range used below.
    const EPS: f64 = 1e-9;

    /// Builds a minimal successful (status 200) request log with the given token counts.
    fn sample_request(input: u32, output: u32, cached: u32) -> RequestLog {
        RequestLog {
            id: Uuid::nil(),
            org_id: Uuid::nil(),
            ts: chrono::Utc.with_ymd_and_hms(2026, 5, 1, 0, 0, 0).unwrap(),
            provider: "anthropic".into(),
            model: "claude-3-5-sonnet".into(),
            input_tokens: input,
            output_tokens: output,
            cached_tokens: cached,
            cost_usd: 0.0,
            baseline_cost_usd: 0.0,
            cached: false,
            cache_layer: None,
            matched_route_id: None,
            latency_ms: 0,
            upstream_latency_ms: None,
            status: 200,
            tag: None,
            embedding: None,
            finish_reason: None,
            body: None,
            response_body: None,
        }
    }

    /// Pricing fixture: 3 USD per million input tokens and 15 USD per million output
    /// tokens, with an optional discounted rate for cached input.
    fn test_pricing(cached_input_per_million: Option<f64>) -> ModelPricing {
        ModelPricing {
            input_per_million: 3.0,
            output_per_million: 15.0,
            cached_input_per_million,
        }
    }

    /// Asserts `got` is within `tol` of `want`; failures report the caller's line.
    #[track_caller]
    fn assert_close(got: f64, want: f64, tol: f64) {
        assert!((got - want).abs() < tol, "got {got}, want {want}");
    }

    #[test]
    fn project_cost_with_full_pricing() {
        let pricing = test_pricing(Some(0.3));
        let req = sample_request(1_000_000, 1_000_000, 0);
        assert_close(project_cost(&req, "x", &pricing).cost_usd, 18.0, EPS);
    }

    #[test]
    fn project_cost_charges_cached_at_discount() {
        let pricing = test_pricing(Some(0.3));
        // 500k fresh input at $3/M plus 500k cached input at $0.3/M.
        let req = sample_request(1_000_000, 0, 500_000);
        assert_close(project_cost(&req, "x", &pricing).cost_usd, 1.65, EPS);
    }

    #[test]
    fn project_cost_falls_back_to_full_rate_when_no_cache_discount() {
        let pricing = test_pricing(None);
        let req = sample_request(1_000_000, 0, 500_000);
        assert_close(project_cost(&req, "x", &pricing).cost_usd, 3.0, EPS);
    }

    #[test]
    fn project_cost_clamps_cached_to_input() {
        let pricing = test_pricing(Some(0.3));
        // Reported cached tokens exceed input; only the 1_000 input tokens are billed,
        // all of them at the cached rate.
        let req = sample_request(1_000, 0, 5_000);
        let want = 1_000.0 * 0.3 / 1_000_000.0;
        assert_close(project_cost(&req, "x", &pricing).cost_usd, want, 1e-12);
    }

    #[test]
    fn project_cost_zero_tokens_is_free() {
        let pricing = test_pricing(Some(0.3));
        let req = sample_request(0, 0, 0);
        assert_close(project_cost(&req, "x", &pricing).cost_usd, 0.0, EPS);
    }

    #[test]
    fn compute_baseline_cost_prices_request_model() {
        let pricing = test_pricing(Some(0.3));
        let req = sample_request(2_000, 1_000, 1_000);
        let baseline = compute_baseline_cost(&req, &pricing);
        assert_close(
            baseline,
            project_cost(&req, &req.model, &pricing).cost_usd,
            1e-12,
        );
        // 1_000 fresh input at $3/M + 1_000 cached at $0.3/M + 1_000 output at $15/M.
        assert_close(baseline, 0.0183, 1e-12);
    }
}