sentinel_proxy/inference/
cost.rs1use tracing::{debug, trace};
6
7use sentinel_common::budget::{CostAttributionConfig, CostResult, ModelPricing};
8
9pub struct CostCalculator {
14 config: CostAttributionConfig,
16 route_id: String,
18}
19
20impl CostCalculator {
21 pub fn new(config: CostAttributionConfig, route_id: impl Into<String>) -> Self {
23 let route_id = route_id.into();
24
25 debug!(
26 route_id = %route_id,
27 enabled = config.enabled,
28 pricing_rules = config.pricing.len(),
29 default_input = config.default_input_cost,
30 default_output = config.default_output_cost,
31 currency = %config.currency,
32 "Created cost calculator"
33 );
34
35 Self { config, route_id }
36 }
37
38 pub fn is_enabled(&self) -> bool {
40 self.config.enabled
41 }
42
43 pub fn calculate(
47 &self,
48 model: &str,
49 input_tokens: u64,
50 output_tokens: u64,
51 ) -> CostResult {
52 if !self.config.enabled {
53 return CostResult::new(model, input_tokens, output_tokens, 0.0, 0.0, "USD");
54 }
55
56 let (input_cost_per_million, output_cost_per_million, currency) =
58 if let Some(pricing) = self.find_pricing(model) {
59 let currency = pricing
60 .currency
61 .as_ref()
62 .unwrap_or(&self.config.currency)
63 .clone();
64 (
65 pricing.input_cost_per_million,
66 pricing.output_cost_per_million,
67 currency,
68 )
69 } else {
70 (
71 self.config.default_input_cost,
72 self.config.default_output_cost,
73 self.config.currency.clone(),
74 )
75 };
76
77 let input_cost = (input_tokens as f64 / 1_000_000.0) * input_cost_per_million;
79 let output_cost = (output_tokens as f64 / 1_000_000.0) * output_cost_per_million;
80 let total_cost = input_cost + output_cost;
81
82 trace!(
83 route_id = %self.route_id,
84 model = model,
85 input_tokens = input_tokens,
86 output_tokens = output_tokens,
87 input_cost = input_cost,
88 output_cost = output_cost,
89 total_cost = total_cost,
90 currency = %currency,
91 "Calculated cost"
92 );
93
94 CostResult::new(model, input_tokens, output_tokens, input_cost, output_cost, currency)
95 }
96
97 pub fn find_pricing(&self, model: &str) -> Option<&ModelPricing> {
101 self.config.pricing.iter().find(|p| p.matches(model))
102 }
103
104 pub fn currency(&self) -> &str {
106 &self.config.currency
107 }
108
109 pub fn pricing_rule_count(&self) -> usize {
111 self.config.pricing.len()
112 }
113}
114
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance for cost comparisons. The arithmetic under test is accurate far
    /// below this, so any genuine pricing error (even a fraction of a percent)
    /// fails the assertion.
    const EPS: f64 = 1e-9;

    /// Asserts that two costs are equal within [`EPS`], reporting both values.
    #[track_caller]
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < EPS,
            "expected {expected}, got {actual}"
        );
    }

    fn test_config() -> CostAttributionConfig {
        CostAttributionConfig {
            enabled: true,
            pricing: vec![
                ModelPricing {
                    model_pattern: "gpt-4*".to_string(),
                    input_cost_per_million: 30.0,
                    output_cost_per_million: 60.0,
                    currency: None,
                },
                ModelPricing {
                    model_pattern: "gpt-3.5*".to_string(),
                    input_cost_per_million: 0.5,
                    output_cost_per_million: 1.5,
                    currency: None,
                },
                ModelPricing {
                    model_pattern: "claude-*".to_string(),
                    input_cost_per_million: 15.0,
                    output_cost_per_million: 75.0,
                    currency: Some("EUR".to_string()),
                },
            ],
            default_input_cost: 1.0,
            default_output_cost: 2.0,
            currency: "USD".to_string(),
        }
    }

    #[test]
    fn test_calculate_gpt4() {
        let calc = CostCalculator::new(test_config(), "test-route");

        let result = calc.calculate("gpt-4-turbo", 1000, 500);

        assert_eq!(result.model, "gpt-4-turbo");
        assert_eq!(result.input_tokens, 1000);
        assert_eq!(result.output_tokens, 500);
        assert_eq!(result.currency, "USD");

        // 1_000 tokens at $30/M and 500 tokens at $60/M.
        assert_close(result.input_cost, 0.03);
        assert_close(result.output_cost, 0.03);
        assert_close(result.total_cost, 0.06);
    }

    #[test]
    fn test_calculate_gpt35() {
        let calc = CostCalculator::new(test_config(), "test-route");

        let result = calc.calculate("gpt-3.5-turbo", 1_000_000, 1_000_000);

        assert_close(result.input_cost, 0.5);
        assert_close(result.output_cost, 1.5);
        assert_close(result.total_cost, 2.0);
    }

    #[test]
    fn test_calculate_claude_with_currency_override() {
        let calc = CostCalculator::new(test_config(), "test-route");

        let result = calc.calculate("claude-3-opus", 1000, 1000);

        assert_eq!(result.currency, "EUR");
        // 1_000 tokens at 15/M and 1_000 tokens at 75/M.
        assert_close(result.input_cost, 0.015);
        assert_close(result.output_cost, 0.075);
        assert_close(result.total_cost, 0.09);
    }

    #[test]
    fn test_calculate_unknown_model_uses_default() {
        let calc = CostCalculator::new(test_config(), "test-route");

        let result = calc.calculate("llama-3", 1_000_000, 1_000_000);

        assert_close(result.input_cost, 1.0);
        assert_close(result.output_cost, 2.0);
        assert_close(result.total_cost, 3.0);
        assert_eq!(result.currency, "USD");
    }

    #[test]
    fn test_calculate_zero_tokens_is_free() {
        let calc = CostCalculator::new(test_config(), "test-route");

        let result = calc.calculate("gpt-4", 0, 0);

        assert_eq!(result.input_tokens, 0);
        assert_eq!(result.output_tokens, 0);
        assert_close(result.input_cost, 0.0);
        assert_close(result.output_cost, 0.0);
        assert_close(result.total_cost, 0.0);
    }

    #[test]
    fn test_disabled_returns_zero() {
        let mut config = test_config();
        config.enabled = false;

        let calc = CostCalculator::new(config, "test-route");
        assert!(!calc.is_enabled());

        let result = calc.calculate("gpt-4", 1000, 500);

        // Token counts are still reported even though no cost is attributed.
        assert_eq!(result.input_tokens, 1000);
        assert_eq!(result.output_tokens, 500);
        assert_close(result.input_cost, 0.0);
        assert_close(result.output_cost, 0.0);
        assert_close(result.total_cost, 0.0);
    }

    #[test]
    fn test_find_pricing() {
        let calc = CostCalculator::new(test_config(), "test-route");

        assert!(calc.find_pricing("gpt-4").is_some());
        assert!(calc.find_pricing("gpt-4-turbo").is_some());
        assert!(calc.find_pricing("gpt-3.5-turbo").is_some());
        assert!(calc.find_pricing("claude-3-sonnet").is_some());
        assert!(calc.find_pricing("llama-3").is_none());
    }

    #[test]
    fn test_find_pricing_selects_matching_rule() {
        let calc = CostCalculator::new(test_config(), "test-route");

        let gpt4 = calc.find_pricing("gpt-4-turbo").expect("gpt-4* rule");
        assert_eq!(gpt4.model_pattern, "gpt-4*");

        let gpt35 = calc.find_pricing("gpt-3.5-turbo").expect("gpt-3.5* rule");
        assert_eq!(gpt35.model_pattern, "gpt-3.5*");

        let claude = calc.find_pricing("claude-3-sonnet").expect("claude-* rule");
        assert_eq!(claude.model_pattern, "claude-*");
        assert_eq!(claude.currency.as_deref(), Some("EUR"));
    }

    #[test]
    fn test_accessors() {
        let calc = CostCalculator::new(test_config(), "test-route");

        assert!(calc.is_enabled());
        assert_eq!(calc.currency(), "USD");
        assert_eq!(calc.pricing_rule_count(), 3);
    }
}