Skip to main content

steer_core/auth/
registry.rs

1use crate::config::provider::{ProviderConfig, ProviderId};
2use crate::config::toml_types::Catalog;
3use std::collections::HashMap;
4use std::path::Path;
5
6/// Registry for provider definitions and authentication flow factories.
7///
8/// This struct is pure domain logic – no networking or gRPC dependencies.
9#[derive(Debug, Clone)]
10pub struct ProviderRegistry {
11    pub(crate) providers: HashMap<ProviderId, ProviderConfig>,
12}
13
14const DEFAULT_CATALOG_TOML: &str = include_str!("../../assets/default_catalog.toml");
15
16impl ProviderRegistry {
17    /// Load provider definitions with optional additional catalog files.
18    ///
19    /// Merge order (later overrides earlier):
20    /// 1. Built-in defaults from embedded catalog
21    /// 2. Discovered catalogs (project, then user)
22    /// 3. Additional catalog files specified
23    pub fn load(additional_catalogs: &[String]) -> crate::error::Result<Self> {
24        let mut providers: HashMap<ProviderId, ProviderConfig> = HashMap::new();
25
26        // 1. Built-in providers from embedded catalog
27        let builtin_catalog: Catalog = toml::from_str(DEFAULT_CATALOG_TOML).map_err(|e| {
28            crate::error::Error::Configuration(format!(
29                "Failed to parse embedded default_catalog.toml: {e}"
30            ))
31        })?;
32
33        for p in builtin_catalog.providers {
34            let config = ProviderConfig::from(p);
35            providers.insert(config.id.clone(), config);
36        }
37
38        // 2. Discovered catalog files (project then user)
39        for path in crate::utils::paths::AppPaths::discover_catalogs() {
40            if let Some(catalog) = Self::load_catalog_file(&path)? {
41                for p in catalog.providers {
42                    let config = ProviderConfig::from(p);
43                    providers.insert(config.id.clone(), config);
44                }
45            }
46        }
47
48        // 3. Additional catalog files
49        for catalog_path in additional_catalogs {
50            if let Some(catalog) = Self::load_catalog_file(Path::new(catalog_path))? {
51                for p in catalog.providers {
52                    let config = ProviderConfig::from(p);
53                    providers.insert(config.id.clone(), config);
54                }
55            }
56        }
57
58        Ok(Self { providers })
59    }
60
61    /// Load a catalog file from disk.
62    fn load_catalog_file(path: &Path) -> crate::error::Result<Option<Catalog>> {
63        if !path.exists() {
64            return Ok(None);
65        }
66
67        let contents = std::fs::read_to_string(path)?;
68        let catalog: Catalog = toml::from_str(&contents).map_err(|e| {
69            crate::error::Error::Configuration(format!("Failed to parse {}: {}", path.display(), e))
70        })?;
71
72        Ok(Some(catalog))
73    }
74
75    /// Build an empty registry (primarily for fallbacks/tests).
76    pub fn empty() -> Self {
77        Self {
78            providers: HashMap::new(),
79        }
80    }
81
82    /// Get a provider config by ID.
83    pub fn get(&self, id: &ProviderId) -> Option<&ProviderConfig> {
84        self.providers.get(id)
85    }
86
87    /// Iterate over all provider configs.
88    pub fn all(&self) -> impl Iterator<Item = &ProviderConfig> {
89        self.providers.values()
90    }
91}
92
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::provider::{self, ApiFormat, AuthScheme, ProviderId};
    use crate::config::toml_types::{Catalog, ProviderData};
    use std::fs;

    /// Serialize `catalog` into `test_catalog.toml` inside `base_dir` and return
    /// the path of the written file.
    fn write_test_catalog(base_dir: &std::path::Path, catalog: &Catalog) -> std::path::PathBuf {
        let catalog_path = base_dir.join("test_catalog.toml");
        let toml_str = toml::to_string(catalog).unwrap();
        fs::write(&catalog_path, toml_str).unwrap();
        catalog_path
    }

    #[test]
    fn loads_builtin_when_no_additional_catalogs() {
        let reg = ProviderRegistry::load(&[]).expect("load registry");
        // At least built-ins; discovered catalogs may add more in dev envs
        assert!(reg.all().count() >= 4);
    }

    #[test]
    fn loads_and_merges_additional_catalog() {
        let temp = tempfile::tempdir().unwrap();

        // Create a catalog with override and new provider
        let catalog = Catalog {
            providers: vec![
                ProviderData {
                    id: "anthropic".to_string(),
                    name: "Anthropic (override)".to_string(),
                    api_format: ApiFormat::Anthropic,
                    auth_schemes: vec![AuthScheme::ApiKey],
                    base_url: None,
                },
                ProviderData {
                    id: "myprov".to_string(),
                    name: "My Provider".to_string(),
                    api_format: ApiFormat::OpenaiResponses,
                    auth_schemes: vec![AuthScheme::ApiKey],
                    base_url: None,
                },
            ],
            models: vec![],
        };

        let catalog_path = write_test_catalog(temp.path(), &catalog)
            .to_string_lossy()
            .into_owned();
        let reg = ProviderRegistry::load(&[catalog_path]).expect("load registry");

        // Overridden provider
        let anthro = reg.get(&provider::anthropic()).unwrap();
        assert_eq!(anthro.name, "Anthropic (override)");

        // Custom provider present
        let custom = reg.get(&ProviderId("myprov".to_string())).unwrap();
        assert_eq!(custom.name, "My Provider");

        assert!(reg.all().count() >= 5);
    }

    #[test]
    fn skips_missing_additional_catalog() {
        let temp = tempfile::tempdir().unwrap();
        let missing = temp
            .path()
            .join("does_not_exist.toml")
            .to_string_lossy()
            .into_owned();

        // A nonexistent catalog path is ignored rather than treated as an error.
        let reg = ProviderRegistry::load(&[missing]).expect("skip missing catalog");
        let baseline = ProviderRegistry::load(&[]).expect("load registry");
        assert_eq!(reg.all().count(), baseline.all().count());
    }

    #[test]
    fn rejects_malformed_additional_catalog() {
        let temp = tempfile::tempdir().unwrap();
        let bad_path = temp.path().join("bad_catalog.toml");
        fs::write(&bad_path, "providers = [ this is not toml").unwrap();

        let result = ProviderRegistry::load(&[bad_path.to_string_lossy().into_owned()]);
        assert!(matches!(
            result,
            Err(crate::error::Error::Configuration(_))
        ));
    }
}
158}