1use std::collections::{HashMap, HashSet};
2use std::path::{Path, PathBuf};
3
4use tracing::debug;
5
6use crate::config::model::{ModelConfig, ModelId};
7use crate::config::provider::ProviderId;
8use crate::config::toml_types::Catalog;
9use crate::error::Error;
10
11const DEFAULT_CATALOG_TOML: &str = include_str!("../assets/default_catalog.toml");
12
/// In-memory index of model configurations, addressable by model ID or by alias.
#[derive(Debug, Clone)]
pub struct ModelRegistry {
    /// Model configurations keyed by their `ModelId`.
    models: HashMap<ModelId, ModelConfig>,
    /// Maps an alias string to the `ModelId` it stands for.
    aliases: HashMap<String, ModelId>,
    /// Provider IDs recorded in this registry; `load` adds the provider of every
    /// model it registers.
    providers: HashSet<ProviderId>,
}
23
24impl ModelRegistry {
25 pub fn load(additional_catalogs: &[String]) -> Result<Self, Error> {
32 let builtin_catalog: Catalog = toml::from_str(DEFAULT_CATALOG_TOML)
34 .map_err(|e| Error::Configuration(format!("Failed to parse default catalog: {e}")))?;
35
36 let mut models: Vec<ModelConfig> = builtin_catalog
38 .models
39 .into_iter()
40 .map(ModelConfig::from)
41 .collect();
42
43 let mut known_providers: HashMap<ProviderId, bool> = HashMap::new();
45 for p in builtin_catalog.providers {
46 known_providers.insert(ProviderId(p.id), true);
47 }
48
49 for path in Self::discover_catalog_paths() {
51 if let Some(catalog) = Self::load_catalog_file(&path)? {
52 for p in catalog.providers {
53 known_providers.insert(ProviderId(p.id), true);
54 }
55 let more_models: Vec<ModelConfig> =
56 catalog.models.into_iter().map(ModelConfig::from).collect();
57 Self::merge_models(&mut models, more_models);
58 }
59 }
60
61 for catalog_path in additional_catalogs {
63 if let Some(catalog) = Self::load_catalog_file(Path::new(catalog_path))? {
64 for p in catalog.providers {
66 known_providers.insert(ProviderId(p.id), true);
67 }
68
69 let catalog_models: Vec<ModelConfig> =
71 catalog.models.into_iter().map(ModelConfig::from).collect();
72 Self::merge_models(&mut models, catalog_models);
73 }
74 }
75
76 for model in &models {
78 if !known_providers.contains_key(&model.provider) {
79 return Err(Error::Configuration(format!(
80 "Model '{}' references unknown provider '{}'",
81 model.id, model.provider
82 )));
83 }
84 }
85
86 let mut registry = Self {
88 models: HashMap::new(),
89 aliases: HashMap::new(),
90 providers: HashSet::new(),
91 };
92
93 for model in models {
94 let model_id = (model.provider.clone(), model.id.clone());
95
96 registry.providers.insert(model.provider.clone());
98
99 for raw in &model.aliases {
101 let alias = raw.trim();
102 if alias.is_empty() {
103 return Err(Error::Configuration(format!(
104 "Empty alias found for {}/{}",
105 model_id.0.storage_key(),
106 model_id.1,
107 )));
108 }
109 if let Some(existing) = registry.aliases.get(alias) {
110 if existing != &model_id {
111 return Err(Error::Configuration(format!(
112 "Duplicate alias '{}' used by {}/{} and {}/{}",
113 alias,
114 existing.0.storage_key(),
115 existing.1,
116 model_id.0.storage_key(),
117 model_id.1,
118 )));
119 }
120 }
121 registry.aliases.insert(alias.to_string(), model_id.clone());
122 }
123
124 registry.models.insert(model_id, model);
126 }
127
128 debug!(
129 target: "model_registry::load",
130 "Loaded models: {:?}",
131 registry.models
132 );
133
134 {
136 let mut seen: HashMap<ProviderId, HashSet<String>> = HashMap::new();
137 for ((prov, _mid), cfg) in ®istry.models {
138 if let Some(name_raw) = cfg.display_name.as_deref() {
139 let name = name_raw.trim();
140 if name.is_empty() {
141 return Err(Error::Configuration(format!(
142 "Invalid display_name '{}' for {}/{}",
143 name_raw,
144 prov.storage_key(),
145 cfg.id
146 )));
147 }
148 let set = seen.entry(prov.clone()).or_default();
149 if !set.insert(name.to_string()) {
150 return Err(Error::Configuration(format!(
151 "Duplicate display_name '{}' for provider {}",
152 name,
153 prov.storage_key()
154 )));
155 }
156 }
157 }
158 }
159
160 Ok(registry)
164 }
165
166 pub fn get(&self, id: &ModelId) -> Option<&ModelConfig> {
168 self.models.get(id)
169 }
170
171 pub fn by_alias(&self, alias: &str) -> Option<&ModelConfig> {
173 self.aliases.get(alias).and_then(|id| self.models.get(id))
174 }
175 pub fn resolve(&self, input: &str) -> Result<ModelId, Error> {
182 if let Some((provider_str, part_raw)) = input.split_once('/') {
183 let provider: ProviderId = ProviderId(provider_str.to_string());
185 let provider_known = self.providers.contains(&provider);
186 if !provider_known {
187 return Err(Error::Configuration(format!(
188 "Unknown provider: {provider_str}"
189 )));
190 }
191
192 let part = part_raw.trim();
193 if part.is_empty() {
194 return Err(Error::Configuration(
195 "Model name cannot be empty".to_string(),
196 ));
197 }
198
199 let candidate = (provider.clone(), part.to_string());
201 if self.models.contains_key(&candidate) {
202 return Ok(candidate);
203 }
204
205 if let Some(alias_id) = self.aliases.get(part) {
207 if alias_id.0 == provider {
208 return Ok(alias_id.clone());
209 }
210 }
211
212 Err(Error::Configuration(format!(
213 "Unknown model or alias: {input}"
214 )))
215 } else {
216 self.by_alias(input)
217 .map(|config| (config.provider.clone(), config.id.clone()))
218 .ok_or_else(|| Error::Configuration(format!("Unknown model or alias: {input}")))
219 }
220 }
221
222 pub fn recommended(&self) -> impl Iterator<Item = &ModelConfig> {
223 self.models.values().filter(|model| model.recommended)
224 }
225
226 pub fn all(&self) -> impl Iterator<Item = &ModelConfig> {
228 self.models.values()
229 }
230
231 fn load_catalog_file(path: &Path) -> Result<Option<Catalog>, Error> {
233 if !path.exists() {
234 return Ok(None);
235 }
236
237 let content = std::fs::read_to_string(path).map_err(Error::Io)?;
238 let catalog: Catalog = toml::from_str(&content).map_err(|e| {
240 Error::Configuration(format!(
241 "Failed to parse catalog at {}: {}",
242 path.display(),
243 e
244 ))
245 })?;
246 Ok(Some(catalog))
247 }
248
249 fn discover_catalog_paths() -> Vec<PathBuf> {
251 let paths: Vec<PathBuf> = crate::utils::paths::AppPaths::discover_catalogs();
253 paths
255 }
256
257 fn merge_models(base: &mut Vec<ModelConfig>, user_models: Vec<ModelConfig>) {
260 let mut existing_models: HashMap<(ProviderId, String), usize> = HashMap::new();
262 for (idx, model) in base.iter().enumerate() {
263 existing_models.insert((model.provider.clone(), model.id.clone()), idx);
264 }
265
266 for user_model in user_models {
268 let key = (user_model.provider.clone(), user_model.id.clone());
269
270 if let Some(&idx) = existing_models.get(&key) {
271 base[idx].merge_with(user_model);
273 } else {
274 base.push(user_model);
276 }
277 }
278 }
279}
280
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::provider;

    /// Returns a registry with no models, aliases, or providers.
    fn empty_registry() -> ModelRegistry {
        ModelRegistry {
            models: HashMap::new(),
            aliases: HashMap::new(),
            providers: HashSet::new(),
        }
    }

    /// Parses `source` as a catalog, panicking if it is not valid.
    fn parse_catalog(source: &str) -> Catalog {
        toml::from_str(source).unwrap()
    }

    /// Converts every model entry of `catalog` into a `ModelConfig`.
    fn model_configs(catalog: Catalog) -> Vec<ModelConfig> {
        catalog.models.into_iter().map(ModelConfig::from).collect()
    }

    /// Writes `contents` to `file_name` inside `dir` and returns the full path.
    fn write_catalog(dir: &Path, file_name: &str, contents: &str) -> PathBuf {
        let path = dir.join(file_name);
        std::fs::write(&path, contents).unwrap();
        path
    }

    /// Runs `ModelRegistry::load` with `path` as the only additional catalog.
    fn load_with(path: &Path) -> Result<ModelRegistry, Error> {
        ModelRegistry::load(&[path.to_string_lossy().to_string()])
    }

    /// Builds a non-recommended, parameterless model for the anthropic provider.
    fn anthropic_model(id: &str, display_name: &str, aliases: &[&str]) -> ModelConfig {
        ModelConfig {
            provider: provider::anthropic(),
            id: id.to_string(),
            display_name: Some(display_name.to_string()),
            aliases: aliases.iter().map(|alias| alias.to_string()).collect(),
            recommended: false,
            parameters: None,
        }
    }

    /// Finds the model with the given provider and ID, panicking if it is absent.
    fn find_model<'a>(
        models: &'a [ModelConfig],
        provider: &ProviderId,
        id: &str,
    ) -> &'a ModelConfig {
        models
            .iter()
            .find(|m| &m.provider == provider && m.id == id)
            .unwrap()
    }

    #[test]
    fn test_load_builtin_models() {
        let catalog = parse_catalog(DEFAULT_CATALOG_TOML);
        assert!(!catalog.models.is_empty());
        assert!(!catalog.providers.is_empty());

        let has_claude = catalog
            .models
            .iter()
            .any(|m| m.provider == "anthropic" && m.id.contains("claude"));
        assert!(has_claude, "Should have at least one Claude model");
    }

    #[test]
    fn test_registry_creation() {
        let source = r#"
[[providers]]
id = "anthropic"
name = "Anthropic"
api_format = "anthropic"
auth_schemes = ["api-key"]

[[models]]
provider = "anthropic"
id = "test-model"
aliases = ["test", "tm"]
recommended = true
parameters = { thinking_config = { enabled = true } }
"#;

        // Populate the registry by hand so this test does not depend on `load`.
        let mut registry = empty_registry();
        for model in model_configs(parse_catalog(source)) {
            let key = (model.provider.clone(), model.id.clone());
            registry.providers.insert(model.provider.clone());
            registry.aliases.extend(
                model
                    .aliases
                    .iter()
                    .cloned()
                    .map(|alias| (alias, key.clone())),
            );
            registry.models.insert(key, model);
        }

        let wanted = (provider::anthropic(), "test-model".to_string());
        let model = registry.get(&wanted).unwrap();
        assert_eq!(model.id, "test-model");
        assert!(model.recommended);

        assert!(model.parameters.is_some());
        let params = model.parameters.unwrap();
        assert!(params.thinking_config.is_some());
        assert!(params.thinking_config.unwrap().enabled);

        for alias in ["test", "tm"] {
            assert_eq!(registry.by_alias(alias).unwrap().id, "test-model");
        }

        let recommended: Vec<_> = registry.recommended().collect();
        assert_eq!(recommended.len(), 1);
        assert_eq!(recommended[0].id, "test-model");
    }

    #[test]
    fn test_merge_models() {
        let base_toml = r#"
[[providers]]
id = "anthropic"
name = "Anthropic"
api_format = "anthropic"
auth_schemes = ["api-key"]

[[providers]]
id = "openai"
name = "OpenAI"
api_format = "openai-responses"
auth_schemes = ["api-key"]

[[models]]
provider = "anthropic"
id = "claude-3"
aliases = ["claude"]
recommended = false
parameters = { temperature = 0.7, max_tokens = 2048 }

[[models]]
provider = "openai"
id = "gpt-4"
aliases = ["gpt"]
recommended = true
"#;

        let user_toml = r#"
[[providers]]
id = "google"
name = "Google"
api_format = "google"
auth_schemes = ["api-key"]

[[models]]
provider = "anthropic"
id = "claude-3"
aliases = ["c3", "claude3"]
recommended = true
parameters = { temperature = 0.9, thinking_config = { enabled = true } }

[[models]]
provider = "google"
id = "gemini-pro"
aliases = ["gemini"]
recommended = true
parameters = { temperature = 0.5, top_p = 0.95 }
"#;

        let mut merged = model_configs(parse_catalog(base_toml));
        let overrides = model_configs(parse_catalog(user_toml));
        ModelRegistry::merge_models(&mut merged, overrides);

        // Two base models plus one model that is new in the user catalog.
        assert_eq!(merged.len(), 3);

        // The overridden model: aliases are unioned and parameters merged field-wise.
        let claude = find_model(&merged, &provider::anthropic(), "claude-3");
        assert_eq!(claude.aliases.len(), 3);
        for alias in ["claude", "c3", "claude3"] {
            assert!(claude.aliases.contains(&alias.to_string()));
        }
        assert!(claude.recommended);
        assert!(claude.parameters.is_some());
        let claude_params = claude.parameters.unwrap();
        assert_eq!(claude_params.temperature, Some(0.9));
        assert_eq!(claude_params.max_tokens, Some(2048));
        assert!(claude_params.thinking_config.is_some());
        assert!(claude_params.thinking_config.unwrap().enabled);

        // The untouched base model.
        let gpt4 = find_model(&merged, &provider::openai(), "gpt-4");
        assert!(gpt4.recommended);
        assert!(gpt4.parameters.is_none());

        // The model that only the user catalog defines.
        let gemini = find_model(&merged, &provider::google(), "gemini-pro");
        assert!(gemini.recommended);
        assert!(gemini.parameters.is_some());
        let gemini_params = gemini.parameters.unwrap();
        assert_eq!(gemini_params.temperature, Some(0.5));
        assert_eq!(gemini_params.top_p, Some(0.95));
    }

    #[test]
    fn test_load_catalog_from_path() {
        let config = r#"
[[providers]]
id = "anthropic"
name = "Anthropic"
api_format = "anthropic"
auth_schemes = ["api-key"]

[[models]]
provider = "anthropic"
id = "test-model"
aliases = ["test"]
recommended = true
"#;

        let dir = tempfile::TempDir::new().unwrap();
        let config_path = write_catalog(dir.path(), "test_catalog.toml", config);

        let loaded = ModelRegistry::load_catalog_file(&config_path).unwrap();
        assert!(loaded.is_some());

        let catalog = loaded.unwrap();
        assert_eq!(catalog.models.len(), 1);
        assert_eq!(catalog.models[0].id, "test-model");
        assert_eq!(catalog.providers.len(), 1);
        assert_eq!(catalog.providers[0].id, "anthropic");
    }

    #[test]
    fn test_resolve_by_provider_and_parts() {
        let prov = provider::anthropic();
        let first = anthropic_model("id-1", "NiceName", &["alias1"]);
        let second = anthropic_model("id-2", "Other", &["alias2"]);
        let id1 = (prov.clone(), first.id.clone());
        let id2 = (prov.clone(), second.id.clone());

        let mut registry = empty_registry();
        registry.aliases.insert("alias1".into(), id1.clone());
        registry.aliases.insert("alias2".into(), id2.clone());
        registry.models.insert(id1.clone(), first);
        registry.models.insert(id2.clone(), second);
        registry.providers.insert(prov);

        // Exact ID resolves; a display name does not; an alias does; unknown fails.
        assert_eq!(registry.resolve("anthropic/id-1").unwrap(), id1);
        assert!(registry.resolve("anthropic/NiceName").is_err());
        assert_eq!(registry.resolve("anthropic/alias2").unwrap(), id2);
        assert!(registry.resolve("anthropic/does-not-exist").is_err());
    }

    #[test]
    fn test_resolve_by_display_name_is_not_supported() {
        let prov = provider::anthropic();
        let mut registry = empty_registry();
        for id in ["id-1", "id-2"] {
            let model = anthropic_model(id, "Same", &[]);
            registry.models.insert((prov.clone(), model.id.clone()), model);
        }
        registry.providers.insert(prov);

        let err = registry.resolve("anthropic/Same").unwrap_err();
        if let Error::Configuration(msg) = err {
            assert!(msg.contains("Unknown model or alias"));
        } else {
            panic!("unexpected error type");
        }
    }

    #[test]
    fn test_load_rejects_invalid_or_duplicate_display_names() {
        let dir = tempfile::TempDir::new().unwrap();

        let bad = r#"
[[providers]]
id = "custom"
name = "Custom"
api_format = "openai-responses"
auth_schemes = ["api-key"]

[[models]]
provider = "custom"
id = "m1"
display_name = ""
"#;
        let bad_path = write_catalog(dir.path(), "bad_catalog.toml", bad);
        let res = load_with(&bad_path);
        assert!(matches!(res, Err(Error::Configuration(_))));

        let dup = r#"
[[providers]]
id = "custom"
name = "Custom"
api_format = "openai-responses"
auth_schemes = ["api-key"]

[[models]]
provider = "custom"
id = "m1"
display_name = "Same"

[[models]]
provider = "custom"
id = "m2"
display_name = "Same"
"#;
        let dup_path = write_catalog(dir.path(), "dup_catalog.toml", dup);
        let res2 = load_with(&dup_path);
        assert!(matches!(res2, Err(Error::Configuration(_))));
    }

    #[test]
    fn test_duplicate_aliases_across_providers_error() {
        let source = r#"
[[providers]]
id = "p1"
name = "P1"
api_format = "openai-responses"
auth_schemes = ["api-key"]

[[providers]]
id = "p2"
name = "P2"
api_format = "openai-responses"
auth_schemes = ["api-key"]

[[models]]
provider = "p1"
id = "m1"
aliases = ["shared"]

[[models]]
provider = "p2"
id = "m2"
aliases = ["shared"]
"#;
        let dir = tempfile::TempDir::new().unwrap();
        let path = write_catalog(dir.path(), "alias_conflict.toml", source);
        let res = load_with(&path);
        assert!(matches!(res, Err(Error::Configuration(_))));
    }
}