Skip to main content

steer_auth_plugin/
flow.rs

1use crate::error::{AuthError, Result};
2use async_trait::async_trait;
3
/// Authentication mechanisms a provider can offer.
///
/// A provider lists the ones it supports via `available_methods`, and the
/// caller passes one back to `start_auth` to choose which flow to run.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum AuthMethod {
    /// Log in through an OAuth flow.
    OAuth,
    /// Authenticate with an API key.
    ApiKey,
}
10
/// Status an authentication flow reports after starting or after each input.
#[derive(Clone, Debug)]
pub enum AuthProgress {
    /// The flow is waiting on the user; carries the prompt to display.
    NeedInput(String),
    /// The flow is still running; carries a human-readable status message.
    InProgress(String),
    /// Authentication is complete.
    Complete,
    /// The flow hit an error; carries a description of it.
    Error(String),
    /// An OAuth flow has been started.
    OAuthStarted {
        /// The authorization URL for the started OAuth flow.
        auth_url: String,
    },
}
25
26/// Generic authentication flow trait that providers can implement
27#[async_trait]
28pub trait AuthenticationFlow: Send + Sync {
29    /// The state type for this authentication flow
30    type State: Send + Sync;
31
32    /// Get available authentication methods for this provider
33    fn available_methods(&self) -> Vec<AuthMethod>;
34
35    /// Start an authentication flow
36    async fn start_auth(&self, method: AuthMethod) -> Result<Self::State>;
37
38    /// Get initial progress/instructions after starting auth
39    async fn get_initial_progress(
40        &self,
41        state: &Self::State,
42        method: AuthMethod,
43    ) -> Result<AuthProgress>;
44
45    /// Handle user input during authentication
46    async fn handle_input(&self, state: &mut Self::State, input: &str) -> Result<AuthProgress>;
47
48    /// Check if the provider is already authenticated
49    async fn is_authenticated(&self) -> Result<bool>;
50
51    /// Get a display name for the provider
52    fn provider_name(&self) -> String;
53}
54
55/// Type-erased authentication flow for dynamic dispatch
56#[async_trait]
57pub trait DynAuthenticationFlow: Send + Sync {
58    /// Get available authentication methods for this provider
59    fn available_methods(&self) -> Vec<AuthMethod>;
60
61    /// Start an authentication flow
62    async fn start_auth(&self, method: AuthMethod) -> Result<Box<dyn std::any::Any + Send + Sync>>;
63
64    /// Get initial progress/instructions after starting auth
65    async fn get_initial_progress(
66        &self,
67        state: &Box<dyn std::any::Any + Send + Sync>,
68        method: AuthMethod,
69    ) -> Result<AuthProgress>;
70
71    /// Handle user input during authentication
72    async fn handle_input(
73        &self,
74        state: &mut Box<dyn std::any::Any + Send + Sync>,
75        input: &str,
76    ) -> Result<AuthProgress>;
77
78    /// Check if the provider is already authenticated
79    async fn is_authenticated(&self) -> Result<bool>;
80
81    /// Get a display name for the provider
82    fn provider_name(&self) -> String;
83}
84
85/// Wrapper to convert a concrete AuthenticationFlow into a DynAuthenticationFlow
86pub struct AuthFlowWrapper<T: AuthenticationFlow> {
87    inner: T,
88}
89
90impl<T: AuthenticationFlow> AuthFlowWrapper<T> {
91    pub fn new(inner: T) -> Self {
92        Self { inner }
93    }
94}
95
96#[async_trait]
97impl<T: AuthenticationFlow + 'static> DynAuthenticationFlow for AuthFlowWrapper<T>
98where
99    T::State: 'static,
100{
101    fn available_methods(&self) -> Vec<AuthMethod> {
102        self.inner.available_methods()
103    }
104
105    async fn start_auth(&self, method: AuthMethod) -> Result<Box<dyn std::any::Any + Send + Sync>> {
106        let state = self.inner.start_auth(method).await?;
107        Ok(Box::new(state))
108    }
109
110    async fn get_initial_progress(
111        &self,
112        state: &Box<dyn std::any::Any + Send + Sync>,
113        method: AuthMethod,
114    ) -> Result<AuthProgress> {
115        let concrete_state = state
116            .downcast_ref::<T::State>()
117            .ok_or_else(|| AuthError::InvalidResponse("Invalid state type".to_string()))?;
118        self.inner
119            .get_initial_progress(concrete_state, method)
120            .await
121    }
122
123    async fn handle_input(
124        &self,
125        state: &mut Box<dyn std::any::Any + Send + Sync>,
126        input: &str,
127    ) -> Result<AuthProgress> {
128        let concrete_state = state
129            .downcast_mut::<T::State>()
130            .ok_or_else(|| AuthError::InvalidResponse("Invalid state type".to_string()))?;
131        self.inner.handle_input(concrete_state, input).await
132    }
133
134    async fn is_authenticated(&self) -> Result<bool> {
135        self.inner.is_authenticated().await
136    }
137
138    fn provider_name(&self) -> String {
139        self.inner.provider_name()
140    }
141}