Skip to main content

steer_core/auth/
plugin_registry.rs

1use crate::config::provider::ProviderId;
2use crate::error::{Error, Result};
3use std::collections::HashMap;
4use std::sync::Arc;
5use steer_auth_anthropic::AnthropicAuthPlugin;
6use steer_auth_openai::OpenAiAuthPlugin;
7use steer_auth_plugin::identifiers::ProviderId as PluginProviderId;
8use steer_auth_plugin::plugin::AuthPlugin;
9
10#[derive(Clone, Default)]
11pub struct AuthPluginRegistry {
12    plugins: HashMap<ProviderId, Arc<dyn AuthPlugin>>,
13}
14
15impl AuthPluginRegistry {
16    pub fn new() -> Self {
17        Self::default()
18    }
19
20    pub fn with_defaults() -> Result<Self> {
21        let mut registry = Self::new();
22        registry.register(Arc::new(OpenAiAuthPlugin::new()))?;
23        registry.register(Arc::new(AnthropicAuthPlugin::new()))?;
24        Ok(registry)
25    }
26
27    pub fn register(&mut self, plugin: Arc<dyn AuthPlugin>) -> Result<()> {
28        let plugin_id: PluginProviderId = plugin.provider_id();
29        let provider_id = ProviderId(plugin_id.0);
30        if self.plugins.contains_key(&provider_id) {
31            return Err(Error::Configuration(format!(
32                "Auth plugin conflict for provider {}",
33                provider_id.as_str()
34            )));
35        }
36        self.plugins.insert(provider_id, plugin);
37        Ok(())
38    }
39
40    pub fn get(&self, provider_id: &ProviderId) -> Option<&Arc<dyn AuthPlugin>> {
41        self.plugins.get(provider_id)
42    }
43
44    pub fn all(&self) -> impl Iterator<Item = &Arc<dyn AuthPlugin>> {
45        self.plugins.values()
46    }
47}