steer_core/auth/
plugin_registry.rs1use crate::config::provider::ProviderId;
2use crate::error::{Error, Result};
3use std::collections::HashMap;
4use std::sync::Arc;
5use steer_auth_anthropic::AnthropicAuthPlugin;
6use steer_auth_openai::OpenAiAuthPlugin;
7use steer_auth_plugin::identifiers::ProviderId as PluginProviderId;
8use steer_auth_plugin::plugin::AuthPlugin;
9
10#[derive(Clone, Default)]
11pub struct AuthPluginRegistry {
12 plugins: HashMap<ProviderId, Arc<dyn AuthPlugin>>,
13}
14
15impl AuthPluginRegistry {
16 pub fn new() -> Self {
17 Self::default()
18 }
19
20 pub fn with_defaults() -> Result<Self> {
21 let mut registry = Self::new();
22 registry.register(Arc::new(OpenAiAuthPlugin::new()))?;
23 registry.register(Arc::new(AnthropicAuthPlugin::new()))?;
24 Ok(registry)
25 }
26
27 pub fn register(&mut self, plugin: Arc<dyn AuthPlugin>) -> Result<()> {
28 let plugin_id: PluginProviderId = plugin.provider_id();
29 let provider_id = ProviderId(plugin_id.0);
30 if self.plugins.contains_key(&provider_id) {
31 return Err(Error::Configuration(format!(
32 "Auth plugin conflict for provider {}",
33 provider_id.as_str()
34 )));
35 }
36 self.plugins.insert(provider_id, plugin);
37 Ok(())
38 }
39
40 pub fn get(&self, provider_id: &ProviderId) -> Option<&Arc<dyn AuthPlugin>> {
41 self.plugins.get(provider_id)
42 }
43
44 pub fn all(&self) -> impl Iterator<Item = &Arc<dyn AuthPlugin>> {
45 self.plugins.values()
46 }
47}