1use crate::auth::{AuthError, AuthStorage, Credential, CredentialType, Result};
2use crate::auth::{AuthMethod, AuthProgress, AuthenticationFlow};
3use crate::config::provider::{self, ProviderId};
4use async_trait::async_trait;
5use std::sync::Arc;
6
7pub struct ApiKeyAuthFlow {
9 storage: Arc<dyn AuthStorage>,
10 provider_id: ProviderId,
11 provider_display_name: String,
12}
13
14impl ApiKeyAuthFlow {
15 pub fn new(
16 storage: Arc<dyn AuthStorage>,
17 provider_id: ProviderId,
18 provider_display_name: String,
19 ) -> Self {
20 Self {
21 storage,
22 provider_id,
23 provider_display_name,
24 }
25 }
26
27 fn provider_name(&self) -> String {
29 self.provider_id.storage_key()
30 }
31
32 fn provider_display_name(&self) -> String {
34 self.provider_display_name.clone()
35 }
36
37 fn validate_api_key(&self, api_key: &str) -> Result<()> {
39 let trimmed = api_key.trim();
40
41 if trimmed.is_empty() {
42 return Err(AuthError::InvalidCredential(
43 "API key cannot be empty".to_string(),
44 ));
45 }
46
47 if self.provider_id.as_str() == provider::OPENAI_ID {
49 if !trimmed.starts_with("sk-") || trimmed.len() < 20 {
50 return Err(AuthError::InvalidCredential(
51 "OpenAI API keys should start with 'sk-' and be at least 20 characters"
52 .to_string(),
53 ));
54 }
55 } else if self.provider_id.as_str() == provider::ANTHROPIC_ID {
56 if !trimmed.starts_with("sk-ant-") {
57 return Err(AuthError::InvalidCredential(
58 "Anthropic API keys should start with 'sk-ant-'".to_string(),
59 ));
60 }
61 } else if self.provider_id.as_str() == provider::GOOGLE_ID {
62 if trimmed.len() < 30 {
64 return Err(AuthError::InvalidCredential(
65 "Google API key appears to be too short".to_string(),
66 ));
67 }
68 } else if self.provider_id.as_str() == provider::XAI_ID {
69 if trimmed.len() < 10 {
71 return Err(AuthError::InvalidCredential(
72 "API key appears to be too short".to_string(),
73 ));
74 }
75 }
76 if trimmed.contains(' ') && !trimmed.contains("Bearer") {
80 return Err(AuthError::InvalidCredential(
81 "API key should not contain spaces".to_string(),
82 ));
83 }
84
85 Ok(())
86 }
87}
88
/// Per-attempt state for the API-key authentication flow.
#[derive(Clone, Debug)]
pub struct ApiKeyAuthState {
    /// Set to `true` by `start_auth` and cleared by `handle_input` once a key has been
    /// stored; `handle_input` rejects input while this is `false`.
    pub awaiting_input: bool,
}
94
95#[async_trait]
96impl AuthenticationFlow for ApiKeyAuthFlow {
97 type State = ApiKeyAuthState;
98
99 fn available_methods(&self) -> Vec<AuthMethod> {
100 vec![AuthMethod::ApiKey]
101 }
102
103 async fn start_auth(&self, method: AuthMethod) -> Result<Self::State> {
104 match method {
105 AuthMethod::ApiKey => Ok(ApiKeyAuthState {
106 awaiting_input: true,
107 }),
108 AuthMethod::OAuth => Err(AuthError::UnsupportedMethod {
109 method: format!("{method:?}"),
110 provider: self.provider_display_name(),
111 }),
112 }
113 }
114
115 async fn get_initial_progress(
116 &self,
117 _state: &Self::State,
118 method: AuthMethod,
119 ) -> Result<AuthProgress> {
120 match method {
121 AuthMethod::ApiKey => Ok(AuthProgress::NeedInput(format!(
122 "Enter your {} API key",
123 self.provider_display_name()
124 ))),
125 AuthMethod::OAuth => Err(AuthError::UnsupportedMethod {
126 method: format!("{method:?}"),
127 provider: self.provider_display_name(),
128 }),
129 }
130 }
131
132 async fn handle_input(&self, state: &mut Self::State, input: &str) -> Result<AuthProgress> {
133 if !state.awaiting_input {
134 return Err(AuthError::InvalidState(
135 "Not expecting input at this stage".to_string(),
136 ));
137 }
138
139 self.validate_api_key(input)?;
141
142 self.storage
144 .set_credential(
145 &self.provider_name(),
146 Credential::ApiKey {
147 value: input.trim().to_string(),
148 },
149 )
150 .await
151 .map_err(|e| AuthError::Storage(format!("Failed to store API key: {e}")))?;
152
153 state.awaiting_input = false;
154 Ok(AuthProgress::Complete)
155 }
156
157 async fn is_authenticated(&self) -> Result<bool> {
158 Ok(self
159 .storage
160 .get_credential(&self.provider_name(), CredentialType::ApiKey)
161 .await?
162 .is_some())
163 }
164
165 fn provider_name(&self) -> String {
166 ApiKeyAuthFlow::provider_name(self)
167 }
168}
169
#[cfg(test)]
mod tests {
    use super::*;
    use crate::auth::{AuthStorage, Credential, CredentialType};
    use async_trait::async_trait;
    use std::collections::HashMap;
    use tokio::sync::Mutex;

    /// In-memory `AuthStorage` keyed by `(provider, credential type)`.
    struct MockAuthStorage {
        credentials: Arc<Mutex<HashMap<(String, CredentialType), Credential>>>,
    }

    impl MockAuthStorage {
        fn new() -> Self {
            Self {
                credentials: Arc::new(Mutex::new(HashMap::new())),
            }
        }
    }

    #[async_trait]
    impl AuthStorage for MockAuthStorage {
        async fn get_credential(
            &self,
            provider: &str,
            credential_type: CredentialType,
        ) -> Result<Option<Credential>> {
            let creds = self.credentials.lock().await;
            Ok(creds.get(&(provider.to_string(), credential_type)).cloned())
        }

        async fn set_credential(&self, provider: &str, credential: Credential) -> Result<()> {
            let mut creds = self.credentials.lock().await;
            let cred_type = match &credential {
                Credential::ApiKey { .. } => CredentialType::ApiKey,
                Credential::OAuth2 { .. } => CredentialType::OAuth2,
            };
            creds.insert((provider.to_string(), cred_type), credential);
            Ok(())
        }

        async fn remove_credential(
            &self,
            provider: &str,
            credential_type: CredentialType,
        ) -> Result<()> {
            let mut creds = self.credentials.lock().await;
            creds.remove(&(provider.to_string(), credential_type));
            Ok(())
        }
    }

    #[tokio::test]
    async fn test_api_key_flow() {
        let storage = Arc::new(MockAuthStorage::new());
        let auth_flow = ApiKeyAuthFlow::new(storage.clone(), provider::xai(), "xAI".to_string());

        let methods = auth_flow.available_methods();
        assert_eq!(methods.len(), 1);
        assert!(methods.contains(&AuthMethod::ApiKey));

        let mut state = auth_flow.start_auth(AuthMethod::ApiKey).await.unwrap();
        assert!(state.awaiting_input);

        let progress = auth_flow
            .get_initial_progress(&state, AuthMethod::ApiKey)
            .await
            .unwrap();
        match progress {
            AuthProgress::NeedInput(msg) => assert_eq!(msg, "Enter your xAI API key"),
            _ => panic!("Expected NeedInput progress"),
        }

        let progress = auth_flow
            .handle_input(&mut state, "test-api-key-12345")
            .await
            .unwrap();
        assert!(matches!(progress, AuthProgress::Complete));
        assert!(!state.awaiting_input);

        let cred = storage
            .get_credential("xai", CredentialType::ApiKey)
            .await
            .unwrap();
        assert!(cred.is_some());
        if let Some(Credential::ApiKey { value }) = cred {
            assert_eq!(value, "test-api-key-12345");
        } else {
            panic!("Expected API key credential");
        }

        assert!(auth_flow.is_authenticated().await.unwrap());
    }

    #[tokio::test]
    async fn test_not_authenticated_before_input() {
        let storage = Arc::new(MockAuthStorage::new());
        let auth_flow = ApiKeyAuthFlow::new(storage, provider::xai(), "xAI".to_string());

        assert!(!auth_flow.is_authenticated().await.unwrap());
    }

    #[tokio::test]
    async fn test_input_is_trimmed_before_storage() {
        let storage = Arc::new(MockAuthStorage::new());
        let auth_flow = ApiKeyAuthFlow::new(storage.clone(), provider::xai(), "xAI".to_string());

        let mut state = auth_flow.start_auth(AuthMethod::ApiKey).await.unwrap();
        auth_flow
            .handle_input(&mut state, "  test-api-key-12345\n")
            .await
            .unwrap();

        let cred = storage
            .get_credential("xai", CredentialType::ApiKey)
            .await
            .unwrap();
        match cred {
            Some(Credential::ApiKey { value }) => assert_eq!(value, "test-api-key-12345"),
            _ => panic!("Expected API key credential"),
        }
    }

    #[tokio::test]
    async fn test_input_rejected_after_completion() {
        let storage = Arc::new(MockAuthStorage::new());
        let auth_flow = ApiKeyAuthFlow::new(storage, provider::xai(), "xAI".to_string());

        let mut state = auth_flow.start_auth(AuthMethod::ApiKey).await.unwrap();
        auth_flow
            .handle_input(&mut state, "test-api-key-12345")
            .await
            .unwrap();

        let result = auth_flow
            .handle_input(&mut state, "another-key-12345")
            .await;
        match result {
            Err(AuthError::InvalidState(msg)) => {
                assert_eq!(msg, "Not expecting input at this stage");
            }
            _ => panic!("Expected InvalidState error"),
        }
    }

    #[tokio::test]
    async fn test_empty_api_key() {
        let storage = Arc::new(MockAuthStorage::new());
        let auth_flow = ApiKeyAuthFlow::new(storage, provider::xai(), "xAI".to_string());

        let mut state = auth_flow.start_auth(AuthMethod::ApiKey).await.unwrap();

        let result = auth_flow.handle_input(&mut state, "").await;
        assert!(result.is_err());
        match result.unwrap_err() {
            AuthError::InvalidCredential(msg) => {
                assert_eq!(msg, "API key cannot be empty");
            }
            _ => panic!("Expected InvalidCredential error"),
        }

        // A rejected key must neither advance the state nor reach storage.
        assert!(state.awaiting_input);
        assert!(!auth_flow.is_authenticated().await.unwrap());
    }

    #[tokio::test]
    async fn test_invalid_method() {
        let storage = Arc::new(MockAuthStorage::new());
        let auth_flow = ApiKeyAuthFlow::new(storage, provider::xai(), "xAI".to_string());

        let result = auth_flow.start_auth(AuthMethod::OAuth).await;
        assert!(result.is_err());
        match result.unwrap_err() {
            AuthError::UnsupportedMethod { method, provider } => {
                assert_eq!(method, "OAuth");
                assert_eq!(provider, "xAI");
            }
            _ => panic!("Expected UnsupportedMethod error"),
        }
    }

    #[tokio::test]
    async fn test_initial_progress_rejects_oauth() {
        let storage = Arc::new(MockAuthStorage::new());
        let auth_flow = ApiKeyAuthFlow::new(storage, provider::xai(), "xAI".to_string());

        let state = auth_flow.start_auth(AuthMethod::ApiKey).await.unwrap();
        let result = auth_flow
            .get_initial_progress(&state, AuthMethod::OAuth)
            .await;
        match result {
            Err(AuthError::UnsupportedMethod { method, provider }) => {
                assert_eq!(method, "OAuth");
                assert_eq!(provider, "xAI");
            }
            _ => panic!("Expected UnsupportedMethod error"),
        }
    }

    #[tokio::test]
    async fn test_openai_key_validation() {
        let storage = Arc::new(MockAuthStorage::new());
        let auth_flow = ApiKeyAuthFlow::new(storage, provider::openai(), "OpenAI".to_string());

        let mut state = auth_flow.start_auth(AuthMethod::ApiKey).await.unwrap();

        let result = auth_flow.handle_input(&mut state, "invalid-key").await;
        assert!(result.is_err());
        match result.unwrap_err() {
            AuthError::InvalidCredential(msg) => {
                assert!(msg.contains("OpenAI API keys should start with 'sk-'"));
            }
            _ => panic!("Expected InvalidCredential error"),
        }

        let result = auth_flow
            .handle_input(&mut state, "sk-1234567890abcdef1234567890")
            .await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_api_key_with_spaces() {
        let storage = Arc::new(MockAuthStorage::new());
        let auth_flow = ApiKeyAuthFlow::new(storage, provider::xai(), "xAI".to_string());

        let mut state = auth_flow.start_auth(AuthMethod::ApiKey).await.unwrap();

        let result = auth_flow
            .handle_input(&mut state, "test key with spaces")
            .await;
        assert!(result.is_err());
        match result.unwrap_err() {
            AuthError::InvalidCredential(msg) => {
                assert_eq!(msg, "API key should not contain spaces");
            }
            _ => panic!("Expected InvalidCredential error"),
        }
    }
}