// steer_core/model_registry.rs

1use std::collections::{HashMap, HashSet};
2use std::path::{Path, PathBuf};
3
4use tracing::debug;
5
6use crate::config::model::{ModelConfig, ModelId};
7use crate::config::provider::ProviderId;
8use crate::config::toml_types::Catalog;
9use crate::error::Error;
10
/// Built-in model catalog (TOML), embedded into the binary at compile time.
///
/// The path is resolved by `include_str!` relative to this source file. `ModelRegistry::load`
/// parses it as the lowest-priority layer, which discovered and additional catalogs override.
const DEFAULT_CATALOG_TOML: &str = include_str!("../assets/default_catalog.toml");
12
/// Registry containing all available model configurations.
///
/// Normally built by `ModelRegistry::load`, which merges the embedded catalog with any
/// discovered or explicitly requested catalog files and validates the result.
#[derive(Debug, Clone)]
pub struct ModelRegistry {
    /// Model configurations keyed by `ModelId`, i.e. a `(provider, model id)` pair.
    models: HashMap<ModelId, ModelConfig>,
    /// Trimmed alias -> owning model. `load` rejects an alias claimed by two different models,
    /// so aliases are unique across all providers.
    aliases: HashMap<String, ModelId>,
    /// Providers that own at least one model in the registry; `resolve` uses this to validate
    /// the `provider/` prefix cheaply.
    providers: HashSet<ProviderId>,
}
23
24impl ModelRegistry {
25    /// Load the model registry with optional additional catalog files.
26    ///
27    /// Merge order (later overrides earlier):
28    /// 1. Built-in defaults from embedded catalog
29    /// 2. Discovered catalogs (project, then user)
30    /// 3. Additional catalog files specified
31    pub fn load(additional_catalogs: &[String]) -> Result<Self, Error> {
32        // First, load the built-in models from embedded catalog
33        let builtin_catalog: Catalog = toml::from_str(DEFAULT_CATALOG_TOML)
34            .map_err(|e| Error::Configuration(format!("Failed to parse default catalog: {e}")))?;
35
36        // Convert TOML models to ModelConfig
37        let mut models: Vec<ModelConfig> = builtin_catalog
38            .models
39            .into_iter()
40            .map(ModelConfig::from)
41            .collect();
42
43        // Validate providers exist for built-in models
44        let mut known_providers: HashMap<ProviderId, bool> = HashMap::new();
45        for p in builtin_catalog.providers {
46            known_providers.insert(ProviderId(p.id), true);
47        }
48
49        // Load discovered catalogs (user + project)
50        for path in Self::discover_catalog_paths() {
51            if let Some(catalog) = Self::load_catalog_file(&path)? {
52                for p in catalog.providers {
53                    known_providers.insert(ProviderId(p.id), true);
54                }
55                let more_models: Vec<ModelConfig> =
56                    catalog.models.into_iter().map(ModelConfig::from).collect();
57                Self::merge_models(&mut models, more_models);
58            }
59        }
60
61        // Load additional catalog files
62        for catalog_path in additional_catalogs {
63            if let Some(catalog) = Self::load_catalog_file(Path::new(catalog_path))? {
64                // Add new providers to known set
65                for p in catalog.providers {
66                    known_providers.insert(ProviderId(p.id), true);
67                }
68
69                // Merge models
70                let catalog_models: Vec<ModelConfig> =
71                    catalog.models.into_iter().map(ModelConfig::from).collect();
72                Self::merge_models(&mut models, catalog_models);
73            }
74        }
75
76        // Validate all models reference known providers
77        for model in &models {
78            if !known_providers.contains_key(&model.provider) {
79                return Err(Error::Configuration(format!(
80                    "Model '{}' references unknown provider '{}'",
81                    model.id, model.provider
82                )));
83            }
84        }
85
86        // Build the registry from the merged models
87        let mut registry = Self {
88            models: HashMap::new(),
89            aliases: HashMap::new(),
90            providers: HashSet::new(),
91        };
92
93        for model in models {
94            let model_id = (model.provider.clone(), model.id.clone());
95
96            // Track provider presence
97            registry.providers.insert(model.provider.clone());
98
99            // Store aliases, ensuring global uniqueness; trim and reject empty aliases
100            for raw in &model.aliases {
101                let alias = raw.trim();
102                if alias.is_empty() {
103                    return Err(Error::Configuration(format!(
104                        "Empty alias found for {}/{}",
105                        model_id.0.storage_key(),
106                        model_id.1,
107                    )));
108                }
109                if let Some(existing) = registry.aliases.get(alias) {
110                    if existing != &model_id {
111                        return Err(Error::Configuration(format!(
112                            "Duplicate alias '{}' used by {}/{} and {}/{}",
113                            alias,
114                            existing.0.storage_key(),
115                            existing.1,
116                            model_id.0.storage_key(),
117                            model_id.1,
118                        )));
119                    }
120                }
121                registry.aliases.insert(alias.to_string(), model_id.clone());
122            }
123
124            // Store model
125            registry.models.insert(model_id, model);
126        }
127
128        debug!(
129            target: "model_registry::load",
130            "Loaded models: {:?}",
131            registry.models
132        );
133
134        // Validate display_name values: non-empty, unique per provider
135        {
136            let mut seen: HashMap<ProviderId, HashSet<String>> = HashMap::new();
137            for ((prov, _mid), cfg) in &registry.models {
138                if let Some(name_raw) = cfg.display_name.as_deref() {
139                    let name = name_raw.trim();
140                    if name.is_empty() {
141                        return Err(Error::Configuration(format!(
142                            "Invalid display_name '{}' for {}/{}",
143                            name_raw,
144                            prov.storage_key(),
145                            cfg.id
146                        )));
147                    }
148                    let set = seen.entry(prov.clone()).or_default();
149                    if !set.insert(name.to_string()) {
150                        return Err(Error::Configuration(format!(
151                            "Duplicate display_name '{}' for provider {}",
152                            name,
153                            prov.storage_key()
154                        )));
155                    }
156                }
157            }
158        }
159
160        // Validate alias collisions across providers (already enforced during build)
161        // Add a targeted test to ensure cross-provider duplicate aliases error out.
162
163        Ok(registry)
164    }
165
166    /// Get a model by its ID.
167    pub fn get(&self, id: &ModelId) -> Option<&ModelConfig> {
168        self.models.get(id)
169    }
170
171    /// Find a model by its alias.
172    pub fn by_alias(&self, alias: &str) -> Option<&ModelConfig> {
173        self.aliases.get(alias).and_then(|id| self.models.get(id))
174    }
175    /// Resolve a model string to a ModelId.
176    /// - If input contains '/', treats as 'provider/<id|alias>' and resolves accordingly
177    ///   Note: model IDs may themselves contain '/', so everything after the first '/'
178    ///   is treated as the model ID or alias.
179    /// - Otherwise, looks up by alias
180    /// - Returns error if not found or invalid
181    pub fn resolve(&self, input: &str) -> Result<ModelId, Error> {
182        if let Some((provider_str, part_raw)) = input.split_once('/') {
183            // Parse provider directly and validate it exists in the registry
184            let provider: ProviderId = ProviderId(provider_str.to_string());
185            let provider_known = self.providers.contains(&provider);
186            if !provider_known {
187                return Err(Error::Configuration(format!(
188                    "Unknown provider: {provider_str}"
189                )));
190            }
191
192            let part = part_raw.trim();
193            if part.is_empty() {
194                return Err(Error::Configuration(
195                    "Model name cannot be empty".to_string(),
196                ));
197            }
198
199            // 1) Try exact model id match (ID can include '/')
200            let candidate = (provider.clone(), part.to_string());
201            if self.models.contains_key(&candidate) {
202                return Ok(candidate);
203            }
204
205            // 2) Try alias scoped to the provider
206            if let Some(alias_id) = self.aliases.get(part) {
207                if alias_id.0 == provider {
208                    return Ok(alias_id.clone());
209                }
210            }
211
212            Err(Error::Configuration(format!(
213                "Unknown model or alias: {input}"
214            )))
215        } else {
216            self.by_alias(input)
217                .map(|config| (config.provider.clone(), config.id.clone()))
218                .ok_or_else(|| Error::Configuration(format!("Unknown model or alias: {input}")))
219        }
220    }
221
222    pub fn recommended(&self) -> impl Iterator<Item = &ModelConfig> {
223        self.models.values().filter(|model| model.recommended)
224    }
225
226    /// Get all models in the registry
227    pub fn all(&self) -> impl Iterator<Item = &ModelConfig> {
228        self.models.values()
229    }
230
231    /// Load catalog from a specific path.
232    fn load_catalog_file(path: &Path) -> Result<Option<Catalog>, Error> {
233        if !path.exists() {
234            return Ok(None);
235        }
236
237        let content = std::fs::read_to_string(path).map_err(Error::Io)?;
238        // Parse as full catalog only
239        let catalog: Catalog = toml::from_str(&content).map_err(|e| {
240            Error::Configuration(format!(
241                "Failed to parse catalog at {}: {}",
242                path.display(),
243                e
244            ))
245        })?;
246        Ok(Some(catalog))
247    }
248
249    /// Determine default discovery paths for catalogs (user + project)
250    fn discover_catalog_paths() -> Vec<PathBuf> {
251        // Standardized discovery paths via utils::paths
252        let paths: Vec<PathBuf> = crate::utils::paths::AppPaths::discover_catalogs();
253        // Do not filter by existence here; load() already checks existence when reading
254        paths
255    }
256
257    /// Merge user models into the base models file.
258    /// Arrays are appended, scalar fields use last-write-wins.
259    fn merge_models(base: &mut Vec<ModelConfig>, user_models: Vec<ModelConfig>) {
260        // Create a map of existing models by (provider, id) for efficient lookup
261        let mut existing_models: HashMap<(ProviderId, String), usize> = HashMap::new();
262        for (idx, model) in base.iter().enumerate() {
263            existing_models.insert((model.provider.clone(), model.id.clone()), idx);
264        }
265
266        // Process each user model
267        for user_model in user_models {
268            let key = (user_model.provider.clone(), user_model.id.clone());
269
270            if let Some(&idx) = existing_models.get(&key) {
271                // Model exists - merge it
272                base[idx].merge_with(user_model);
273            } else {
274                // New model - add it
275                base.push(user_model);
276            }
277        }
278    }
279}
280
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::provider;
    use std::fs;
    use tempfile::TempDir;

    /// An empty registry, for tests that populate the private maps directly.
    fn empty_registry() -> ModelRegistry {
        ModelRegistry {
            models: HashMap::new(),
            aliases: HashMap::new(),
            providers: HashSet::new(),
        }
    }

    /// A non-recommended, parameterless model config.
    fn model(
        prov: &ProviderId,
        id: &str,
        display_name: Option<&str>,
        aliases: &[&str],
    ) -> ModelConfig {
        ModelConfig {
            provider: prov.clone(),
            id: id.to_string(),
            display_name: display_name.map(String::from),
            aliases: aliases.iter().map(|&a| String::from(a)).collect(),
            recommended: false,
            parameters: None,
        }
    }

    /// Insert `cfg` (provider, aliases and model) without load-time validation; returns its key.
    fn register(registry: &mut ModelRegistry, cfg: ModelConfig) -> ModelId {
        let id = (cfg.provider.clone(), cfg.id.clone());
        registry.providers.insert(cfg.provider.clone());
        for alias in &cfg.aliases {
            registry.aliases.insert(alias.clone(), id.clone());
        }
        registry.models.insert(id.clone(), cfg);
        id
    }

    /// Write `contents` to `dir/name` and return the path in the `String` form `load` accepts.
    fn write_catalog(dir: &TempDir, name: &str, contents: &str) -> String {
        let path = dir.path().join(name);
        fs::write(&path, contents).unwrap();
        path.to_string_lossy().to_string()
    }

    /// Assert `res` is a configuration error whose message contains `needle`.
    ///
    /// Checking the message (not just the variant) keeps tests from passing on an unrelated
    /// configuration failure.
    fn assert_config_error<T>(res: Result<T, Error>, needle: &str) {
        match res {
            Err(Error::Configuration(msg)) => assert!(
                msg.contains(needle),
                "expected error containing {needle:?}, got {msg:?}"
            ),
            Err(_) => panic!("expected a configuration error containing {needle:?}"),
            Ok(_) => panic!("expected a configuration error containing {needle:?}, got Ok"),
        }
    }

    #[test]
    fn test_load_builtin_models() {
        // Test that we can parse the built-in catalog
        let catalog: Catalog = toml::from_str(DEFAULT_CATALOG_TOML).unwrap();
        assert!(!catalog.models.is_empty());
        assert!(!catalog.providers.is_empty());

        // Check that we have some expected models
        let has_claude = catalog
            .models
            .iter()
            .any(|m| m.provider == "anthropic" && m.id.contains("claude"));
        assert!(has_claude, "Should have at least one Claude model");
    }

    #[test]
    fn test_registry_creation() {
        // Create a test catalog
        let toml = r#"
[[providers]]
id = "anthropic"
name = "Anthropic"
api_format = "anthropic"
auth_schemes = ["api-key"]

[[models]]
provider = "anthropic"
id = "test-model"
aliases = ["test", "tm"]
recommended = true
parameters = { thinking_config = { enabled = true } }
"#;

        let catalog: Catalog = toml::from_str(toml).unwrap();
        let mut registry = empty_registry();
        for cfg in catalog.models.into_iter().map(ModelConfig::from) {
            register(&mut registry, cfg);
        }

        // Test get
        let model_id = (provider::anthropic(), "test-model".to_string());
        let model = registry.get(&model_id).unwrap();
        assert_eq!(model.id, "test-model");
        assert!(model.recommended);

        // Test parameters were parsed correctly
        assert!(model.parameters.is_some());
        let params = model.parameters.unwrap();
        assert!(params.thinking_config.is_some());
        assert!(params.thinking_config.unwrap().enabled);

        // Test by_alias
        let model_by_alias = registry.by_alias("test").unwrap();
        assert_eq!(model_by_alias.id, "test-model");

        let model_by_alias2 = registry.by_alias("tm").unwrap();
        assert_eq!(model_by_alias2.id, "test-model");

        // Test recommended
        let recommended: Vec<_> = registry.recommended().collect();
        assert_eq!(recommended.len(), 1);
        assert_eq!(recommended[0].id, "test-model");
    }

    #[test]
    fn test_merge_models() {
        let base_toml = r#"
[[providers]]
id = "anthropic"
name = "Anthropic"
api_format = "anthropic"
auth_schemes = ["api-key"]

[[providers]]
id = "openai"
name = "OpenAI"
api_format = "openai-responses"
auth_schemes = ["api-key"]

[[models]]
provider = "anthropic"
id = "claude-3"
aliases = ["claude"]
recommended = false
parameters = { temperature = 0.7, max_tokens = 2048 }

[[models]]
provider = "openai"
id = "gpt-4"
aliases = ["gpt"]
recommended = true
"#;

        let user_toml = r#"
[[providers]]
id = "google"
name = "Google"
api_format = "google"
auth_schemes = ["api-key"]

[[models]]
provider = "anthropic"
id = "claude-3"
aliases = ["c3", "claude3"]
recommended = true
parameters = { temperature = 0.9, thinking_config = { enabled = true } }

[[models]]
provider = "google"
id = "gemini-pro"
aliases = ["gemini"]
recommended = true
parameters = { temperature = 0.5, top_p = 0.95 }
"#;

        let base: Catalog = toml::from_str(base_toml).unwrap();
        let user: Catalog = toml::from_str(user_toml).unwrap();

        let mut merged: Vec<ModelConfig> =
            base.models.into_iter().map(ModelConfig::from).collect();
        let user_models: Vec<ModelConfig> =
            user.models.into_iter().map(ModelConfig::from).collect();
        ModelRegistry::merge_models(&mut merged, user_models);

        // Check that we have 3 models total
        assert_eq!(merged.len(), 3);

        // Check the merged Claude model
        let claude = merged
            .iter()
            .find(|m| m.provider == provider::anthropic() && m.id == "claude-3")
            .unwrap();

        // Aliases should be merged
        assert_eq!(claude.aliases.len(), 3);
        assert!(claude.aliases.contains(&"claude".to_string()));
        assert!(claude.aliases.contains(&"c3".to_string()));
        assert!(claude.aliases.contains(&"claude3".to_string()));

        // Scalar fields should be overridden
        assert!(claude.recommended);

        // Parameters should be merged (user overrides base)
        assert!(claude.parameters.is_some());
        let claude_params = claude.parameters.unwrap();
        assert_eq!(claude_params.temperature, Some(0.9)); // overridden from 0.7
        assert_eq!(claude_params.max_tokens, Some(2048)); // kept from base
        assert!(claude_params.thinking_config.is_some());
        assert!(claude_params.thinking_config.unwrap().enabled);

        // Check that GPT-4 is unchanged
        let gpt4 = merged
            .iter()
            .find(|m| m.provider == provider::openai() && m.id == "gpt-4")
            .unwrap();
        assert!(gpt4.recommended);
        assert!(gpt4.parameters.is_none()); // No parameters in either base or user

        // Check that new model was added
        let gemini = merged
            .iter()
            .find(|m| m.provider == provider::google() && m.id == "gemini-pro")
            .unwrap();
        assert!(gemini.recommended);
        assert!(gemini.parameters.is_some());
        let gemini_params = gemini.parameters.unwrap();
        assert_eq!(gemini_params.temperature, Some(0.5));
        assert_eq!(gemini_params.top_p, Some(0.95));
    }

    #[test]
    fn test_load_catalog_from_path() {
        let dir = TempDir::new().unwrap();
        let config_path = dir.path().join("test_catalog.toml");

        let config = r#"
[[providers]]
id = "anthropic"
name = "Anthropic"
api_format = "anthropic"
auth_schemes = ["api-key"]

[[models]]
provider = "anthropic"
id = "test-model"
aliases = ["test"]
recommended = true
"#;

        fs::write(&config_path, config).unwrap();

        let catalog = ModelRegistry::load_catalog_file(&config_path)
            .unwrap()
            .expect("catalog file exists, so it should load");
        assert_eq!(catalog.models.len(), 1);
        assert_eq!(catalog.models[0].id, "test-model");
        assert_eq!(catalog.providers.len(), 1);
        assert_eq!(catalog.providers[0].id, "anthropic");

        // A missing catalog is not an error; it is simply skipped.
        let missing = dir.path().join("missing_catalog.toml");
        assert!(ModelRegistry::load_catalog_file(&missing).unwrap().is_none());
    }

    #[test]
    fn test_resolve_by_provider_and_parts() {
        let mut registry = empty_registry();
        let prov = provider::anthropic();
        let id1 = register(
            &mut registry,
            model(&prov, "id-1", Some("NiceName"), &["alias1"]),
        );
        let id2 = register(
            &mut registry,
            model(&prov, "id-2", Some("Other"), &["alias2"]),
        );

        // provider/id
        assert_eq!(registry.resolve("anthropic/id-1").unwrap(), id1);
        // provider/display_name should NOT resolve
        assert!(registry.resolve("anthropic/NiceName").is_err());
        // provider/alias should resolve if alias maps to this provider
        assert_eq!(registry.resolve("anthropic/alias2").unwrap(), id2);
        // unknown
        assert!(registry.resolve("anthropic/does-not-exist").is_err());
    }

    #[test]
    fn test_resolve_edge_cases() {
        let mut registry = empty_registry();
        let anthropic = provider::anthropic();
        let openai = provider::openai();
        let slashy = register(&mut registry, model(&anthropic, "org/model", None, &["a1"]));
        register(&mut registry, model(&openai, "gpt", None, &[]));

        // Everything after the first '/' is the model part, so IDs may contain '/'.
        assert_eq!(registry.resolve("anthropic/org/model").unwrap(), slashy);
        // A bare alias resolves without a provider prefix.
        assert_eq!(registry.resolve("a1").unwrap(), slashy);
        // An alias only resolves under the provider that owns it.
        assert_config_error(registry.resolve("openai/a1"), "Unknown model or alias");
        // Empty (or whitespace-only) model part.
        assert_config_error(registry.resolve("anthropic/"), "Model name cannot be empty");
        assert_config_error(registry.resolve("anthropic/   "), "Model name cannot be empty");
        // Provider with no models in the registry.
        assert_config_error(registry.resolve("google/gemini"), "Unknown provider");
        // Unknown bare alias.
        assert_config_error(registry.resolve("missing"), "Unknown model or alias");
    }

    #[test]
    fn test_resolve_by_display_name_is_not_supported() {
        // Two models with same display name under same provider
        let mut registry = empty_registry();
        let prov = provider::anthropic();
        register(&mut registry, model(&prov, "id-1", Some("Same"), &[]));
        register(&mut registry, model(&prov, "id-2", Some("Same"), &[]));

        // Resolving by display name should not work
        assert_config_error(registry.resolve("anthropic/Same"), "Unknown model or alias");
    }

    #[test]
    fn test_load_rejects_invalid_or_duplicate_display_names() {
        let dir = TempDir::new().unwrap();

        // Invalid: empty display_name
        let bad = r#"
[[providers]]
id = "custom"
name = "Custom"
api_format = "openai-responses"
auth_schemes = ["api-key"]

[[models]]
provider = "custom"
id = "m1"
display_name = ""
"#;
        let bad_path = write_catalog(&dir, "bad_catalog.toml", bad);
        assert_config_error(ModelRegistry::load(&[bad_path]), "Invalid display_name");

        // Duplicate display_name within provider
        let dup = r#"
[[providers]]
id = "custom"
name = "Custom"
api_format = "openai-responses"
auth_schemes = ["api-key"]

[[models]]
provider = "custom"
id = "m1"
display_name = "Same"

[[models]]
provider = "custom"
id = "m2"
display_name = "Same"
"#;
        let dup_path = write_catalog(&dir, "dup_catalog.toml", dup);
        assert_config_error(ModelRegistry::load(&[dup_path]), "Duplicate display_name");
    }

    #[test]
    fn test_duplicate_aliases_across_providers_error() {
        // Two providers, same alias used by different models => should error on load
        let dir = TempDir::new().unwrap();
        let toml = r#"
[[providers]]
id = "p1"
name = "P1"
api_format = "openai-responses"
auth_schemes = ["api-key"]

[[providers]]
id = "p2"
name = "P2"
api_format = "openai-responses"
auth_schemes = ["api-key"]

[[models]]
provider = "p1"
id = "m1"
aliases = ["shared"]

[[models]]
provider = "p2"
id = "m2"
aliases = ["shared"]
"#;
        let path = write_catalog(&dir, "alias_conflict.toml", toml);
        assert_config_error(ModelRegistry::load(&[path]), "Duplicate alias 'shared'");
    }
}