Skip to main content

steer_auth_anthropic/
lib.rs

1use async_trait::async_trait;
2use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
3use rand::Rng;
4use serde::{Deserialize, Serialize};
5use sha2::{Digest, Sha256};
6use std::collections::HashMap;
7use std::sync::Arc;
8use std::time::{Duration, SystemTime};
9
10use steer_auth_plugin::AuthPlugin;
11use steer_auth_plugin::{
12    AnthropicAuth, AuthDirective, AuthError, AuthErrorAction, AuthErrorContext, AuthHeaderContext,
13    AuthHeaderProvider, AuthMethod, AuthProgress, AuthStorage, AuthTokens, AuthenticationFlow,
14    Credential, CredentialType, DynAuthenticationFlow, HeaderPair, InstructionPolicy, ProviderId,
15    QueryParam, Result,
16};
17
/// Provider identifier used as the credential-storage key and in error messages.
const PROVIDER_ID: &str = "anthropic";
/// Browser-facing OAuth authorization endpoint (see `AnthropicOAuth::build_auth_url`).
const AUTHORIZE_URL: &str = "https://claude.ai/oauth/authorize";
/// Token endpoint used for both the authorization-code exchange and the refresh grant.
const TOKEN_URL: &str = "https://console.anthropic.com/v1/oauth/token";
/// Public OAuth client id. No client secret is sent; the flow relies on PKCE instead.
const CLIENT_ID: &str = "9d1c250a-e61b-44d9-88ed-5944d1962f5e";
/// Redirect URI sent with both the authorization request and the code exchange.
const REDIRECT_URI: &str = "https://console.anthropic.com/oauth/code/callback";
/// Space-separated OAuth scopes requested in the authorization URL.
const SCOPES: &str = "org:create_api_key user:profile user:inference";
24
/// A PKCE (RFC 7636, `S256` method) verifier/challenge pair.
///
/// NOTE(review): `Debug` is derived, so formatting this value prints the secret `verifier`;
/// avoid logging it.
#[derive(Debug)]
pub struct PkceChallenge {
    /// Random secret later sent to the token endpoint as `code_verifier`
    /// (this crate also sends it as the OAuth `state` parameter).
    pub verifier: String,
    /// Unpadded base64url encoding of the SHA-256 digest of `verifier`, sent as `code_challenge`.
    pub challenge: String,
}
30
31#[derive(Clone)]
32pub struct AnthropicOAuth {
33    client_id: String,
34    redirect_uri: String,
35    http_client: reqwest::Client,
36}
37
38impl Default for AnthropicOAuth {
39    fn default() -> Self {
40        Self::new()
41    }
42}
43
44impl AnthropicOAuth {
45    pub fn new() -> Self {
46        Self {
47            client_id: CLIENT_ID.to_string(),
48            redirect_uri: REDIRECT_URI.to_string(),
49            http_client: reqwest::Client::new(),
50        }
51    }
52
53    /// Generate PKCE challenge
54    pub fn generate_pkce() -> PkceChallenge {
55        let verifier = generate_random_string(128);
56        let challenge = base64_url_encode(&sha256(&verifier));
57        PkceChallenge {
58            verifier,
59            challenge,
60        }
61    }
62
63    /// Build authorization URL
64    pub fn build_auth_url(&self, pkce: &PkceChallenge) -> String {
65        let params = [
66            ("code", "true"),
67            ("client_id", &self.client_id),
68            ("response_type", "code"),
69            ("redirect_uri", &self.redirect_uri),
70            ("scope", SCOPES),
71            ("code_challenge", &pkce.challenge),
72            ("code_challenge_method", "S256"),
73            ("state", &pkce.verifier),
74        ];
75
76        let query = serde_urlencoded::to_string(params).unwrap_or_default();
77        format!("{AUTHORIZE_URL}?{query}")
78    }
79
80    /// Parse the callback code from the redirect URL or query fragment.
81    pub fn parse_callback_code(callback_code: &str) -> Result<(String, String)> {
82        let trimmed = callback_code.trim();
83        if trimmed.is_empty() {
84            return Err(AuthError::InvalidResponse(
85                "Invalid callback code format. Expected a URL or code/state parameters."
86                    .to_string(),
87            ));
88        }
89
90        if let Ok(url) = reqwest::Url::parse(trimmed)
91            && let Some(pair) = extract_code_state_from_url(&url)
92        {
93            return Ok(pair);
94        }
95
96        if let Some(pair) = extract_code_state_from_str(trimmed) {
97            return Ok(pair);
98        }
99
100        if let Some(pair) = extract_legacy_code_state(trimmed) {
101            return Ok(pair);
102        }
103
104        Err(AuthError::InvalidResponse(
105            "Invalid callback code format. Expected a URL or code/state parameters.".to_string(),
106        ))
107    }
108
109    /// Exchange authorization code for tokens
110    pub async fn exchange_code_for_tokens(
111        &self,
112        code: &str,
113        state: &str,
114        pkce_verifier: &str,
115    ) -> Result<AuthTokens> {
116        #[derive(Serialize)]
117        struct TokenRequest {
118            code: String,
119            state: String,
120            grant_type: String,
121            client_id: String,
122            redirect_uri: String,
123            code_verifier: String,
124        }
125
126        #[derive(Deserialize)]
127        struct TokenResponse {
128            access_token: String,
129            refresh_token: String,
130            expires_in: u64,
131        }
132
133        let request = TokenRequest {
134            code: code.to_string(),
135            state: state.to_string(),
136            grant_type: "authorization_code".to_string(),
137            client_id: self.client_id.clone(),
138            redirect_uri: self.redirect_uri.clone(),
139            code_verifier: pkce_verifier.to_string(),
140        };
141
142        let response = self
143            .http_client
144            .post(TOKEN_URL)
145            .json(&request)
146            .send()
147            .await?;
148
149        if !response.status().is_success() {
150            let status = response.status();
151            let error_text = response
152                .text()
153                .await
154                .unwrap_or_else(|_| "Unknown error".to_string());
155            return Err(AuthError::InvalidResponse(format!(
156                "Token exchange failed with status {status}: {error_text}"
157            )));
158        }
159
160        let token_response: TokenResponse = response.json().await.map_err(|e| {
161            AuthError::InvalidResponse(format!("Failed to parse token response: {e}"))
162        })?;
163
164        let expires_at = SystemTime::now() + Duration::from_secs(token_response.expires_in);
165
166        Ok(AuthTokens {
167            access_token: token_response.access_token,
168            refresh_token: token_response.refresh_token,
169            expires_at,
170            id_token: None,
171        })
172    }
173
174    /// Refresh access token using refresh token
175    pub async fn refresh_tokens(&self, refresh_token: &str) -> Result<AuthTokens> {
176        #[derive(Serialize)]
177        struct RefreshRequest {
178            grant_type: String,
179            refresh_token: String,
180            client_id: String,
181        }
182
183        #[derive(Deserialize)]
184        struct TokenResponse {
185            access_token: String,
186            refresh_token: String,
187            expires_in: u64,
188        }
189
190        let request = RefreshRequest {
191            grant_type: "refresh_token".to_string(),
192            refresh_token: refresh_token.to_string(),
193            client_id: self.client_id.clone(),
194        };
195
196        let response = self
197            .http_client
198            .post(TOKEN_URL)
199            .json(&request)
200            .send()
201            .await?;
202
203        if !response.status().is_success() {
204            if response.status() == reqwest::StatusCode::UNAUTHORIZED {
205                return Err(AuthError::ReauthRequired);
206            }
207
208            let status = response.status();
209            let error_text = response
210                .text()
211                .await
212                .unwrap_or_else(|_| "Unknown error".to_string());
213            return Err(AuthError::InvalidResponse(format!(
214                "Token refresh failed with status {status}: {error_text}"
215            )));
216        }
217
218        let token_response: TokenResponse = response.json().await.map_err(|e| {
219            AuthError::InvalidResponse(format!("Failed to parse refresh response: {e}"))
220        })?;
221
222        let expires_at = SystemTime::now() + Duration::from_secs(token_response.expires_in);
223
224        Ok(AuthTokens {
225            access_token: token_response.access_token,
226            refresh_token: token_response.refresh_token,
227            expires_at,
228            id_token: None,
229        })
230    }
231}
232
233fn resolve_callback_input(input: &str, verifier: &str) -> Result<(String, String)> {
234    match AnthropicOAuth::parse_callback_code(input) {
235        Ok(pair) => Ok(pair),
236        Err(err) => {
237            let trimmed = input.trim();
238            let fallback_code = extract_code_only_from_str(trimmed).or_else(|| {
239                reqwest::Url::parse(trimmed)
240                    .ok()
241                    .and_then(|url| extract_code_only_from_url(&url))
242            });
243
244            if let Some(code) = fallback_code {
245                Ok((code, verifier.to_string()))
246            } else {
247                Err(err)
248            }
249        }
250    }
251}
252
253fn extract_code_state_from_url(url: &reqwest::Url) -> Option<(String, String)> {
254    if let Some(query) = url.query()
255        && let Some(pair) = extract_code_state_from_kv(query)
256    {
257        return Some(pair);
258    }
259
260    if let Some(fragment) = url.fragment()
261        && let Some(pair) = extract_code_state_from_kv(fragment)
262    {
263        return Some(pair);
264    }
265
266    None
267}
268
269fn extract_code_state_from_str(input: &str) -> Option<(String, String)> {
270    if let Some(pair) = extract_code_state_from_kv(input) {
271        return Some(pair);
272    }
273
274    if let Some(query_start) = input.find('?')
275        && let Some(pair) = extract_code_state_from_kv(&input[query_start + 1..])
276    {
277        return Some(pair);
278    }
279
280    if let Some(fragment_start) = input.find('#')
281        && let Some(pair) = extract_code_state_from_kv(&input[fragment_start + 1..])
282    {
283        return Some(pair);
284    }
285
286    None
287}
288
289fn extract_code_state_from_kv(raw: &str) -> Option<(String, String)> {
290    if raw.is_empty() {
291        return None;
292    }
293
294    let params: HashMap<String, String> = serde_urlencoded::from_str(raw).ok()?;
295    let code = params.get("code")?;
296    let state = params.get("state")?;
297    Some((code.clone(), state.clone()))
298}
299
300fn extract_code_only_from_url(url: &reqwest::Url) -> Option<String> {
301    if let Some(query) = url.query()
302        && let Some(code) = extract_code_only_from_kv(query)
303    {
304        return Some(code);
305    }
306
307    if let Some(fragment) = url.fragment()
308        && let Some(code) = extract_code_only_from_kv(fragment)
309    {
310        return Some(code);
311    }
312
313    None
314}
315
316fn extract_code_only_from_str(input: &str) -> Option<String> {
317    if let Some(code) = extract_code_only_from_kv(input) {
318        return Some(code);
319    }
320
321    if let Some(query_start) = input.find('?')
322        && let Some(code) = extract_code_only_from_kv(&input[query_start + 1..])
323    {
324        return Some(code);
325    }
326
327    if let Some(fragment_start) = input.find('#')
328        && let Some(code) = extract_code_only_from_kv(&input[fragment_start + 1..])
329    {
330        return Some(code);
331    }
332
333    None
334}
335
336fn extract_code_only_from_kv(raw: &str) -> Option<String> {
337    if raw.is_empty() {
338        return None;
339    }
340
341    let params: HashMap<String, String> = serde_urlencoded::from_str(raw).ok()?;
342    params.get("code").cloned()
343}
344
/// Parse the legacy `<code>#<state>` form.
///
/// The input must contain exactly one `#` with non-empty text on both sides.
fn extract_legacy_code_state(input: &str) -> Option<(String, String)> {
    let (code, state) = input.split_once('#')?;
    let well_formed = !code.is_empty() && !state.is_empty() && !state.contains('#');
    well_formed.then(|| (code.to_owned(), state.to_owned()))
}
353
354/// Check if tokens need refresh (within 5 minutes of expiry)
355pub fn tokens_need_refresh(tokens: &AuthTokens) -> bool {
356    match tokens.expires_at.duration_since(SystemTime::now()) {
357        Ok(duration) => duration.as_secs() <= 300,
358        Err(_) => true,
359    }
360}
361
362/// Get OAuth headers for Anthropic API requests
363pub fn get_oauth_headers(access_token: &str) -> Vec<HeaderPair> {
364    vec![
365        HeaderPair {
366            name: "authorization".to_string(),
367            value: format!("Bearer {access_token}"),
368        },
369        HeaderPair {
370            name: "anthropic-beta".to_string(),
371            value: "oauth-2025-04-20,interleaved-thinking-2025-05-14,claude-code-20250219"
372                .to_string(),
373        },
374        HeaderPair {
375            name: "user-agent".to_string(),
376            value: "claude-cli/2.1.2 (external, cli)".to_string(),
377        },
378    ]
379}
380
381/// Helper to refresh tokens if needed
382pub async fn refresh_if_needed(
383    storage: &Arc<dyn AuthStorage>,
384    oauth_client: &AnthropicOAuth,
385) -> Result<AuthTokens> {
386    let credential = storage
387        .get_credential(PROVIDER_ID, CredentialType::OAuth2)
388        .await?
389        .ok_or(AuthError::ReauthRequired)?;
390
391    let mut tokens = match credential {
392        Credential::OAuth2(tokens) => tokens,
393        Credential::ApiKey { .. } => return Err(AuthError::ReauthRequired),
394    };
395
396    if tokens_need_refresh(&tokens) {
397        match oauth_client.refresh_tokens(&tokens.refresh_token).await {
398            Ok(new_tokens) => {
399                storage
400                    .set_credential(PROVIDER_ID, Credential::OAuth2(new_tokens.clone()))
401                    .await?;
402                tokens = new_tokens;
403            }
404            Err(AuthError::ReauthRequired) => {
405                storage
406                    .remove_credential(PROVIDER_ID, CredentialType::OAuth2)
407                    .await?;
408                return Err(AuthError::ReauthRequired);
409            }
410            Err(e) => return Err(e),
411        }
412    }
413
414    Ok(tokens)
415}
416
417async fn force_refresh(
418    storage: &Arc<dyn AuthStorage>,
419    oauth_client: &AnthropicOAuth,
420) -> Result<AuthTokens> {
421    let credential = storage
422        .get_credential(PROVIDER_ID, CredentialType::OAuth2)
423        .await?
424        .ok_or(AuthError::ReauthRequired)?;
425
426    let tokens = match credential {
427        Credential::OAuth2(tokens) => tokens,
428        Credential::ApiKey { .. } => return Err(AuthError::ReauthRequired),
429    };
430
431    match oauth_client.refresh_tokens(&tokens.refresh_token).await {
432        Ok(new_tokens) => {
433            storage
434                .set_credential(PROVIDER_ID, Credential::OAuth2(new_tokens.clone()))
435                .await?;
436            Ok(new_tokens)
437        }
438        Err(AuthError::ReauthRequired) => {
439            storage
440                .remove_credential(PROVIDER_ID, CredentialType::OAuth2)
441                .await?;
442            Err(AuthError::ReauthRequired)
443        }
444        Err(err) => Err(err),
445    }
446}
447
448fn generate_random_string(length: usize) -> String {
449    const CHARSET: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-._~";
450    let mut rng = rand::thread_rng();
451
452    (0..length)
453        .map(|_| {
454            let idx = rng.gen_range(0..CHARSET.len());
455            CHARSET[idx] as char
456        })
457        .collect()
458}
459
460fn sha256(data: &str) -> Vec<u8> {
461    let mut hasher = Sha256::new();
462    hasher.update(data.as_bytes());
463    hasher.finalize().to_vec()
464}
465
466fn base64_url_encode(data: &[u8]) -> String {
467    URL_SAFE_NO_PAD.encode(data)
468}
469
/// In-progress state of an [`AnthropicOAuthFlow`] login.
#[derive(Debug, Clone)]
pub struct AnthropicAuthState {
    /// Current step of the flow.
    pub kind: AnthropicAuthStateKind,
}

/// Steps of the Anthropic OAuth login flow.
///
/// NOTE(review): the derived `Debug` prints the PKCE `verifier`, which must stay secret until the
/// code exchange; avoid logging these values.
#[derive(Debug, Clone)]
pub enum AnthropicAuthStateKind {
    /// The authorization URL has been generated and the flow is waiting for the user to paste the
    /// redirect URL or code. `verifier` is the PKCE verifier needed for the token exchange.
    OAuthStarted { verifier: String, auth_url: String },
}
479
480pub struct AnthropicOAuthFlow {
481    storage: Arc<dyn AuthStorage>,
482    oauth_client: AnthropicOAuth,
483}
484
485impl AnthropicOAuthFlow {
486    pub fn new(storage: Arc<dyn AuthStorage>) -> Self {
487        Self {
488            storage,
489            oauth_client: AnthropicOAuth::new(),
490        }
491    }
492}
493
494#[async_trait]
495impl AuthenticationFlow for AnthropicOAuthFlow {
496    type State = AnthropicAuthState;
497
498    fn available_methods(&self) -> Vec<AuthMethod> {
499        vec![AuthMethod::OAuth]
500    }
501
502    async fn start_auth(&self, method: AuthMethod) -> Result<Self::State> {
503        match method {
504            AuthMethod::OAuth => {
505                let pkce = AnthropicOAuth::generate_pkce();
506                let auth_url = self.oauth_client.build_auth_url(&pkce);
507
508                Ok(AnthropicAuthState {
509                    kind: AnthropicAuthStateKind::OAuthStarted {
510                        verifier: pkce.verifier,
511                        auth_url,
512                    },
513                })
514            }
515            AuthMethod::ApiKey => Err(AuthError::UnsupportedMethod {
516                method: format!("{method:?}"),
517                provider: PROVIDER_ID.to_string(),
518            }),
519        }
520    }
521
522    async fn get_initial_progress(
523        &self,
524        state: &Self::State,
525        method: AuthMethod,
526    ) -> Result<AuthProgress> {
527        match method {
528            AuthMethod::OAuth => {
529                let AnthropicAuthStateKind::OAuthStarted { auth_url, .. } = &state.kind;
530                Ok(AuthProgress::OAuthStarted {
531                    auth_url: auth_url.clone(),
532                })
533            }
534            AuthMethod::ApiKey => Err(AuthError::UnsupportedMethod {
535                method: format!("{method:?}"),
536                provider: PROVIDER_ID.to_string(),
537            }),
538        }
539    }
540
541    async fn handle_input(&self, state: &mut Self::State, input: &str) -> Result<AuthProgress> {
542        match &mut state.kind {
543            AnthropicAuthStateKind::OAuthStarted { verifier, .. } => {
544                if input.trim().is_empty() {
545                    return Ok(AuthProgress::NeedInput(
546                        "Paste the redirect URL or code from your browser".to_string(),
547                    ));
548                }
549
550                let (code, state_param) = resolve_callback_input(input, verifier)?;
551
552                let tokens = self
553                    .oauth_client
554                    .exchange_code_for_tokens(&code, &state_param, verifier)
555                    .await?;
556
557                self.storage
558                    .set_credential(PROVIDER_ID, Credential::OAuth2(tokens))
559                    .await?;
560
561                Ok(AuthProgress::Complete)
562            }
563        }
564    }
565
566    async fn is_authenticated(&self) -> Result<bool> {
567        if let Some(Credential::OAuth2(tokens)) = self
568            .storage
569            .get_credential(PROVIDER_ID, CredentialType::OAuth2)
570            .await?
571        {
572            return Ok(!tokens_need_refresh(&tokens));
573        }
574
575        Ok(false)
576    }
577
578    fn provider_name(&self) -> String {
579        PROVIDER_ID.to_string()
580    }
581}
582
583#[derive(Clone)]
584struct AnthropicHeaderProvider {
585    storage: Arc<dyn AuthStorage>,
586    oauth: AnthropicOAuth,
587}
588
589impl AnthropicHeaderProvider {
590    fn new(storage: Arc<dyn AuthStorage>) -> Self {
591        Self {
592            storage,
593            oauth: AnthropicOAuth::new(),
594        }
595    }
596
597    async fn header_pairs(&self, _ctx: AuthHeaderContext) -> Result<Vec<HeaderPair>> {
598        let tokens = refresh_if_needed(&self.storage, &self.oauth).await?;
599        Ok(get_oauth_headers(&tokens.access_token))
600    }
601}
602
603#[async_trait]
604impl AuthHeaderProvider for AnthropicHeaderProvider {
605    async fn headers(&self, ctx: AuthHeaderContext) -> Result<Vec<HeaderPair>> {
606        self.header_pairs(ctx).await
607    }
608
609    async fn on_auth_error(&self, _ctx: AuthErrorContext) -> Result<AuthErrorAction> {
610        match force_refresh(&self.storage, &self.oauth).await {
611            Ok(_) => Ok(AuthErrorAction::RetryOnce),
612            Err(AuthError::ReauthRequired) => Ok(AuthErrorAction::ReauthRequired),
613            Err(err) => Err(err),
614        }
615    }
616}
617
/// Authentication plugin for Anthropic's OAuth login.
///
/// The plugin holds no state; credentials live in the `AuthStorage` passed to each call.
#[derive(Clone)]
pub struct AnthropicAuthPlugin;

impl Default for AnthropicAuthPlugin {
    fn default() -> Self {
        AnthropicAuthPlugin
    }
}

impl AnthropicAuthPlugin {
    /// Create the plugin; equivalent to `AnthropicAuthPlugin::default()`.
    pub fn new() -> Self {
        Self::default()
    }
}
632
633#[async_trait]
634impl AuthPlugin for AnthropicAuthPlugin {
635    fn provider_id(&self) -> ProviderId {
636        ProviderId(PROVIDER_ID.to_string())
637    }
638
639    fn supported_methods(&self) -> Vec<AuthMethod> {
640        vec![AuthMethod::OAuth]
641    }
642
643    fn create_flow(&self, storage: Arc<dyn AuthStorage>) -> Option<Box<dyn DynAuthenticationFlow>> {
644        Some(Box::new(steer_auth_plugin::AuthFlowWrapper::new(
645            AnthropicOAuthFlow::new(storage),
646        )))
647    }
648
649    async fn resolve_auth(&self, storage: Arc<dyn AuthStorage>) -> Result<Option<AuthDirective>> {
650        let is_authenticated = self.is_authenticated(storage.clone()).await?;
651        if !is_authenticated {
652            return Ok(None);
653        }
654
655        let headers = Arc::new(AnthropicHeaderProvider::new(storage));
656        let directive = AnthropicAuth {
657            headers,
658            instruction_policy: Some(InstructionPolicy::Prefix(
659                "You are Claude Code, Anthropic's official CLI for Claude.".to_string(),
660            )),
661            query_params: Some(vec![QueryParam {
662                name: "beta".to_string(),
663                value: "true".to_string(),
664            }]),
665        };
666
667        Ok(Some(AuthDirective::Anthropic(directive)))
668    }
669
670    async fn is_authenticated(&self, storage: Arc<dyn AuthStorage>) -> Result<bool> {
671        if let Some(Credential::OAuth2(tokens)) = storage
672            .get_credential(PROVIDER_ID, CredentialType::OAuth2)
673            .await?
674        {
675            return Ok(!tokens_need_refresh(&tokens));
676        }
677
678        Ok(false)
679    }
680}
681
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;
    use tokio::sync::Mutex;

    /// In-memory `AuthStorage` keyed by `"{provider}-{credential_type}"`.
    struct TestAuthStorage {
        credentials: Arc<Mutex<HashMap<String, Credential>>>,
    }

    impl TestAuthStorage {
        fn new() -> Self {
            Self {
                credentials: Arc::new(Mutex::new(HashMap::new())),
            }
        }
    }

    #[async_trait]
    impl AuthStorage for TestAuthStorage {
        async fn get_credential(
            &self,
            provider: &str,
            credential_type: CredentialType,
        ) -> Result<Option<Credential>> {
            let key = format!("{provider}-{credential_type}");
            Ok(self.credentials.lock().await.get(&key).cloned())
        }

        async fn set_credential(&self, provider: &str, credential: Credential) -> Result<()> {
            let key = format!("{}-{}", provider, credential.credential_type());
            self.credentials.lock().await.insert(key, credential);
            Ok(())
        }

        async fn remove_credential(
            &self,
            provider: &str,
            credential_type: CredentialType,
        ) -> Result<()> {
            let key = format!("{provider}-{credential_type}");
            self.credentials.lock().await.remove(&key);
            Ok(())
        }
    }

    /// OAuth tokens with placeholder values that expire at `expires_at`.
    fn oauth_tokens_expiring_at(expires_at: SystemTime) -> AuthTokens {
        AuthTokens {
            access_token: "access-token".to_string(),
            refresh_token: "refresh-token".to_string(),
            expires_at,
            id_token: None,
        }
    }

    #[test]
    fn test_pkce_generation() {
        let pkce = AnthropicOAuth::generate_pkce();

        assert_eq!(pkce.verifier.len(), 128);
        assert_eq!(pkce.challenge.len(), 43);

        let expected_challenge = base64_url_encode(&sha256(&pkce.verifier));
        assert_eq!(pkce.challenge, expected_challenge);
    }

    #[test]
    fn test_state_generation() {
        let pkce1 = AnthropicOAuth::generate_pkce();
        let pkce2 = AnthropicOAuth::generate_pkce();

        assert_ne!(pkce1.verifier, pkce2.verifier);
    }

    #[test]
    fn test_build_auth_url() {
        let oauth = AnthropicOAuth::new();
        let pkce = AnthropicOAuth::generate_pkce();
        let url = oauth.build_auth_url(&pkce);

        assert!(url.contains(AUTHORIZE_URL));
        assert!(url.contains(&format!("client_id={CLIENT_ID}")));
        assert!(url.contains("response_type=code"));
        assert!(url.contains("code_challenge="));
        assert!(url.contains("code_challenge_method=S256"));
        assert!(url.contains(
            "redirect_uri=https%3A%2F%2Fconsole.anthropic.com%2Foauth%2Fcode%2Fcallback"
        ));
    }

    #[test]
    fn test_parse_callback_code_from_url() {
        let input = "https://console.anthropic.com/oauth/code/callback?code=abc123&state=state456";
        let (code, state) = AnthropicOAuth::parse_callback_code(input).unwrap();
        assert_eq!(code, "abc123");
        assert_eq!(state, "state456");
    }

    #[test]
    fn test_parse_callback_code_from_query() {
        let input = "code=abc123&state=state456";
        let (code, state) = AnthropicOAuth::parse_callback_code(input).unwrap();
        assert_eq!(code, "abc123");
        assert_eq!(state, "state456");
    }

    #[test]
    fn test_parse_callback_code_from_fragment() {
        let input = "https://console.anthropic.com/oauth/code/callback#code=abc123&state=state456";
        let (code, state) = AnthropicOAuth::parse_callback_code(input).unwrap();
        assert_eq!(code, "abc123");
        assert_eq!(state, "state456");
    }

    #[test]
    fn test_parse_callback_code_legacy() {
        let input = "abc123#state456";
        let (code, state) = AnthropicOAuth::parse_callback_code(input).unwrap();
        assert_eq!(code, "abc123");
        assert_eq!(state, "state456");
    }

    #[test]
    fn test_parse_callback_code_trims_whitespace() {
        let (code, state) = AnthropicOAuth::parse_callback_code("  abc123#state456\n").unwrap();
        assert_eq!(code, "abc123");
        assert_eq!(state, "state456");
    }

    #[test]
    fn test_parse_callback_code_rejects_empty_and_malformed() {
        assert!(AnthropicOAuth::parse_callback_code("").is_err());
        assert!(AnthropicOAuth::parse_callback_code("   ").is_err());
        assert!(AnthropicOAuth::parse_callback_code("abc#def#ghi").is_err());
        // A code without a state is not a complete pair.
        assert!(AnthropicOAuth::parse_callback_code("code=abc123").is_err());
    }

    #[test]
    fn test_extract_code_only_from_query() {
        let input = "code=abc123";
        let code = extract_code_only_from_str(input).unwrap();
        assert_eq!(code, "abc123");
    }

    #[test]
    fn test_extract_code_only_from_url() {
        let input = "https://console.anthropic.com/oauth/code/callback?code=abc123";
        let code = extract_code_only_from_str(input).unwrap();
        assert_eq!(code, "abc123");
    }

    #[test]
    fn test_extract_code_only_from_fragment() {
        let input = "https://console.anthropic.com/oauth/code/callback#code=abc123";
        let code = extract_code_only_from_str(input).unwrap();
        assert_eq!(code, "abc123");
    }

    #[test]
    fn test_resolve_callback_input_code_only_uses_verifier() {
        let (code, state) = resolve_callback_input("code=abc123", "verifier-123").unwrap();
        assert_eq!(code, "abc123");
        assert_eq!(state, "verifier-123");
    }

    #[test]
    fn test_resolve_callback_input_prefers_explicit_state() {
        let (code, state) =
            resolve_callback_input("code=abc123&state=state456", "verifier-123").unwrap();
        assert_eq!(code, "abc123");
        assert_eq!(state, "state456");
    }

    #[test]
    fn test_resolve_callback_input_rejects_garbage() {
        assert!(resolve_callback_input("not-a-callback", "verifier-123").is_err());
    }

    #[test]
    fn test_tokens_need_refresh() {
        let now = SystemTime::now();
        assert!(tokens_need_refresh(&oauth_tokens_expiring_at(
            now - Duration::from_secs(60)
        )));
        assert!(tokens_need_refresh(&oauth_tokens_expiring_at(
            now + Duration::from_secs(60)
        )));
        assert!(!tokens_need_refresh(&oauth_tokens_expiring_at(
            now + Duration::from_secs(3600)
        )));
    }

    #[tokio::test]
    async fn test_handle_input_empty_returns_need_input() {
        let storage = Arc::new(TestAuthStorage::new());
        let flow = AnthropicOAuthFlow::new(storage);
        let mut state = flow.start_auth(AuthMethod::OAuth).await.unwrap();

        let progress = flow.handle_input(&mut state, "").await.unwrap();

        match progress {
            AuthProgress::NeedInput(message) => {
                assert!(message.contains("Paste the redirect URL"));
            }
            other => panic!("Expected NeedInput, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn test_is_authenticated_reflects_stored_tokens() {
        let storage = Arc::new(TestAuthStorage::new());
        let flow = AnthropicOAuthFlow::new(storage.clone());
        assert!(!flow.is_authenticated().await.unwrap());

        storage
            .set_credential(
                PROVIDER_ID,
                Credential::OAuth2(oauth_tokens_expiring_at(
                    SystemTime::now() + Duration::from_secs(3600),
                )),
            )
            .await
            .unwrap();
        assert!(flow.is_authenticated().await.unwrap());
    }

    #[tokio::test]
    async fn test_resolve_auth_without_credentials_returns_none() {
        let storage: Arc<dyn AuthStorage> = Arc::new(TestAuthStorage::new());
        let directive = AnthropicAuthPlugin::new()
            .resolve_auth(storage)
            .await
            .unwrap();
        assert!(directive.is_none());
    }

    #[tokio::test]
    async fn test_resolve_auth_with_fresh_tokens_returns_directive() {
        let storage: Arc<dyn AuthStorage> = Arc::new(TestAuthStorage::new());
        storage
            .set_credential(
                PROVIDER_ID,
                Credential::OAuth2(oauth_tokens_expiring_at(
                    SystemTime::now() + Duration::from_secs(3600),
                )),
            )
            .await
            .unwrap();

        let directive = AnthropicAuthPlugin::new()
            .resolve_auth(storage)
            .await
            .unwrap();
        assert!(directive.is_some());
    }

    #[test]
    fn test_get_oauth_headers() {
        let headers = get_oauth_headers("test-token");

        assert_eq!(headers.len(), 3);

        let auth = headers.iter().find(|h| h.name == "authorization").unwrap();
        assert_eq!(auth.value, "Bearer test-token");

        let beta = headers.iter().find(|h| h.name == "anthropic-beta").unwrap();
        assert!(beta.value.contains("oauth-2025-04-20"));
        assert!(beta.value.contains("interleaved-thinking-2025-05-14"));
        assert!(beta.value.contains("claude-code-20250219"));

        let ua = headers.iter().find(|h| h.name == "user-agent").unwrap();
        assert_eq!(ua.value, "claude-cli/2.1.2 (external, cli)");
    }
}