steer_core/auth/
registry.rs

1use crate::config::provider::{ProviderConfig, ProviderId};
2use crate::config::toml_types::Catalog;
3use std::collections::HashMap;
4use std::path::Path;
5
6/// Registry for provider definitions and authentication flow factories.
7///
8/// This struct is pure domain logic – no networking or gRPC dependencies.
9#[derive(Debug, Clone)]
10pub struct ProviderRegistry {
11    providers: HashMap<ProviderId, ProviderConfig>,
12}
13
/// TOML text of the built-in provider catalog, embedded into the binary at
/// compile time via `include_str!`.
const DEFAULT_CATALOG_TOML: &str = include_str!("../../assets/default_catalog.toml");
15
16impl ProviderRegistry {
17    /// Load provider definitions with optional additional catalog files.
18    ///
19    /// Merge order (later overrides earlier):
20    /// 1. Built-in defaults from embedded catalog
21    /// 2. Discovered catalogs (project, then user)
22    /// 3. Additional catalog files specified
23    pub fn load(additional_catalogs: &[String]) -> crate::error::Result<Self> {
24        let mut providers: HashMap<ProviderId, ProviderConfig> = HashMap::new();
25
26        // 1. Built-in providers from embedded catalog
27        let builtin_catalog: Catalog = toml::from_str(DEFAULT_CATALOG_TOML).map_err(|e| {
28            crate::error::Error::Configuration(format!(
29                "Failed to parse embedded default_catalog.toml: {e}"
30            ))
31        })?;
32
33        for p in builtin_catalog.providers {
34            let config = ProviderConfig::from(p);
35            providers.insert(config.id.clone(), config);
36        }
37
38        // 2. Discovered catalog files (project then user)
39        for path in crate::utils::paths::AppPaths::discover_catalogs() {
40            if let Some(catalog) = Self::load_catalog_file(&path)? {
41                for p in catalog.providers {
42                    let config = ProviderConfig::from(p);
43                    providers.insert(config.id.clone(), config);
44                }
45            }
46        }
47
48        // 3. Additional catalog files
49        for catalog_path in additional_catalogs {
50            if let Some(catalog) = Self::load_catalog_file(Path::new(catalog_path))? {
51                for p in catalog.providers {
52                    let config = ProviderConfig::from(p);
53                    providers.insert(config.id.clone(), config);
54                }
55            }
56        }
57
58        Ok(Self { providers })
59    }
60
61    /// Load a catalog file from disk.
62    fn load_catalog_file(path: &Path) -> crate::error::Result<Option<Catalog>> {
63        if !path.exists() {
64            return Ok(None);
65        }
66
67        let contents = std::fs::read_to_string(path)?;
68        let catalog: Catalog = toml::from_str(&contents).map_err(|e| {
69            crate::error::Error::Configuration(format!("Failed to parse {}: {}", path.display(), e))
70        })?;
71
72        Ok(Some(catalog))
73    }
74
75    /// Get a provider config by ID.
76    pub fn get(&self, id: &ProviderId) -> Option<&ProviderConfig> {
77        self.providers.get(id)
78    }
79
80    /// Iterate over all provider configs.
81    pub fn all(&self) -> impl Iterator<Item = &ProviderConfig> {
82        self.providers.values()
83    }
84}
85
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::provider::{self, ApiFormat, AuthScheme, ProviderId};
    use crate::config::toml_types::{Catalog, ProviderData};
    use std::fs;

    /// Serialize `catalog` to TOML, write it as `test_catalog.toml` inside
    /// `base_dir`, and return the path of the written file.
    fn write_test_catalog(base_dir: &std::path::Path, catalog: &Catalog) -> std::path::PathBuf {
        let destination = base_dir.join("test_catalog.toml");
        let serialized = toml::to_string(catalog).unwrap();
        fs::write(&destination, serialized).unwrap();
        destination
    }

    /// Build a provider entry that authenticates with an API key and has no
    /// custom base URL.
    fn api_key_provider(id: &str, name: &str, api_format: ApiFormat) -> ProviderData {
        ProviderData {
            id: id.to_string(),
            name: name.to_string(),
            api_format,
            auth_schemes: vec![AuthScheme::ApiKey],
            base_url: None,
        }
    }

    #[test]
    fn loads_builtin_when_no_additional_catalogs() {
        let registry = ProviderRegistry::load(&[]).expect("load registry");
        // Discovered catalogs in a dev environment may add providers, so only
        // a lower bound on the count is checked.
        let provider_count = registry.all().count();
        assert!(provider_count >= 4);
    }

    #[test]
    fn loads_and_merges_additional_catalog() {
        let scratch = tempfile::tempdir().unwrap();

        // One entry overrides a built-in provider, the other adds a new one.
        let catalog = Catalog {
            providers: vec![
                api_key_provider("anthropic", "Anthropic (override)", ApiFormat::Anthropic),
                api_key_provider("myprov", "My Provider", ApiFormat::OpenaiResponses),
            ],
            models: vec![],
        };

        let written = write_test_catalog(scratch.path(), &catalog);
        let extra_catalogs = [written.to_string_lossy().into_owned()];
        let registry = ProviderRegistry::load(&extra_catalogs).expect("load registry");

        // The built-in entry must have been replaced by the override.
        let overridden = registry.get(&provider::anthropic()).unwrap();
        assert_eq!(overridden.name, "Anthropic (override)");

        // The custom entry must be present.
        let added = registry.get(&ProviderId("myprov".to_string())).unwrap();
        assert_eq!(added.name, "My Provider");

        assert!(registry.all().count() >= 5);
    }
}