1use std::collections::HashMap;
20
21use serde::Deserialize;
22use tracing::debug;
23
24use crate::provider::CostRates;
25
/// Built-in model catalog, embedded into the binary at compile time from
/// `defaults/models.toml` (path resolved relative to this source file).
/// Parsed by `ModelRegistry::with_defaults`.
const DEFAULTS_TOML: &str = include_str!("defaults/models.toml");
28
/// Top-level shape of a catalog TOML document.
///
/// The document has no fixed top-level keys: `#[serde(flatten)]` collects every
/// top-level table into `providers`, keyed by the table name.
#[derive(Debug, Deserialize)]
struct CatalogFile {
    /// Provider name -> provider table, one entry per top-level TOML table.
    #[serde(flatten)]
    providers: HashMap<String, ProviderEntry>,
}
36
/// One provider table as written in the catalog TOML.
///
/// Every key is optional in the file: absent scalar keys deserialize to `None`
/// and an absent `models` table deserializes to an empty map.
#[derive(Debug, Deserialize)]
struct ProviderEntry {
    /// Model id exposed through `ModelRegistry::default_model`.
    #[serde(default)]
    default_model: Option<String>,
    /// Name of an environment variable, exposed through
    /// `ModelRegistry::api_key_env` (presumably the variable holding the API key).
    #[serde(default)]
    api_key_env: Option<String>,
    /// Provider-wide cache-read multiplier; used for any model of this provider
    /// that does not set its own.
    #[serde(default)]
    cache_read_multiplier: Option<f64>,
    /// Provider-wide cache-creation multiplier; used for any model of this
    /// provider that does not set its own.
    #[serde(default)]
    cache_creation_multiplier: Option<f64>,
    /// Model id -> model table for this provider.
    #[serde(default)]
    models: HashMap<String, ModelEntry>,
}
50
/// One model table as written in the catalog TOML.
///
/// `input` and `output` are required; every other key has a default.
#[derive(Debug, Deserialize)]
struct ModelEntry {
    /// Input price; copied into `CostRates::input_per_million`.
    input: f64,
    /// Output price; copied into `CostRates::output_per_million`.
    output: f64,
    /// Context window size, if listed (unit not stated here; presumably tokens).
    #[serde(default)]
    context_window: Option<u64>,
    /// Whether the model supports tool use; defaults to `true` when omitted.
    #[serde(default = "default_true")]
    supports_tool_use: bool,
    /// Whether the model supports vision input; defaults to `false` when omitted.
    #[serde(default)]
    supports_vision: bool,
    /// Model-specific cache-read multiplier; takes precedence over the
    /// provider-level value when present.
    #[serde(default)]
    cache_read_multiplier: Option<f64>,
    /// Model-specific cache-creation multiplier; takes precedence over the
    /// provider-level value when present.
    #[serde(default)]
    cache_creation_multiplier: Option<f64>,
}
66
/// Serde default provider for boolean fields that should be `true` when the
/// key is omitted (referenced by name from `#[serde(default = "default_true")]`).
const fn default_true() -> bool {
    true
}
70
/// Resolved description of a single model as stored in a `ModelRegistry`.
#[derive(Debug, Clone)]
pub struct ModelInfo {
    /// Model identifier (the model's table name when loaded from TOML).
    pub id: String,
    /// Name of the provider this model belongs to.
    pub provider: String,
    /// Pricing for the model. When loaded from TOML the cache multipliers are
    /// already resolved: the model's own value if set, else the provider's.
    pub pricing: CostRates,
    /// Context window size, if known (presumably in tokens).
    pub context_window: Option<u64>,
    /// Whether the model supports tool use.
    pub supports_tool_use: bool,
    /// Whether the model supports vision input.
    pub supports_vision: bool,
}
89
/// Provider-level settings stored in a `ModelRegistry`.
#[derive(Debug, Clone)]
pub struct ProviderInfo {
    /// Provider name (the top-level table name when loaded from TOML).
    pub name: String,
    /// Model id to use when the caller does not pick one, if configured.
    pub default_model: Option<String>,
    /// Name of an environment variable associated with the provider
    /// (presumably the one holding its API key), if configured.
    pub api_key_env: Option<String>,
    /// Provider-wide cache-read multiplier, if configured.
    pub cache_read_multiplier: Option<f64>,
    /// Provider-wide cache-creation multiplier, if configured.
    pub cache_creation_multiplier: Option<f64>,
}
104
/// Key type of the registry's model map: provider and model id joined by `::`.
type ModelKey = String;

/// Builds the composite map key `"<provider>::<model>"`.
///
/// No escaping is performed, so a provider or model id that itself contains
/// `::` can produce the same key as a different pair.
fn make_key(provider: &str, model: &str) -> ModelKey {
    [provider, model].join("::")
}
113
/// In-memory catalog of providers and their models.
///
/// Models and providers are stored in two independent maps; a model can exist
/// under a provider name that has no entry in `providers`.
#[derive(Debug, Clone)]
pub struct ModelRegistry {
    /// Models keyed by the composite `provider::model` key from `make_key`.
    models: HashMap<ModelKey, ModelInfo>,
    /// Provider-level settings keyed by provider name.
    providers: HashMap<String, ProviderInfo>,
}
126
127impl ModelRegistry {
128 pub fn new() -> Self {
130 Self {
131 models: HashMap::new(),
132 providers: HashMap::new(),
133 }
134 }
135
136 pub fn with_defaults() -> Self {
138 Self::from_toml(DEFAULTS_TOML).expect("embedded models.toml must be valid")
139 }
140
141 pub fn from_toml(toml_str: &str) -> Result<Self, String> {
143 let file: CatalogFile =
144 toml::from_str(toml_str).map_err(|e| format!("models TOML parse error: {e}"))?;
145
146 let mut models = HashMap::new();
147 let mut providers = HashMap::new();
148
149 for (prov_name, pe) in &file.providers {
150 providers.insert(
151 prov_name.clone(),
152 ProviderInfo {
153 name: prov_name.clone(),
154 default_model: pe.default_model.clone(),
155 api_key_env: pe.api_key_env.clone(),
156 cache_read_multiplier: pe.cache_read_multiplier,
157 cache_creation_multiplier: pe.cache_creation_multiplier,
158 },
159 );
160
161 for (model_id, me) in &pe.models {
162 let info = ModelInfo {
163 id: model_id.clone(),
164 provider: prov_name.clone(),
165 pricing: CostRates {
166 input_per_million: me.input,
167 output_per_million: me.output,
168 cache_read_multiplier: me.cache_read_multiplier.or(pe.cache_read_multiplier),
169 cache_creation_multiplier: me
170 .cache_creation_multiplier
171 .or(pe.cache_creation_multiplier),
172 },
173 context_window: me.context_window,
174 supports_tool_use: me.supports_tool_use,
175 supports_vision: me.supports_vision,
176 };
177 models.insert(make_key(prov_name, model_id), info);
178 }
179 }
180
181 Ok(Self { models, providers })
182 }
183
184 pub fn merge(&mut self, other: Self) {
186 for (key, info) in other.models {
187 self.models.insert(key, info);
188 }
189 for (key, info) in other.providers {
190 if let Some(existing) = self.providers.get_mut(&key) {
191 if info.default_model.is_some() {
192 existing.default_model = info.default_model;
193 }
194 if info.api_key_env.is_some() {
195 existing.api_key_env = info.api_key_env;
196 }
197 if info.cache_read_multiplier.is_some() {
198 existing.cache_read_multiplier = info.cache_read_multiplier;
199 }
200 if info.cache_creation_multiplier.is_some() {
201 existing.cache_creation_multiplier = info.cache_creation_multiplier;
202 }
203 } else {
204 self.providers.insert(key, info);
205 }
206 }
207 }
208
209 pub fn register(&mut self, provider: &str, model_id: &str, info: ModelInfo) {
213 self.models.insert(make_key(provider, model_id), info);
214 }
215
216 pub fn get(&self, provider: &str, model: &str) -> Option<&ModelInfo> {
220 self.models.get(&make_key(provider, model))
221 }
222
223 pub fn get_fuzzy(&self, provider: &str, model: &str) -> Option<&ModelInfo> {
226 if let Some(info) = self.get(provider, model) {
227 return Some(info);
228 }
229
230 let prefix = format!("{provider}::");
231
232 let mut best: Option<(&str, &ModelInfo)> = None;
233 for (key, info) in &self.models {
234 if let Some(registered) = key.strip_prefix(&prefix) {
235 if model.contains(registered) || registered.contains(model) {
236 let dominated = best
237 .map(|(prev, _)| registered.len() > prev.len())
238 .unwrap_or(true);
239 if dominated {
240 best = Some((registered, info));
241 }
242 }
243 }
244 }
245 if let Some((matched, info)) = best {
246 debug!(provider, model, matched, "fuzzy model match");
247 return Some(info);
248 }
249
250 None
251 }
252
253 pub fn get_pricing(&self, provider: &str, model: &str) -> Option<CostRates> {
256 if let Some(info) = self.get_fuzzy(provider, model) {
257 return Some(info.pricing.clone());
258 }
259
260 self.providers.get(provider).and_then(|p| {
261 if p.cache_read_multiplier.is_some() || p.cache_creation_multiplier.is_some() {
262 Some(CostRates {
263 input_per_million: 0.0,
264 output_per_million: 0.0,
265 cache_read_multiplier: p.cache_read_multiplier,
266 cache_creation_multiplier: p.cache_creation_multiplier,
267 })
268 } else {
269 None
270 }
271 })
272 }
273
274 pub fn provider(&self, name: &str) -> Option<&ProviderInfo> {
278 self.providers.get(name)
279 }
280
281 pub fn provider_names(&self) -> Vec<&str> {
283 let mut names: Vec<&str> = self.providers.keys().map(|s| s.as_str()).collect();
284 names.sort();
285 names
286 }
287
288 pub fn default_model(&self, provider: &str) -> Option<&str> {
290 self.providers
291 .get(provider)
292 .and_then(|p| p.default_model.as_deref())
293 }
294
295 pub fn api_key_env(&self, provider: &str) -> Option<&str> {
297 self.providers
298 .get(provider)
299 .and_then(|p| p.api_key_env.as_deref())
300 }
301
302 pub fn models_for_provider(&self, provider: &str) -> Vec<&str> {
304 let prefix = format!("{provider}::");
305 let mut out: Vec<&str> = self
306 .models
307 .iter()
308 .filter_map(|(key, info)| {
309 if key.starts_with(&prefix) {
310 Some(info.id.as_str())
311 } else {
312 None
313 }
314 })
315 .collect();
316 out.sort();
317 out
318 }
319
320 pub fn models_by_provider(&self) -> HashMap<String, Vec<String>> {
322 let mut result: HashMap<String, Vec<String>> = HashMap::new();
323 for prov in self.providers.keys() {
324 result.insert(
325 prov.clone(),
326 self.models_for_provider(prov)
327 .into_iter()
328 .map(String::from)
329 .collect(),
330 );
331 }
332 result
333 }
334
335 pub fn len(&self) -> usize {
337 self.models.len()
338 }
339
340 pub fn is_empty(&self) -> bool {
342 self.models.is_empty()
343 }
344}
345
346impl Default for ModelRegistry {
347 fn default() -> Self {
348 Self::with_defaults()
349 }
350}
351
/// Alias for [`ModelRegistry`] (presumably kept so code using the older
/// pricing-only name keeps compiling — confirm against callers).
pub type PricingRegistry = ModelRegistry;
354
#[cfg(test)]
mod tests {
    use super::*;

    /// Float equality with the tolerance used by every assertion in this module.
    fn close(actual: f64, expected: f64) -> bool {
        (actual - expected).abs() < 1e-9
    }

    /// Builds a tool-capable `ModelInfo` without cache multipliers.
    fn model(
        provider: &str,
        id: &str,
        input: f64,
        output: f64,
        context_window: Option<u64>,
        supports_vision: bool,
    ) -> ModelInfo {
        ModelInfo {
            id: id.into(),
            provider: provider.into(),
            pricing: CostRates {
                input_per_million: input,
                output_per_million: output,
                cache_read_multiplier: None,
                cache_creation_multiplier: None,
            },
            context_window,
            supports_tool_use: true,
            supports_vision,
        }
    }

    // The embedded catalog parses and contains at least one model.
    #[test]
    fn defaults_load_successfully() {
        let reg = ModelRegistry::with_defaults();
        assert!(!reg.is_empty());
    }

    // Exact lookup returns the catalog's full record for a known model.
    #[test]
    fn exact_match() {
        let reg = ModelRegistry::with_defaults();
        let info = reg.get("anthropic", "claude-sonnet-4-5").unwrap();
        let pricing = &info.pricing;
        assert!(close(pricing.input_per_million, 3.0));
        assert!(close(pricing.output_per_million, 15.0));
        assert!(close(pricing.cache_read_multiplier.unwrap(), 0.1));
        assert!(close(pricing.cache_creation_multiplier.unwrap(), 1.25));
        assert_eq!(info.context_window, Some(200_000));
        assert!(info.supports_tool_use);
        assert!(info.supports_vision);
    }

    // A dated model id resolves to the registered undated id.
    #[test]
    fn fuzzy_match_longer_model_id() {
        let reg = ModelRegistry::with_defaults();
        let info = reg.get_fuzzy("anthropic", "claude-sonnet-4-5-20250514").unwrap();
        assert!(close(info.pricing.input_per_million, 3.0));
    }

    // With two candidates, the longer (more specific) registered id wins.
    #[test]
    fn fuzzy_match_picks_most_specific() {
        let mut reg = ModelRegistry::new();
        reg.models.insert(
            make_key("test", "claude-sonnet"),
            model("test", "claude-sonnet", 1.0, 5.0, None, false),
        );
        reg.models.insert(
            make_key("test", "claude-sonnet-4-5"),
            model("test", "claude-sonnet-4-5", 3.0, 15.0, None, false),
        );
        let info = reg.get_fuzzy("test", "claude-sonnet-4-5-20250514").unwrap();
        assert!(close(info.pricing.input_per_million, 3.0));
    }

    // Unknown models fall back to the provider-level cache multipliers.
    #[test]
    fn provider_default_cache_multipliers() {
        let reg = ModelRegistry::with_defaults();
        let pricing = reg.get_pricing("anthropic", "claude-unknown-99").unwrap();
        assert!(close(pricing.cache_read_multiplier.unwrap(), 0.1));
    }

    // Merging an overlay replaces the pricing of an existing model.
    #[test]
    fn merge_overrides() {
        let mut base = ModelRegistry::with_defaults();
        let overlay = ModelRegistry::from_toml(r#"
[anthropic.models.claude-sonnet-4-5]
input = 99.0
output = 99.0
"#)
        .unwrap();
        base.merge(overlay);
        let info = base.get("anthropic", "claude-sonnet-4-5").unwrap();
        assert!(close(info.pricing.input_per_million, 99.0));
    }

    #[test]
    fn openai_cache_rates() {
        let reg = ModelRegistry::with_defaults();
        let pricing = &reg.get("openai", "gpt-4o").unwrap().pricing;
        assert!(close(pricing.cache_read_multiplier.unwrap(), 0.1));
        assert!(close(pricing.cache_creation_multiplier.unwrap(), 1.0));
    }

    #[test]
    fn gemini_cache_rates() {
        let reg = ModelRegistry::with_defaults();
        let info = reg.get_fuzzy("gemini", "gemini-2-5-flash").unwrap();
        assert!(close(info.pricing.cache_read_multiplier.unwrap(), 0.1));
    }

    // A model inherits the provider multiplier it does not set itself.
    #[test]
    fn from_toml_custom() {
        let toml = r#"
[custom]
cache_read_multiplier = 0.3

[custom.models.my-model]
input = 5.0
output = 20.0
"#;
        let reg = ModelRegistry::from_toml(toml).unwrap();
        let pricing = &reg.get("custom", "my-model").unwrap().pricing;
        assert!(close(pricing.input_per_million, 5.0));
        assert!(close(pricing.cache_read_multiplier.unwrap(), 0.3));
        assert!(pricing.cache_creation_multiplier.is_none());
    }

    // A model-level multiplier wins; the unset one still comes from the provider.
    #[test]
    fn per_model_cache_override() {
        let toml = r#"
[prov]
cache_read_multiplier = 0.1
cache_creation_multiplier = 1.25

[prov.models.special]
input = 10.0
output = 50.0
cache_read_multiplier = 0.05
"#;
        let reg = ModelRegistry::from_toml(toml).unwrap();
        let pricing = &reg.get("prov", "special").unwrap().pricing;
        assert!(close(pricing.cache_read_multiplier.unwrap(), 0.05));
        assert!(close(pricing.cache_creation_multiplier.unwrap(), 1.25));
    }

    // A provider table without models parses and yields no matches.
    #[test]
    fn empty_provider_no_panic() {
        let toml = r#"
[empty]
"#;
        let reg = ModelRegistry::from_toml(toml).unwrap();
        assert!(reg.get("empty", "anything").is_none());
        assert!(reg.get_fuzzy("empty", "anything").is_none());
    }

    #[test]
    fn default_model_per_provider() {
        let reg = ModelRegistry::with_defaults();
        assert_eq!(reg.default_model("anthropic"), Some("claude-haiku-4-5"));
        assert_eq!(reg.default_model("openai"), Some("gpt-4o"));
        assert_eq!(reg.default_model("gemini"), Some("gemini-2.5-pro"));
        assert_eq!(reg.default_model("groq"), Some("llama-3.3-70b-versatile"));
        assert_eq!(reg.default_model("deepseek"), Some("deepseek-chat"));
        assert_eq!(reg.default_model("ollama"), Some("qwen3.5:9b"));
    }

    #[test]
    fn api_key_env_per_provider() {
        let reg = ModelRegistry::with_defaults();
        assert_eq!(reg.api_key_env("anthropic"), Some("ANTHROPIC_API_KEY"));
        assert_eq!(reg.api_key_env("openai"), Some("OPENAI_API_KEY"));
        assert_eq!(reg.api_key_env("ollama"), None);
    }

    #[test]
    fn models_for_provider_lists_all() {
        let reg = ModelRegistry::with_defaults();
        let anthropic = reg.models_for_provider("anthropic");
        assert!(anthropic.contains(&"claude-haiku-4-5"));
        assert!(anthropic.contains(&"claude-sonnet-4-6"));
        assert!(anthropic.contains(&"claude-opus-4-6"));
        assert!(anthropic.len() >= 4);
    }

    #[test]
    fn models_by_provider_for_settings_api() {
        let reg = ModelRegistry::with_defaults();
        let map = reg.models_by_provider();
        assert!(map.contains_key("anthropic"));
        assert!(map.contains_key("openai"));
        assert!(map.contains_key("ollama"));
        assert!(map["ollama"].is_empty());
    }

    #[test]
    fn provider_names_returns_all() {
        let reg = ModelRegistry::with_defaults();
        let names = reg.provider_names();
        assert!(names.contains(&"anthropic"));
        assert!(names.contains(&"openai"));
        assert!(names.contains(&"gemini"));
        assert!(names.contains(&"groq"));
        assert!(names.contains(&"deepseek"));
        assert!(names.contains(&"openrouter"));
        assert!(names.contains(&"ollama"));
    }

    #[test]
    fn model_capabilities() {
        let reg = ModelRegistry::with_defaults();
        let haiku = reg.get("anthropic", "claude-haiku-4-5").unwrap();
        assert!(haiku.supports_tool_use);
        assert!(haiku.supports_vision);

        let gpt41 = reg.get("openai", "gpt-4.1").unwrap();
        assert!(gpt41.supports_tool_use);
        assert!(!gpt41.supports_vision);
    }

    // A runtime-registered model is retrievable by exact lookup.
    #[test]
    fn register_makes_model_visible_via_get() {
        let mut reg = ModelRegistry::new();
        reg.register(
            "ollama",
            "qwen3.5:9b",
            model("ollama", "qwen3.5:9b", 0.0, 0.0, Some(262_144), true),
        );
        let info = reg.get("ollama", "qwen3.5:9b").unwrap();
        assert_eq!(info.context_window, Some(262_144));
        assert!(info.supports_vision);
    }

    // A runtime-registered model shows up in the provider's model listing.
    #[test]
    fn register_appears_in_models_for_provider() {
        let mut reg = ModelRegistry::with_defaults();
        assert!(reg.models_for_provider("ollama").is_empty());

        reg.register(
            "ollama",
            "llama3:8b",
            model("ollama", "llama3:8b", 0.0, 0.0, Some(131_072), false),
        );
        let models = reg.models_for_provider("ollama");
        assert_eq!(models, vec!["llama3:8b"]);
    }

    // Registering under an existing key replaces the catalog entry.
    #[test]
    fn register_overrides_existing() {
        let mut reg = ModelRegistry::with_defaults();
        let original = reg.get("anthropic", "claude-haiku-4-5").unwrap();
        assert!(original.pricing.input_per_million > 0.0);

        reg.register(
            "anthropic",
            "claude-haiku-4-5",
            model("anthropic", "claude-haiku-4-5", 99.0, 99.0, Some(200_000), true),
        );
        let updated = reg.get("anthropic", "claude-haiku-4-5").unwrap();
        assert!(close(updated.pricing.input_per_million, 99.0));
    }
}