// steer_core/auth/registry.rs
registry.rs1use crate::config::provider::{ProviderConfig, ProviderId};
2use crate::config::toml_types::Catalog;
3use std::collections::HashMap;
4use std::path::Path;
5
6#[derive(Debug, Clone)]
10pub struct ProviderRegistry {
11 pub(crate) providers: HashMap<ProviderId, ProviderConfig>,
12}
13
14const DEFAULT_CATALOG_TOML: &str = include_str!("../../assets/default_catalog.toml");
15
16impl ProviderRegistry {
17 pub fn load(additional_catalogs: &[String]) -> crate::error::Result<Self> {
24 let mut providers: HashMap<ProviderId, ProviderConfig> = HashMap::new();
25
26 let builtin_catalog: Catalog = toml::from_str(DEFAULT_CATALOG_TOML).map_err(|e| {
28 crate::error::Error::Configuration(format!(
29 "Failed to parse embedded default_catalog.toml: {e}"
30 ))
31 })?;
32
33 for p in builtin_catalog.providers {
34 let config = ProviderConfig::from(p);
35 providers.insert(config.id.clone(), config);
36 }
37
38 for path in crate::utils::paths::AppPaths::discover_catalogs() {
40 if let Some(catalog) = Self::load_catalog_file(&path)? {
41 for p in catalog.providers {
42 let config = ProviderConfig::from(p);
43 providers.insert(config.id.clone(), config);
44 }
45 }
46 }
47
48 for catalog_path in additional_catalogs {
50 if let Some(catalog) = Self::load_catalog_file(Path::new(catalog_path))? {
51 for p in catalog.providers {
52 let config = ProviderConfig::from(p);
53 providers.insert(config.id.clone(), config);
54 }
55 }
56 }
57
58 Ok(Self { providers })
59 }
60
61 fn load_catalog_file(path: &Path) -> crate::error::Result<Option<Catalog>> {
63 if !path.exists() {
64 return Ok(None);
65 }
66
67 let contents = std::fs::read_to_string(path)?;
68 let catalog: Catalog = toml::from_str(&contents).map_err(|e| {
69 crate::error::Error::Configuration(format!("Failed to parse {}: {}", path.display(), e))
70 })?;
71
72 Ok(Some(catalog))
73 }
74
75 pub fn empty() -> Self {
77 Self {
78 providers: HashMap::new(),
79 }
80 }
81
82 pub fn get(&self, id: &ProviderId) -> Option<&ProviderConfig> {
84 self.providers.get(id)
85 }
86
87 pub fn all(&self) -> impl Iterator<Item = &ProviderConfig> {
89 self.providers.values()
90 }
91}
92
#[cfg(test)]
mod tests {
    // `super::*` already brings `ProviderRegistry`, `ProviderId`, `Catalog`, `Path`, etc.
    use super::*;
    use crate::config::provider::{self, ApiFormat, AuthScheme};
    use crate::config::toml_types::ProviderData;
    use std::fs;
    use std::path::PathBuf;

    /// Serializes `catalog` to `<base_dir>/test_catalog.toml` and returns that file's path.
    fn write_test_catalog(base_dir: &Path, catalog: &Catalog) -> PathBuf {
        let catalog_path = base_dir.join("test_catalog.toml");
        let toml_str = toml::to_string(catalog).expect("serialize test catalog");
        fs::write(&catalog_path, toml_str).expect("write test catalog");
        catalog_path
    }

    /// Converts a path into the `String` form accepted by `ProviderRegistry::load`.
    fn path_arg(path: &Path) -> String {
        path.to_string_lossy().into_owned()
    }

    #[test]
    fn loads_builtin_when_no_additional_catalogs() {
        let reg = ProviderRegistry::load(&[]).expect("load registry");
        assert!(reg.all().count() >= 4);
    }

    #[test]
    fn loads_and_merges_additional_catalog() {
        let temp = tempfile::tempdir().unwrap();

        let catalog = Catalog {
            providers: vec![
                ProviderData {
                    id: "anthropic".to_string(),
                    name: "Anthropic (override)".to_string(),
                    api_format: ApiFormat::Anthropic,
                    auth_schemes: vec![AuthScheme::ApiKey],
                    base_url: None,
                },
                ProviderData {
                    id: "myprov".to_string(),
                    name: "My Provider".to_string(),
                    api_format: ApiFormat::OpenaiResponses,
                    auth_schemes: vec![AuthScheme::ApiKey],
                    base_url: None,
                },
            ],
            models: vec![],
        };
        let catalog_path = write_test_catalog(temp.path(), &catalog);

        // Load once without the extra catalog so the count check does not depend on
        // which catalogs happen to be discoverable on this machine.
        let baseline = ProviderRegistry::load(&[]).expect("load baseline registry");
        let reg = ProviderRegistry::load(&[path_arg(&catalog_path)]).expect("load registry");

        let anthropic_id = provider::anthropic();
        let custom_id = ProviderId("myprov".to_string());

        let anthro = reg.get(&anthropic_id).expect("anthropic provider present");
        assert_eq!(anthro.name, "Anthropic (override)");

        let custom = reg.get(&custom_id).expect("custom provider present");
        assert_eq!(custom.name, "My Provider");

        // Overrides replace existing entries; only ids unknown to the baseline add to the count.
        let newly_added = [&anthropic_id, &custom_id]
            .into_iter()
            .filter(|id| baseline.get(id).is_none())
            .count();
        assert_eq!(reg.all().count(), baseline.all().count() + newly_added);
    }

    #[test]
    fn missing_additional_catalog_is_ignored() {
        let temp = tempfile::tempdir().unwrap();
        let missing = temp.path().join("does_not_exist.toml");

        let baseline = ProviderRegistry::load(&[]).expect("load baseline registry");
        let reg =
            ProviderRegistry::load(&[path_arg(&missing)]).expect("missing catalog is skipped");
        assert_eq!(reg.all().count(), baseline.all().count());
    }

    #[test]
    fn malformed_additional_catalog_is_a_configuration_error() {
        let temp = tempfile::tempdir().unwrap();
        let bad_path = temp.path().join("bad_catalog.toml");
        fs::write(&bad_path, "this is not toml").expect("write malformed catalog");

        let result = ProviderRegistry::load(&[path_arg(&bad_path)]);
        assert!(matches!(
            result,
            Err(crate::error::Error::Configuration(_))
        ));
    }
}