steer_auth_plugin/
flow.rs1use crate::error::{AuthError, Result};
2use async_trait::async_trait;
3
/// Authentication mechanisms a flow can report as available and be started with.
///
/// The type is a small fieldless `Copy` enum; it derives `Hash` alongside
/// equality so values can be used directly as `HashMap`/`HashSet` keys
/// (for example to index per-method configuration).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AuthMethod {
    /// OAuth-based authentication.
    OAuth,
    /// API-key-based authentication.
    ApiKey,
}
10
/// Progress value returned by the steps of an authentication flow.
///
/// Derives `PartialEq`/`Eq` so callers and tests can compare progress values
/// directly (e.g. with `assert_eq!`); every payload is a `String`, so equality
/// is a plain structural comparison.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AuthProgress {
    /// The flow is waiting for input.
    /// NOTE(review): the string is presumably the prompt to show — confirm against implementors.
    NeedInput(String),
    /// The flow is still running.
    /// NOTE(review): the string is presumably a status message — confirm against implementors.
    InProgress(String),
    /// The flow finished.
    Complete,
    /// The flow failed.
    /// NOTE(review): the string is presumably a description of the failure — confirm against implementors.
    Error(String),
    /// An OAuth flow was started; `auth_url` carries the associated URL.
    /// NOTE(review): presumably the URL the user must open to authorize — confirm against implementors.
    OAuthStarted { auth_url: String },
}
25
/// A statically typed authentication flow.
///
/// Each implementor picks its own `State` type: `start_auth` returns it, and
/// `get_initial_progress` / `handle_input` accept it. Fallible methods return
/// the crate's `Result` alias. Implementors must be `Send + Sync`.
#[async_trait]
pub trait AuthenticationFlow: Send + Sync {
    /// Type returned by `start_auth` and accepted by `get_initial_progress`
    /// (shared borrow) and `handle_input` (mutable borrow).
    type State: Send + Sync;

    /// Lists the [`AuthMethod`]s this flow reports as available.
    fn available_methods(&self) -> Vec<AuthMethod>;

    /// Starts an authentication attempt with `method` and returns its state.
    async fn start_auth(&self, method: AuthMethod) -> Result<Self::State>;

    /// Reports the first [`AuthProgress`] for `state` and `method`.
    ///
    /// NOTE(review): presumably called right after `start_auth` with the same
    /// `method` — confirm against callers.
    async fn get_initial_progress(
        &self,
        state: &Self::State,
        method: AuthMethod,
    ) -> Result<AuthProgress>;

    /// Feeds `input` into the flow, possibly mutating `state`, and returns
    /// the resulting [`AuthProgress`].
    ///
    /// NOTE(review): what `input` contains (code, API key, ...) is not visible
    /// in this file; it appears to be implementor-defined.
    async fn handle_input(&self, state: &mut Self::State, input: &str) -> Result<AuthProgress>;

    /// Reports whether authentication is currently in place.
    ///
    /// NOTE(review): takes no state argument, so this presumably reflects
    /// stored credentials rather than an in-progress attempt — confirm.
    async fn is_authenticated(&self) -> Result<bool>;

    /// Returns the provider's name as an owned string.
    fn provider_name(&self) -> String;
}
54
/// Type-erased counterpart of [`AuthenticationFlow`].
///
/// Mirrors that trait method-for-method, but has no associated `State` type:
/// state is passed around as `Box<dyn Any + Send + Sync>` instead.
///
/// NOTE(review): presumably this exists so flows with different state types
/// can be handled uniformly (e.g. behind `dyn DynAuthenticationFlow`) —
/// confirm against callers.
#[async_trait]
pub trait DynAuthenticationFlow: Send + Sync {
    /// Lists the [`AuthMethod`]s this flow reports as available.
    fn available_methods(&self) -> Vec<AuthMethod>;

    /// Starts an authentication attempt with `method` and returns its state
    /// as a boxed `Any`.
    async fn start_auth(&self, method: AuthMethod) -> Result<Box<dyn std::any::Any + Send + Sync>>;

    /// Reports the first [`AuthProgress`] for the type-erased `state` and `method`.
    ///
    /// NOTE(review): `state` is presumably expected to be a box previously
    /// returned by `start_auth` on the same flow — confirm against callers.
    async fn get_initial_progress(
        &self,
        state: &Box<dyn std::any::Any + Send + Sync>,
        method: AuthMethod,
    ) -> Result<AuthProgress>;

    /// Feeds `input` into the flow using the type-erased, mutable `state`
    /// and returns the resulting [`AuthProgress`].
    async fn handle_input(
        &self,
        state: &mut Box<dyn std::any::Any + Send + Sync>,
        input: &str,
    ) -> Result<AuthProgress>;

    /// Reports whether authentication is currently in place.
    async fn is_authenticated(&self) -> Result<bool>;

    /// Returns the provider's name as an owned string.
    fn provider_name(&self) -> String;
}
84
/// Adapter that owns a concrete flow value in its `inner` field.
///
/// The `AuthenticationFlow` requirement is stated on the impl blocks that
/// need it rather than on the type itself, so code that merely names
/// `AuthFlowWrapper<T>` (in a field or a signature) does not have to repeat
/// the bound. Construction through `new` and the `DynAuthenticationFlow`
/// impl still require `T: AuthenticationFlow`.
pub struct AuthFlowWrapper<T> {
    // The wrapped flow; every trait method of the wrapper forwards to it.
    inner: T,
}
89
90impl<T: AuthenticationFlow> AuthFlowWrapper<T> {
91 pub fn new(inner: T) -> Self {
92 Self { inner }
93 }
94}
95
96#[async_trait]
97impl<T: AuthenticationFlow + 'static> DynAuthenticationFlow for AuthFlowWrapper<T>
98where
99 T::State: 'static,
100{
101 fn available_methods(&self) -> Vec<AuthMethod> {
102 self.inner.available_methods()
103 }
104
105 async fn start_auth(&self, method: AuthMethod) -> Result<Box<dyn std::any::Any + Send + Sync>> {
106 let state = self.inner.start_auth(method).await?;
107 Ok(Box::new(state))
108 }
109
110 async fn get_initial_progress(
111 &self,
112 state: &Box<dyn std::any::Any + Send + Sync>,
113 method: AuthMethod,
114 ) -> Result<AuthProgress> {
115 let concrete_state = state
116 .downcast_ref::<T::State>()
117 .ok_or_else(|| AuthError::InvalidResponse("Invalid state type".to_string()))?;
118 self.inner
119 .get_initial_progress(concrete_state, method)
120 .await
121 }
122
123 async fn handle_input(
124 &self,
125 state: &mut Box<dyn std::any::Any + Send + Sync>,
126 input: &str,
127 ) -> Result<AuthProgress> {
128 let concrete_state = state
129 .downcast_mut::<T::State>()
130 .ok_or_else(|| AuthError::InvalidResponse("Invalid state type".to_string()))?;
131 self.inner.handle_input(concrete_state, input).await
132 }
133
134 async fn is_authenticated(&self) -> Result<bool> {
135 self.inner.is_authenticated().await
136 }
137
138 fn provider_name(&self) -> String {
139 self.inner.provider_name()
140 }
141}