steer_core/auth/
api_key.rs

1use crate::api::ProviderKind;
2use crate::auth::{AuthError, AuthStorage, Credential, CredentialType, Result};
3use crate::auth::{AuthMethod, AuthProgress, AuthenticationFlow};
4use async_trait::async_trait;
5use std::sync::Arc;
6
7/// Generic API key authentication flow for providers that support API keys
8pub struct ApiKeyAuthFlow {
9    storage: Arc<dyn AuthStorage>,
10    provider: ProviderKind,
11}
12
13impl ApiKeyAuthFlow {
14    pub fn new(storage: Arc<dyn AuthStorage>, provider: ProviderKind) -> Self {
15        Self { storage, provider }
16    }
17
18    /// Validate an API key format based on provider-specific rules
19    fn validate_api_key(&self, api_key: &str) -> Result<()> {
20        let trimmed = api_key.trim();
21
22        if trimmed.is_empty() {
23            return Err(AuthError::InvalidCredential(
24                "API key cannot be empty".to_string(),
25            ));
26        }
27
28        // Provider-specific validation
29        match self.provider {
30            ProviderKind::OpenAI => {
31                if !trimmed.starts_with("sk-") || trimmed.len() < 20 {
32                    return Err(AuthError::InvalidCredential(
33                        "OpenAI API keys should start with 'sk-' and be at least 20 characters"
34                            .to_string(),
35                    ));
36                }
37            }
38            ProviderKind::Anthropic => {
39                if !trimmed.starts_with("sk-ant-") {
40                    return Err(AuthError::InvalidCredential(
41                        "Anthropic API keys should start with 'sk-ant-'".to_string(),
42                    ));
43                }
44            }
45            ProviderKind::Google => {
46                // Google/Gemini keys are typically 39 characters
47                if trimmed.len() < 30 {
48                    return Err(AuthError::InvalidCredential(
49                        "Google API key appears to be too short".to_string(),
50                    ));
51                }
52            }
53            ProviderKind::XAI => {
54                // Grok doesn't have a specific format requirement yet
55                if trimmed.len() < 10 {
56                    return Err(AuthError::InvalidCredential(
57                        "API key appears to be too short".to_string(),
58                    ));
59                }
60            }
61        }
62
63        // Check for common mistakes
64        if trimmed.contains(' ') && !trimmed.contains("Bearer") {
65            return Err(AuthError::InvalidCredential(
66                "API key should not contain spaces".to_string(),
67            ));
68        }
69
70        Ok(())
71    }
72}
73
/// Progress marker for the API key flow: records whether the flow still expects the user
/// to submit a key.
#[derive(Clone, Debug)]
pub struct ApiKeyAuthState {
    /// Set when the flow starts; cleared once a key has been validated and stored.
    pub awaiting_input: bool,
}
79
80#[async_trait]
81impl AuthenticationFlow for ApiKeyAuthFlow {
82    type State = ApiKeyAuthState;
83
84    fn available_methods(&self) -> Vec<AuthMethod> {
85        vec![AuthMethod::ApiKey]
86    }
87
88    async fn start_auth(&self, method: AuthMethod) -> Result<Self::State> {
89        match method {
90            AuthMethod::ApiKey => Ok(ApiKeyAuthState {
91                awaiting_input: true,
92            }),
93            _ => Err(AuthError::UnsupportedMethod {
94                method: format!("{method:?}"),
95                provider: self.provider,
96            }),
97        }
98    }
99
100    async fn get_initial_progress(
101        &self,
102        _state: &Self::State,
103        method: AuthMethod,
104    ) -> Result<AuthProgress> {
105        match method {
106            AuthMethod::ApiKey => Ok(AuthProgress::NeedInput(format!(
107                "Enter your {} API key",
108                self.provider
109            ))),
110            _ => Err(AuthError::UnsupportedMethod {
111                method: format!("{method:?}"),
112                provider: self.provider,
113            }),
114        }
115    }
116
117    async fn handle_input(&self, state: &mut Self::State, input: &str) -> Result<AuthProgress> {
118        if !state.awaiting_input {
119            return Err(AuthError::InvalidState(
120                "Not expecting input at this stage".to_string(),
121            ));
122        }
123
124        // Validate the API key
125        self.validate_api_key(input)?;
126
127        // Store the API key
128        self.storage
129            .set_credential(
130                &self.provider.to_string(),
131                Credential::ApiKey {
132                    value: input.trim().to_string(),
133                },
134            )
135            .await
136            .map_err(|e| AuthError::Storage(format!("Failed to store API key: {e}")))?;
137
138        state.awaiting_input = false;
139        Ok(AuthProgress::Complete)
140    }
141
142    async fn is_authenticated(&self) -> Result<bool> {
143        Ok(self
144            .storage
145            .get_credential(&self.provider.to_string(), CredentialType::ApiKey)
146            .await?
147            .is_some())
148    }
149
150    fn provider_name(&self) -> String {
151        self.provider.to_string()
152    }
153}
154
#[cfg(test)]
mod tests {
    use super::*;
    use crate::auth::{AuthStorage, Credential, CredentialType};
    use async_trait::async_trait;
    use std::collections::HashMap;
    use tokio::sync::Mutex;

    /// In-memory implementation of `AuthStorage` used to observe what the flow persists.
    struct MockAuthStorage {
        credentials: Arc<Mutex<HashMap<(String, CredentialType), Credential>>>,
    }

    impl MockAuthStorage {
        fn new() -> Self {
            Self {
                credentials: Arc::new(Mutex::new(HashMap::new())),
            }
        }
    }

    #[async_trait]
    impl AuthStorage for MockAuthStorage {
        async fn get_credential(
            &self,
            provider: &str,
            credential_type: CredentialType,
        ) -> Result<Option<Credential>> {
            let creds = self.credentials.lock().await;
            Ok(creds.get(&(provider.to_string(), credential_type)).cloned())
        }

        async fn set_credential(&self, provider: &str, credential: Credential) -> Result<()> {
            let mut creds = self.credentials.lock().await;
            let cred_type = match &credential {
                Credential::ApiKey { .. } => CredentialType::ApiKey,
                Credential::AuthTokens { .. } => CredentialType::AuthTokens,
            };
            creds.insert((provider.to_string(), cred_type), credential);
            Ok(())
        }

        async fn remove_credential(
            &self,
            provider: &str,
            credential_type: CredentialType,
        ) -> Result<()> {
            let mut creds = self.credentials.lock().await;
            creds.remove(&(provider.to_string(), credential_type));
            Ok(())
        }
    }

    /// Build a flow for `provider` backed by a fresh mock store; returns both so tests can
    /// inspect what was persisted.
    fn flow_for(provider: ProviderKind) -> (Arc<MockAuthStorage>, ApiKeyAuthFlow) {
        let storage = Arc::new(MockAuthStorage::new());
        let auth_flow = ApiKeyAuthFlow::new(storage.clone(), provider);
        (storage, auth_flow)
    }

    /// Submit `input` to a fresh flow for `provider` and return the `InvalidCredential`
    /// message it produced, panicking on any other outcome.
    async fn rejection_message(provider: ProviderKind, input: &str) -> String {
        let (_storage, auth_flow) = flow_for(provider);
        let mut state = auth_flow.start_auth(AuthMethod::ApiKey).await.unwrap();
        match auth_flow.handle_input(&mut state, input).await {
            Err(AuthError::InvalidCredential(msg)) => {
                // A rejected key must leave the flow ready for another attempt.
                assert!(state.awaiting_input);
                msg
            }
            Err(_) => panic!("Expected InvalidCredential error"),
            Ok(_) => panic!("Expected InvalidCredential error, but input was accepted"),
        }
    }

    #[tokio::test]
    async fn test_api_key_flow() {
        let (storage, auth_flow) = flow_for(ProviderKind::XAI);

        // Test available methods
        let methods = auth_flow.available_methods();
        assert_eq!(methods.len(), 1);
        assert!(methods.contains(&AuthMethod::ApiKey));

        // Nothing stored yet
        assert!(!auth_flow.is_authenticated().await.unwrap());

        // Start API key flow
        let mut state = auth_flow.start_auth(AuthMethod::ApiKey).await.unwrap();
        assert!(state.awaiting_input);

        // Get initial progress
        let progress = auth_flow
            .get_initial_progress(&state, AuthMethod::ApiKey)
            .await
            .unwrap();
        match progress {
            AuthProgress::NeedInput(msg) => assert_eq!(msg, "Enter your xai API key"),
            _ => panic!("Expected NeedInput progress"),
        }

        // Handle API key input
        let progress = auth_flow
            .handle_input(&mut state, "test-api-key-12345")
            .await
            .unwrap();
        assert!(matches!(progress, AuthProgress::Complete));
        assert!(!state.awaiting_input);

        // Verify API key was stored
        let cred = storage
            .get_credential(&ProviderKind::XAI.to_string(), CredentialType::ApiKey)
            .await
            .unwrap();
        match cred {
            Some(Credential::ApiKey { value }) => assert_eq!(value, "test-api-key-12345"),
            _ => panic!("Expected API key credential"),
        }

        // Verify authentication status
        assert!(auth_flow.is_authenticated().await.unwrap());
    }

    #[tokio::test]
    async fn test_api_key_is_trimmed_before_storing() {
        let (storage, auth_flow) = flow_for(ProviderKind::XAI);
        let mut state = auth_flow.start_auth(AuthMethod::ApiKey).await.unwrap();

        auth_flow
            .handle_input(&mut state, "  test-api-key-12345\n")
            .await
            .unwrap();

        let cred = storage
            .get_credential(&ProviderKind::XAI.to_string(), CredentialType::ApiKey)
            .await
            .unwrap();
        match cred {
            Some(Credential::ApiKey { value }) => assert_eq!(value, "test-api-key-12345"),
            _ => panic!("Expected API key credential"),
        }
    }

    #[tokio::test]
    async fn test_input_rejected_after_completion() {
        let (_storage, auth_flow) = flow_for(ProviderKind::XAI);
        let mut state = auth_flow.start_auth(AuthMethod::ApiKey).await.unwrap();

        auth_flow
            .handle_input(&mut state, "test-api-key-12345")
            .await
            .unwrap();

        match auth_flow.handle_input(&mut state, "test-api-key-67890").await {
            Err(AuthError::InvalidState(msg)) => {
                assert_eq!(msg, "Not expecting input at this stage");
            }
            _ => panic!("Expected InvalidState error"),
        }
    }

    #[tokio::test]
    async fn test_empty_api_key() {
        let msg = rejection_message(ProviderKind::XAI, "").await;
        assert_eq!(msg, "API key cannot be empty");

        // Whitespace-only input is empty after trimming.
        let msg = rejection_message(ProviderKind::XAI, "   ").await;
        assert_eq!(msg, "API key cannot be empty");
    }

    #[tokio::test]
    async fn test_invalid_method() {
        let (_storage, auth_flow) = flow_for(ProviderKind::XAI);

        // Test with OAuth method
        let result = auth_flow.start_auth(AuthMethod::OAuth).await;
        assert!(result.is_err());
        match result.unwrap_err() {
            AuthError::UnsupportedMethod { method, provider } => {
                assert_eq!(method, "OAuth");
                assert_eq!(provider, ProviderKind::XAI);
            }
            _ => panic!("Expected UnsupportedMethod error"),
        }
    }

    #[tokio::test]
    async fn test_initial_progress_rejects_unsupported_method() {
        let (_storage, auth_flow) = flow_for(ProviderKind::XAI);
        let state = auth_flow.start_auth(AuthMethod::ApiKey).await.unwrap();

        match auth_flow
            .get_initial_progress(&state, AuthMethod::OAuth)
            .await
        {
            Err(AuthError::UnsupportedMethod { method, provider }) => {
                assert_eq!(method, "OAuth");
                assert_eq!(provider, ProviderKind::XAI);
            }
            _ => panic!("Expected UnsupportedMethod error"),
        }
    }

    #[tokio::test]
    async fn test_openai_key_validation() {
        let (_storage, auth_flow) = flow_for(ProviderKind::OpenAI);
        let mut state = auth_flow.start_auth(AuthMethod::ApiKey).await.unwrap();

        // Test invalid OpenAI key format
        let result = auth_flow.handle_input(&mut state, "invalid-key").await;
        assert!(result.is_err());
        match result.unwrap_err() {
            AuthError::InvalidCredential(msg) => {
                assert!(msg.contains("OpenAI API keys should start with 'sk-'"));
            }
            _ => panic!("Expected InvalidCredential error"),
        }
        // A failed attempt keeps the flow open for a retry.
        assert!(state.awaiting_input);

        // Test valid OpenAI key format
        let result = auth_flow
            .handle_input(&mut state, "sk-1234567890abcdef1234567890")
            .await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_anthropic_key_validation() {
        let msg = rejection_message(ProviderKind::Anthropic, "sk-1234567890abcdef").await;
        assert_eq!(msg, "Anthropic API keys should start with 'sk-ant-'");

        let (_storage, auth_flow) = flow_for(ProviderKind::Anthropic);
        let mut state = auth_flow.start_auth(AuthMethod::ApiKey).await.unwrap();
        let result = auth_flow
            .handle_input(&mut state, "sk-ant-api03-abcdef")
            .await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_google_key_validation() {
        let msg = rejection_message(ProviderKind::Google, "AIzaShort").await;
        assert_eq!(msg, "Google API key appears to be too short");

        let (_storage, auth_flow) = flow_for(ProviderKind::Google);
        let mut state = auth_flow.start_auth(AuthMethod::ApiKey).await.unwrap();
        let key = format!("AIza{}", "x".repeat(35));
        let result = auth_flow.handle_input(&mut state, &key).await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_api_key_with_spaces() {
        let msg = rejection_message(ProviderKind::XAI, "test key with spaces").await;
        assert_eq!(msg, "API key should not contain spaces");
    }
}