1use std::collections::{HashMap, HashSet};
2use std::path::{Path, PathBuf};
3
4use tracing::debug;
5
6use crate::config::model::{ModelConfig, ModelId};
7use crate::config::provider::ProviderId;
8use crate::config::toml_types::Catalog;
9use crate::error::Error;
10
11const DEFAULT_CATALOG_TOML: &str = include_str!("../assets/default_catalog.toml");
12
13#[derive(Debug, Clone)]
15pub struct ModelRegistry {
16 models: HashMap<ModelId, ModelConfig>,
18 aliases: HashMap<String, ModelId>,
20 providers: HashSet<ProviderId>,
22}
23
24impl ModelRegistry {
25 pub fn load(additional_catalogs: &[String]) -> Result<Self, Error> {
32 let builtin_catalog: Catalog = toml::from_str(DEFAULT_CATALOG_TOML)
34 .map_err(|e| Error::Configuration(format!("Failed to parse default catalog: {e}")))?;
35
36 let mut models: Vec<ModelConfig> = builtin_catalog
38 .models
39 .into_iter()
40 .map(ModelConfig::from)
41 .collect();
42
43 let mut known_providers: HashMap<ProviderId, bool> = HashMap::new();
45 for p in builtin_catalog.providers {
46 known_providers.insert(ProviderId(p.id), true);
47 }
48
49 for path in Self::discover_catalog_paths() {
51 if let Some(catalog) = Self::load_catalog_file(&path)? {
52 for p in catalog.providers {
53 known_providers.insert(ProviderId(p.id), true);
54 }
55 let more_models: Vec<ModelConfig> =
56 catalog.models.into_iter().map(ModelConfig::from).collect();
57 Self::merge_models(&mut models, more_models);
58 }
59 }
60
61 for catalog_path in additional_catalogs {
63 if let Some(catalog) = Self::load_catalog_file(Path::new(catalog_path))? {
64 for p in catalog.providers {
66 known_providers.insert(ProviderId(p.id), true);
67 }
68
69 let catalog_models: Vec<ModelConfig> =
71 catalog.models.into_iter().map(ModelConfig::from).collect();
72 Self::merge_models(&mut models, catalog_models);
73 }
74 }
75
76 for model in &models {
78 if !known_providers.contains_key(&model.provider) {
79 return Err(Error::Configuration(format!(
80 "Model '{}' references unknown provider '{}'",
81 model.id, model.provider
82 )));
83 }
84 }
85
86 let mut registry = Self {
88 models: HashMap::new(),
89 aliases: HashMap::new(),
90 providers: HashSet::new(),
91 };
92
93 for model in models {
94 let model_id = ModelId::new(model.provider.clone(), model.id.clone());
95
96 registry.providers.insert(model.provider.clone());
98
99 for raw in &model.aliases {
101 let alias = raw.trim();
102 if alias.is_empty() {
103 return Err(Error::Configuration(format!(
104 "Empty alias found for {}/{}",
105 model_id.provider.storage_key(),
106 model_id.id.as_str(),
107 )));
108 }
109 if let Some(existing) = registry.aliases.get(alias)
110 && existing != &model_id
111 {
112 return Err(Error::Configuration(format!(
113 "Duplicate alias '{}' used by {}/{} and {}/{}",
114 alias,
115 existing.provider.storage_key(),
116 existing.id.as_str(),
117 model_id.provider.storage_key(),
118 model_id.id.as_str(),
119 )));
120 }
121 registry.aliases.insert(alias.to_string(), model_id.clone());
122 }
123
124 registry.models.insert(model_id, model);
126 }
127
128 debug!(
129 target: "model_registry::load",
130 "Loaded models: {:?}",
131 registry.models
132 );
133
134 {
136 let mut seen: HashMap<ProviderId, HashSet<String>> = HashMap::new();
137 for (model_id, cfg) in ®istry.models {
138 if let Some(name_raw) = cfg.display_name.as_deref() {
139 let name = name_raw.trim();
140 if name.is_empty() {
141 return Err(Error::Configuration(format!(
142 "Invalid display_name '{}' for {}/{}",
143 name_raw,
144 model_id.provider.storage_key(),
145 cfg.id
146 )));
147 }
148 let set = seen.entry(model_id.provider.clone()).or_default();
149 if !set.insert(name.to_string()) {
150 return Err(Error::Configuration(format!(
151 "Duplicate display_name '{}' for provider {}",
152 name,
153 model_id.provider.storage_key()
154 )));
155 }
156 }
157 }
158 }
159
160 Ok(registry)
164 }
165
166 pub fn empty() -> Self {
168 Self {
169 models: HashMap::new(),
170 aliases: HashMap::new(),
171 providers: HashSet::new(),
172 }
173 }
174
175 pub fn get(&self, id: &ModelId) -> Option<&ModelConfig> {
177 self.models.get(id)
178 }
179
180 pub fn by_alias(&self, alias: &str) -> Option<&ModelConfig> {
182 self.aliases.get(alias).and_then(|id| self.models.get(id))
183 }
184 pub fn resolve(&self, input: &str) -> Result<ModelId, Error> {
191 if let Some((provider_str, part_raw)) = input.split_once('/') {
192 let provider: ProviderId = ProviderId(provider_str.to_string());
194 let provider_known = self.providers.contains(&provider);
195 if !provider_known {
196 return Err(Error::Configuration(format!(
197 "Unknown provider: {provider_str}"
198 )));
199 }
200
201 let part = part_raw.trim();
202 if part.is_empty() {
203 return Err(Error::Configuration(
204 "Model name cannot be empty".to_string(),
205 ));
206 }
207
208 let candidate = ModelId::new(provider.clone(), part.to_string());
210 if self.models.contains_key(&candidate) {
211 return Ok(candidate);
212 }
213
214 if let Some(alias_id) = self.aliases.get(part)
216 && alias_id.provider == provider
217 {
218 return Ok(alias_id.clone());
219 }
220
221 Err(Error::Configuration(format!(
222 "Unknown model or alias: {input}"
223 )))
224 } else {
225 self.by_alias(input)
226 .map(|config| ModelId::new(config.provider.clone(), config.id.clone()))
227 .ok_or_else(|| Error::Configuration(format!("Unknown model or alias: {input}")))
228 }
229 }
230
231 pub fn recommended(&self) -> impl Iterator<Item = &ModelConfig> {
232 self.models.values().filter(|model| model.recommended)
233 }
234
235 pub fn all(&self) -> impl Iterator<Item = &ModelConfig> {
237 self.models.values()
238 }
239
240 fn load_catalog_file(path: &Path) -> Result<Option<Catalog>, Error> {
242 if !path.exists() {
243 return Ok(None);
244 }
245
246 let content = std::fs::read_to_string(path).map_err(Error::Io)?;
247 let catalog: Catalog = toml::from_str(&content).map_err(|e| {
249 Error::Configuration(format!(
250 "Failed to parse catalog at {}: {}",
251 path.display(),
252 e
253 ))
254 })?;
255 Ok(Some(catalog))
256 }
257
258 fn discover_catalog_paths() -> Vec<PathBuf> {
260 let paths: Vec<PathBuf> = crate::utils::paths::AppPaths::discover_catalogs();
262 paths
264 }
265
266 fn merge_models(base: &mut Vec<ModelConfig>, user_models: Vec<ModelConfig>) {
269 let mut existing_models: HashMap<ModelId, usize> = HashMap::new();
271 for (idx, model) in base.iter().enumerate() {
272 existing_models.insert(ModelId::new(model.provider.clone(), model.id.clone()), idx);
273 }
274
275 for user_model in user_models {
277 let key = ModelId::new(user_model.provider.clone(), user_model.id.clone());
278
279 if let Some(&idx) = existing_models.get(&key) {
280 base[idx].merge_with(user_model);
282 } else {
283 base.push(user_model);
285 }
286 }
287 }
288}
289
290#[cfg(test)]
291mod tests {
292 use super::*;
293 use crate::config::provider;
294
295 #[test]
296 fn test_load_builtin_models() {
297 let catalog: Catalog = toml::from_str(DEFAULT_CATALOG_TOML).unwrap();
299 assert!(!catalog.models.is_empty());
300 assert!(!catalog.providers.is_empty());
301
302 let has_claude = catalog
304 .models
305 .iter()
306 .any(|m| m.provider == "anthropic" && m.id.contains("claude"));
307 assert!(has_claude, "Should have at least one Claude model");
308 }
309
310 #[test]
311 fn test_registry_creation() {
312 let toml = r#"
314[[providers]]
315id = "anthropic"
316name = "Anthropic"
317api_format = "anthropic"
318auth_schemes = ["api-key"]
319
320[[models]]
321provider = "anthropic"
322id = "test-model"
323aliases = ["test", "tm"]
324recommended = true
325parameters = { thinking_config = { enabled = true } }
326"#;
327
328 let catalog: Catalog = toml::from_str(toml).unwrap();
329 let models: Vec<ModelConfig> = catalog.models.into_iter().map(ModelConfig::from).collect();
331
332 let mut registry = ModelRegistry {
333 models: HashMap::new(),
334 aliases: HashMap::new(),
335 providers: HashSet::new(),
336 };
337
338 for model in models {
339 let model_id = ModelId::new(model.provider.clone(), model.id.clone());
340
341 registry.providers.insert(model.provider.clone());
343
344 for alias in &model.aliases {
345 registry.aliases.insert(alias.clone(), model_id.clone());
346 }
347
348 registry.models.insert(model_id, model);
349 }
350
351 let model_id = ModelId::new(provider::anthropic(), "test-model");
353 let model = registry.get(&model_id).unwrap();
354 assert_eq!(model.id, "test-model");
355 assert!(model.recommended);
356
357 assert!(model.parameters.is_some());
359 let params = model.parameters.unwrap();
360 assert!(params.thinking_config.is_some());
361 assert!(params.thinking_config.unwrap().enabled);
362
363 let model_by_alias = registry.by_alias("test").unwrap();
365 assert_eq!(model_by_alias.id, "test-model");
366
367 let model_by_alias2 = registry.by_alias("tm").unwrap();
368 assert_eq!(model_by_alias2.id, "test-model");
369
370 let recommended: Vec<_> = registry.recommended().collect();
372 assert_eq!(recommended.len(), 1);
373 assert_eq!(recommended[0].id, "test-model");
374 }
375
376 #[test]
377 fn test_merge_models() {
378 let base_toml = r#"
379[[providers]]
380id = "anthropic"
381name = "Anthropic"
382api_format = "anthropic"
383auth_schemes = ["api-key"]
384
385[[providers]]
386id = "openai"
387name = "OpenAI"
388api_format = "openai-responses"
389auth_schemes = ["api-key"]
390
391[[models]]
392provider = "anthropic"
393id = "claude-3"
394aliases = ["claude"]
395recommended = false
396parameters = { temperature = 0.7, max_tokens = 2048 }
397
398[[models]]
399provider = "openai"
400id = "gpt-4"
401aliases = ["gpt"]
402recommended = true
403"#;
404
405 let user_toml = r#"
406[[providers]]
407id = "google"
408name = "Google"
409api_format = "google"
410auth_schemes = ["api-key"]
411
412[[models]]
413provider = "anthropic"
414id = "claude-3"
415aliases = ["c3", "claude3"]
416recommended = true
417parameters = { temperature = 0.9, thinking_config = { enabled = true } }
418
419[[models]]
420provider = "google"
421id = "gemini-pro"
422aliases = ["gemini"]
423recommended = true
424parameters = { temperature = 0.5, top_p = 0.95 }
425"#;
426
427 let base: Catalog = toml::from_str(base_toml).unwrap();
428 let user: Catalog = toml::from_str(user_toml).unwrap();
429
430 let base_models: Vec<_> = base.models.into_iter().map(ModelConfig::from).collect();
432 let user_models: Vec<_> = user.models.into_iter().map(ModelConfig::from).collect();
433
434 let mut base_models_mut = base_models;
435 ModelRegistry::merge_models(&mut base_models_mut, user_models);
436
437 assert_eq!(base_models_mut.len(), 3);
439
440 let claude = base_models_mut
442 .iter()
443 .find(|m| m.provider == provider::anthropic() && m.id == "claude-3")
444 .unwrap();
445
446 assert_eq!(claude.aliases.len(), 3);
448 assert!(claude.aliases.contains(&"claude".to_string()));
449 assert!(claude.aliases.contains(&"c3".to_string()));
450 assert!(claude.aliases.contains(&"claude3".to_string()));
451
452 assert!(claude.recommended);
454
455 assert!(claude.parameters.is_some());
457 let claude_params = claude.parameters.unwrap();
458 assert_eq!(claude_params.temperature, Some(0.9)); assert_eq!(claude_params.max_tokens, Some(2048)); assert!(claude_params.thinking_config.is_some());
461 assert!(claude_params.thinking_config.unwrap().enabled);
462
463 let gpt4 = base_models_mut
465 .iter()
466 .find(|m| m.provider == provider::openai() && m.id == "gpt-4")
467 .unwrap();
468 assert!(gpt4.recommended);
469 assert!(gpt4.parameters.is_none()); let gemini = base_models_mut
473 .iter()
474 .find(|m| m.provider == provider::google() && m.id == "gemini-pro")
475 .unwrap();
476 assert!(gemini.recommended);
477 assert!(gemini.parameters.is_some());
478 let gemini_params = gemini.parameters.unwrap();
479 assert_eq!(gemini_params.temperature, Some(0.5));
480 assert_eq!(gemini_params.top_p, Some(0.95));
481 }
482
483 #[test]
484 fn test_load_catalog_from_path() {
485 use std::fs;
486 use tempfile::TempDir;
487
488 let dir = TempDir::new().unwrap();
489 let config_path = dir.path().join("test_catalog.toml");
490
491 let config = r#"
492[[providers]]
493id = "anthropic"
494name = "Anthropic"
495api_format = "anthropic"
496auth_schemes = ["api-key"]
497
498[[models]]
499provider = "anthropic"
500id = "test-model"
501aliases = ["test"]
502recommended = true
503"#;
504
505 fs::write(&config_path, config).unwrap();
506
507 let result = ModelRegistry::load_catalog_file(&config_path).unwrap();
508 assert!(result.is_some());
509
510 let catalog = result.unwrap();
511 assert_eq!(catalog.models.len(), 1);
512 assert_eq!(catalog.models[0].id, "test-model");
513 assert_eq!(catalog.providers.len(), 1);
514 assert_eq!(catalog.providers[0].id, "anthropic");
515 }
516
517 #[test]
518 fn test_resolve_by_provider_and_parts() {
519 let mut registry = ModelRegistry {
521 models: HashMap::new(),
522 aliases: HashMap::new(),
523 providers: HashSet::new(),
524 };
525 let prov = provider::anthropic();
526
527 let m1 = ModelConfig {
528 provider: prov.clone(),
529 id: "id-1".to_string(),
530 display_name: Some("NiceName".to_string()),
531 aliases: vec!["alias1".into()],
532 recommended: false,
533 parameters: None,
534 };
535 let m2 = ModelConfig {
536 provider: prov.clone(),
537 id: "id-2".to_string(),
538 display_name: Some("Other".to_string()),
539 aliases: vec!["alias2".into()],
540 recommended: false,
541 parameters: None,
542 };
543 let id1 = ModelId::new(prov.clone(), m1.id.clone());
544 let id2 = ModelId::new(prov.clone(), m2.id.clone());
545 registry.aliases.insert("alias1".into(), id1.clone());
546 registry.aliases.insert("alias2".into(), id2.clone());
547 registry.models.insert(id1.clone(), m1.clone());
548 registry.models.insert(id2.clone(), m2.clone());
549 registry.providers.insert(prov.clone());
550
551 assert_eq!(registry.resolve("anthropic/id-1").unwrap(), id1);
553 assert!(registry.resolve("anthropic/NiceName").is_err());
555 assert_eq!(registry.resolve("anthropic/alias2").unwrap(), id2);
557 assert!(registry.resolve("anthropic/does-not-exist").is_err());
559 }
560
561 #[test]
562 fn test_resolve_by_display_name_is_not_supported() {
563 let mut registry = ModelRegistry {
565 models: HashMap::new(),
566 aliases: HashMap::new(),
567 providers: HashSet::new(),
568 };
569 let prov = provider::anthropic();
570 let m1 = ModelConfig {
571 provider: prov.clone(),
572 id: "id-1".into(),
573 display_name: Some("Same".into()),
574 aliases: vec![],
575 recommended: false,
576 parameters: None,
577 };
578 let m2 = ModelConfig {
579 provider: prov.clone(),
580 id: "id-2".into(),
581 display_name: Some("Same".into()),
582 aliases: vec![],
583 recommended: false,
584 parameters: None,
585 };
586 let id1 = ModelId::new(prov.clone(), m1.id.clone());
587 let id2 = ModelId::new(prov.clone(), m2.id.clone());
588 registry.models.insert(id1, m1);
589 registry.models.insert(id2, m2);
590 registry.providers.insert(prov.clone());
591
592 let err = registry.resolve("anthropic/Same").unwrap_err();
594 match err {
595 Error::Configuration(msg) => assert!(msg.contains("Unknown model or alias")),
596 _ => panic!("unexpected error type"),
597 }
598 }
599
600 #[test]
601 fn test_load_rejects_invalid_or_duplicate_display_names() {
602 use std::fs;
603 use tempfile::TempDir;
604
605 let dir = TempDir::new().unwrap();
606 let bad_path = dir.path().join("bad_catalog.toml");
607 let dup_path = dir.path().join("dup_catalog.toml");
608
609 let bad = r#"
611[[providers]]
612id = "custom"
613name = "Custom"
614api_format = "openai-responses"
615auth_schemes = ["api-key"]
616
617[[models]]
618provider = "custom"
619id = "m1"
620display_name = ""
621"#;
622 fs::write(&bad_path, bad).unwrap();
623 let res = ModelRegistry::load(&[bad_path.to_string_lossy().to_string()]);
624 assert!(matches!(res, Err(Error::Configuration(_))));
625
626 let dup = r#"
628[[providers]]
629id = "custom"
630name = "Custom"
631api_format = "openai-responses"
632auth_schemes = ["api-key"]
633
634[[models]]
635provider = "custom"
636id = "m1"
637display_name = "Same"
638
639[[models]]
640provider = "custom"
641id = "m2"
642display_name = "Same"
643"#;
644 fs::write(&dup_path, dup).unwrap();
645 let res2 = ModelRegistry::load(&[dup_path.to_string_lossy().to_string()]);
646 assert!(matches!(res2, Err(Error::Configuration(_))));
647 }
648
649 #[test]
650 fn test_duplicate_aliases_across_providers_error() {
651 use std::fs;
653 use tempfile::TempDir;
654
655 let dir = TempDir::new().unwrap();
656 let path = dir.path().join("alias_conflict.toml");
657 let toml = r#"
658[[providers]]
659id = "p1"
660name = "P1"
661api_format = "openai-responses"
662auth_schemes = ["api-key"]
663
664[[providers]]
665id = "p2"
666name = "P2"
667api_format = "openai-responses"
668auth_schemes = ["api-key"]
669
670[[models]]
671provider = "p1"
672id = "m1"
673aliases = ["shared"]
674
675[[models]]
676provider = "p2"
677id = "m2"
678aliases = ["shared"]
679"#;
680 fs::write(&path, toml).unwrap();
681 let res = ModelRegistry::load(&[path.to_string_lossy().to_string()]);
682 assert!(matches!(res, Err(Error::Configuration(_))));
683 }
684}