1use std::collections::HashMap;
16use std::sync::OnceLock;
17
18use chrono::{DateTime, Utc};
19use serde::{Deserialize, Serialize};
20
21#[derive(Debug, Clone, Serialize, Deserialize)]
22pub struct ModelPricing {
23 pub input_per_million: f64,
25 pub output_per_million: f64,
27 pub cached_input_per_million: Option<f64>,
29 pub cache_write_per_million: Option<f64>,
33 pub effective_at: DateTime<Utc>,
35}
36
/// Descriptive metadata for a model offered by a provider.
///
/// Field declaration order is the order `Serialize` emits fields in — keep it stable.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct ModelInfo {
    /// Model identifier.
    pub id: String,
    /// Provider that serves the model.
    pub provider: String,
    /// Features the model supports.
    pub capabilities: Vec<Capability>,
    /// Input token limit. NOTE(review): meaning inferred from the field name — confirm
    /// against the data source.
    pub max_input_tokens: u64,
    /// Output token limit. NOTE(review): meaning inferred from the field name — confirm
    /// against the data source.
    pub max_output_tokens: u64,
}
45
/// A feature a model may support.
///
/// Serialized in `snake_case` (`text`, `vision`, `audio`, `tools`, `json_mode`,
/// `streaming`, `reasoning`, `prompt_caching`). Variant order determines serde's
/// variant index, which index-based formats rely on — append new variants at the end.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum Capability {
    Text,
    Vision,
    Audio,
    Tools,
    JsonMode,
    Streaming,
    Reasoning,
    PromptCaching,
}
58
/// Pricing data embedded into the binary at compile time (path is relative to this
/// source file); parsed lazily by [`catalog`].
const PRICING_TOML: &str = include_str!("../data/pricing.toml");
62
/// One `[[entry]]` table exactly as it appears in the pricing TOML.
///
/// [`PricingCatalog::parse`] groups these by `(provider, model)` and converts each into a
/// [`ModelPricing`].
#[derive(Debug, Deserialize)]
struct RawEntry {
    provider: String,
    model: String,
    input_per_million: f64,
    output_per_million: f64,
    // Optional rates: absent keys deserialize to `None`.
    #[serde(default)]
    cached_input_per_million: Option<f64>,
    #[serde(default)]
    cache_write_per_million: Option<f64>,
    effective_at: DateTime<Utc>,
}
76
/// Top-level shape of the pricing TOML: an array of `[[entry]]` tables.
#[derive(Debug, Deserialize)]
struct RawCatalog {
    // `default` lets a document with no `[[entry]]` tables (including an empty string)
    // parse to an empty list.
    #[serde(default)]
    entry: Vec<RawEntry>,
}
82
/// Parsed pricing data: a price history per `(provider, model)` pair.
#[derive(Debug)]
pub struct PricingCatalog {
    // Each history is sorted oldest-first by `effective_at` (see `parse`), so `last()`
    // is the current rate.
    by_model: HashMap<(String, String), Vec<ModelPricing>>,
}
89
90impl PricingCatalog {
91 pub fn parse(toml_text: &str) -> Result<Self, toml::de::Error> {
94 let raw: RawCatalog = toml::from_str(toml_text)?;
95 let mut by_model: HashMap<(String, String), Vec<ModelPricing>> = HashMap::new();
96 for e in raw.entry {
97 by_model
98 .entry((e.provider, e.model))
99 .or_default()
100 .push(ModelPricing {
101 input_per_million: e.input_per_million,
102 output_per_million: e.output_per_million,
103 cached_input_per_million: e.cached_input_per_million,
104 cache_write_per_million: e.cache_write_per_million,
105 effective_at: e.effective_at,
106 });
107 }
108 for history in by_model.values_mut() {
111 history.sort_by_key(|p| p.effective_at);
112 }
113 Ok(Self { by_model })
114 }
115
116 pub fn latest(&self, provider: &str, model: &str) -> Option<ModelPricing> {
119 self.by_model
120 .get(&(provider.to_string(), model.to_string()))?
121 .last()
122 .cloned()
123 }
124
125 pub fn at(&self, provider: &str, model: &str, at: DateTime<Utc>) -> Option<ModelPricing> {
130 let history = self
131 .by_model
132 .get(&(provider.to_string(), model.to_string()))?;
133 history
134 .iter()
135 .rev()
136 .find(|p| p.effective_at <= at)
137 .or_else(|| history.first())
138 .cloned()
139 }
140
141 pub fn latest_for_provider(&self, provider: &str) -> Vec<(String, ModelPricing)> {
145 self.by_model
146 .iter()
147 .filter(|((p, _), _)| p == provider)
148 .filter_map(|((_, model), history)| history.last().map(|p| (model.clone(), p.clone())))
149 .collect()
150 }
151
152 pub fn pairs(&self) -> Vec<(String, String)> {
156 self.by_model.keys().cloned().collect()
157 }
158
159 pub fn len(&self) -> usize {
161 self.by_model.len()
162 }
163
164 pub fn is_empty(&self) -> bool {
166 self.by_model.is_empty()
167 }
168
169 pub fn catalog_max_effective_at(&self) -> Option<DateTime<Utc>> {
177 self.by_model
178 .values()
179 .filter_map(|history| history.last().map(|p| p.effective_at))
180 .max()
181 }
182}
183
184pub fn catalog() -> &'static PricingCatalog {
188 static CATALOG: OnceLock<PricingCatalog> = OnceLock::new();
189 CATALOG.get_or_init(|| {
190 PricingCatalog::parse(PRICING_TOML).expect("embedded data/pricing.toml must be valid")
191 })
192}
193
194#[must_use]
197pub fn is_stale(newest: Option<DateTime<Utc>>, now: DateTime<Utc>, max_days: i64) -> bool {
198 match newest {
199 Some(d) => (now - d).num_days() > max_days,
200 None => false,
201 }
202}
203
#[cfg(test)]
mod catalog_tests {
    use super::*;
    use chrono::{Duration, TimeZone};

    /// Midnight UTC on the given calendar date (test helper; panics on invalid dates).
    fn utc(year: i32, month: u32, day: u32) -> DateTime<Utc> {
        Utc.with_ymd_and_hms(year, month, day, 0, 0, 0).unwrap()
    }

    #[test]
    fn is_stale_thresholds() {
        let now = utc(2026, 6, 5);
        // Nothing dated is never stale.
        assert!(!is_stale(None, now, 90));
        // Well inside the window.
        assert!(!is_stale(Some(now - Duration::days(10)), now, 90));
        // Past the window.
        assert!(is_stale(Some(now - Duration::days(100)), now, 90));
    }

    #[test]
    fn embedded_catalog_parses_and_is_populated() {
        let embedded = catalog();
        assert!(!embedded.is_empty(), "embedded catalog should not be empty");
        // Counts distinct (provider, model) pairs, not dated entries.
        assert_eq!(
            embedded.len(),
            36,
            "unexpected catalog size — update if intentional"
        );
    }

    /// Guards against shipping a catalog whose newest rate is implausibly old.
    #[test]
    fn catalog_max_effective_at_is_present() {
        let floor = utc(2026, 1, 1);
        let max_date = catalog()
            .catalog_max_effective_at()
            .expect("non-empty catalog must have a max effective_at");
        assert!(
            max_date >= floor,
            "catalog_max_effective_at = {max_date} is older than expected floor {floor}"
        );
    }

    /// The maximum is taken across models, not just within one history.
    #[test]
    fn catalog_max_effective_at_picks_newest() {
        let toml = r#"
            [[entry]]
            provider = "p"
            model = "m1"
            input_per_million = 1.0
            output_per_million = 2.0
            effective_at = "2026-03-01T00:00:00Z"

            [[entry]]
            provider = "p"
            model = "m2"
            input_per_million = 3.0
            output_per_million = 4.0
            effective_at = "2026-05-01T00:00:00Z"
        "#;
        let parsed = PricingCatalog::parse(toml).expect("valid");
        let newest = parsed.catalog_max_effective_at().expect("present");
        assert_eq!(
            newest,
            utc(2026, 5, 1),
            "should return the newest effective_at across all models"
        );
    }

    /// An empty document parses and has no maximum date.
    #[test]
    fn catalog_max_effective_at_empty_catalog() {
        let empty = PricingCatalog::parse("").expect("empty TOML is valid");
        assert!(empty.catalog_max_effective_at().is_none());
    }

    #[test]
    fn latest_returns_known_rates() {
        let embedded = catalog();

        let gpt4o = embedded.latest("openai", "gpt-4o").expect("gpt-4o present");
        assert_eq!(gpt4o.input_per_million, 2.50);
        assert_eq!(gpt4o.output_per_million, 10.00);
        assert_eq!(gpt4o.cached_input_per_million, Some(1.25));

        // An entry without a cached rate deserializes to `None`.
        let groq = embedded.latest("groq", "llama-3.1-8b-instant").expect("present");
        assert_eq!(groq.cached_input_per_million, None);
    }

    /// Cache-write rates are present where the catalog lists them and `None` elsewhere.
    #[test]
    fn anthropic_models_have_cache_write_rate() {
        let embedded = catalog();

        let haiku = embedded.latest("anthropic", "claude-haiku-4-5").expect("present");
        assert_eq!(
            haiku.cache_write_per_million,
            Some(1.25),
            "haiku write rate = 1.25× base input (1.00)"
        );

        let sonnet = embedded.latest("anthropic", "claude-sonnet-4-6").expect("present");
        assert_eq!(
            sonnet.cache_write_per_million,
            Some(3.75),
            "sonnet write rate = 1.25× base input (3.00)"
        );

        let opus = embedded.latest("anthropic", "claude-opus-4-7").expect("present");
        assert_eq!(
            opus.cache_write_per_million,
            Some(6.25),
            "opus write rate = 1.25× base input (5.00)"
        );

        let gpt4o = embedded.latest("openai", "gpt-4o").expect("gpt-4o present");
        assert_eq!(
            gpt4o.cache_write_per_million, None,
            "OpenAI has no cache-write premium"
        );

        let groq_llama = embedded.latest("groq", "llama-3.1-8b-instant").expect("present");
        assert_eq!(
            groq_llama.cache_write_per_million, None,
            "Groq has no cache-write premium"
        );
    }

    #[test]
    fn unknown_provider_or_model_is_none() {
        let embedded = catalog();
        assert!(embedded.latest("openai", "no-such-model").is_none());
        assert!(embedded.latest("no-such-provider", "gpt-4o").is_none());
    }

    #[test]
    fn at_selects_rate_effective_at_timestamp() {
        let toml = r#"
            [[entry]]
            provider = "p"
            model = "m"
            input_per_million = 1.0
            output_per_million = 2.0
            effective_at = "2026-01-01T00:00:00Z"

            [[entry]]
            provider = "p"
            model = "m"
            input_per_million = 3.0
            output_per_million = 4.0
            effective_at = "2026-06-01T00:00:00Z"
        "#;
        let parsed = PricingCatalog::parse(toml).expect("valid");
        let input_rate_on =
            |y, m, d| parsed.at("p", "m", utc(y, m, d)).unwrap().input_per_million;

        // Before the first entry: falls back to the oldest rate.
        assert_eq!(input_rate_on(2025, 1, 1), 1.0);
        // Between the two entries: the earlier rate is still in force.
        assert_eq!(input_rate_on(2026, 3, 1), 1.0);
        // After the second entry: the newer rate applies.
        assert_eq!(input_rate_on(2026, 9, 1), 3.0);

        // `latest` agrees with the newest entry.
        assert_eq!(parsed.latest("p", "m").unwrap().input_per_million, 3.0);
    }
}