1use std::sync::OnceLock;
2
3use serde::Deserialize;
4
5use crate::ModelSpec;
6use crate::types::{Cost, ModelCapabilities, Usage};
7
/// Category of a catalog provider.
///
/// Deserialized from the `snake_case` strings `remote` and `local`.
#[derive(Debug, Clone, PartialEq, Eq, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ProviderKind {
    /// Written as `remote` in the catalog. Presumably a hosted API reached
    /// over the network — confirm against the provider implementations.
    Remote,
    /// Written as `local` in the catalog. Presumably a model executed on the
    /// user's machine — confirm against the provider implementations.
    Local,
}
14
/// Authentication scheme declared for a provider.
///
/// Deserialized from the `snake_case` strings `bearer`, `api_key_header` and
/// `aws_sigv4`.
#[derive(Debug, Clone, PartialEq, Eq, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum AuthMode {
    /// Written as `bearer`. Presumably an `Authorization: Bearer` token —
    /// confirm against the HTTP client code.
    Bearer,
    /// Written as `api_key_header`. Presumably the key is sent in a dedicated
    /// request header — confirm against the HTTP client code.
    ApiKeyHeader,
    /// Written as `aws_sigv4`. Presumably AWS Signature Version 4 request
    /// signing — confirm against the HTTP client code.
    AwsSigv4,
}
22
/// API version tag that a preset may declare.
///
/// Deserialized from the `snake_case` strings `v1` and `v1beta`.
#[derive(Debug, Clone, PartialEq, Eq, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ApiVersion {
    /// Written as `v1` in the catalog.
    V1,
    /// Written as `v1beta` in the catalog.
    V1beta,
}
29
/// A single capability a preset can advertise in its `capabilities` list.
///
/// Deserialized from `snake_case` strings (`text`, `tools`, `thinking`,
/// `images_in`, `streaming`, `structured_output`).
/// `CatalogPreset::model_capabilities` turns these into the boolean flags of
/// `ModelCapabilities`.
#[derive(Debug, Clone, PartialEq, Eq, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum PresetCapability {
    /// Written as `text`. Not mapped to any `ModelCapabilities` flag.
    Text,
    /// Written as `tools`. Drives `supports_tool_use`.
    Tools,
    /// Written as `thinking`. Drives `supports_thinking`.
    Thinking,
    /// Written as `images_in`. Drives `supports_vision`.
    ImagesIn,
    /// Written as `streaming`. Drives `supports_streaming`.
    Streaming,
    /// Written as `structured_output`. Drives `supports_structured_output`.
    StructuredOutput,
}
40
/// Release status that a preset may declare.
///
/// Deserialized from the `snake_case` strings `ga` and `preview`.
#[derive(Debug, Clone, PartialEq, Eq, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum PresetStatus {
    /// Written as `ga`. Presumably "generally available" — confirm.
    Ga,
    /// Written as `preview`.
    Preview,
}
47
/// One preset entry of a provider, as deserialized from the catalog.
///
/// Fields without `#[serde(default)]` that are not `Option` are required in
/// the catalog document.
#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct PresetCatalog {
    /// Identifier used by `ProviderCatalog::preset` to look this entry up.
    pub id: String,
    /// Human-readable name of the preset.
    pub display_name: String,
    /// Optional grouping label for the preset.
    pub group: Option<String>,
    /// Model identifier; matched by `ModelCatalog::find_preset_by_model_id`
    /// and passed to `ModelSpec::new` by `CatalogPreset::model_spec`.
    pub model_id: String,
    /// Optional API version tag for this preset.
    pub api_version: Option<ApiVersion>,
    /// Capabilities advertised by the preset; empty when omitted.
    #[serde(default)]
    pub capabilities: Vec<PresetCapability>,
    /// Optional release status.
    pub status: Option<PresetStatus>,
    /// Optional context window size, in tokens.
    pub context_window_tokens: Option<u64>,
    /// Optional maximum number of output tokens.
    pub max_output_tokens: Option<u64>,
    /// `false` when omitted.
    /// NOTE(review): the consumer of this flag is not visible here — confirm
    /// what "default" inclusion refers to.
    #[serde(default)]
    pub include_by_default: bool,
    /// Optional repository identifier.
    /// NOTE(review): presumably locates a downloadable model — confirm.
    pub repo_id: Option<String>,
    /// Optional file name.
    /// NOTE(review): presumably the model file inside `repo_id` — confirm.
    pub filename: Option<String>,
    /// Price per one million input tokens; `None` when no pricing is listed.
    #[serde(default)]
    pub cost_per_million_input: Option<f64>,
    /// Price per one million output tokens; `None` when no pricing is listed.
    #[serde(default)]
    pub cost_per_million_output: Option<f64>,
    /// Price per one million cache-read tokens; `None` when not listed.
    #[serde(default)]
    pub cost_per_million_cache_read: Option<f64>,
    /// Price per one million cache-write tokens; `None` when not listed.
    #[serde(default)]
    pub cost_per_million_cache_write: Option<f64>,
}
73
/// One provider entry of the catalog together with its presets.
#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct ProviderCatalog {
    /// Identifier used by `ModelCatalog::provider` to look this entry up.
    pub key: String,
    /// Human-readable name of the provider.
    pub display_name: String,
    /// Whether the provider is remote or local.
    pub kind: ProviderKind,
    /// Optional authentication scheme.
    pub auth_mode: Option<AuthMode>,
    /// Optional name of the environment variable holding the credential
    /// (the tests in this file expect e.g. `OPENAI_API_KEY`).
    pub credential_env_var: Option<String>,
    /// Optional name of the environment variable that supplies a base URL
    /// (the tests in this file expect e.g. `OPENAI_BASE_URL`).
    pub base_url_env_var: Option<String>,
    /// Optional base URL declared in the catalog itself.
    pub default_base_url: Option<String>,
    /// `false` when omitted.
    /// NOTE(review): presumably means a base URL must be configured before
    /// the provider can be used — confirm against the caller.
    #[serde(default)]
    pub requires_base_url: bool,
    /// Optional name of the environment variable holding a region
    /// (the tests in this file expect `AWS_REGION` for `bedrock`).
    pub region_env_var: Option<String>,
    /// Presets offered by this provider; empty when omitted.
    #[serde(default)]
    pub presets: Vec<PresetCatalog>,
}
89
90impl ProviderCatalog {
91 #[must_use]
92 pub fn preset(&self, preset_id: &str) -> Option<&PresetCatalog> {
93 self.presets.iter().find(|preset| preset.id == preset_id)
94 }
95}
96
/// Root of the deserialized model catalog: a list of providers.
#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct ModelCatalog {
    /// All providers in the catalog; empty when omitted.
    #[serde(default)]
    pub providers: Vec<ProviderCatalog>,
}
102
103impl ModelCatalog {
104 #[must_use]
105 pub fn provider(&self, provider_key: &str) -> Option<&ProviderCatalog> {
106 self.providers
107 .iter()
108 .find(|provider| provider.key == provider_key)
109 }
110
111 #[must_use]
113 pub fn find_preset_by_model_id(&self, model_id: &str) -> Option<CatalogPreset> {
114 for provider in &self.providers {
115 for preset in &provider.presets {
116 if preset.model_id == model_id {
117 return self.preset(&provider.key, &preset.id);
118 }
119 }
120 }
121 None
122 }
123
124 #[must_use]
125 pub fn preset(&self, provider_key: &str, preset_id: &str) -> Option<CatalogPreset> {
126 let provider = self.provider(provider_key)?;
127 let preset = provider.preset(preset_id)?;
128 Some(CatalogPreset {
129 provider_key: provider.key.clone(),
130 provider_display_name: provider.display_name.clone(),
131 provider_kind: provider.kind.clone(),
132 preset_id: preset.id.clone(),
133 display_name: preset.display_name.clone(),
134 group: preset.group.clone(),
135 model_id: preset.model_id.clone(),
136 api_version: preset.api_version.clone(),
137 capabilities: preset.capabilities.clone(),
138 status: preset.status.clone(),
139 context_window_tokens: preset.context_window_tokens,
140 max_output_tokens: preset.max_output_tokens,
141 auth_mode: provider.auth_mode.clone(),
142 credential_env_var: provider.credential_env_var.clone(),
143 base_url_env_var: provider.base_url_env_var.clone(),
144 default_base_url: provider.default_base_url.clone(),
145 requires_base_url: provider.requires_base_url,
146 region_env_var: provider.region_env_var.clone(),
147 include_by_default: preset.include_by_default,
148 repo_id: preset.repo_id.clone(),
149 filename: preset.filename.clone(),
150 cost_per_million_input: preset.cost_per_million_input,
151 cost_per_million_output: preset.cost_per_million_output,
152 cost_per_million_cache_read: preset.cost_per_million_cache_read,
153 cost_per_million_cache_write: preset.cost_per_million_cache_write,
154 })
155 }
156}
157
/// Owned, flattened view of one preset combined with the metadata of the
/// provider it belongs to.
///
/// Produced by `ModelCatalog::preset` and
/// `ModelCatalog::find_preset_by_model_id`; every field is copied or cloned
/// from the underlying `ProviderCatalog` / `PresetCatalog` entry.
#[derive(Debug, Clone, PartialEq)]
pub struct CatalogPreset {
    /// Copied from `ProviderCatalog::key`.
    pub provider_key: String,
    /// Copied from `ProviderCatalog::display_name`.
    pub provider_display_name: String,
    /// Copied from `ProviderCatalog::kind`.
    pub provider_kind: ProviderKind,
    /// Copied from `PresetCatalog::id`.
    pub preset_id: String,
    /// Copied from `PresetCatalog::display_name`.
    pub display_name: String,
    /// Copied from `PresetCatalog::group`.
    pub group: Option<String>,
    /// Copied from `PresetCatalog::model_id`.
    pub model_id: String,
    /// Copied from `PresetCatalog::api_version`.
    pub api_version: Option<ApiVersion>,
    /// Copied from `PresetCatalog::capabilities`.
    pub capabilities: Vec<PresetCapability>,
    /// Copied from `PresetCatalog::status`.
    pub status: Option<PresetStatus>,
    /// Copied from `PresetCatalog::context_window_tokens`.
    pub context_window_tokens: Option<u64>,
    /// Copied from `PresetCatalog::max_output_tokens`.
    pub max_output_tokens: Option<u64>,
    /// Copied from `ProviderCatalog::auth_mode`.
    pub auth_mode: Option<AuthMode>,
    /// Copied from `ProviderCatalog::credential_env_var`.
    pub credential_env_var: Option<String>,
    /// Copied from `ProviderCatalog::base_url_env_var`.
    pub base_url_env_var: Option<String>,
    /// Copied from `ProviderCatalog::default_base_url`.
    pub default_base_url: Option<String>,
    /// Copied from `ProviderCatalog::requires_base_url`.
    pub requires_base_url: bool,
    /// Copied from `ProviderCatalog::region_env_var`.
    pub region_env_var: Option<String>,
    /// Copied from `PresetCatalog::include_by_default`.
    pub include_by_default: bool,
    /// Copied from `PresetCatalog::repo_id`.
    pub repo_id: Option<String>,
    /// Copied from `PresetCatalog::filename`.
    pub filename: Option<String>,
    /// Copied from `PresetCatalog::cost_per_million_input`.
    pub cost_per_million_input: Option<f64>,
    /// Copied from `PresetCatalog::cost_per_million_output`.
    pub cost_per_million_output: Option<f64>,
    /// Copied from `PresetCatalog::cost_per_million_cache_read`.
    pub cost_per_million_cache_read: Option<f64>,
    /// Copied from `PresetCatalog::cost_per_million_cache_write`.
    pub cost_per_million_cache_write: Option<f64>,
}
186
187impl CatalogPreset {
188 #[must_use]
191 pub fn model_capabilities(&self) -> ModelCapabilities {
192 let has = |cap: &PresetCapability| self.capabilities.contains(cap);
193 ModelCapabilities {
194 supports_thinking: has(&PresetCapability::Thinking),
195 supports_vision: has(&PresetCapability::ImagesIn),
196 supports_tool_use: has(&PresetCapability::Tools),
197 supports_streaming: has(&PresetCapability::Streaming),
198 supports_structured_output: has(&PresetCapability::StructuredOutput),
199 max_context_window: self.context_window_tokens,
200 max_output_tokens: self.max_output_tokens,
201 }
202 }
203
204 #[must_use]
206 pub fn model_spec(&self) -> ModelSpec {
207 ModelSpec::new(&self.provider_key, &self.model_id)
208 .with_capabilities(self.model_capabilities())
209 }
210}
211
212#[must_use]
213pub fn model_catalog() -> &'static ModelCatalog {
214 static MODEL_CATALOG: OnceLock<ModelCatalog> = OnceLock::new();
215 MODEL_CATALOG.get_or_init(|| {
216 toml::from_str(include_str!("model_catalog.toml"))
217 .expect("src/model_catalog.toml must be valid TOML")
218 })
219}
220
221#[must_use]
226pub fn calculate_cost(model_id: &str, usage: &Usage) -> Cost {
227 let Some(preset) = model_catalog().find_preset_by_model_id(model_id) else {
228 return Cost::default();
229 };
230
231 #[allow(clippy::cast_precision_loss)] let per_m = |tokens: u64, rate: Option<f64>| -> f64 {
233 rate.map_or(0.0, |r| tokens as f64 * r / 1_000_000.0)
234 };
235
236 let input = per_m(usage.input, preset.cost_per_million_input);
237 let output = per_m(usage.output, preset.cost_per_million_output);
238 let cache_read = per_m(usage.cache_read, preset.cost_per_million_cache_read);
239 let cache_write = per_m(usage.cache_write, preset.cost_per_million_cache_write);
240
241 Cost {
242 input,
243 output,
244 cache_read,
245 cache_write,
246 total: input + output + cache_read + cache_write,
247 ..Cost::default()
248 }
249}
250
// Unit tests for the embedded catalog. Every expected literal below mirrors a
// value in `model_catalog.toml`; update both together.
#[cfg(test)]
mod tests {
    use super::*;

    /// The embedded catalog parses and exposes the expected providers, their
    /// kinds, and a few preset-level fields.
    #[test]
    fn catalog_loads_grouped_presets() {
        let catalog = model_catalog();
        let anthropic = catalog.provider("anthropic").unwrap();
        assert_eq!(anthropic.kind, ProviderKind::Remote);
        assert!(anthropic.preset("sonnet_46").is_some());

        let local = catalog.provider("local").unwrap();
        assert_eq!(local.kind, ProviderKind::Local);
        assert!(!local.preset("smollm3_3b").unwrap().include_by_default);
        assert!(local.preset("gemma4_e2b").unwrap().include_by_default);
        assert_eq!(
            local.preset("gemma4_e2b").unwrap().context_window_tokens,
            Some(128_000)
        );

        let google = catalog.provider("google").unwrap();
        assert_eq!(google.kind, ProviderKind::Remote);
        // Pinned to the current number of Google presets in the catalog file.
        assert_eq!(google.presets.len(), 4);

        let bedrock = catalog.provider("bedrock").unwrap();
        assert_eq!(bedrock.auth_mode, Some(AuthMode::AwsSigv4));
        assert_eq!(bedrock.region_env_var.as_deref(), Some("AWS_REGION"));
    }

    /// `ModelCatalog::preset` merges provider-level fields (auth, env vars)
    /// into the returned preset.
    #[test]
    fn preset_lookup_returns_provider_metadata() {
        let preset = model_catalog().preset("openai", "gpt_5_4").unwrap();
        assert_eq!(preset.display_name, "OpenAI GPT-5.4");
        assert_eq!(preset.model_id, "gpt-5.4");
        assert_eq!(preset.credential_env_var.as_deref(), Some("OPENAI_API_KEY"));
        assert_eq!(preset.base_url_env_var.as_deref(), Some("OPENAI_BASE_URL"));
        assert_eq!(preset.auth_mode, Some(AuthMode::Bearer));
    }

    /// Optional preset fields (API version, status, capability list, token
    /// limits) survive the flattening.
    #[test]
    fn google_preset_lookup_returns_extended_metadata() {
        let preset = model_catalog().preset("google", "gemini_3_flash").unwrap();
        assert_eq!(preset.display_name, "Google Gemini 3 Flash");
        assert_eq!(preset.model_id, "gemini-3-flash-preview");
        assert_eq!(preset.api_version, Some(ApiVersion::V1beta));
        assert_eq!(preset.status, Some(PresetStatus::Preview));
        // Order matters: this compares against the list order in the catalog.
        assert_eq!(
            preset.capabilities,
            vec![
                PresetCapability::Text,
                PresetCapability::Tools,
                PresetCapability::Thinking,
                PresetCapability::ImagesIn,
                PresetCapability::Streaming,
                PresetCapability::StructuredOutput,
            ]
        );
        assert_eq!(preset.context_window_tokens, Some(1_000_000));
        assert_eq!(preset.max_output_tokens, Some(65536));
        assert_eq!(preset.credential_env_var.as_deref(), Some("GEMINI_API_KEY"));
        assert_eq!(preset.base_url_env_var.as_deref(), Some("GEMINI_BASE_URL"));
    }

    /// Provider-specific settings (`requires_base_url`, region variable,
    /// auth mode, group) are visible on flattened presets.
    #[test]
    fn azure_and_bedrock_presets_expose_provider_specific_metadata() {
        let azure = model_catalog().preset("azure", "gpt_4o").unwrap();
        assert_eq!(azure.auth_mode, Some(AuthMode::ApiKeyHeader));
        assert!(azure.requires_base_url);
        assert_eq!(azure.base_url_env_var.as_deref(), Some("AZURE_BASE_URL"));

        let bedrock = model_catalog()
            .preset("bedrock", "anthropic_claude_sonnet_45")
            .unwrap();
        assert_eq!(bedrock.auth_mode, Some(AuthMode::AwsSigv4));
        assert_eq!(bedrock.region_env_var.as_deref(), Some("AWS_REGION"));
        assert_eq!(bedrock.group.as_deref(), Some("anthropic"));
    }

    /// All capability flags and both token limits for `anthropic/sonnet_46`.
    #[test]
    fn anthropic_preset_model_capabilities() {
        let preset = model_catalog().preset("anthropic", "sonnet_46").unwrap();
        let caps = preset.model_capabilities();
        assert!(caps.supports_thinking);
        assert!(caps.supports_vision);
        assert!(caps.supports_tool_use);
        assert!(caps.supports_streaming);
        assert!(caps.supports_structured_output);
        assert_eq!(caps.max_context_window, Some(200_000));
        assert_eq!(caps.max_output_tokens, Some(16384));
    }

    /// `model_spec` forwards the preset's capabilities to the `ModelSpec`.
    #[test]
    fn model_spec_carries_capabilities_from_preset() {
        let preset = model_catalog().preset("anthropic", "opus_46").unwrap();
        let spec = preset.model_spec();
        let caps = spec.capabilities();
        assert!(caps.supports_thinking);
        assert!(caps.supports_vision);
        assert!(caps.supports_tool_use);
        assert_eq!(caps.max_context_window, Some(200_000));
        assert_eq!(caps.max_output_tokens, Some(32768));
    }

    /// A preset that omits `thinking` reports `supports_thinking == false`
    /// while the other listed capabilities stay enabled.
    #[test]
    fn openai_preset_no_thinking() {
        let preset = model_catalog().preset("openai", "gpt_5_4_mini").unwrap();
        let caps = preset.model_capabilities();
        assert!(!caps.supports_thinking);
        assert!(caps.supports_tool_use);
        assert!(caps.supports_vision);
        assert!(caps.supports_streaming);
        assert!(caps.supports_structured_output);
        assert_eq!(caps.max_context_window, Some(400_000));
    }

    /// The local `smollm3_3b` preset only reports streaming support.
    #[test]
    fn local_preset_minimal_capabilities() {
        let preset = model_catalog().preset("local", "smollm3_3b").unwrap();
        let caps = preset.model_capabilities();
        assert!(!caps.supports_thinking);
        assert!(!caps.supports_vision);
        assert!(!caps.supports_tool_use);
        assert!(caps.supports_streaming);
        assert!(!caps.supports_structured_output);
        assert_eq!(caps.max_context_window, Some(8192));
        assert_eq!(caps.max_output_tokens, Some(2048));
    }

    /// The Bedrock Sonnet preset reports everything except structured output.
    #[test]
    fn bedrock_preset_capabilities() {
        let preset = model_catalog()
            .preset("bedrock", "anthropic_claude_sonnet_45")
            .unwrap();
        let caps = preset.model_capabilities();
        assert!(caps.supports_thinking);
        assert!(caps.supports_vision);
        assert!(caps.supports_tool_use);
        assert!(caps.supports_streaming);
        assert!(!caps.supports_structured_output);
    }

    /// A `ModelSpec` built by hand, without a catalog preset, has every
    /// capability disabled and no token limits.
    #[test]
    fn manual_model_spec_defaults_to_no_capabilities() {
        let spec = crate::ModelSpec::new("custom", "my-model");
        let caps = spec.capabilities();
        assert!(!caps.supports_thinking);
        assert!(!caps.supports_vision);
        assert!(!caps.supports_tool_use);
        assert!(!caps.supports_streaming);
        assert!(!caps.supports_structured_output);
        assert_eq!(caps.max_context_window, None);
        assert_eq!(caps.max_output_tokens, None);
    }

    /// Test helper: builds a `Usage` from the four token counts, sets `total`
    /// to their sum, and leaves any other field at its default.
    fn usage(input: u64, output: u64, cache_read: u64, cache_write: u64) -> crate::types::Usage {
        crate::types::Usage {
            input,
            output,
            cache_read,
            cache_write,
            total: input + output + cache_read + cache_write,
            ..Default::default()
        }
    }

    /// Input and output costs for a model that has pricing in the catalog.
    /// Floats are compared with a 0.001 tolerance.
    #[test]
    fn calculate_cost_known_model() {
        let cost = calculate_cost("claude-sonnet-4-6", &usage(1_000_000, 500_000, 0, 0));
        assert!((cost.input - 3.0).abs() < 0.001);
        assert!((cost.output - 7.5).abs() < 0.001);
        assert!((cost.total - 10.5).abs() < 0.001);
    }

    /// A model id that is not in the catalog yields a zero cost.
    #[test]
    fn calculate_cost_unknown_model() {
        let cost = calculate_cost("nonexistent-model-xyz", &usage(1_000_000, 1_000_000, 0, 0));
        assert!((cost.input).abs() < 0.001);
        assert!((cost.output).abs() < 0.001);
        assert!((cost.total).abs() < 0.001);
    }

    /// Zero tokens cost nothing even for a priced model.
    #[test]
    fn calculate_cost_zero_usage() {
        let cost = calculate_cost("claude-sonnet-4-6", &usage(0, 0, 0, 0));
        assert!((cost.total).abs() < 0.001);
    }

    /// Cache-read and cache-write tokens are priced with their own rates.
    #[test]
    fn calculate_cost_cache_tokens() {
        let cost = calculate_cost("claude-sonnet-4-6", &usage(0, 0, 2_000_000, 1_000_000));
        assert!((cost.cache_read - 0.60).abs() < 0.001);
        assert!((cost.cache_write - 3.75).abs() < 0.001);
        assert!((cost.total - 4.35).abs() < 0.001);
    }

    /// A catalog model without pricing data yields a zero total.
    #[test]
    fn calculate_cost_no_pricing_data() {
        let cost = calculate_cost("SmolLM3-3B-Q4_K_M", &usage(1_000_000, 500_000, 0, 0));
        assert!((cost.total).abs() < 0.001);
    }

    /// Capability flags derived from the `anthropic/sonnet_46` preset.
    #[test]
    fn capabilities_from_catalog_preset() {
        let preset = model_catalog().preset("anthropic", "sonnet_46").unwrap();
        let caps = preset.model_capabilities();
        assert!(caps.supports_thinking);
        assert!(caps.supports_vision);
        assert!(caps.supports_tool_use);
        assert!(caps.supports_streaming);
        assert!(caps.supports_structured_output);
    }

    /// Token limits are copied from the preset into the capabilities.
    #[test]
    fn capabilities_context_window_and_output() {
        let preset = model_catalog().preset("openai", "gpt_5_4").unwrap();
        let caps = preset.model_capabilities();
        assert_eq!(caps.max_context_window, Some(1_050_000));
        assert_eq!(caps.max_output_tokens, Some(128_000));
    }

    /// `model_spec` carries capabilities for a Google preset as well.
    #[test]
    fn model_spec_carries_capabilities() {
        let preset = model_catalog().preset("google", "gemini_3_flash").unwrap();
        let spec = preset.model_spec();
        let caps = spec.capabilities();
        assert!(caps.supports_thinking);
        assert!(caps.supports_vision);
        assert!(caps.supports_tool_use);
        assert_eq!(caps.max_context_window, Some(1_000_000));
    }

    /// Reverse lookup by model id resolves both the preset and its provider.
    #[test]
    fn find_preset_by_model_id_works() {
        let preset = model_catalog()
            .find_preset_by_model_id("claude-sonnet-4-6")
            .unwrap();
        assert_eq!(preset.preset_id, "sonnet_46");
        assert_eq!(preset.provider_key, "anthropic");
    }

    /// Reverse lookup of an unknown model id returns `None`.
    #[test]
    fn find_preset_by_model_id_unknown_returns_none() {
        assert!(
            model_catalog()
                .find_preset_by_model_id("nonexistent")
                .is_none()
        );
    }
}