1use crate::api::ProviderKind;
2use crate::auth::{AuthError, AuthStorage, Credential, CredentialType, Result};
3use crate::auth::{AuthMethod, AuthProgress, AuthenticationFlow};
4use async_trait::async_trait;
5use std::sync::Arc;
6
7pub struct ApiKeyAuthFlow {
9 storage: Arc<dyn AuthStorage>,
10 provider: ProviderKind,
11}
12
13impl ApiKeyAuthFlow {
14 pub fn new(storage: Arc<dyn AuthStorage>, provider: ProviderKind) -> Self {
15 Self { storage, provider }
16 }
17
18 fn validate_api_key(&self, api_key: &str) -> Result<()> {
20 let trimmed = api_key.trim();
21
22 if trimmed.is_empty() {
23 return Err(AuthError::InvalidCredential(
24 "API key cannot be empty".to_string(),
25 ));
26 }
27
28 match self.provider {
30 ProviderKind::OpenAI => {
31 if !trimmed.starts_with("sk-") || trimmed.len() < 20 {
32 return Err(AuthError::InvalidCredential(
33 "OpenAI API keys should start with 'sk-' and be at least 20 characters"
34 .to_string(),
35 ));
36 }
37 }
38 ProviderKind::Anthropic => {
39 if !trimmed.starts_with("sk-ant-") {
40 return Err(AuthError::InvalidCredential(
41 "Anthropic API keys should start with 'sk-ant-'".to_string(),
42 ));
43 }
44 }
45 ProviderKind::Google => {
46 if trimmed.len() < 30 {
48 return Err(AuthError::InvalidCredential(
49 "Google API key appears to be too short".to_string(),
50 ));
51 }
52 }
53 ProviderKind::XAI => {
54 if trimmed.len() < 10 {
56 return Err(AuthError::InvalidCredential(
57 "API key appears to be too short".to_string(),
58 ));
59 }
60 }
61 }
62
63 if trimmed.contains(' ') && !trimmed.contains("Bearer") {
65 return Err(AuthError::InvalidCredential(
66 "API key should not contain spaces".to_string(),
67 ));
68 }
69
70 Ok(())
71 }
72}
73
/// Progress marker for a single API-key authentication attempt.
#[derive(Debug, Clone)]
pub struct ApiKeyAuthState {
    /// Set to `true` by `start_auth`; cleared once a key has been validated and stored.
    pub awaiting_input: bool,
}
79
80#[async_trait]
81impl AuthenticationFlow for ApiKeyAuthFlow {
82 type State = ApiKeyAuthState;
83
84 fn available_methods(&self) -> Vec<AuthMethod> {
85 vec![AuthMethod::ApiKey]
86 }
87
88 async fn start_auth(&self, method: AuthMethod) -> Result<Self::State> {
89 match method {
90 AuthMethod::ApiKey => Ok(ApiKeyAuthState {
91 awaiting_input: true,
92 }),
93 _ => Err(AuthError::UnsupportedMethod {
94 method: format!("{method:?}"),
95 provider: self.provider,
96 }),
97 }
98 }
99
100 async fn get_initial_progress(
101 &self,
102 _state: &Self::State,
103 method: AuthMethod,
104 ) -> Result<AuthProgress> {
105 match method {
106 AuthMethod::ApiKey => Ok(AuthProgress::NeedInput(format!(
107 "Enter your {} API key",
108 self.provider
109 ))),
110 _ => Err(AuthError::UnsupportedMethod {
111 method: format!("{method:?}"),
112 provider: self.provider,
113 }),
114 }
115 }
116
117 async fn handle_input(&self, state: &mut Self::State, input: &str) -> Result<AuthProgress> {
118 if !state.awaiting_input {
119 return Err(AuthError::InvalidState(
120 "Not expecting input at this stage".to_string(),
121 ));
122 }
123
124 self.validate_api_key(input)?;
126
127 self.storage
129 .set_credential(
130 &self.provider.to_string(),
131 Credential::ApiKey {
132 value: input.trim().to_string(),
133 },
134 )
135 .await
136 .map_err(|e| AuthError::Storage(format!("Failed to store API key: {e}")))?;
137
138 state.awaiting_input = false;
139 Ok(AuthProgress::Complete)
140 }
141
142 async fn is_authenticated(&self) -> Result<bool> {
143 Ok(self
144 .storage
145 .get_credential(&self.provider.to_string(), CredentialType::ApiKey)
146 .await?
147 .is_some())
148 }
149
150 fn provider_name(&self) -> String {
151 self.provider.to_string()
152 }
153}
154
#[cfg(test)]
mod tests {
    use super::*;
    use crate::auth::{AuthStorage, Credential, CredentialType};
    use async_trait::async_trait;
    use std::collections::HashMap;
    use tokio::sync::Mutex;

    /// Map key used by the mock store: `(provider name, credential kind)`.
    type CredentialKey = (String, CredentialType);

    /// In-memory [`AuthStorage`] so tests can inspect what the flow persisted.
    struct MockAuthStorage {
        credentials: Arc<Mutex<HashMap<CredentialKey, Credential>>>,
    }

    impl MockAuthStorage {
        fn new() -> Self {
            let empty: HashMap<CredentialKey, Credential> = HashMap::new();
            Self {
                credentials: Arc::new(Mutex::new(empty)),
            }
        }
    }

    #[async_trait]
    impl AuthStorage for MockAuthStorage {
        async fn get_credential(
            &self,
            provider: &str,
            credential_type: CredentialType,
        ) -> Result<Option<Credential>> {
            let key: CredentialKey = (provider.to_string(), credential_type);
            let guard = self.credentials.lock().await;
            let found = guard.get(&key).cloned();
            Ok(found)
        }

        async fn set_credential(&self, provider: &str, credential: Credential) -> Result<()> {
            let kind = match &credential {
                Credential::ApiKey { .. } => CredentialType::ApiKey,
                Credential::AuthTokens { .. } => CredentialType::AuthTokens,
            };
            self.credentials
                .lock()
                .await
                .insert((provider.to_string(), kind), credential);
            Ok(())
        }

        async fn remove_credential(
            &self,
            provider: &str,
            credential_type: CredentialType,
        ) -> Result<()> {
            let key: CredentialKey = (provider.to_string(), credential_type);
            self.credentials.lock().await.remove(&key);
            Ok(())
        }
    }

    /// Builds a flow for `provider` backed by a fresh, empty mock store.
    fn flow_for(provider: ProviderKind) -> ApiKeyAuthFlow {
        ApiKeyAuthFlow::new(Arc::new(MockAuthStorage::new()), provider)
    }

    /// Extracts the message of an `AuthError::InvalidCredential`, panicking on any other outcome.
    fn invalid_credential_message(outcome: Result<AuthProgress>) -> String {
        match outcome {
            Err(AuthError::InvalidCredential(msg)) => msg,
            Err(_) => panic!("Expected InvalidCredential error"),
            Ok(_) => panic!("Expected an error, but the input was accepted"),
        }
    }

    #[tokio::test]
    async fn test_api_key_flow() {
        let storage = Arc::new(MockAuthStorage::new());
        let auth_flow = ApiKeyAuthFlow::new(storage.clone(), ProviderKind::XAI);

        assert_eq!(auth_flow.available_methods(), vec![AuthMethod::ApiKey]);

        let mut state = auth_flow.start_auth(AuthMethod::ApiKey).await.unwrap();
        assert!(state.awaiting_input);

        let initial = auth_flow
            .get_initial_progress(&state, AuthMethod::ApiKey)
            .await
            .unwrap();
        if let AuthProgress::NeedInput(prompt) = initial {
            assert_eq!(prompt, "Enter your xai API key");
        } else {
            panic!("Expected NeedInput progress");
        }

        let finished = auth_flow
            .handle_input(&mut state, "test-api-key-12345")
            .await
            .unwrap();
        assert!(matches!(finished, AuthProgress::Complete));
        assert!(!state.awaiting_input);

        let stored = storage
            .get_credential(&ProviderKind::XAI.to_string(), CredentialType::ApiKey)
            .await
            .unwrap();
        assert!(stored.is_some());
        if let Some(Credential::ApiKey { value }) = stored {
            assert_eq!(value, "test-api-key-12345");
        } else {
            panic!("Expected API key credential");
        }

        assert!(auth_flow.is_authenticated().await.unwrap());
    }

    #[tokio::test]
    async fn test_empty_api_key() {
        let auth_flow = flow_for(ProviderKind::XAI);
        let mut state = auth_flow.start_auth(AuthMethod::ApiKey).await.unwrap();

        let msg = invalid_credential_message(auth_flow.handle_input(&mut state, "").await);
        assert_eq!(msg, "API key cannot be empty");
    }

    #[tokio::test]
    async fn test_invalid_method() {
        let auth_flow = flow_for(ProviderKind::XAI);

        let outcome = auth_flow.start_auth(AuthMethod::OAuth).await;
        assert!(outcome.is_err());
        if let Err(AuthError::UnsupportedMethod { method, provider }) = outcome {
            assert_eq!(method, "OAuth");
            assert_eq!(provider, ProviderKind::XAI);
        } else {
            panic!("Expected UnsupportedMethod error");
        }
    }

    #[tokio::test]
    async fn test_openai_key_validation() {
        let auth_flow = flow_for(ProviderKind::OpenAI);
        let mut state = auth_flow.start_auth(AuthMethod::ApiKey).await.unwrap();

        let msg =
            invalid_credential_message(auth_flow.handle_input(&mut state, "invalid-key").await);
        assert!(msg.contains("OpenAI API keys should start with 'sk-'"));

        let accepted = auth_flow
            .handle_input(&mut state, "sk-1234567890abcdef1234567890")
            .await;
        assert!(accepted.is_ok());
    }

    #[tokio::test]
    async fn test_api_key_with_spaces() {
        let auth_flow = flow_for(ProviderKind::XAI);
        let mut state = auth_flow.start_auth(AuthMethod::ApiKey).await.unwrap();

        let msg = invalid_credential_message(
            auth_flow
                .handle_input(&mut state, "test key with spaces")
                .await,
        );
        assert_eq!(msg, "API key should not contain spaces");
    }
}