Skip to main content

steer_core/
model_registry.rs

1use std::collections::{HashMap, HashSet};
2use std::path::{Path, PathBuf};
3
4use tracing::debug;
5
6use crate::config::model::{ModelConfig, ModelId};
7use crate::config::provider::ProviderId;
8use crate::config::toml_types::Catalog;
9use crate::error::Error;
10
/// TOML text of the built-in catalog, embedded at compile time via `include_str!`.
/// `ModelRegistry::load` parses this first and layers every other catalog on top of it.
const DEFAULT_CATALOG_TOML: &str = include_str!("../assets/default_catalog.toml");
12
13/// Registry containing all available model configurations.
14#[derive(Debug, Clone)]
15pub struct ModelRegistry {
16    /// Map of ModelId to ModelConfig for fast lookups.
17    models: HashMap<ModelId, ModelConfig>,
18    /// Map of aliases to ModelIds for alias resolution.
19    aliases: HashMap<String, ModelId>,
20    /// Set of providers that have at least one model in the registry (for fast provider checks).
21    providers: HashSet<ProviderId>,
22}
23
24impl ModelRegistry {
25    /// Load the model registry with optional additional catalog files.
26    ///
27    /// Merge order (later overrides earlier):
28    /// 1. Built-in defaults from embedded catalog
29    /// 2. Discovered catalogs (project, then user)
30    /// 3. Additional catalog files specified
31    pub fn load(additional_catalogs: &[String]) -> Result<Self, Error> {
32        // First, load the built-in models from embedded catalog
33        let builtin_catalog: Catalog = toml::from_str(DEFAULT_CATALOG_TOML)
34            .map_err(|e| Error::Configuration(format!("Failed to parse default catalog: {e}")))?;
35
36        // Convert TOML models to ModelConfig
37        let mut models: Vec<ModelConfig> = builtin_catalog
38            .models
39            .into_iter()
40            .map(ModelConfig::from)
41            .collect();
42
43        // Validate providers exist for built-in models
44        let mut known_providers: HashMap<ProviderId, bool> = HashMap::new();
45        for p in builtin_catalog.providers {
46            known_providers.insert(ProviderId(p.id), true);
47        }
48
49        // Load discovered catalogs (user + project)
50        for path in Self::discover_catalog_paths() {
51            if let Some(catalog) = Self::load_catalog_file(&path)? {
52                for p in catalog.providers {
53                    known_providers.insert(ProviderId(p.id), true);
54                }
55                let more_models: Vec<ModelConfig> =
56                    catalog.models.into_iter().map(ModelConfig::from).collect();
57                Self::merge_models(&mut models, more_models);
58            }
59        }
60
61        // Load additional catalog files
62        for catalog_path in additional_catalogs {
63            if let Some(catalog) = Self::load_catalog_file(Path::new(catalog_path))? {
64                // Add new providers to known set
65                for p in catalog.providers {
66                    known_providers.insert(ProviderId(p.id), true);
67                }
68
69                // Merge models
70                let catalog_models: Vec<ModelConfig> =
71                    catalog.models.into_iter().map(ModelConfig::from).collect();
72                Self::merge_models(&mut models, catalog_models);
73            }
74        }
75
76        // Validate all models reference known providers
77        for model in &models {
78            if !known_providers.contains_key(&model.provider) {
79                return Err(Error::Configuration(format!(
80                    "Model '{}' references unknown provider '{}'",
81                    model.id, model.provider
82                )));
83            }
84        }
85
86        // Build the registry from the merged models
87        let mut registry = Self {
88            models: HashMap::new(),
89            aliases: HashMap::new(),
90            providers: HashSet::new(),
91        };
92
93        for model in models {
94            let model_id = ModelId::new(model.provider.clone(), model.id.clone());
95
96            // Track provider presence
97            registry.providers.insert(model.provider.clone());
98
99            // Store aliases, ensuring global uniqueness; trim and reject empty aliases
100            for raw in &model.aliases {
101                let alias = raw.trim();
102                if alias.is_empty() {
103                    return Err(Error::Configuration(format!(
104                        "Empty alias found for {}/{}",
105                        model_id.provider.storage_key(),
106                        model_id.id.as_str(),
107                    )));
108                }
109                if let Some(existing) = registry.aliases.get(alias)
110                    && existing != &model_id
111                {
112                    return Err(Error::Configuration(format!(
113                        "Duplicate alias '{}' used by {}/{} and {}/{}",
114                        alias,
115                        existing.provider.storage_key(),
116                        existing.id.as_str(),
117                        model_id.provider.storage_key(),
118                        model_id.id.as_str(),
119                    )));
120                }
121                registry.aliases.insert(alias.to_string(), model_id.clone());
122            }
123
124            // Store model
125            registry.models.insert(model_id, model);
126        }
127
128        debug!(
129            target: "model_registry::load",
130            "Loaded models: {:?}",
131            registry.models
132        );
133
134        // Validate display_name values: non-empty, unique per provider
135        {
136            let mut seen: HashMap<ProviderId, HashSet<String>> = HashMap::new();
137            for (model_id, cfg) in &registry.models {
138                if let Some(name_raw) = cfg.display_name.as_deref() {
139                    let name = name_raw.trim();
140                    if name.is_empty() {
141                        return Err(Error::Configuration(format!(
142                            "Invalid display_name '{}' for {}/{}",
143                            name_raw,
144                            model_id.provider.storage_key(),
145                            cfg.id
146                        )));
147                    }
148                    let set = seen.entry(model_id.provider.clone()).or_default();
149                    if !set.insert(name.to_string()) {
150                        return Err(Error::Configuration(format!(
151                            "Duplicate display_name '{}' for provider {}",
152                            name,
153                            model_id.provider.storage_key()
154                        )));
155                    }
156                }
157            }
158        }
159
160        // Validate alias collisions across providers (already enforced during build)
161        // Add a targeted test to ensure cross-provider duplicate aliases error out.
162
163        Ok(registry)
164    }
165
166    /// Build an empty registry (primarily for fallbacks/tests).
167    pub fn empty() -> Self {
168        Self {
169            models: HashMap::new(),
170            aliases: HashMap::new(),
171            providers: HashSet::new(),
172        }
173    }
174
175    /// Get a model by its ID.
176    pub fn get(&self, id: &ModelId) -> Option<&ModelConfig> {
177        self.models.get(id)
178    }
179
180    /// Find a model by its alias.
181    pub fn by_alias(&self, alias: &str) -> Option<&ModelConfig> {
182        self.aliases.get(alias).and_then(|id| self.models.get(id))
183    }
184    /// Resolve a model string to a ModelId.
185    /// - If input contains '/', treats as 'provider/<id|alias>' and resolves accordingly
186    ///   Note: model IDs may themselves contain '/', so everything after the first '/'
187    ///   is treated as the model ID or alias.
188    /// - Otherwise, looks up by alias
189    /// - Returns error if not found or invalid
190    pub fn resolve(&self, input: &str) -> Result<ModelId, Error> {
191        if let Some((provider_str, part_raw)) = input.split_once('/') {
192            // Parse provider directly and validate it exists in the registry
193            let provider: ProviderId = ProviderId(provider_str.to_string());
194            let provider_known = self.providers.contains(&provider);
195            if !provider_known {
196                return Err(Error::Configuration(format!(
197                    "Unknown provider: {provider_str}"
198                )));
199            }
200
201            let part = part_raw.trim();
202            if part.is_empty() {
203                return Err(Error::Configuration(
204                    "Model name cannot be empty".to_string(),
205                ));
206            }
207
208            // 1) Try exact model id match (ID can include '/')
209            let candidate = ModelId::new(provider.clone(), part.to_string());
210            if self.models.contains_key(&candidate) {
211                return Ok(candidate);
212            }
213
214            // 2) Try alias scoped to the provider
215            if let Some(alias_id) = self.aliases.get(part)
216                && alias_id.provider == provider
217            {
218                return Ok(alias_id.clone());
219            }
220
221            Err(Error::Configuration(format!(
222                "Unknown model or alias: {input}"
223            )))
224        } else {
225            self.by_alias(input)
226                .map(|config| ModelId::new(config.provider.clone(), config.id.clone()))
227                .ok_or_else(|| Error::Configuration(format!("Unknown model or alias: {input}")))
228        }
229    }
230
231    pub fn recommended(&self) -> impl Iterator<Item = &ModelConfig> {
232        self.models.values().filter(|model| model.recommended)
233    }
234
235    /// Get all models in the registry
236    pub fn all(&self) -> impl Iterator<Item = &ModelConfig> {
237        self.models.values()
238    }
239
240    /// Load catalog from a specific path.
241    fn load_catalog_file(path: &Path) -> Result<Option<Catalog>, Error> {
242        if !path.exists() {
243            return Ok(None);
244        }
245
246        let content = std::fs::read_to_string(path).map_err(Error::Io)?;
247        // Parse as full catalog only
248        let catalog: Catalog = toml::from_str(&content).map_err(|e| {
249            Error::Configuration(format!(
250                "Failed to parse catalog at {}: {}",
251                path.display(),
252                e
253            ))
254        })?;
255        Ok(Some(catalog))
256    }
257
258    /// Determine default discovery paths for catalogs (user + project)
259    fn discover_catalog_paths() -> Vec<PathBuf> {
260        // Standardized discovery paths via utils::paths
261        let paths: Vec<PathBuf> = crate::utils::paths::AppPaths::discover_catalogs();
262        // Do not filter by existence here; load() already checks existence when reading
263        paths
264    }
265
266    /// Merge user models into the base models file.
267    /// Arrays are appended, scalar fields use last-write-wins.
268    fn merge_models(base: &mut Vec<ModelConfig>, user_models: Vec<ModelConfig>) {
269        // Create a map of existing models by ModelId for efficient lookup
270        let mut existing_models: HashMap<ModelId, usize> = HashMap::new();
271        for (idx, model) in base.iter().enumerate() {
272            existing_models.insert(ModelId::new(model.provider.clone(), model.id.clone()), idx);
273        }
274
275        // Process each user model
276        for user_model in user_models {
277            let key = ModelId::new(user_model.provider.clone(), user_model.id.clone());
278
279            if let Some(&idx) = existing_models.get(&key) {
280                // Model exists - merge it
281                base[idx].merge_with(user_model);
282            } else {
283                // New model - add it
284                base.push(user_model);
285            }
286        }
287    }
288}
289
290#[cfg(test)]
291mod tests {
292    use super::*;
293    use crate::config::provider;
294
295    #[test]
296    fn test_load_builtin_models() {
297        // Test that we can parse the built-in catalog
298        let catalog: Catalog = toml::from_str(DEFAULT_CATALOG_TOML).unwrap();
299        assert!(!catalog.models.is_empty());
300        assert!(!catalog.providers.is_empty());
301
302        // Check that we have some expected models
303        let has_claude = catalog
304            .models
305            .iter()
306            .any(|m| m.provider == "anthropic" && m.id.contains("claude"));
307        assert!(has_claude, "Should have at least one Claude model");
308    }
309
310    #[test]
311    fn test_registry_creation() {
312        // Create a test catalog
313        let toml = r#"
314[[providers]]
315id = "anthropic"
316name = "Anthropic"
317api_format = "anthropic"
318auth_schemes = ["api-key"]
319
320[[models]]
321provider = "anthropic"
322id = "test-model"
323aliases = ["test", "tm"]
324recommended = true
325parameters = { thinking_config = { enabled = true } }
326"#;
327
328        let catalog: Catalog = toml::from_str(toml).unwrap();
329        // Convert Catalog to ModelConfig list using From trait
330        let models: Vec<ModelConfig> = catalog.models.into_iter().map(ModelConfig::from).collect();
331
332        let mut registry = ModelRegistry {
333            models: HashMap::new(),
334            aliases: HashMap::new(),
335            providers: HashSet::new(),
336        };
337
338        for model in models {
339            let model_id = ModelId::new(model.provider.clone(), model.id.clone());
340
341            // track provider
342            registry.providers.insert(model.provider.clone());
343
344            for alias in &model.aliases {
345                registry.aliases.insert(alias.clone(), model_id.clone());
346            }
347
348            registry.models.insert(model_id, model);
349        }
350
351        // Test get
352        let model_id = ModelId::new(provider::anthropic(), "test-model");
353        let model = registry.get(&model_id).unwrap();
354        assert_eq!(model.id, "test-model");
355        assert!(model.recommended);
356
357        // Test parameters were parsed correctly
358        assert!(model.parameters.is_some());
359        let params = model.parameters.unwrap();
360        assert!(params.thinking_config.is_some());
361        assert!(params.thinking_config.unwrap().enabled);
362
363        // Test by_alias
364        let model_by_alias = registry.by_alias("test").unwrap();
365        assert_eq!(model_by_alias.id, "test-model");
366
367        let model_by_alias2 = registry.by_alias("tm").unwrap();
368        assert_eq!(model_by_alias2.id, "test-model");
369
370        // Test recommended
371        let recommended: Vec<_> = registry.recommended().collect();
372        assert_eq!(recommended.len(), 1);
373        assert_eq!(recommended[0].id, "test-model");
374    }
375
376    #[test]
377    fn test_merge_models() {
378        let base_toml = r#"
379[[providers]]
380id = "anthropic"
381name = "Anthropic"
382api_format = "anthropic"
383auth_schemes = ["api-key"]
384
385[[providers]]
386id = "openai"
387name = "OpenAI"
388api_format = "openai-responses"
389auth_schemes = ["api-key"]
390
391[[models]]
392provider = "anthropic"
393id = "claude-3"
394aliases = ["claude"]
395recommended = false
396parameters = { temperature = 0.7, max_tokens = 2048 }
397
398[[models]]
399provider = "openai"
400id = "gpt-4"
401aliases = ["gpt"]
402recommended = true
403"#;
404
405        let user_toml = r#"
406[[providers]]
407id = "google"
408name = "Google"
409api_format = "google"
410auth_schemes = ["api-key"]
411
412[[models]]
413provider = "anthropic"
414id = "claude-3"
415aliases = ["c3", "claude3"]
416recommended = true
417parameters = { temperature = 0.9, thinking_config = { enabled = true } }
418
419[[models]]
420provider = "google"
421id = "gemini-pro"
422aliases = ["gemini"]
423recommended = true
424parameters = { temperature = 0.5, top_p = 0.95 }
425"#;
426
427        let base: Catalog = toml::from_str(base_toml).unwrap();
428        let user: Catalog = toml::from_str(user_toml).unwrap();
429
430        // Convert to ModelConfig using From trait
431        let base_models: Vec<_> = base.models.into_iter().map(ModelConfig::from).collect();
432        let user_models: Vec<_> = user.models.into_iter().map(ModelConfig::from).collect();
433
434        let mut base_models_mut = base_models;
435        ModelRegistry::merge_models(&mut base_models_mut, user_models);
436
437        // Check that we have 3 models total
438        assert_eq!(base_models_mut.len(), 3);
439
440        // Check the merged Claude model
441        let claude = base_models_mut
442            .iter()
443            .find(|m| m.provider == provider::anthropic() && m.id == "claude-3")
444            .unwrap();
445
446        // Aliases should be merged
447        assert_eq!(claude.aliases.len(), 3);
448        assert!(claude.aliases.contains(&"claude".to_string()));
449        assert!(claude.aliases.contains(&"c3".to_string()));
450        assert!(claude.aliases.contains(&"claude3".to_string()));
451
452        // Scalar fields should be overridden
453        assert!(claude.recommended);
454
455        // Parameters should be merged (user overrides base)
456        assert!(claude.parameters.is_some());
457        let claude_params = claude.parameters.unwrap();
458        assert_eq!(claude_params.temperature, Some(0.9)); // overridden from 0.7
459        assert_eq!(claude_params.max_tokens, Some(2048)); // kept from base
460        assert!(claude_params.thinking_config.is_some());
461        assert!(claude_params.thinking_config.unwrap().enabled);
462
463        // Check that GPT-4 is unchanged
464        let gpt4 = base_models_mut
465            .iter()
466            .find(|m| m.provider == provider::openai() && m.id == "gpt-4")
467            .unwrap();
468        assert!(gpt4.recommended);
469        assert!(gpt4.parameters.is_none()); // No parameters in either base or user
470
471        // Check that new model was added
472        let gemini = base_models_mut
473            .iter()
474            .find(|m| m.provider == provider::google() && m.id == "gemini-pro")
475            .unwrap();
476        assert!(gemini.recommended);
477        assert!(gemini.parameters.is_some());
478        let gemini_params = gemini.parameters.unwrap();
479        assert_eq!(gemini_params.temperature, Some(0.5));
480        assert_eq!(gemini_params.top_p, Some(0.95));
481    }
482
483    #[test]
484    fn test_load_catalog_from_path() {
485        use std::fs;
486        use tempfile::TempDir;
487
488        let dir = TempDir::new().unwrap();
489        let config_path = dir.path().join("test_catalog.toml");
490
491        let config = r#"
492[[providers]]
493id = "anthropic"
494name = "Anthropic"
495api_format = "anthropic"
496auth_schemes = ["api-key"]
497
498[[models]]
499provider = "anthropic"
500id = "test-model"
501aliases = ["test"]
502recommended = true
503"#;
504
505        fs::write(&config_path, config).unwrap();
506
507        let result = ModelRegistry::load_catalog_file(&config_path).unwrap();
508        assert!(result.is_some());
509
510        let catalog = result.unwrap();
511        assert_eq!(catalog.models.len(), 1);
512        assert_eq!(catalog.models[0].id, "test-model");
513        assert_eq!(catalog.providers.len(), 1);
514        assert_eq!(catalog.providers[0].id, "anthropic");
515    }
516
517    #[test]
518    fn test_resolve_by_provider_and_parts() {
519        // Build a small registry manually
520        let mut registry = ModelRegistry {
521            models: HashMap::new(),
522            aliases: HashMap::new(),
523            providers: HashSet::new(),
524        };
525        let prov = provider::anthropic();
526
527        let m1 = ModelConfig {
528            provider: prov.clone(),
529            id: "id-1".to_string(),
530            display_name: Some("NiceName".to_string()),
531            aliases: vec!["alias1".into()],
532            recommended: false,
533            parameters: None,
534        };
535        let m2 = ModelConfig {
536            provider: prov.clone(),
537            id: "id-2".to_string(),
538            display_name: Some("Other".to_string()),
539            aliases: vec!["alias2".into()],
540            recommended: false,
541            parameters: None,
542        };
543        let id1 = ModelId::new(prov.clone(), m1.id.clone());
544        let id2 = ModelId::new(prov.clone(), m2.id.clone());
545        registry.aliases.insert("alias1".into(), id1.clone());
546        registry.aliases.insert("alias2".into(), id2.clone());
547        registry.models.insert(id1.clone(), m1.clone());
548        registry.models.insert(id2.clone(), m2.clone());
549        registry.providers.insert(prov.clone());
550
551        // provider/id
552        assert_eq!(registry.resolve("anthropic/id-1").unwrap(), id1);
553        // provider/display_name should NOT resolve
554        assert!(registry.resolve("anthropic/NiceName").is_err());
555        // provider/alias should resolve if alias maps to this provider
556        assert_eq!(registry.resolve("anthropic/alias2").unwrap(), id2);
557        // unknown
558        assert!(registry.resolve("anthropic/does-not-exist").is_err());
559    }
560
561    #[test]
562    fn test_resolve_by_display_name_is_not_supported() {
563        // Two models with same display name under same provider
564        let mut registry = ModelRegistry {
565            models: HashMap::new(),
566            aliases: HashMap::new(),
567            providers: HashSet::new(),
568        };
569        let prov = provider::anthropic();
570        let m1 = ModelConfig {
571            provider: prov.clone(),
572            id: "id-1".into(),
573            display_name: Some("Same".into()),
574            aliases: vec![],
575            recommended: false,
576            parameters: None,
577        };
578        let m2 = ModelConfig {
579            provider: prov.clone(),
580            id: "id-2".into(),
581            display_name: Some("Same".into()),
582            aliases: vec![],
583            recommended: false,
584            parameters: None,
585        };
586        let id1 = ModelId::new(prov.clone(), m1.id.clone());
587        let id2 = ModelId::new(prov.clone(), m2.id.clone());
588        registry.models.insert(id1, m1);
589        registry.models.insert(id2, m2);
590        registry.providers.insert(prov.clone());
591
592        // Resolving by display name should not work
593        let err = registry.resolve("anthropic/Same").unwrap_err();
594        match err {
595            Error::Configuration(msg) => assert!(msg.contains("Unknown model or alias")),
596            _ => panic!("unexpected error type"),
597        }
598    }
599
600    #[test]
601    fn test_load_rejects_invalid_or_duplicate_display_names() {
602        use std::fs;
603        use tempfile::TempDir;
604
605        let dir = TempDir::new().unwrap();
606        let bad_path = dir.path().join("bad_catalog.toml");
607        let dup_path = dir.path().join("dup_catalog.toml");
608
609        // Invalid: empty display_name
610        let bad = r#"
611[[providers]]
612id = "custom"
613name = "Custom"
614api_format = "openai-responses"
615auth_schemes = ["api-key"]
616
617[[models]]
618provider = "custom"
619id = "m1"
620display_name = ""
621"#;
622        fs::write(&bad_path, bad).unwrap();
623        let res = ModelRegistry::load(&[bad_path.to_string_lossy().to_string()]);
624        assert!(matches!(res, Err(Error::Configuration(_))));
625
626        // Duplicate display_name within provider
627        let dup = r#"
628[[providers]]
629id = "custom"
630name = "Custom"
631api_format = "openai-responses"
632auth_schemes = ["api-key"]
633
634[[models]]
635provider = "custom"
636id = "m1"
637display_name = "Same"
638
639[[models]]
640provider = "custom"
641id = "m2"
642display_name = "Same"
643"#;
644        fs::write(&dup_path, dup).unwrap();
645        let res2 = ModelRegistry::load(&[dup_path.to_string_lossy().to_string()]);
646        assert!(matches!(res2, Err(Error::Configuration(_))));
647    }
648
649    #[test]
650    fn test_duplicate_aliases_across_providers_error() {
651        // Two providers, same alias used by different models => should error on load
652        use std::fs;
653        use tempfile::TempDir;
654
655        let dir = TempDir::new().unwrap();
656        let path = dir.path().join("alias_conflict.toml");
657        let toml = r#"
658[[providers]]
659id = "p1"
660name = "P1"
661api_format = "openai-responses"
662auth_schemes = ["api-key"]
663
664[[providers]]
665id = "p2"
666name = "P2"
667api_format = "openai-responses"
668auth_schemes = ["api-key"]
669
670[[models]]
671provider = "p1"
672id = "m1"
673aliases = ["shared"]
674
675[[models]]
676provider = "p2"
677id = "m2"
678aliases = ["shared"]
679"#;
680        fs::write(&path, toml).unwrap();
681        let res = ModelRegistry::load(&[path.to_string_lossy().to_string()]);
682        assert!(matches!(res, Err(Error::Configuration(_))));
683    }
684}