steer_core/auth/
anthropic.rs

1use crate::api::ProviderKind;
2use crate::auth::{AuthError, AuthStorage, AuthTokens, Credential, CredentialType, Result};
3use crate::auth::{AuthMethod, AuthProgress, AuthenticationFlow};
4use async_trait::async_trait;
5use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
6use serde::{Deserialize, Serialize};
7use sha2::{Digest, Sha256};
8use std::sync::Arc;
9use std::time::{Duration, SystemTime};
10
// OAuth constants

/// Authorization endpoint the user's browser is sent to.
const AUTHORIZE_URL: &str = "https://claude.ai/oauth/authorize";
/// Token endpoint used for both the authorization-code exchange and the refresh grant.
const TOKEN_URL: &str = "https://console.anthropic.com/v1/oauth/token";
/// OAuth client identifier sent with the authorize, exchange and refresh requests.
const CLIENT_ID: &str = "9d1c250a-e61b-44d9-88ed-5944d1962f5e";
/// Redirect URI sent in both the authorize request and the code exchange.
const REDIRECT_URI: &str = "https://console.anthropic.com/oauth/code/callback";
/// Space-separated scopes requested during authorization.
const SCOPES: &str = "org:create_api_key user:profile user:inference";
17
/// A PKCE (RFC 7636) verifier together with the challenge derived from it.
#[derive(Debug)]
pub struct PkceChallenge {
    /// Secret random string; sent only with the token exchange.
    pub verifier: String,
    /// Encoded digest of `verifier`, sent in the authorization URL.
    pub challenge: String,
}
23
24pub struct AnthropicOAuth {
25    client_id: String,
26    redirect_uri: String,
27    http_client: reqwest::Client,
28}
29
30impl Default for AnthropicOAuth {
31    fn default() -> Self {
32        Self::new()
33    }
34}
35
36impl AnthropicOAuth {
37    pub fn new() -> Self {
38        Self {
39            client_id: CLIENT_ID.to_string(),
40            redirect_uri: REDIRECT_URI.to_string(),
41            http_client: reqwest::Client::new(),
42        }
43    }
44
45    /// Generate PKCE challenge
46    pub fn generate_pkce() -> PkceChallenge {
47        let verifier = generate_random_string(128);
48        let challenge = base64_url_encode(&sha256(&verifier));
49        PkceChallenge {
50            verifier,
51            challenge,
52        }
53    }
54
55    /// Build authorization URL
56    pub fn build_auth_url(&self, pkce: &PkceChallenge) -> String {
57        // Use the PKCE verifier as the state parameter
58        let params = [
59            ("code", "true"),
60            ("client_id", &self.client_id),
61            ("response_type", "code"),
62            ("redirect_uri", &self.redirect_uri),
63            ("scope", SCOPES),
64            ("code_challenge", &pkce.challenge),
65            ("code_challenge_method", "S256"),
66            ("state", &pkce.verifier),
67        ];
68
69        let query = serde_urlencoded::to_string(params).unwrap();
70        format!("{AUTHORIZE_URL}?{query}")
71    }
72
73    /// Parse the callback code from the redirect URL
74    /// The format is: code#state
75    pub fn parse_callback_code(callback_code: &str) -> Result<(String, String)> {
76        let parts: Vec<&str> = callback_code.split('#').collect();
77        if parts.len() != 2 {
78            return Err(AuthError::InvalidResponse(
79                "Invalid callback code format. Expected format: code#state".to_string(),
80            ));
81        }
82        Ok((parts[0].to_string(), parts[1].to_string()))
83    }
84
85    /// Exchange authorization code for tokens
86    pub async fn exchange_code_for_tokens(
87        &self,
88        code: &str,
89        state: &str,
90        pkce_verifier: &str,
91    ) -> Result<AuthTokens> {
92        #[derive(Serialize)]
93        struct TokenRequest {
94            code: String,
95            state: String,
96            grant_type: String,
97            client_id: String,
98            redirect_uri: String,
99            code_verifier: String,
100        }
101
102        #[derive(Deserialize)]
103        struct TokenResponse {
104            access_token: String,
105            refresh_token: String,
106            expires_in: u64,
107        }
108
109        let request = TokenRequest {
110            code: code.to_string(),
111            state: state.to_string(),
112            grant_type: "authorization_code".to_string(),
113            client_id: self.client_id.clone(),
114            redirect_uri: self.redirect_uri.clone(),
115            code_verifier: pkce_verifier.to_string(),
116        };
117
118        let response = self
119            .http_client
120            .post(TOKEN_URL)
121            .json(&request)
122            .send()
123            .await?;
124
125        if !response.status().is_success() {
126            let status = response.status();
127            let error_text = response
128                .text()
129                .await
130                .unwrap_or_else(|_| "Unknown error".to_string());
131            return Err(AuthError::InvalidResponse(format!(
132                "Token exchange failed with status {status}: {error_text}"
133            )));
134        }
135
136        let token_response: TokenResponse = response.json().await.map_err(|e| {
137            AuthError::InvalidResponse(format!("Failed to parse token response: {e}"))
138        })?;
139
140        let expires_at = SystemTime::now() + Duration::from_secs(token_response.expires_in);
141
142        Ok(AuthTokens {
143            access_token: token_response.access_token,
144            refresh_token: token_response.refresh_token,
145            expires_at,
146        })
147    }
148
149    /// Refresh access token using refresh token
150    pub async fn refresh_tokens(&self, refresh_token: &str) -> Result<AuthTokens> {
151        #[derive(Serialize)]
152        struct RefreshRequest {
153            grant_type: String,
154            refresh_token: String,
155            client_id: String,
156        }
157
158        #[derive(Deserialize)]
159        struct TokenResponse {
160            access_token: String,
161            refresh_token: String,
162            expires_in: u64,
163        }
164
165        let request = RefreshRequest {
166            grant_type: "refresh_token".to_string(),
167            refresh_token: refresh_token.to_string(),
168            client_id: self.client_id.clone(),
169        };
170
171        let response = self
172            .http_client
173            .post(TOKEN_URL)
174            .json(&request)
175            .send()
176            .await?;
177
178        if !response.status().is_success() {
179            if response.status() == reqwest::StatusCode::UNAUTHORIZED {
180                return Err(AuthError::ReauthRequired);
181            }
182
183            let status = response.status();
184            let error_text = response
185                .text()
186                .await
187                .unwrap_or_else(|_| "Unknown error".to_string());
188            return Err(AuthError::InvalidResponse(format!(
189                "Token refresh failed with status {status}: {error_text}"
190            )));
191        }
192
193        let token_response: TokenResponse = response.json().await.map_err(|e| {
194            AuthError::InvalidResponse(format!("Failed to parse refresh response: {e}"))
195        })?;
196
197        let expires_at = SystemTime::now() + Duration::from_secs(token_response.expires_in);
198
199        Ok(AuthTokens {
200            access_token: token_response.access_token,
201            refresh_token: token_response.refresh_token,
202            expires_at,
203        })
204    }
205}
206
207/// Check if tokens need refresh (within 5 minutes of expiry)
208pub fn tokens_need_refresh(tokens: &AuthTokens) -> bool {
209    match tokens.expires_at.duration_since(SystemTime::now()) {
210        Ok(duration) => duration.as_secs() <= 300, // 5 minutes
211        Err(_) => true,                            // Already expired
212    }
213}
214
/// Get OAuth headers for Anthropic API requests
///
/// Returns `(name, value)` pairs:
/// - `authorization: Bearer <access_token>`
/// - `anthropic-beta: oauth-2025-04-20`
pub fn get_oauth_headers(access_token: &str) -> Vec<(String, String)> {
    let authorization = format!("Bearer {access_token}");
    [
        ("authorization", authorization.as_str()),
        ("anthropic-beta", "oauth-2025-04-20"),
    ]
    .iter()
    .map(|&(name, value)| (name.to_owned(), value.to_owned()))
    .collect()
}
225
226/// Helper to refresh tokens if needed
227pub async fn refresh_if_needed(
228    storage: &Arc<dyn AuthStorage>,
229    oauth_client: &AnthropicOAuth,
230) -> Result<AuthTokens> {
231    let credential = storage
232        .get_credential("anthropic", CredentialType::AuthTokens)
233        .await?
234        .ok_or(AuthError::ReauthRequired)?;
235
236    let mut tokens = match credential {
237        Credential::AuthTokens(tokens) => tokens,
238        _ => return Err(AuthError::ReauthRequired),
239    };
240
241    if tokens_need_refresh(&tokens) {
242        // Try to refresh
243        match oauth_client.refresh_tokens(&tokens.refresh_token).await {
244            Ok(new_tokens) => {
245                storage
246                    .set_credential("anthropic", Credential::AuthTokens(new_tokens.clone()))
247                    .await?;
248                tokens = new_tokens;
249            }
250            Err(AuthError::ReauthRequired) => {
251                // Refresh token is invalid, clear tokens
252                storage
253                    .remove_credential("anthropic", CredentialType::AuthTokens)
254                    .await?;
255                return Err(AuthError::ReauthRequired);
256            }
257            Err(e) => return Err(e),
258        }
259    }
260
261    Ok(tokens)
262}
263
264// Helper functions
265fn generate_random_string(length: usize) -> String {
266    use rand::Rng;
267
268    const CHARSET: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-._~";
269    let mut rng = rand::thread_rng();
270
271    (0..length)
272        .map(|_| {
273            let idx = rng.gen_range(0..CHARSET.len());
274            CHARSET[idx] as char
275        })
276        .collect()
277}
278
279fn sha256(data: &str) -> Vec<u8> {
280    let mut hasher = Sha256::new();
281    hasher.update(data.as_bytes());
282    hasher.finalize().to_vec()
283}
284
285fn base64_url_encode(data: &[u8]) -> String {
286    URL_SAFE_NO_PAD.encode(data)
287}
288
289/// State for the Anthropic authentication flow
290#[derive(Debug, Clone)]
291pub struct AnthropicAuthState {
292    pub kind: AnthropicAuthStateKind,
293}
294
/// The step an Anthropic authentication flow is at.
#[derive(Clone, Debug)]
pub enum AnthropicAuthStateKind {
    /// Initial state - choosing auth method
    Initial,
    /// OAuth flow started; waiting for the redirect URL (or the `code#state` string)
    OAuthStarted {
        /// PKCE verifier generated for this attempt; required for the code exchange.
        verifier: String,
        /// Authorization URL presented to the user.
        auth_url: String,
    },
    /// Waiting for API key input
    AwaitingApiKey,
}
304
305/// Anthropic-specific authentication flow implementation
306pub struct AnthropicOAuthFlow {
307    storage: Arc<dyn AuthStorage>,
308    oauth_client: AnthropicOAuth,
309}
310
311impl AnthropicOAuthFlow {
312    pub fn new(storage: Arc<dyn AuthStorage>) -> Self {
313        Self {
314            storage,
315            oauth_client: AnthropicOAuth::new(),
316        }
317    }
318}
319
320#[async_trait]
321impl AuthenticationFlow for AnthropicOAuthFlow {
322    type State = AnthropicAuthState;
323
324    fn available_methods(&self) -> Vec<AuthMethod> {
325        vec![AuthMethod::OAuth, AuthMethod::ApiKey]
326    }
327
328    async fn start_auth(&self, method: AuthMethod) -> Result<Self::State> {
329        match method {
330            AuthMethod::OAuth => {
331                let pkce = AnthropicOAuth::generate_pkce();
332                let auth_url = self.oauth_client.build_auth_url(&pkce);
333
334                Ok(AnthropicAuthState {
335                    kind: AnthropicAuthStateKind::OAuthStarted {
336                        verifier: pkce.verifier,
337                        auth_url,
338                    },
339                })
340            }
341            AuthMethod::ApiKey => Ok(AnthropicAuthState {
342                kind: AnthropicAuthStateKind::AwaitingApiKey,
343            }),
344        }
345    }
346
347    async fn get_initial_progress(
348        &self,
349        state: &Self::State,
350        method: AuthMethod,
351    ) -> Result<AuthProgress> {
352        match method {
353            AuthMethod::OAuth => {
354                if let AnthropicAuthStateKind::OAuthStarted { auth_url, .. } = &state.kind {
355                    Ok(AuthProgress::OAuthStarted {
356                        auth_url: auth_url.clone(),
357                    })
358                } else {
359                    Err(AuthError::InvalidState(
360                        "Invalid state for OAuth".to_string(),
361                    ))
362                }
363            }
364            AuthMethod::ApiKey => Ok(AuthProgress::NeedInput("Enter your API key".to_string())),
365        }
366    }
367
368    async fn handle_input(&self, state: &mut Self::State, input: &str) -> Result<AuthProgress> {
369        match &mut state.kind {
370            AnthropicAuthStateKind::Initial => Err(AuthError::InvalidState(
371                "No input expected in initial state".to_string(),
372            )),
373            AnthropicAuthStateKind::OAuthStarted { verifier, .. } => {
374                // Check if the input contains a redirect URL
375                let (code, state_param) = if input.contains("code=") && input.contains("state=") {
376                    // User pasted the full redirect URL
377                    let url = reqwest::Url::parse(input).map_err(|_| {
378                        AuthError::InvalidCredential("Invalid redirect URL".to_string())
379                    })?;
380
381                    let params: std::collections::HashMap<_, _> = url.query_pairs().collect();
382                    let code = params
383                        .get("code")
384                        .ok_or_else(|| AuthError::MissingInput("code parameter".to_string()))?;
385                    let state = params
386                        .get("state")
387                        .ok_or_else(|| AuthError::MissingInput("state parameter".to_string()))?;
388
389                    (code.to_string(), state.to_string())
390                } else {
391                    // Legacy: try to parse as callback code
392                    AnthropicOAuth::parse_callback_code(input)?
393                };
394
395                // Exchange code for tokens
396                let tokens = self
397                    .oauth_client
398                    .exchange_code_for_tokens(&code, &state_param, verifier)
399                    .await?;
400
401                // Store the tokens
402                self.storage
403                    .set_credential("anthropic", Credential::AuthTokens(tokens))
404                    .await?;
405
406                Ok(AuthProgress::Complete)
407            }
408            AnthropicAuthStateKind::AwaitingApiKey => {
409                if input.trim().is_empty() {
410                    return Err(AuthError::InvalidCredential(
411                        "API key cannot be empty".to_string(),
412                    ));
413                }
414
415                // Store the API key
416                self.storage
417                    .set_credential(
418                        "anthropic",
419                        Credential::ApiKey {
420                            value: input.to_string(),
421                        },
422                    )
423                    .await?;
424
425                Ok(AuthProgress::Complete)
426            }
427        }
428    }
429
430    async fn is_authenticated(&self) -> Result<bool> {
431        // Check for OAuth tokens first
432        if let Some(Credential::AuthTokens(tokens)) = self
433            .storage
434            .get_credential("anthropic", CredentialType::AuthTokens)
435            .await?
436        {
437            // Check if tokens are still valid
438            return Ok(!tokens_need_refresh(&tokens));
439        }
440
441        // Check for API key
442        Ok(self
443            .storage
444            .get_credential("anthropic", CredentialType::ApiKey)
445            .await?
446            .is_some())
447    }
448
449    fn provider_name(&self) -> String {
450        ProviderKind::Anthropic.to_string()
451    }
452}
453
#[cfg(test)]
mod tests {
    use super::*;
    use crate::auth::{AuthStorage, Credential, CredentialType};
    use async_trait::async_trait;
    use std::collections::HashMap;
    use tokio::sync::Mutex;

    #[test]
    fn test_pkce_generation() {
        let PkceChallenge {
            verifier,
            challenge,
        } = AnthropicOAuth::generate_pkce();

        // The verifier uses the full 128 characters.
        assert_eq!(verifier.len(), 128);

        // Unpadded base64url of a 32-byte SHA-256 digest is 43 characters long.
        assert_eq!(challenge.len(), 43);

        // The challenge is derived from the verifier.
        assert_eq!(challenge, base64_url_encode(&sha256(&verifier)));
    }

    #[test]
    fn test_state_generation() {
        // The OAuth state is the PKCE verifier, so it shares the verifier's length.
        let pkce = AnthropicOAuth::generate_pkce();
        assert_eq!(pkce.verifier.len(), 128);
    }

    #[test]
    fn test_auth_url_building() {
        let pkce = AnthropicOAuth::generate_pkce();
        let url = AnthropicOAuth::new().build_auth_url(&pkce);

        // `state=` is only checked for presence: the verifier may be percent-encoded.
        let required = [
            AUTHORIZE_URL.to_string(),
            format!("client_id={CLIENT_ID}"),
            "response_type=code".to_string(),
            "state=".to_string(),
            format!("code_challenge={}", pkce.challenge),
            "code_challenge_method=S256".to_string(),
            "code=true".to_string(),
            // The redirect URI must appear percent-encoded.
            "redirect_uri=https%3A%2F%2Fconsole.anthropic.com%2Foauth%2Fcode%2Fcallback"
                .to_string(),
        ];
        for fragment in &required {
            assert!(
                url.contains(fragment.as_str()),
                "missing `{fragment}` in {url}"
            );
        }
    }

    #[test]
    fn test_parse_callback_code() {
        let parsed = AnthropicOAuth::parse_callback_code("abc123#xyz789").unwrap();
        assert_eq!(parsed, ("abc123".to_string(), "xyz789".to_string()));

        // Rejected: no separator, and more than one separator.
        for bad in ["abc123", "abc#123#xyz"] {
            assert!(AnthropicOAuth::parse_callback_code(bad).is_err());
        }
    }

    /// In-memory `AuthStorage` that records writes. `get_credential` always reports nothing stored.
    struct MockAuthStorage {
        credentials: Arc<Mutex<HashMap<String, Credential>>>,
    }

    impl MockAuthStorage {
        fn new() -> Self {
            let credentials = Arc::new(Mutex::new(HashMap::new()));
            Self { credentials }
        }
    }

    #[async_trait]
    impl AuthStorage for MockAuthStorage {
        async fn get_credential(
            &self,
            _provider: &str,
            _credential_type: CredentialType,
        ) -> Result<Option<Credential>> {
            Ok(None)
        }

        async fn set_credential(&self, provider: &str, credential: Credential) -> Result<()> {
            self.credentials
                .lock()
                .await
                .insert(provider.to_string(), credential);
            Ok(())
        }

        async fn remove_credential(
            &self,
            provider: &str,
            _credential_type: CredentialType,
        ) -> Result<()> {
            self.credentials.lock().await.remove(provider);
            Ok(())
        }
    }

    #[tokio::test]
    async fn test_auth_flow_api_key() {
        let storage = Arc::new(MockAuthStorage::new());
        let auth_flow = AnthropicOAuthFlow::new(storage.clone());

        // Both methods are offered.
        let methods = auth_flow.available_methods();
        assert_eq!(methods.len(), 2);
        assert!(methods.contains(&AuthMethod::OAuth));
        assert!(methods.contains(&AuthMethod::ApiKey));

        // Starting the API-key flow waits for the key.
        let mut state = auth_flow.start_auth(AuthMethod::ApiKey).await.unwrap();
        assert!(matches!(state.kind, AnthropicAuthStateKind::AwaitingApiKey));

        // Submitting the key completes the flow.
        let progress = auth_flow
            .handle_input(&mut state, "test-api-key")
            .await
            .unwrap();
        assert!(matches!(progress, AuthProgress::Complete));

        // The key is persisted under the provider name.
        let creds = storage.credentials.lock().await;
        assert!(creds.contains_key("anthropic"));
        match creds.get("anthropic") {
            Some(Credential::ApiKey { value }) => assert_eq!(value, "test-api-key"),
            _ => panic!("Expected API key credential"),
        }
    }

    #[tokio::test]
    async fn test_auth_flow_oauth_start() {
        let storage = Arc::new(MockAuthStorage::new());
        let auth_flow = AnthropicOAuthFlow::new(storage);

        let state = auth_flow.start_auth(AuthMethod::OAuth).await.unwrap();

        let AnthropicAuthStateKind::OAuthStarted { auth_url, verifier } = &state.kind else {
            panic!("Expected OAuth started state");
        };

        // The authorization URL carries the required parameters.
        assert!(auth_url.contains(AUTHORIZE_URL));
        assert!(auth_url.contains("client_id="));
        assert!(auth_url.contains("code_challenge="));
        assert!(!verifier.is_empty());
    }
}