1use chrono::{DateTime, TimeZone, Utc};
6
7use crate::models::Usage;
8
/// Currency in which a cost amount is reported.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CostCurrency {
    /// US dollars, rendered with the `$` symbol.
    Usd,
    /// Chinese yuan, rendered with the `¥` symbol.
    Cny,
}

impl CostCurrency {
    /// Spellings accepted for [`CostCurrency::Usd`] (compared ASCII case-insensitively).
    const USD_ALIASES: &'static [&'static str] = &["usd", "dollar", "dollars", "$"];

    /// Spellings accepted for [`CostCurrency::Cny`] (compared ASCII case-insensitively).
    ///
    /// Both the half-width `¥` (U+00A5) and the full-width `￥` (U+FFE5) signs are
    /// listed, so either form of the symbol selects yuan.
    const CNY_ALIASES: &'static [&'static str] = &["cny", "rmb", "yuan", "¥", "￥"];

    /// Parses a currency from a free-form setting value.
    ///
    /// Surrounding whitespace is ignored and ASCII letters are matched without
    /// regard to case. Returns `None` when the value is not a known alias.
    pub fn from_setting(value: &str) -> Option<Self> {
        let wanted = value.trim();
        // Compare in place instead of allocating a lowercased copy of the input.
        let is_alias = |aliases: &[&str]| {
            aliases
                .iter()
                .any(|alias| alias.eq_ignore_ascii_case(wanted))
        };

        if is_alias(Self::USD_ALIASES) {
            Some(Self::Usd)
        } else if is_alias(Self::CNY_ALIASES) {
            Some(Self::Cny)
        } else {
            None
        }
    }

    /// Returns the symbol placed in front of formatted amounts.
    fn symbol(self) -> &'static str {
        match self {
            Self::Usd => "$",
            Self::Cny => "¥",
        }
    }
}
32
33#[derive(Debug, Clone, Copy, Default, PartialEq)]
35pub struct CostEstimate {
36 pub usd: f64,
37 pub cny: f64,
38}
39
40impl CostEstimate {
41 #[allow(dead_code)]
42 pub fn usd_only(usd: f64) -> Self {
43 Self { usd, cny: 0.0 }
44 }
45
46 pub fn is_positive(self) -> bool {
47 self.usd > 0.0 || self.cny > 0.0
48 }
49
50 pub fn amount(self, currency: CostCurrency) -> f64 {
51 match currency {
52 CostCurrency::Usd => self.usd,
53 CostCurrency::Cny => self.cny,
54 }
55 }
56}
57
/// Token prices for a single currency, each quoted per one million tokens.
#[derive(Clone, Copy, Debug)]
struct CurrencyPricing {
    /// Price per million input tokens counted as prompt-cache hits.
    input_cache_hit_per_million: f64,
    /// Price per million input tokens counted as prompt-cache misses.
    input_cache_miss_per_million: f64,
    /// Price per million output tokens.
    output_per_million: f64,
}
65
66#[derive(Debug, Clone, Copy)]
68struct ModelPricing {
69 usd: CurrencyPricing,
70 cny: CurrencyPricing,
71}
72
73fn v4_pro_discount_ends_at() -> DateTime<Utc> {
74 Utc.with_ymd_and_hms(2026, 5, 31, 15, 59, 0)
75 .single()
76 .expect("valid DeepSeek V4 Pro discount end timestamp")
77}
78
79fn pricing_for_model(model: &str) -> Option<ModelPricing> {
81 pricing_for_model_at(model, Utc::now())
82}
83
84fn pricing_for_model_at(model: &str, now: DateTime<Utc>) -> Option<ModelPricing> {
85 let lower = model.to_lowercase();
86 if lower.starts_with("deepseek-ai/") {
87 return None;
90 }
91 if !lower.contains("deepseek") {
92 return None;
93 }
94 if lower.contains("v4-pro") || lower.contains("v4pro") {
95 if now <= v4_pro_discount_ends_at() {
96 return Some(ModelPricing {
99 usd: CurrencyPricing {
100 input_cache_hit_per_million: 0.003625,
101 input_cache_miss_per_million: 0.435,
102 output_per_million: 0.87,
103 },
104 cny: CurrencyPricing {
105 input_cache_hit_per_million: 0.025,
106 input_cache_miss_per_million: 3.0,
107 output_per_million: 6.0,
108 },
109 });
110 }
111 Some(ModelPricing {
112 usd: CurrencyPricing {
113 input_cache_hit_per_million: 0.0145,
114 input_cache_miss_per_million: 1.74,
115 output_per_million: 3.48,
116 },
117 cny: CurrencyPricing {
118 input_cache_hit_per_million: 0.1,
119 input_cache_miss_per_million: 12.0,
120 output_per_million: 24.0,
121 },
122 })
123 } else {
124 Some(ModelPricing {
126 usd: CurrencyPricing {
127 input_cache_hit_per_million: 0.0028,
128 input_cache_miss_per_million: 0.14,
129 output_per_million: 0.28,
130 },
131 cny: CurrencyPricing {
132 input_cache_hit_per_million: 0.02,
133 input_cache_miss_per_million: 1.0,
134 output_per_million: 2.0,
135 },
136 })
137 }
138}
139
140#[must_use]
142#[allow(dead_code)]
143pub fn calculate_turn_cost(model: &str, input_tokens: u32, output_tokens: u32) -> Option<f64> {
144 calculate_turn_cost_estimate(model, input_tokens, output_tokens).map(|estimate| estimate.usd)
145}
146
147#[must_use]
149pub fn calculate_turn_cost_estimate(
150 model: &str,
151 input_tokens: u32,
152 output_tokens: u32,
153) -> Option<CostEstimate> {
154 let pricing = pricing_for_model(model)?;
155 Some(CostEstimate {
156 usd: calculate_turn_cost_with_pricing(pricing.usd, input_tokens, output_tokens),
157 cny: calculate_turn_cost_with_pricing(pricing.cny, input_tokens, output_tokens),
158 })
159}
160
161fn calculate_turn_cost_with_pricing(
162 pricing: CurrencyPricing,
163 input_tokens: u32,
164 output_tokens: u32,
165) -> f64 {
166 let input_cost = (input_tokens as f64 / 1_000_000.0) * pricing.input_cache_miss_per_million;
167 let output_cost = (output_tokens as f64 / 1_000_000.0) * pricing.output_per_million;
168 input_cost + output_cost
169}
170
171#[must_use]
173pub fn calculate_turn_cost_from_usage(model: &str, usage: &Usage) -> Option<f64> {
174 calculate_turn_cost_estimate_from_usage(model, usage).map(|estimate| estimate.usd)
175}
176
177#[must_use]
179pub fn calculate_turn_cost_estimate_from_usage(model: &str, usage: &Usage) -> Option<CostEstimate> {
180 let pricing = pricing_for_model(model)?;
181 Some(CostEstimate {
182 usd: calculate_turn_cost_from_usage_with_pricing(pricing.usd, usage),
183 cny: calculate_turn_cost_from_usage_with_pricing(pricing.cny, usage),
184 })
185}
186
187fn calculate_turn_cost_from_usage_with_pricing(pricing: CurrencyPricing, usage: &Usage) -> f64 {
188 let hit_tokens = usage.prompt_cache_hit_tokens.unwrap_or(0);
189 let miss_tokens = usage
190 .prompt_cache_miss_tokens
191 .unwrap_or_else(|| usage.input_tokens.saturating_sub(hit_tokens));
192 let accounted_input = hit_tokens.saturating_add(miss_tokens);
193 let uncategorized_input = usage.input_tokens.saturating_sub(accounted_input);
194
195 let hit_cost = (hit_tokens as f64 / 1_000_000.0) * pricing.input_cache_hit_per_million;
196 let miss_cost = ((miss_tokens.saturating_add(uncategorized_input)) as f64 / 1_000_000.0)
197 * pricing.input_cache_miss_per_million;
198 let output_cost = (usage.output_tokens as f64 / 1_000_000.0) * pricing.output_per_million;
199 hit_cost + miss_cost + output_cost
200}
201
202#[must_use]
204#[allow(dead_code)]
205pub fn format_cost(cost: f64) -> String {
206 format_cost_amount(cost, CostCurrency::Usd)
207}
208
209#[must_use]
211pub fn format_cost_amount(cost: f64, currency: CostCurrency) -> String {
212 let symbol = currency.symbol();
213 if cost < 0.0001 {
214 format!("<{symbol}0.0001")
215 } else if cost < 0.01 {
216 format!("{symbol}{cost:.4}")
217 } else {
218 format!("{symbol}{cost:.2}")
219 }
220}
221
222#[must_use]
224pub fn format_cost_amount_precise(cost: f64, currency: CostCurrency) -> String {
225 let symbol = currency.symbol();
226 if cost < 0.0001 {
227 format!("<{symbol}0.0001")
228 } else {
229 format!("{symbol}{cost:.4}")
230 }
231}
232
233#[must_use]
235pub fn format_cost_estimate(estimate: CostEstimate, currency: CostCurrency) -> String {
236 format_cost_amount(estimate.amount(currency), currency)
237}
238
/// Unit tests for model pricing lookup, cost estimation and cost formatting.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::models::Usage;

    #[test]
    fn nvidia_nim_deepseek_model_does_not_use_deepseek_platform_pricing() {
        assert!(calculate_turn_cost("deepseek-ai/deepseek-v4-pro", 1_000, 1_000).is_none());
    }

    #[test]
    fn v4_pro_uses_limited_time_discount_before_expiry() {
        let before_expiry = Utc
            .with_ymd_and_hms(2026, 5, 31, 15, 58, 59)
            .single()
            .unwrap();
        let pricing = pricing_for_model_at("deepseek-v4-pro", before_expiry).unwrap();

        assert_eq!(pricing.usd.input_cache_hit_per_million, 0.003625);
        assert_eq!(pricing.usd.input_cache_miss_per_million, 0.435);
        assert_eq!(pricing.usd.output_per_million, 0.87);
        assert_eq!(pricing.cny.input_cache_hit_per_million, 0.025);
        assert_eq!(pricing.cny.input_cache_miss_per_million, 3.0);
        assert_eq!(pricing.cny.output_per_million, 6.0);
    }

    #[test]
    fn v4_pro_returns_to_base_rates_after_discount_expiry() {
        let after_expiry = Utc
            .with_ymd_and_hms(2026, 5, 31, 16, 0, 0)
            .single()
            .unwrap();
        let pricing = pricing_for_model_at("deepseek-v4-pro", after_expiry).unwrap();

        assert_eq!(pricing.usd.input_cache_hit_per_million, 0.0145);
        assert_eq!(pricing.usd.input_cache_miss_per_million, 1.74);
        assert_eq!(pricing.usd.output_per_million, 3.48);
        assert_eq!(pricing.cny.input_cache_hit_per_million, 0.1);
        assert_eq!(pricing.cny.input_cache_miss_per_million, 12.0);
        assert_eq!(pricing.cny.output_per_million, 24.0);
    }

    // The fixture is 2026-05-06, i.e. after the date the old name called the
    // "old May 5 expiry"; the test name now says "after" to match the fixture.
    #[test]
    fn v4_pro_discount_still_applies_after_old_may5_expiry() {
        let after_old_expiry = Utc.with_ymd_and_hms(2026, 5, 6, 0, 0, 0).single().unwrap();
        let pricing = pricing_for_model_at("deepseek-v4-pro", after_old_expiry).unwrap();

        assert_eq!(pricing.usd.input_cache_hit_per_million, 0.003625);
        assert_eq!(pricing.usd.input_cache_miss_per_million, 0.435);
        assert_eq!(pricing.usd.output_per_million, 0.87);
    }

    #[test]
    fn v4_flash_keeps_current_published_rates() {
        let now = Utc.with_ymd_and_hms(2026, 4, 25, 0, 0, 0).single().unwrap();
        let pricing = pricing_for_model_at("deepseek-v4-flash", now).unwrap();

        assert_eq!(pricing.usd.input_cache_hit_per_million, 0.0028);
        assert_eq!(pricing.usd.input_cache_miss_per_million, 0.14);
        assert_eq!(pricing.usd.output_per_million, 0.28);
        assert_eq!(pricing.cny.input_cache_hit_per_million, 0.02);
        assert_eq!(pricing.cny.input_cache_miss_per_million, 1.0);
        assert_eq!(pricing.cny.output_per_million, 2.0);
    }

    #[test]
    fn cost_estimate_calculates_usd_and_cny() {
        let estimate = calculate_turn_cost_estimate("deepseek-v4-flash", 1_000_000, 500_000)
            .expect("estimate");

        assert_eq!(estimate.usd, 0.28);
        assert_eq!(estimate.cny, 2.0);
    }

    #[test]
    fn cost_currency_accepts_yuan_aliases() {
        assert_eq!(CostCurrency::from_setting("usd"), Some(CostCurrency::Usd));
        assert_eq!(CostCurrency::from_setting("yuan"), Some(CostCurrency::Cny));
        assert_eq!(CostCurrency::from_setting("rmb"), Some(CostCurrency::Cny));
        assert_eq!(CostCurrency::from_setting("cny"), Some(CostCurrency::Cny));
        assert_eq!(CostCurrency::from_setting("eur"), None);
    }

    #[test]
    fn format_cost_amount_uses_selected_symbol() {
        assert_eq!(format_cost_amount(0.42, CostCurrency::Usd), "$0.42");
        assert_eq!(format_cost_amount(2.0, CostCurrency::Cny), "¥2.00");
    }

    #[test]
    fn cost_from_usage_splits_cache_hit_and_miss() {
        let usage = Usage {
            input_tokens: 1000,
            output_tokens: 200,
            prompt_cache_hit_tokens: Some(800),
            prompt_cache_miss_tokens: Some(200),
            reasoning_tokens: None,
            reasoning_replay_tokens: None,
            server_tool_use: None,
        };
        let actual = calculate_turn_cost_from_usage("deepseek-v4-flash", &usage).unwrap();
        let all_miss = calculate_turn_cost_estimate(
            "deepseek-v4-flash",
            usage.input_tokens,
            usage.output_tokens,
        )
        .unwrap()
        .usd;
        assert!(actual < all_miss);

        // 800 hit tokens at $0.0028/M + 200 miss tokens at $0.14/M
        // + 200 output tokens at $0.28/M = $0.00008624.
        let expected = 0.000_086_24;
        assert!((actual - expected).abs() < 1e-12);
    }

    #[test]
    fn format_cost_amount_precise_keeps_report_precision() {
        assert_eq!(
            format_cost_amount_precise(0.1234, CostCurrency::Usd),
            "$0.1234"
        );
        assert_eq!(
            format_cost_amount_precise(0.1234, CostCurrency::Cny),
            "¥0.1234"
        );
    }
}