Skip to main content

steer_core/
model_registry.rs

1use std::collections::{HashMap, HashSet};
2use std::path::{Path, PathBuf};
3
4use tracing::debug;
5
6use crate::config::model::{ModelConfig, ModelId};
7use crate::config::provider::ProviderId;
8use crate::config::toml_types::Catalog;
9use crate::error::Error;
10
11const DEFAULT_CATALOG_TOML: &str = include_str!("../assets/default_catalog.toml");
12
13/// Registry containing all available model configurations.
14#[derive(Debug, Clone)]
15pub struct ModelRegistry {
16    /// Map of ModelId to ModelConfig for fast lookups.
17    models: HashMap<ModelId, ModelConfig>,
18    /// Map of aliases to ModelIds for alias resolution.
19    aliases: HashMap<String, ModelId>,
20    /// Set of providers that have at least one model in the registry (for fast provider checks).
21    providers: HashSet<ProviderId>,
22}
23
24impl ModelRegistry {
25    /// Load the model registry with optional additional catalog files.
26    ///
27    /// Merge order (later overrides earlier):
28    /// 1. Built-in defaults from embedded catalog
29    /// 2. Discovered catalogs (project, then user)
30    /// 3. Additional catalog files specified
31    pub fn load(additional_catalogs: &[String]) -> Result<Self, Error> {
32        // First, load the built-in models from embedded catalog
33        let builtin_catalog: Catalog = toml::from_str(DEFAULT_CATALOG_TOML)
34            .map_err(|e| Error::Configuration(format!("Failed to parse default catalog: {e}")))?;
35
36        // Convert TOML models to ModelConfig
37        let mut models: Vec<ModelConfig> = builtin_catalog
38            .models
39            .into_iter()
40            .map(ModelConfig::from)
41            .collect();
42
43        // Validate providers exist for built-in models
44        let mut known_providers: HashMap<ProviderId, bool> = HashMap::new();
45        for p in builtin_catalog.providers {
46            known_providers.insert(ProviderId(p.id), true);
47        }
48
49        // Load discovered catalogs (user + project)
50        for path in Self::discover_catalog_paths() {
51            if let Some(catalog) = Self::load_catalog_file(&path)? {
52                for p in catalog.providers {
53                    known_providers.insert(ProviderId(p.id), true);
54                }
55                let more_models: Vec<ModelConfig> =
56                    catalog.models.into_iter().map(ModelConfig::from).collect();
57                Self::merge_models(&mut models, more_models);
58            }
59        }
60
61        // Load additional catalog files
62        for catalog_path in additional_catalogs {
63            if let Some(catalog) = Self::load_catalog_file(Path::new(catalog_path))? {
64                // Add new providers to known set
65                for p in catalog.providers {
66                    known_providers.insert(ProviderId(p.id), true);
67                }
68
69                // Merge models
70                let catalog_models: Vec<ModelConfig> =
71                    catalog.models.into_iter().map(ModelConfig::from).collect();
72                Self::merge_models(&mut models, catalog_models);
73            }
74        }
75
76        // Validate all models reference known providers and include required max_output_tokens.
77        for model in &models {
78            if !known_providers.contains_key(&model.provider) {
79                return Err(Error::Configuration(format!(
80                    "Model '{}' references unknown provider '{}'",
81                    model.id, model.provider
82                )));
83            }
84
85            let has_required_max_output_tokens = model
86                .parameters
87                .as_ref()
88                .and_then(|parameters| parameters.max_output_tokens)
89                .is_some();
90            if !has_required_max_output_tokens {
91                return Err(Error::Configuration(format!(
92                    "Model '{}' is missing required parameters.max_output_tokens",
93                    model.id
94                )));
95            }
96        }
97
98        // Build the registry from the merged models
99        let mut registry = Self {
100            models: HashMap::new(),
101            aliases: HashMap::new(),
102            providers: HashSet::new(),
103        };
104
105        for model in models {
106            let model_id = ModelId::new(model.provider.clone(), model.id.clone());
107
108            // Track provider presence
109            registry.providers.insert(model.provider.clone());
110
111            // Store aliases, ensuring global uniqueness; trim and reject empty aliases
112            for raw in &model.aliases {
113                let alias = raw.trim();
114                if alias.is_empty() {
115                    return Err(Error::Configuration(format!(
116                        "Empty alias found for {}/{}",
117                        model_id.provider.storage_key(),
118                        model_id.id.as_str(),
119                    )));
120                }
121                if let Some(existing) = registry.aliases.get(alias)
122                    && existing != &model_id
123                {
124                    return Err(Error::Configuration(format!(
125                        "Duplicate alias '{}' used by {}/{} and {}/{}",
126                        alias,
127                        existing.provider.storage_key(),
128                        existing.id.as_str(),
129                        model_id.provider.storage_key(),
130                        model_id.id.as_str(),
131                    )));
132                }
133                registry.aliases.insert(alias.to_string(), model_id.clone());
134            }
135
136            // Store model
137            registry.models.insert(model_id, model);
138        }
139
140        debug!(
141            target: "model_registry::load",
142            "Loaded models: {:?}",
143            registry.models
144        );
145
146        // Validate display_name values: non-empty, unique per provider
147        {
148            let mut seen: HashMap<ProviderId, HashSet<String>> = HashMap::new();
149            for (model_id, cfg) in &registry.models {
150                if let Some(name_raw) = cfg.display_name.as_deref() {
151                    let name = name_raw.trim();
152                    if name.is_empty() {
153                        return Err(Error::Configuration(format!(
154                            "Invalid display_name '{}' for {}/{}",
155                            name_raw,
156                            model_id.provider.storage_key(),
157                            cfg.id
158                        )));
159                    }
160                    let set = seen.entry(model_id.provider.clone()).or_default();
161                    if !set.insert(name.to_string()) {
162                        return Err(Error::Configuration(format!(
163                            "Duplicate display_name '{}' for provider {}",
164                            name,
165                            model_id.provider.storage_key()
166                        )));
167                    }
168                }
169            }
170        }
171
172        // Validate alias collisions across providers (already enforced during build)
173        // Add a targeted test to ensure cross-provider duplicate aliases error out.
174
175        Ok(registry)
176    }
177
178    /// Build an empty registry (primarily for fallbacks/tests).
179    pub fn empty() -> Self {
180        Self {
181            models: HashMap::new(),
182            aliases: HashMap::new(),
183            providers: HashSet::new(),
184        }
185    }
186
187    /// Get a model by its ID.
188    pub fn get(&self, id: &ModelId) -> Option<&ModelConfig> {
189        self.models.get(id)
190    }
191
192    /// Find a model by its alias.
193    pub fn by_alias(&self, alias: &str) -> Option<&ModelConfig> {
194        self.aliases.get(alias).and_then(|id| self.models.get(id))
195    }
196    /// Resolve a model string to a ModelId.
197    /// - If input contains '/', treats as 'provider/<id|alias>' and resolves accordingly
198    ///   Note: model IDs may themselves contain '/', so everything after the first '/'
199    ///   is treated as the model ID or alias.
200    /// - Otherwise, looks up by alias
201    /// - Returns error if not found or invalid
202    pub fn resolve(&self, input: &str) -> Result<ModelId, Error> {
203        if let Some((provider_str, part_raw)) = input.split_once('/') {
204            // Parse provider directly and validate it exists in the registry
205            let provider: ProviderId = ProviderId(provider_str.to_string());
206            let provider_known = self.providers.contains(&provider);
207            if !provider_known {
208                return Err(Error::Configuration(format!(
209                    "Unknown provider: {provider_str}"
210                )));
211            }
212
213            let part = part_raw.trim();
214            if part.is_empty() {
215                return Err(Error::Configuration(
216                    "Model name cannot be empty".to_string(),
217                ));
218            }
219
220            // 1) Try exact model id match (ID can include '/')
221            let candidate = ModelId::new(provider.clone(), part.to_string());
222            if self.models.contains_key(&candidate) {
223                return Ok(candidate);
224            }
225
226            // 2) Try alias scoped to the provider
227            if let Some(alias_id) = self.aliases.get(part)
228                && alias_id.provider == provider
229            {
230                return Ok(alias_id.clone());
231            }
232
233            Err(Error::Configuration(format!(
234                "Unknown model or alias: {input}"
235            )))
236        } else {
237            self.by_alias(input)
238                .map(|config| ModelId::new(config.provider.clone(), config.id.clone()))
239                .ok_or_else(|| Error::Configuration(format!("Unknown model or alias: {input}")))
240        }
241    }
242
243    pub fn recommended(&self) -> impl Iterator<Item = &ModelConfig> {
244        self.models.values().filter(|model| model.recommended)
245    }
246
247    /// Get all models in the registry
248    pub fn all(&self) -> impl Iterator<Item = &ModelConfig> {
249        self.models.values()
250    }
251
252    /// Load catalog from a specific path.
253    fn load_catalog_file(path: &Path) -> Result<Option<Catalog>, Error> {
254        if !path.exists() {
255            return Ok(None);
256        }
257
258        let content = std::fs::read_to_string(path).map_err(Error::Io)?;
259        // Parse as full catalog only
260        let catalog: Catalog = toml::from_str(&content).map_err(|e| {
261            Error::Configuration(format!(
262                "Failed to parse catalog at {}: {}",
263                path.display(),
264                e
265            ))
266        })?;
267        Ok(Some(catalog))
268    }
269
270    /// Determine default discovery paths for catalogs (user + project)
271    fn discover_catalog_paths() -> Vec<PathBuf> {
272        // Standardized discovery paths via utils::paths
273        let paths: Vec<PathBuf> = crate::utils::paths::AppPaths::discover_catalogs();
274        // Do not filter by existence here; load() already checks existence when reading
275        paths
276    }
277
278    /// Merge user models into the base models file.
279    /// Arrays are appended, scalar fields use last-write-wins.
280    fn merge_models(base: &mut Vec<ModelConfig>, user_models: Vec<ModelConfig>) {
281        // Create a map of existing models by ModelId for efficient lookup
282        let mut existing_models: HashMap<ModelId, usize> = HashMap::new();
283        for (idx, model) in base.iter().enumerate() {
284            existing_models.insert(ModelId::new(model.provider.clone(), model.id.clone()), idx);
285        }
286
287        // Process each user model
288        for user_model in user_models {
289            let key = ModelId::new(user_model.provider.clone(), user_model.id.clone());
290
291            if let Some(&idx) = existing_models.get(&key) {
292                // Model exists - merge it
293                base[idx].merge_with(user_model);
294            } else {
295                // New model - add it
296                base.push(user_model);
297            }
298        }
299    }
300}
301
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::provider;

    /// Write `toml` to a temporary catalog file and run `ModelRegistry::load` with it.
    fn load_with_catalog(toml: &str) -> Result<ModelRegistry, Error> {
        let dir = tempfile::TempDir::new().unwrap();
        let path = dir.path().join("catalog.toml");
        std::fs::write(&path, toml).unwrap();
        ModelRegistry::load(&[path.to_string_lossy().to_string()])
    }

    /// Assert that `res` is an `Error::Configuration` whose message contains `needle`.
    fn assert_config_error(res: Result<ModelRegistry, Error>, needle: &str) {
        match res {
            Err(Error::Configuration(msg)) => assert!(
                msg.contains(needle),
                "expected message containing {needle:?}, got {msg:?}"
            ),
            Err(_) => panic!("expected Error::Configuration containing {needle:?}"),
            Ok(_) => panic!("expected an error containing {needle:?}, but load succeeded"),
        }
    }

    #[test]
    fn test_load_builtin_models() {
        // The built-in catalog must parse and contain models and providers.
        let catalog: Catalog = toml::from_str(DEFAULT_CATALOG_TOML).unwrap();
        assert!(!catalog.models.is_empty());
        assert!(!catalog.providers.is_empty());

        let has_claude = catalog
            .models
            .iter()
            .any(|m| m.provider == "anthropic" && m.id.contains("claude"));
        assert!(has_claude, "Should have at least one Claude model");
    }

    #[test]
    fn test_registry_creation() {
        let toml = r#"
[[providers]]
id = "anthropic"
name = "Anthropic"
api_format = "anthropic"
auth_schemes = ["api-key"]

[[models]]
provider = "anthropic"
id = "test-model"
aliases = ["test", "tm"]
recommended = true
parameters = { thinking_config = { enabled = true } }
"#;

        let catalog: Catalog = toml::from_str(toml).unwrap();
        let models: Vec<ModelConfig> = catalog.models.into_iter().map(ModelConfig::from).collect();

        let mut registry = ModelRegistry::empty();
        for model in models {
            let model_id = ModelId::new(model.provider.clone(), model.id.clone());
            registry.providers.insert(model.provider.clone());
            for alias in &model.aliases {
                registry.aliases.insert(alias.clone(), model_id.clone());
            }
            registry.models.insert(model_id, model);
        }

        // get
        let model_id = ModelId::new(provider::anthropic(), "test-model");
        let model = registry.get(&model_id).unwrap();
        assert_eq!(model.id, "test-model");
        assert!(model.recommended);

        // parameters were parsed
        assert!(model.parameters.is_some());
        let params = model.parameters.unwrap();
        assert!(params.thinking_config.is_some());
        assert!(params.thinking_config.unwrap().enabled);

        // by_alias
        let model_by_alias = registry.by_alias("test").unwrap();
        assert_eq!(model_by_alias.id, "test-model");

        let model_by_alias2 = registry.by_alias("tm").unwrap();
        assert_eq!(model_by_alias2.id, "test-model");

        // recommended
        let recommended: Vec<_> = registry.recommended().collect();
        assert_eq!(recommended.len(), 1);
        assert_eq!(recommended[0].id, "test-model");
    }

    #[test]
    fn test_merge_models() {
        let base_toml = r#"
[[providers]]
id = "anthropic"
name = "Anthropic"
api_format = "anthropic"
auth_schemes = ["api-key"]

[[providers]]
id = "openai"
name = "OpenAI"
api_format = "openai-responses"
auth_schemes = ["api-key"]

[[models]]
provider = "anthropic"
id = "claude-3"
aliases = ["claude"]
recommended = false
parameters = { temperature = 0.7, max_output_tokens = 2048 }

[[models]]
provider = "openai"
id = "gpt-4"
aliases = ["gpt"]
recommended = true
"#;

        let user_toml = r#"
[[providers]]
id = "google"
name = "Google"
api_format = "google"
auth_schemes = ["api-key"]

[[models]]
provider = "anthropic"
id = "claude-3"
aliases = ["c3", "claude3"]
recommended = true
parameters = { temperature = 0.9, thinking_config = { enabled = true } }

[[models]]
provider = "google"
id = "gemini-pro"
aliases = ["gemini"]
recommended = true
parameters = { temperature = 0.5, max_output_tokens = 4096, top_p = 0.95 }
"#;

        let base: Catalog = toml::from_str(base_toml).unwrap();
        let user: Catalog = toml::from_str(user_toml).unwrap();

        let mut merged: Vec<ModelConfig> =
            base.models.into_iter().map(ModelConfig::from).collect();
        let user_models: Vec<ModelConfig> =
            user.models.into_iter().map(ModelConfig::from).collect();
        ModelRegistry::merge_models(&mut merged, user_models);

        assert_eq!(merged.len(), 3);

        // Merged Claude model
        let claude = merged
            .iter()
            .find(|m| m.provider == provider::anthropic() && m.id == "claude-3")
            .unwrap();

        // Aliases are appended
        assert_eq!(claude.aliases.len(), 3);
        assert!(claude.aliases.contains(&"claude".to_string()));
        assert!(claude.aliases.contains(&"c3".to_string()));
        assert!(claude.aliases.contains(&"claude3".to_string()));

        // Scalars are overridden
        assert!(claude.recommended);

        // Parameters are merged field by field (user overrides base)
        assert!(claude.parameters.is_some());
        let claude_params = claude.parameters.unwrap();
        assert_eq!(claude_params.temperature, Some(0.9)); // overridden from 0.7
        assert_eq!(claude_params.max_output_tokens, Some(2048)); // kept from base
        assert!(claude_params.thinking_config.is_some());
        assert!(claude_params.thinking_config.unwrap().enabled);

        // GPT-4 is untouched
        let gpt4 = merged
            .iter()
            .find(|m| m.provider == provider::openai() && m.id == "gpt-4")
            .unwrap();
        assert!(gpt4.recommended);
        assert!(gpt4.parameters.is_none()); // No parameters in either base or user

        // New model was appended
        let gemini = merged
            .iter()
            .find(|m| m.provider == provider::google() && m.id == "gemini-pro")
            .unwrap();
        assert!(gemini.recommended);
        assert!(gemini.parameters.is_some());
        let gemini_params = gemini.parameters.unwrap();
        assert_eq!(gemini_params.temperature, Some(0.5));
        assert_eq!(gemini_params.top_p, Some(0.95));
    }

    #[test]
    fn test_load_catalog_from_path() {
        use std::fs;
        use tempfile::TempDir;

        let dir = TempDir::new().unwrap();
        let config_path = dir.path().join("test_catalog.toml");

        let config = r#"
[[providers]]
id = "anthropic"
name = "Anthropic"
api_format = "anthropic"
auth_schemes = ["api-key"]

[[models]]
provider = "anthropic"
id = "test-model"
aliases = ["test"]
recommended = true
parameters = { max_output_tokens = 4096 }
"#;

        fs::write(&config_path, config).unwrap();

        let catalog = ModelRegistry::load_catalog_file(&config_path)
            .unwrap()
            .expect("catalog file exists");
        assert_eq!(catalog.models.len(), 1);
        assert_eq!(catalog.models[0].id, "test-model");
        assert_eq!(catalog.providers.len(), 1);
        assert_eq!(catalog.providers[0].id, "anthropic");
    }

    #[test]
    fn test_load_catalog_file_missing_path_is_none() {
        let dir = tempfile::TempDir::new().unwrap();
        let missing = dir.path().join("does_not_exist.toml");
        assert!(ModelRegistry::load_catalog_file(&missing).unwrap().is_none());
    }

    #[test]
    fn test_resolve_by_provider_and_parts() {
        let mut registry = ModelRegistry::empty();
        let prov = provider::anthropic();

        let m1 = ModelConfig {
            provider: prov.clone(),
            id: "id-1".to_string(),
            display_name: Some("NiceName".to_string()),
            aliases: vec!["alias1".into()],
            recommended: false,
            context_window_tokens: None,
            parameters: None,
        };
        let m2 = ModelConfig {
            provider: prov.clone(),
            id: "id-2".to_string(),
            display_name: Some("Other".to_string()),
            aliases: vec!["alias2".into()],
            recommended: false,
            context_window_tokens: None,
            parameters: None,
        };
        let id1 = ModelId::new(prov.clone(), m1.id.clone());
        let id2 = ModelId::new(prov.clone(), m2.id.clone());
        registry.aliases.insert("alias1".into(), id1.clone());
        registry.aliases.insert("alias2".into(), id2.clone());
        registry.models.insert(id1.clone(), m1);
        registry.models.insert(id2.clone(), m2);
        registry.providers.insert(prov);

        // provider/id
        assert_eq!(registry.resolve("anthropic/id-1").unwrap(), id1);
        // provider/display_name should NOT resolve
        assert!(registry.resolve("anthropic/NiceName").is_err());
        // provider/alias resolves when the alias maps to this provider
        assert_eq!(registry.resolve("anthropic/alias2").unwrap(), id2);
        // unknown
        assert!(registry.resolve("anthropic/does-not-exist").is_err());
    }

    #[test]
    fn test_resolve_by_display_name_is_not_supported() {
        // Two models with the same display name under the same provider.
        let mut registry = ModelRegistry::empty();
        let prov = provider::anthropic();
        let m1 = ModelConfig {
            provider: prov.clone(),
            id: "id-1".into(),
            display_name: Some("Same".into()),
            aliases: vec![],
            recommended: false,
            context_window_tokens: None,
            parameters: None,
        };
        let m2 = ModelConfig {
            provider: prov.clone(),
            id: "id-2".into(),
            display_name: Some("Same".into()),
            aliases: vec![],
            recommended: false,
            context_window_tokens: None,
            parameters: None,
        };
        let id1 = ModelId::new(prov.clone(), m1.id.clone());
        let id2 = ModelId::new(prov.clone(), m2.id.clone());
        registry.models.insert(id1, m1);
        registry.models.insert(id2, m2);
        registry.providers.insert(prov);

        let err = registry.resolve("anthropic/Same").unwrap_err();
        match err {
            Error::Configuration(msg) => assert!(msg.contains("Unknown model or alias")),
            _ => panic!("unexpected error type"),
        }
    }

    #[test]
    fn test_load_rejects_missing_max_output_tokens() {
        let toml = r#"
[[providers]]
id = "custom"
name = "Custom"
api_format = "openai-responses"
auth_schemes = ["api-key"]

[[models]]
provider = "custom"
id = "m1"
"#;
        assert_config_error(load_with_catalog(toml), "max_output_tokens");
    }

    #[test]
    fn test_load_rejects_invalid_or_duplicate_display_names() {
        // Blank display_name. max_output_tokens is set so that the display_name check,
        // rather than the earlier max_output_tokens check, is what rejects the catalog.
        let blank = r#"
[[providers]]
id = "custom"
name = "Custom"
api_format = "openai-responses"
auth_schemes = ["api-key"]

[[models]]
provider = "custom"
id = "m1"
display_name = ""
parameters = { max_output_tokens = 1024 }
"#;
        assert_config_error(load_with_catalog(blank), "display_name");

        // Duplicate display_name within one provider.
        let dup = r#"
[[providers]]
id = "custom"
name = "Custom"
api_format = "openai-responses"
auth_schemes = ["api-key"]

[[models]]
provider = "custom"
id = "m1"
display_name = "Same"
parameters = { max_output_tokens = 1024 }

[[models]]
provider = "custom"
id = "m2"
display_name = "Same"
parameters = { max_output_tokens = 2048 }
"#;
        assert_config_error(load_with_catalog(dup), "Duplicate display_name");
    }

    #[test]
    fn test_duplicate_aliases_across_providers_error() {
        // Two providers, same alias used by different models => load must fail.
        let toml = r#"
[[providers]]
id = "p1"
name = "P1"
api_format = "openai-responses"
auth_schemes = ["api-key"]

[[providers]]
id = "p2"
name = "P2"
api_format = "openai-responses"
auth_schemes = ["api-key"]

[[models]]
provider = "p1"
id = "m1"
aliases = ["shared"]
parameters = { max_output_tokens = 1024 }

[[models]]
provider = "p2"
id = "m2"
aliases = ["shared"]
parameters = { max_output_tokens = 1024 }
"#;
        assert_config_error(load_with_catalog(toml), "Duplicate alias");
    }
}