1use std::collections::{HashMap, HashSet};
2use std::path::{Path, PathBuf};
3
4use tracing::debug;
5
6use crate::config::model::{ModelConfig, ModelId};
7use crate::config::provider::ProviderId;
8use crate::config::toml_types::Catalog;
9use crate::error::Error;
10
11const DEFAULT_CATALOG_TOML: &str = include_str!("../assets/default_catalog.toml");
12
13#[derive(Debug, Clone)]
15pub struct ModelRegistry {
16 models: HashMap<ModelId, ModelConfig>,
18 aliases: HashMap<String, ModelId>,
20 providers: HashSet<ProviderId>,
22}
23
24impl ModelRegistry {
25 pub fn load(additional_catalogs: &[String]) -> Result<Self, Error> {
32 let builtin_catalog: Catalog = toml::from_str(DEFAULT_CATALOG_TOML)
34 .map_err(|e| Error::Configuration(format!("Failed to parse default catalog: {e}")))?;
35
36 let mut models: Vec<ModelConfig> = builtin_catalog
38 .models
39 .into_iter()
40 .map(ModelConfig::from)
41 .collect();
42
43 let mut known_providers: HashMap<ProviderId, bool> = HashMap::new();
45 for p in builtin_catalog.providers {
46 known_providers.insert(ProviderId(p.id), true);
47 }
48
49 for path in Self::discover_catalog_paths() {
51 if let Some(catalog) = Self::load_catalog_file(&path)? {
52 for p in catalog.providers {
53 known_providers.insert(ProviderId(p.id), true);
54 }
55 let more_models: Vec<ModelConfig> =
56 catalog.models.into_iter().map(ModelConfig::from).collect();
57 Self::merge_models(&mut models, more_models);
58 }
59 }
60
61 for catalog_path in additional_catalogs {
63 if let Some(catalog) = Self::load_catalog_file(Path::new(catalog_path))? {
64 for p in catalog.providers {
66 known_providers.insert(ProviderId(p.id), true);
67 }
68
69 let catalog_models: Vec<ModelConfig> =
71 catalog.models.into_iter().map(ModelConfig::from).collect();
72 Self::merge_models(&mut models, catalog_models);
73 }
74 }
75
76 for model in &models {
78 if !known_providers.contains_key(&model.provider) {
79 return Err(Error::Configuration(format!(
80 "Model '{}' references unknown provider '{}'",
81 model.id, model.provider
82 )));
83 }
84
85 let has_required_max_output_tokens = model
86 .parameters
87 .as_ref()
88 .and_then(|parameters| parameters.max_output_tokens)
89 .is_some();
90 if !has_required_max_output_tokens {
91 return Err(Error::Configuration(format!(
92 "Model '{}' is missing required parameters.max_output_tokens",
93 model.id
94 )));
95 }
96 }
97
98 let mut registry = Self {
100 models: HashMap::new(),
101 aliases: HashMap::new(),
102 providers: HashSet::new(),
103 };
104
105 for model in models {
106 let model_id = ModelId::new(model.provider.clone(), model.id.clone());
107
108 registry.providers.insert(model.provider.clone());
110
111 for raw in &model.aliases {
113 let alias = raw.trim();
114 if alias.is_empty() {
115 return Err(Error::Configuration(format!(
116 "Empty alias found for {}/{}",
117 model_id.provider.storage_key(),
118 model_id.id.as_str(),
119 )));
120 }
121 if let Some(existing) = registry.aliases.get(alias)
122 && existing != &model_id
123 {
124 return Err(Error::Configuration(format!(
125 "Duplicate alias '{}' used by {}/{} and {}/{}",
126 alias,
127 existing.provider.storage_key(),
128 existing.id.as_str(),
129 model_id.provider.storage_key(),
130 model_id.id.as_str(),
131 )));
132 }
133 registry.aliases.insert(alias.to_string(), model_id.clone());
134 }
135
136 registry.models.insert(model_id, model);
138 }
139
140 debug!(
141 target: "model_registry::load",
142 "Loaded models: {:?}",
143 registry.models
144 );
145
146 {
148 let mut seen: HashMap<ProviderId, HashSet<String>> = HashMap::new();
149 for (model_id, cfg) in ®istry.models {
150 if let Some(name_raw) = cfg.display_name.as_deref() {
151 let name = name_raw.trim();
152 if name.is_empty() {
153 return Err(Error::Configuration(format!(
154 "Invalid display_name '{}' for {}/{}",
155 name_raw,
156 model_id.provider.storage_key(),
157 cfg.id
158 )));
159 }
160 let set = seen.entry(model_id.provider.clone()).or_default();
161 if !set.insert(name.to_string()) {
162 return Err(Error::Configuration(format!(
163 "Duplicate display_name '{}' for provider {}",
164 name,
165 model_id.provider.storage_key()
166 )));
167 }
168 }
169 }
170 }
171
172 Ok(registry)
176 }
177
178 pub fn empty() -> Self {
180 Self {
181 models: HashMap::new(),
182 aliases: HashMap::new(),
183 providers: HashSet::new(),
184 }
185 }
186
187 pub fn get(&self, id: &ModelId) -> Option<&ModelConfig> {
189 self.models.get(id)
190 }
191
192 pub fn by_alias(&self, alias: &str) -> Option<&ModelConfig> {
194 self.aliases.get(alias).and_then(|id| self.models.get(id))
195 }
196 pub fn resolve(&self, input: &str) -> Result<ModelId, Error> {
203 if let Some((provider_str, part_raw)) = input.split_once('/') {
204 let provider: ProviderId = ProviderId(provider_str.to_string());
206 let provider_known = self.providers.contains(&provider);
207 if !provider_known {
208 return Err(Error::Configuration(format!(
209 "Unknown provider: {provider_str}"
210 )));
211 }
212
213 let part = part_raw.trim();
214 if part.is_empty() {
215 return Err(Error::Configuration(
216 "Model name cannot be empty".to_string(),
217 ));
218 }
219
220 let candidate = ModelId::new(provider.clone(), part.to_string());
222 if self.models.contains_key(&candidate) {
223 return Ok(candidate);
224 }
225
226 if let Some(alias_id) = self.aliases.get(part)
228 && alias_id.provider == provider
229 {
230 return Ok(alias_id.clone());
231 }
232
233 Err(Error::Configuration(format!(
234 "Unknown model or alias: {input}"
235 )))
236 } else {
237 self.by_alias(input)
238 .map(|config| ModelId::new(config.provider.clone(), config.id.clone()))
239 .ok_or_else(|| Error::Configuration(format!("Unknown model or alias: {input}")))
240 }
241 }
242
243 pub fn recommended(&self) -> impl Iterator<Item = &ModelConfig> {
244 self.models.values().filter(|model| model.recommended)
245 }
246
247 pub fn all(&self) -> impl Iterator<Item = &ModelConfig> {
249 self.models.values()
250 }
251
252 fn load_catalog_file(path: &Path) -> Result<Option<Catalog>, Error> {
254 if !path.exists() {
255 return Ok(None);
256 }
257
258 let content = std::fs::read_to_string(path).map_err(Error::Io)?;
259 let catalog: Catalog = toml::from_str(&content).map_err(|e| {
261 Error::Configuration(format!(
262 "Failed to parse catalog at {}: {}",
263 path.display(),
264 e
265 ))
266 })?;
267 Ok(Some(catalog))
268 }
269
270 fn discover_catalog_paths() -> Vec<PathBuf> {
272 let paths: Vec<PathBuf> = crate::utils::paths::AppPaths::discover_catalogs();
274 paths
276 }
277
278 fn merge_models(base: &mut Vec<ModelConfig>, user_models: Vec<ModelConfig>) {
281 let mut existing_models: HashMap<ModelId, usize> = HashMap::new();
283 for (idx, model) in base.iter().enumerate() {
284 existing_models.insert(ModelId::new(model.provider.clone(), model.id.clone()), idx);
285 }
286
287 for user_model in user_models {
289 let key = ModelId::new(user_model.provider.clone(), user_model.id.clone());
290
291 if let Some(&idx) = existing_models.get(&key) {
292 base[idx].merge_with(user_model);
294 } else {
295 base.push(user_model);
297 }
298 }
299 }
300}
301
302#[cfg(test)]
303mod tests {
304 use super::*;
305 use crate::config::provider;
306
307 #[test]
308 fn test_load_builtin_models() {
309 let catalog: Catalog = toml::from_str(DEFAULT_CATALOG_TOML).unwrap();
311 assert!(!catalog.models.is_empty());
312 assert!(!catalog.providers.is_empty());
313
314 let has_claude = catalog
316 .models
317 .iter()
318 .any(|m| m.provider == "anthropic" && m.id.contains("claude"));
319 assert!(has_claude, "Should have at least one Claude model");
320 }
321
322 #[test]
323 fn test_registry_creation() {
324 let toml = r#"
326[[providers]]
327id = "anthropic"
328name = "Anthropic"
329api_format = "anthropic"
330auth_schemes = ["api-key"]
331
332[[models]]
333provider = "anthropic"
334id = "test-model"
335aliases = ["test", "tm"]
336recommended = true
337parameters = { thinking_config = { enabled = true } }
338"#;
339
340 let catalog: Catalog = toml::from_str(toml).unwrap();
341 let models: Vec<ModelConfig> = catalog.models.into_iter().map(ModelConfig::from).collect();
343
344 let mut registry = ModelRegistry {
345 models: HashMap::new(),
346 aliases: HashMap::new(),
347 providers: HashSet::new(),
348 };
349
350 for model in models {
351 let model_id = ModelId::new(model.provider.clone(), model.id.clone());
352
353 registry.providers.insert(model.provider.clone());
355
356 for alias in &model.aliases {
357 registry.aliases.insert(alias.clone(), model_id.clone());
358 }
359
360 registry.models.insert(model_id, model);
361 }
362
363 let model_id = ModelId::new(provider::anthropic(), "test-model");
365 let model = registry.get(&model_id).unwrap();
366 assert_eq!(model.id, "test-model");
367 assert!(model.recommended);
368
369 assert!(model.parameters.is_some());
371 let params = model.parameters.unwrap();
372 assert!(params.thinking_config.is_some());
373 assert!(params.thinking_config.unwrap().enabled);
374
375 let model_by_alias = registry.by_alias("test").unwrap();
377 assert_eq!(model_by_alias.id, "test-model");
378
379 let model_by_alias2 = registry.by_alias("tm").unwrap();
380 assert_eq!(model_by_alias2.id, "test-model");
381
382 let recommended: Vec<_> = registry.recommended().collect();
384 assert_eq!(recommended.len(), 1);
385 assert_eq!(recommended[0].id, "test-model");
386 }
387
388 #[test]
389 fn test_merge_models() {
390 let base_toml = r#"
391[[providers]]
392id = "anthropic"
393name = "Anthropic"
394api_format = "anthropic"
395auth_schemes = ["api-key"]
396
397[[providers]]
398id = "openai"
399name = "OpenAI"
400api_format = "openai-responses"
401auth_schemes = ["api-key"]
402
403[[models]]
404provider = "anthropic"
405id = "claude-3"
406aliases = ["claude"]
407recommended = false
408parameters = { temperature = 0.7, max_output_tokens = 2048 }
409
410[[models]]
411provider = "openai"
412id = "gpt-4"
413aliases = ["gpt"]
414recommended = true
415"#;
416
417 let user_toml = r#"
418[[providers]]
419id = "google"
420name = "Google"
421api_format = "google"
422auth_schemes = ["api-key"]
423
424[[models]]
425provider = "anthropic"
426id = "claude-3"
427aliases = ["c3", "claude3"]
428recommended = true
429parameters = { temperature = 0.9, thinking_config = { enabled = true } }
430
431[[models]]
432provider = "google"
433id = "gemini-pro"
434aliases = ["gemini"]
435recommended = true
436parameters = { temperature = 0.5, max_output_tokens = 4096, top_p = 0.95 }
437"#;
438
439 let base: Catalog = toml::from_str(base_toml).unwrap();
440 let user: Catalog = toml::from_str(user_toml).unwrap();
441
442 let base_models: Vec<_> = base.models.into_iter().map(ModelConfig::from).collect();
444 let user_models: Vec<_> = user.models.into_iter().map(ModelConfig::from).collect();
445
446 let mut base_models_mut = base_models;
447 ModelRegistry::merge_models(&mut base_models_mut, user_models);
448
449 assert_eq!(base_models_mut.len(), 3);
451
452 let claude = base_models_mut
454 .iter()
455 .find(|m| m.provider == provider::anthropic() && m.id == "claude-3")
456 .unwrap();
457
458 assert_eq!(claude.aliases.len(), 3);
460 assert!(claude.aliases.contains(&"claude".to_string()));
461 assert!(claude.aliases.contains(&"c3".to_string()));
462 assert!(claude.aliases.contains(&"claude3".to_string()));
463
464 assert!(claude.recommended);
466
467 assert!(claude.parameters.is_some());
469 let claude_params = claude.parameters.unwrap();
470 assert_eq!(claude_params.temperature, Some(0.9)); assert_eq!(claude_params.max_output_tokens, Some(2048)); assert!(claude_params.thinking_config.is_some());
473 assert!(claude_params.thinking_config.unwrap().enabled);
474
475 let gpt4 = base_models_mut
477 .iter()
478 .find(|m| m.provider == provider::openai() && m.id == "gpt-4")
479 .unwrap();
480 assert!(gpt4.recommended);
481 assert!(gpt4.parameters.is_none()); let gemini = base_models_mut
485 .iter()
486 .find(|m| m.provider == provider::google() && m.id == "gemini-pro")
487 .unwrap();
488 assert!(gemini.recommended);
489 assert!(gemini.parameters.is_some());
490 let gemini_params = gemini.parameters.unwrap();
491 assert_eq!(gemini_params.temperature, Some(0.5));
492 assert_eq!(gemini_params.top_p, Some(0.95));
493 }
494
495 #[test]
496 fn test_load_catalog_from_path() {
497 use std::fs;
498 use tempfile::TempDir;
499
500 let dir = TempDir::new().unwrap();
501 let config_path = dir.path().join("test_catalog.toml");
502
503 let config = r#"
504[[providers]]
505id = "anthropic"
506name = "Anthropic"
507api_format = "anthropic"
508auth_schemes = ["api-key"]
509
510[[models]]
511provider = "anthropic"
512id = "test-model"
513aliases = ["test"]
514recommended = true
515parameters = { max_output_tokens = 4096 }
516"#;
517
518 fs::write(&config_path, config).unwrap();
519
520 let result = ModelRegistry::load_catalog_file(&config_path).unwrap();
521 assert!(result.is_some());
522
523 let catalog = result.unwrap();
524 assert_eq!(catalog.models.len(), 1);
525 assert_eq!(catalog.models[0].id, "test-model");
526 assert_eq!(catalog.providers.len(), 1);
527 assert_eq!(catalog.providers[0].id, "anthropic");
528 }
529
530 #[test]
531 fn test_resolve_by_provider_and_parts() {
532 let mut registry = ModelRegistry {
534 models: HashMap::new(),
535 aliases: HashMap::new(),
536 providers: HashSet::new(),
537 };
538 let prov = provider::anthropic();
539
540 let m1 = ModelConfig {
541 provider: prov.clone(),
542 id: "id-1".to_string(),
543 display_name: Some("NiceName".to_string()),
544 aliases: vec!["alias1".into()],
545 recommended: false,
546 context_window_tokens: None,
547 parameters: None,
548 };
549 let m2 = ModelConfig {
550 provider: prov.clone(),
551 id: "id-2".to_string(),
552 display_name: Some("Other".to_string()),
553 aliases: vec!["alias2".into()],
554 recommended: false,
555 context_window_tokens: None,
556 parameters: None,
557 };
558 let id1 = ModelId::new(prov.clone(), m1.id.clone());
559 let id2 = ModelId::new(prov.clone(), m2.id.clone());
560 registry.aliases.insert("alias1".into(), id1.clone());
561 registry.aliases.insert("alias2".into(), id2.clone());
562 registry.models.insert(id1.clone(), m1.clone());
563 registry.models.insert(id2.clone(), m2.clone());
564 registry.providers.insert(prov.clone());
565
566 assert_eq!(registry.resolve("anthropic/id-1").unwrap(), id1);
568 assert!(registry.resolve("anthropic/NiceName").is_err());
570 assert_eq!(registry.resolve("anthropic/alias2").unwrap(), id2);
572 assert!(registry.resolve("anthropic/does-not-exist").is_err());
574 }
575
576 #[test]
577 fn test_resolve_by_display_name_is_not_supported() {
578 let mut registry = ModelRegistry {
580 models: HashMap::new(),
581 aliases: HashMap::new(),
582 providers: HashSet::new(),
583 };
584 let prov = provider::anthropic();
585 let m1 = ModelConfig {
586 provider: prov.clone(),
587 id: "id-1".into(),
588 display_name: Some("Same".into()),
589 aliases: vec![],
590 recommended: false,
591 context_window_tokens: None,
592 parameters: None,
593 };
594 let m2 = ModelConfig {
595 provider: prov.clone(),
596 id: "id-2".into(),
597 display_name: Some("Same".into()),
598 aliases: vec![],
599 recommended: false,
600 context_window_tokens: None,
601 parameters: None,
602 };
603 let id1 = ModelId::new(prov.clone(), m1.id.clone());
604 let id2 = ModelId::new(prov.clone(), m2.id.clone());
605 registry.models.insert(id1, m1);
606 registry.models.insert(id2, m2);
607 registry.providers.insert(prov.clone());
608
609 let err = registry.resolve("anthropic/Same").unwrap_err();
611 match err {
612 Error::Configuration(msg) => assert!(msg.contains("Unknown model or alias")),
613 _ => panic!("unexpected error type"),
614 }
615 }
616
617 #[test]
618 fn test_load_rejects_invalid_or_duplicate_display_names() {
619 use std::fs;
620 use tempfile::TempDir;
621
622 let dir = TempDir::new().unwrap();
623 let bad_path = dir.path().join("bad_catalog.toml");
624 let dup_path = dir.path().join("dup_catalog.toml");
625
626 let bad = r#"
628[[providers]]
629id = "custom"
630name = "Custom"
631api_format = "openai-responses"
632auth_schemes = ["api-key"]
633
634[[models]]
635provider = "custom"
636id = "m1"
637display_name = ""
638"#;
639 fs::write(&bad_path, bad).unwrap();
640 let res = ModelRegistry::load(&[bad_path.to_string_lossy().to_string()]);
641 assert!(matches!(res, Err(Error::Configuration(_))));
642
643 let dup = r#"
645[[providers]]
646id = "custom"
647name = "Custom"
648api_format = "openai-responses"
649auth_schemes = ["api-key"]
650
651[[models]]
652provider = "custom"
653id = "m1"
654display_name = "Same"
655parameters = { max_output_tokens = 1024 }
656
657[[models]]
658provider = "custom"
659id = "m2"
660display_name = "Same"
661parameters = { max_output_tokens = 2048 }
662"#;
663 fs::write(&dup_path, dup).unwrap();
664 let res2 = ModelRegistry::load(&[dup_path.to_string_lossy().to_string()]);
665 assert!(matches!(res2, Err(Error::Configuration(_))));
666 }
667
668 #[test]
669 fn test_duplicate_aliases_across_providers_error() {
670 use std::fs;
672 use tempfile::TempDir;
673
674 let dir = TempDir::new().unwrap();
675 let path = dir.path().join("alias_conflict.toml");
676 let toml = r#"
677[[providers]]
678id = "p1"
679name = "P1"
680api_format = "openai-responses"
681auth_schemes = ["api-key"]
682
683[[providers]]
684id = "p2"
685name = "P2"
686api_format = "openai-responses"
687auth_schemes = ["api-key"]
688
689[[models]]
690provider = "p1"
691id = "m1"
692aliases = ["shared"]
693parameters = { max_output_tokens = 1024 }
694
695[[models]]
696provider = "p2"
697id = "m2"
698aliases = ["shared"]
699parameters = { max_output_tokens = 1024 }
700"#;
701 fs::write(&path, toml).unwrap();
702 let res = ModelRegistry::load(&[path.to_string_lossy().to_string()]);
703 assert!(matches!(res, Err(Error::Configuration(_))));
704 }
705}