Skip to main content

steer_core/auth/
api_key.rs

1use crate::auth::{AuthError, AuthStorage, Credential, CredentialType, Result};
2use crate::auth::{AuthMethod, AuthProgress, AuthenticationFlow};
3use crate::config::provider::{self, ProviderId};
4use async_trait::async_trait;
5use std::sync::Arc;
6
7/// Generic API key authentication flow for providers that support API keys
8pub struct ApiKeyAuthFlow {
9    storage: Arc<dyn AuthStorage>,
10    provider_id: ProviderId,
11    provider_display_name: String,
12}
13
14impl ApiKeyAuthFlow {
15    pub fn new(
16        storage: Arc<dyn AuthStorage>,
17        provider_id: ProviderId,
18        provider_display_name: String,
19    ) -> Self {
20        Self {
21            storage,
22            provider_id,
23            provider_display_name,
24        }
25    }
26
27    /// Get the provider name for display/storage purposes
28    fn provider_name(&self) -> String {
29        self.provider_id.storage_key()
30    }
31
32    /// Get the display name for the provider
33    fn provider_display_name(&self) -> String {
34        self.provider_display_name.clone()
35    }
36
37    /// Validate an API key format based on provider-specific rules
38    fn validate_api_key(&self, api_key: &str) -> Result<()> {
39        let trimmed = api_key.trim();
40
41        if trimmed.is_empty() {
42            return Err(AuthError::InvalidCredential(
43                "API key cannot be empty".to_string(),
44            ));
45        }
46
47        // Provider-specific validation
48        if self.provider_id.as_str() == provider::OPENAI_ID {
49            if !trimmed.starts_with("sk-") || trimmed.len() < 20 {
50                return Err(AuthError::InvalidCredential(
51                    "OpenAI API keys should start with 'sk-' and be at least 20 characters"
52                        .to_string(),
53                ));
54            }
55        } else if self.provider_id.as_str() == provider::ANTHROPIC_ID {
56            if !trimmed.starts_with("sk-ant-") {
57                return Err(AuthError::InvalidCredential(
58                    "Anthropic API keys should start with 'sk-ant-'".to_string(),
59                ));
60            }
61        } else if self.provider_id.as_str() == provider::GOOGLE_ID {
62            // Google/Gemini keys are typically 39 characters
63            if trimmed.len() < 30 {
64                return Err(AuthError::InvalidCredential(
65                    "Google API key appears to be too short".to_string(),
66                ));
67            }
68        } else if self.provider_id.as_str() == provider::XAI_ID {
69            // Grok doesn't have a specific format requirement yet
70            if trimmed.len() < 10 {
71                return Err(AuthError::InvalidCredential(
72                    "API key appears to be too short".to_string(),
73                ));
74            }
75        }
76        // Custom providers - no specific validation
77
78        // Check for common mistakes
79        if trimmed.contains(' ') && !trimmed.contains("Bearer") {
80            return Err(AuthError::InvalidCredential(
81                "API key should not contain spaces".to_string(),
82            ));
83        }
84
85        Ok(())
86    }
87}
88
/// Mutable state carried between the steps of an API-key authentication flow.
#[derive(Clone, Debug)]
pub struct ApiKeyAuthState {
    /// `true` from `start_auth` until a key has been validated and stored; once it is
    /// `false`, further input is rejected.
    pub awaiting_input: bool,
}
94
95#[async_trait]
96impl AuthenticationFlow for ApiKeyAuthFlow {
97    type State = ApiKeyAuthState;
98
99    fn available_methods(&self) -> Vec<AuthMethod> {
100        vec![AuthMethod::ApiKey]
101    }
102
103    async fn start_auth(&self, method: AuthMethod) -> Result<Self::State> {
104        match method {
105            AuthMethod::ApiKey => Ok(ApiKeyAuthState {
106                awaiting_input: true,
107            }),
108            AuthMethod::OAuth => Err(AuthError::UnsupportedMethod {
109                method: format!("{method:?}"),
110                provider: self.provider_display_name(),
111            }),
112        }
113    }
114
115    async fn get_initial_progress(
116        &self,
117        _state: &Self::State,
118        method: AuthMethod,
119    ) -> Result<AuthProgress> {
120        match method {
121            AuthMethod::ApiKey => Ok(AuthProgress::NeedInput(format!(
122                "Enter your {} API key",
123                self.provider_display_name()
124            ))),
125            AuthMethod::OAuth => Err(AuthError::UnsupportedMethod {
126                method: format!("{method:?}"),
127                provider: self.provider_display_name(),
128            }),
129        }
130    }
131
132    async fn handle_input(&self, state: &mut Self::State, input: &str) -> Result<AuthProgress> {
133        if !state.awaiting_input {
134            return Err(AuthError::InvalidState(
135                "Not expecting input at this stage".to_string(),
136            ));
137        }
138
139        // Validate the API key
140        self.validate_api_key(input)?;
141
142        // Store the API key
143        self.storage
144            .set_credential(
145                &self.provider_name(),
146                Credential::ApiKey {
147                    value: input.trim().to_string(),
148                },
149            )
150            .await
151            .map_err(|e| AuthError::Storage(format!("Failed to store API key: {e}")))?;
152
153        state.awaiting_input = false;
154        Ok(AuthProgress::Complete)
155    }
156
157    async fn is_authenticated(&self) -> Result<bool> {
158        Ok(self
159            .storage
160            .get_credential(&self.provider_name(), CredentialType::ApiKey)
161            .await?
162            .is_some())
163    }
164
165    fn provider_name(&self) -> String {
166        ApiKeyAuthFlow::provider_name(self)
167    }
168}
169
#[cfg(test)]
mod tests {
    use super::*;
    use crate::auth::{AuthStorage, Credential, CredentialType};
    use async_trait::async_trait;
    use std::collections::HashMap;
    use tokio::sync::Mutex;

    /// In-memory [`AuthStorage`] used to observe what the flow persists.
    ///
    /// Credentials are keyed by `(provider, credential type)`.
    struct MockAuthStorage {
        credentials: Arc<Mutex<HashMap<(String, CredentialType), Credential>>>,
    }

    impl MockAuthStorage {
        fn new() -> Self {
            Self {
                credentials: Arc::new(Mutex::new(HashMap::new())),
            }
        }
    }

    #[async_trait]
    impl AuthStorage for MockAuthStorage {
        async fn get_credential(
            &self,
            provider: &str,
            credential_type: CredentialType,
        ) -> Result<Option<Credential>> {
            let creds = self.credentials.lock().await;
            Ok(creds.get(&(provider.to_string(), credential_type)).cloned())
        }

        async fn set_credential(&self, provider: &str, credential: Credential) -> Result<()> {
            let mut creds = self.credentials.lock().await;
            let cred_type = match &credential {
                Credential::ApiKey { .. } => CredentialType::ApiKey,
                Credential::OAuth2 { .. } => CredentialType::OAuth2,
            };
            creds.insert((provider.to_string(), cred_type), credential);
            Ok(())
        }

        async fn remove_credential(
            &self,
            provider: &str,
            credential_type: CredentialType,
        ) -> Result<()> {
            let mut creds = self.credentials.lock().await;
            creds.remove(&(provider.to_string(), credential_type));
            Ok(())
        }
    }

    /// Builds an xAI flow, whose only format rule is a minimum key length.
    fn xai_flow(storage: Arc<MockAuthStorage>) -> ApiKeyAuthFlow {
        ApiKeyAuthFlow::new(storage, provider::xai(), "xAI".to_string())
    }

    /// Returns the message of an `InvalidCredential` error, panicking on anything else.
    fn invalid_credential_message(result: Result<AuthProgress>) -> String {
        match result {
            Err(AuthError::InvalidCredential(msg)) => msg,
            Err(other) => panic!("Expected InvalidCredential error, got {other:?}"),
            Ok(progress) => panic!("Expected InvalidCredential error, got {progress:?}"),
        }
    }

    #[tokio::test]
    async fn test_api_key_flow() {
        let storage = Arc::new(MockAuthStorage::new());
        let auth_flow = xai_flow(storage.clone());

        // Nothing stored yet
        assert!(!auth_flow.is_authenticated().await.unwrap());

        // Test available methods
        let methods = auth_flow.available_methods();
        assert_eq!(methods.len(), 1);
        assert!(methods.contains(&AuthMethod::ApiKey));

        // Start API key flow
        let mut state = auth_flow.start_auth(AuthMethod::ApiKey).await.unwrap();
        assert!(state.awaiting_input);

        // Get initial progress
        let progress = auth_flow
            .get_initial_progress(&state, AuthMethod::ApiKey)
            .await
            .unwrap();
        match progress {
            AuthProgress::NeedInput(msg) => assert_eq!(msg, "Enter your xAI API key"),
            _ => panic!("Expected NeedInput progress"),
        }

        // Handle API key input
        let progress = auth_flow
            .handle_input(&mut state, "test-api-key-12345")
            .await
            .unwrap();
        assert!(matches!(progress, AuthProgress::Complete));
        assert!(!state.awaiting_input);

        // Verify API key was stored
        let cred = storage
            .get_credential("xai", CredentialType::ApiKey)
            .await
            .unwrap();
        match cred {
            Some(Credential::ApiKey { value }) => assert_eq!(value, "test-api-key-12345"),
            _ => panic!("Expected API key credential"),
        }

        // Verify authentication status
        assert!(auth_flow.is_authenticated().await.unwrap());
    }

    #[tokio::test]
    async fn test_empty_api_key() {
        let auth_flow = xai_flow(Arc::new(MockAuthStorage::new()));
        let mut state = auth_flow.start_auth(AuthMethod::ApiKey).await.unwrap();

        // Empty and whitespace-only input are both empty once trimmed
        for input in ["", "   \t\n"] {
            let msg = invalid_credential_message(auth_flow.handle_input(&mut state, input).await);
            assert_eq!(msg, "API key cannot be empty");
        }

        // A rejected key leaves the flow ready for another attempt
        assert!(state.awaiting_input);
    }

    #[tokio::test]
    async fn test_invalid_method() {
        let auth_flow = xai_flow(Arc::new(MockAuthStorage::new()));

        // OAuth is rejected when starting the flow
        match auth_flow.start_auth(AuthMethod::OAuth).await {
            Err(AuthError::UnsupportedMethod { method, provider }) => {
                assert_eq!(method, "OAuth");
                assert_eq!(provider, "xAI");
            }
            _ => panic!("Expected UnsupportedMethod error"),
        }

        // ...and when asking for the initial prompt
        let state = auth_flow.start_auth(AuthMethod::ApiKey).await.unwrap();
        match auth_flow
            .get_initial_progress(&state, AuthMethod::OAuth)
            .await
        {
            Err(AuthError::UnsupportedMethod { method, provider }) => {
                assert_eq!(method, "OAuth");
                assert_eq!(provider, "xAI");
            }
            _ => panic!("Expected UnsupportedMethod error"),
        }
    }

    #[tokio::test]
    async fn test_openai_key_validation() {
        let storage = Arc::new(MockAuthStorage::new());
        let auth_flow = ApiKeyAuthFlow::new(storage, provider::openai(), "OpenAI".to_string());

        let mut state = auth_flow.start_auth(AuthMethod::ApiKey).await.unwrap();

        // Wrong prefix
        let msg = invalid_credential_message(auth_flow.handle_input(&mut state, "invalid-key").await);
        assert!(msg.contains("OpenAI API keys should start with 'sk-'"));

        // Right prefix but too short
        let msg = invalid_credential_message(auth_flow.handle_input(&mut state, "sk-short").await);
        assert!(msg.contains("at least 20 characters"));

        // Test valid OpenAI key format
        let result = auth_flow
            .handle_input(&mut state, "sk-1234567890abcdef1234567890")
            .await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_xai_key_too_short() {
        let auth_flow = xai_flow(Arc::new(MockAuthStorage::new()));
        let mut state = auth_flow.start_auth(AuthMethod::ApiKey).await.unwrap();

        let msg = invalid_credential_message(auth_flow.handle_input(&mut state, "short").await);
        assert_eq!(msg, "API key appears to be too short");
    }

    #[tokio::test]
    async fn test_api_key_with_spaces() {
        let auth_flow = xai_flow(Arc::new(MockAuthStorage::new()));
        let mut state = auth_flow.start_auth(AuthMethod::ApiKey).await.unwrap();

        let msg = invalid_credential_message(
            auth_flow
                .handle_input(&mut state, "test key with spaces")
                .await,
        );
        assert_eq!(msg, "API key should not contain spaces");
    }

    #[tokio::test]
    async fn test_key_is_trimmed_before_storing() {
        let storage = Arc::new(MockAuthStorage::new());
        let auth_flow = xai_flow(storage.clone());
        let mut state = auth_flow.start_auth(AuthMethod::ApiKey).await.unwrap();

        let progress = auth_flow
            .handle_input(&mut state, "  test-api-key-12345\n")
            .await
            .unwrap();
        assert!(matches!(progress, AuthProgress::Complete));

        match storage
            .get_credential("xai", CredentialType::ApiKey)
            .await
            .unwrap()
        {
            Some(Credential::ApiKey { value }) => assert_eq!(value, "test-api-key-12345"),
            _ => panic!("Expected API key credential"),
        }
    }

    #[tokio::test]
    async fn test_input_rejected_after_completion() {
        let auth_flow = xai_flow(Arc::new(MockAuthStorage::new()));
        let mut state = auth_flow.start_auth(AuthMethod::ApiKey).await.unwrap();

        let progress = auth_flow
            .handle_input(&mut state, "test-api-key-12345")
            .await
            .unwrap();
        assert!(matches!(progress, AuthProgress::Complete));

        // Once the key is stored, the flow no longer accepts input
        match auth_flow
            .handle_input(&mut state, "another-api-key-123")
            .await
        {
            Err(AuthError::InvalidState(msg)) => {
                assert_eq!(msg, "Not expecting input at this stage");
            }
            _ => panic!("Expected InvalidState error"),
        }
    }
}