steer_core/auth/
registry.rs1use crate::config::provider::{ProviderConfig, ProviderId};
2use crate::config::toml_types::Catalog;
3use std::collections::HashMap;
4use std::path::Path;
5
6#[derive(Debug, Clone)]
10pub struct ProviderRegistry {
11 providers: HashMap<ProviderId, ProviderConfig>,
12}
13
14const DEFAULT_CATALOG_TOML: &str = include_str!("../../assets/default_catalog.toml");
15
16impl ProviderRegistry {
17 pub fn load(additional_catalogs: &[String]) -> crate::error::Result<Self> {
24 let mut providers: HashMap<ProviderId, ProviderConfig> = HashMap::new();
25
26 let builtin_catalog: Catalog = toml::from_str(DEFAULT_CATALOG_TOML).map_err(|e| {
28 crate::error::Error::Configuration(format!(
29 "Failed to parse embedded default_catalog.toml: {e}"
30 ))
31 })?;
32
33 for p in builtin_catalog.providers {
34 let config = ProviderConfig::from(p);
35 providers.insert(config.id.clone(), config);
36 }
37
38 for path in crate::utils::paths::AppPaths::discover_catalogs() {
40 if let Some(catalog) = Self::load_catalog_file(&path)? {
41 for p in catalog.providers {
42 let config = ProviderConfig::from(p);
43 providers.insert(config.id.clone(), config);
44 }
45 }
46 }
47
48 for catalog_path in additional_catalogs {
50 if let Some(catalog) = Self::load_catalog_file(Path::new(catalog_path))? {
51 for p in catalog.providers {
52 let config = ProviderConfig::from(p);
53 providers.insert(config.id.clone(), config);
54 }
55 }
56 }
57
58 Ok(Self { providers })
59 }
60
61 fn load_catalog_file(path: &Path) -> crate::error::Result<Option<Catalog>> {
63 if !path.exists() {
64 return Ok(None);
65 }
66
67 let contents = std::fs::read_to_string(path)?;
68 let catalog: Catalog = toml::from_str(&contents).map_err(|e| {
69 crate::error::Error::Configuration(format!("Failed to parse {}: {}", path.display(), e))
70 })?;
71
72 Ok(Some(catalog))
73 }
74
75 pub fn get(&self, id: &ProviderId) -> Option<&ProviderConfig> {
77 self.providers.get(id)
78 }
79
80 pub fn all(&self) -> impl Iterator<Item = &ProviderConfig> {
82 self.providers.values()
83 }
84}
85
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::provider::{self, ApiFormat, AuthScheme, ProviderId};
    use crate::config::toml_types::{Catalog, ProviderData};
    use std::fs;
    use std::path::PathBuf;

    /// Serializes `catalog` to `test_catalog.toml` inside `base_dir` and returns the file path.
    fn write_test_catalog(base_dir: &std::path::Path, catalog: &Catalog) -> PathBuf {
        let catalog_path = base_dir.join("test_catalog.toml");
        let toml_str = toml::to_string(catalog).expect("serialize test catalog");
        fs::write(&catalog_path, toml_str).expect("write test catalog");
        catalog_path
    }

    /// Converts a path into the `String` form accepted by `ProviderRegistry::load`.
    fn path_arg(path: &std::path::Path) -> String {
        path.to_string_lossy().into_owned()
    }

    #[test]
    fn loads_builtin_when_no_additional_catalogs() {
        // NOTE(review): `load(&[])` still merges catalogs from
        // `AppPaths::discover_catalogs()`, so this test can be influenced by the local
        // environment. The count is therefore only a lower bound.
        let reg = ProviderRegistry::load(&[]).expect("load registry");
        assert!(reg.all().count() >= 4);
    }

    #[test]
    fn loads_and_merges_additional_catalog() {
        let temp = tempfile::tempdir().unwrap();

        let catalog = Catalog {
            providers: vec![
                ProviderData {
                    id: "anthropic".to_string(),
                    name: "Anthropic (override)".to_string(),
                    api_format: ApiFormat::Anthropic,
                    auth_schemes: vec![AuthScheme::ApiKey],
                    base_url: None,
                },
                ProviderData {
                    id: "myprov".to_string(),
                    name: "My Provider".to_string(),
                    api_format: ApiFormat::OpenaiResponses,
                    auth_schemes: vec![AuthScheme::ApiKey],
                    base_url: None,
                },
            ],
            models: vec![],
        };

        let catalog_path = write_test_catalog(temp.path(), &catalog);

        let baseline = ProviderRegistry::load(&[]).expect("load baseline registry");
        let reg = ProviderRegistry::load(&[path_arg(&catalog_path)]).expect("load registry");

        let anthro = reg.get(&provider::anthropic()).unwrap();
        assert_eq!(anthro.name, "Anthropic (override)");

        let myprov = ProviderId("myprov".to_string());
        let custom = reg.get(&myprov).unwrap();
        assert_eq!(custom.name, "My Provider");

        // Overrides replace existing entries in place; only ids absent from the baseline
        // should grow the registry.
        let new_ids = [provider::anthropic(), myprov]
            .into_iter()
            .filter(|id| baseline.get(id).is_none())
            .count();
        assert_eq!(reg.all().count(), baseline.all().count() + new_ids);
    }

    #[test]
    fn skips_missing_additional_catalog() {
        let temp = tempfile::tempdir().unwrap();
        let missing = temp.path().join("does_not_exist.toml");

        let baseline = ProviderRegistry::load(&[]).expect("load baseline registry");
        let reg = ProviderRegistry::load(&[path_arg(&missing)])
            .expect("a missing catalog file is skipped, not an error");

        assert_eq!(reg.all().count(), baseline.all().count());
    }

    #[test]
    fn rejects_malformed_additional_catalog() {
        let temp = tempfile::tempdir().unwrap();
        let bad = temp.path().join("bad_catalog.toml");
        fs::write(&bad, "not = [valid toml").expect("write malformed catalog");

        let result = ProviderRegistry::load(&[path_arg(&bad)]);
        assert!(matches!(
            result,
            Err(crate::error::Error::Configuration(_))
        ));
    }
}