steer_core/auth/mod.rs

1pub mod anthropic;
2pub mod api_key;
3pub mod error;
4pub mod registry;
5pub mod storage;
6
7use async_trait::async_trait;
8use std::sync::Arc;
9
10pub use error::{AuthError, Result};
11pub use registry::ProviderRegistry;
12pub use storage::{AuthStorage, AuthTokens, Credential, CredentialType, DefaultAuthStorage};
13
/// The kinds of authentication a provider can offer.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum AuthMethod {
    /// Authenticate through an OAuth authorization flow.
    OAuth,
    /// Authenticate with an API key.
    ApiKey,
}
20
/// A status update emitted while an authentication flow runs.
#[derive(Clone, Debug)]
pub enum AuthProgress {
    /// The flow is waiting on the user; the string is the prompt to show.
    NeedInput(String),
    /// The flow is still working; the string describes the current step.
    InProgress(String),
    /// The flow finished successfully.
    Complete,
    /// The flow failed; the string describes what went wrong.
    Error(String),
    /// An OAuth flow has begun and the user should visit `auth_url` to authorize.
    OAuthStarted {
        /// The authorization URL for the OAuth flow.
        auth_url: String,
    },
}
35
36/// Generic authentication flow trait that providers can implement
37#[async_trait]
38pub trait AuthenticationFlow: Send + Sync {
39    /// The state type for this authentication flow
40    type State: Send + Sync;
41
42    /// Get available authentication methods for this provider
43    fn available_methods(&self) -> Vec<AuthMethod>;
44
45    /// Start an authentication flow
46    async fn start_auth(&self, method: AuthMethod) -> Result<Self::State>;
47
48    /// Get initial progress/instructions after starting auth
49    async fn get_initial_progress(
50        &self,
51        state: &Self::State,
52        method: AuthMethod,
53    ) -> Result<AuthProgress>;
54
55    /// Handle user input during authentication
56    async fn handle_input(&self, state: &mut Self::State, input: &str) -> Result<AuthProgress>;
57
58    /// Check if the provider is already authenticated
59    async fn is_authenticated(&self) -> Result<bool>;
60
61    /// Get a display name for the provider
62    fn provider_name(&self) -> String;
63}
64
65/// Type-erased authentication flow for dynamic dispatch
66#[async_trait]
67pub trait DynAuthenticationFlow: Send + Sync {
68    /// Get available authentication methods for this provider
69    fn available_methods(&self) -> Vec<AuthMethod>;
70
71    /// Start an authentication flow
72    async fn start_auth(&self, method: AuthMethod) -> Result<Box<dyn std::any::Any + Send + Sync>>;
73
74    /// Get initial progress/instructions after starting auth
75    async fn get_initial_progress(
76        &self,
77        state: &Box<dyn std::any::Any + Send + Sync>,
78        method: AuthMethod,
79    ) -> Result<AuthProgress>;
80
81    /// Handle user input during authentication
82    async fn handle_input(
83        &self,
84        state: &mut Box<dyn std::any::Any + Send + Sync>,
85        input: &str,
86    ) -> Result<AuthProgress>;
87
88    /// Check if the provider is already authenticated
89    async fn is_authenticated(&self) -> Result<bool>;
90
91    /// Get a display name for the provider
92    fn provider_name(&self) -> String;
93}
94
/// Adapter that exposes a concrete [`AuthenticationFlow`] as a [`DynAuthenticationFlow`].
///
/// When a flow starts, the wrapped flow's `State` is boxed as `dyn Any + Send + Sync`.
/// On each later call it is downcast back to the concrete state type.
///
/// The struct deliberately carries no trait bound. Bounds live on the `impl` blocks that
/// need them, so naming `AuthFlowWrapper<T>` elsewhere does not force
/// `T: AuthenticationFlow` to be restated at every use site.
pub struct AuthFlowWrapper<T> {
    inner: T,
}
99
100impl<T: AuthenticationFlow> AuthFlowWrapper<T> {
101    pub fn new(inner: T) -> Self {
102        Self { inner }
103    }
104}
105
106#[async_trait]
107impl<T: AuthenticationFlow + 'static> DynAuthenticationFlow for AuthFlowWrapper<T>
108where
109    T::State: 'static,
110{
111    fn available_methods(&self) -> Vec<AuthMethod> {
112        self.inner.available_methods()
113    }
114
115    async fn start_auth(&self, method: AuthMethod) -> Result<Box<dyn std::any::Any + Send + Sync>> {
116        let state = self.inner.start_auth(method).await?;
117        Ok(Box::new(state))
118    }
119
120    async fn get_initial_progress(
121        &self,
122        state: &Box<dyn std::any::Any + Send + Sync>,
123        method: AuthMethod,
124    ) -> Result<AuthProgress> {
125        let concrete_state = state
126            .downcast_ref::<T::State>()
127            .ok_or_else(|| AuthError::InvalidResponse("Invalid state type".to_string()))?;
128        self.inner
129            .get_initial_progress(concrete_state, method)
130            .await
131    }
132
133    async fn handle_input(
134        &self,
135        state: &mut Box<dyn std::any::Any + Send + Sync>,
136        input: &str,
137    ) -> Result<AuthProgress> {
138        let concrete_state = state
139            .downcast_mut::<T::State>()
140            .ok_or_else(|| AuthError::InvalidResponse("Invalid state type".to_string()))?;
141        self.inner.handle_input(concrete_state, input).await
142    }
143
144    async fn is_authenticated(&self) -> Result<bool> {
145        self.inner.is_authenticated().await
146    }
147
148    fn provider_name(&self) -> String {
149        self.inner.provider_name()
150    }
151}
152
153/// Marker trait for providers that support interactive authentication
154pub trait InteractiveAuth: Send + Sync {
155    /// Create an authentication flow for interactive setup
156    fn create_auth_flow(
157        &self,
158        storage: Arc<dyn AuthStorage>,
159    ) -> Option<Box<dyn DynAuthenticationFlow>>;
160}