1use async_trait::async_trait;
2use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
3use rand::Rng;
4use serde::{Deserialize, Serialize};
5use sha2::{Digest, Sha256};
6use std::collections::HashMap;
7use std::sync::Arc;
8use std::time::{Duration, SystemTime};
9
10use steer_auth_plugin::AuthPlugin;
11use steer_auth_plugin::{
12 AnthropicAuth, AuthDirective, AuthError, AuthErrorAction, AuthErrorContext, AuthHeaderContext,
13 AuthHeaderProvider, AuthMethod, AuthProgress, AuthStorage, AuthTokens, AuthenticationFlow,
14 Credential, CredentialType, DynAuthenticationFlow, HeaderPair, InstructionPolicy, ProviderId,
15 QueryParam, Result,
16};
17
/// Provider identifier: the storage key for credentials and the plugin's `ProviderId`.
const PROVIDER_ID: &str = "anthropic";
/// Browser authorization endpoint that `build_auth_url` points the user at.
const AUTHORIZE_URL: &str = "https://claude.ai/oauth/authorize";
/// Token endpoint used for both the authorization-code exchange and token refreshes.
const TOKEN_URL: &str = "https://console.anthropic.com/v1/oauth/token";
/// OAuth client ID sent in the authorization URL, code exchange, and refresh requests.
const CLIENT_ID: &str = "9d1c250a-e61b-44d9-88ed-5944d1962f5e";
/// Redirect URI sent in the authorization URL and echoed in the code exchange.
const REDIRECT_URI: &str = "https://console.anthropic.com/oauth/code/callback";
/// Space-separated OAuth scopes requested in the authorization URL.
const SCOPES: &str = "org:create_api_key user:profile user:inference";
24
/// A PKCE (RFC 7636) verifier together with its derived `S256` challenge.
#[derive(Debug)]
pub struct PkceChallenge {
    /// Random 128-character string; posted as `code_verifier` during the code exchange.
    /// In this file it is also sent as the OAuth `state` parameter.
    pub verifier: String,
    /// Unpadded base64url encoding of SHA-256(`verifier`), sent as `code_challenge`.
    pub challenge: String,
}
30
31#[derive(Clone)]
32pub struct AnthropicOAuth {
33 client_id: String,
34 redirect_uri: String,
35 http_client: reqwest::Client,
36}
37
38impl Default for AnthropicOAuth {
39 fn default() -> Self {
40 Self::new()
41 }
42}
43
44impl AnthropicOAuth {
45 pub fn new() -> Self {
46 Self {
47 client_id: CLIENT_ID.to_string(),
48 redirect_uri: REDIRECT_URI.to_string(),
49 http_client: reqwest::Client::new(),
50 }
51 }
52
53 pub fn generate_pkce() -> PkceChallenge {
55 let verifier = generate_random_string(128);
56 let challenge = base64_url_encode(&sha256(&verifier));
57 PkceChallenge {
58 verifier,
59 challenge,
60 }
61 }
62
63 pub fn build_auth_url(&self, pkce: &PkceChallenge) -> String {
65 let params = [
66 ("code", "true"),
67 ("client_id", &self.client_id),
68 ("response_type", "code"),
69 ("redirect_uri", &self.redirect_uri),
70 ("scope", SCOPES),
71 ("code_challenge", &pkce.challenge),
72 ("code_challenge_method", "S256"),
73 ("state", &pkce.verifier),
74 ];
75
76 let query = serde_urlencoded::to_string(params).unwrap_or_default();
77 format!("{AUTHORIZE_URL}?{query}")
78 }
79
80 pub fn parse_callback_code(callback_code: &str) -> Result<(String, String)> {
82 let trimmed = callback_code.trim();
83 if trimmed.is_empty() {
84 return Err(AuthError::InvalidResponse(
85 "Invalid callback code format. Expected a URL or code/state parameters."
86 .to_string(),
87 ));
88 }
89
90 if let Ok(url) = reqwest::Url::parse(trimmed)
91 && let Some(pair) = extract_code_state_from_url(&url)
92 {
93 return Ok(pair);
94 }
95
96 if let Some(pair) = extract_code_state_from_str(trimmed) {
97 return Ok(pair);
98 }
99
100 if let Some(pair) = extract_legacy_code_state(trimmed) {
101 return Ok(pair);
102 }
103
104 Err(AuthError::InvalidResponse(
105 "Invalid callback code format. Expected a URL or code/state parameters.".to_string(),
106 ))
107 }
108
109 pub async fn exchange_code_for_tokens(
111 &self,
112 code: &str,
113 state: &str,
114 pkce_verifier: &str,
115 ) -> Result<AuthTokens> {
116 #[derive(Serialize)]
117 struct TokenRequest {
118 code: String,
119 state: String,
120 grant_type: String,
121 client_id: String,
122 redirect_uri: String,
123 code_verifier: String,
124 }
125
126 #[derive(Deserialize)]
127 struct TokenResponse {
128 access_token: String,
129 refresh_token: String,
130 expires_in: u64,
131 }
132
133 let request = TokenRequest {
134 code: code.to_string(),
135 state: state.to_string(),
136 grant_type: "authorization_code".to_string(),
137 client_id: self.client_id.clone(),
138 redirect_uri: self.redirect_uri.clone(),
139 code_verifier: pkce_verifier.to_string(),
140 };
141
142 let response = self
143 .http_client
144 .post(TOKEN_URL)
145 .json(&request)
146 .send()
147 .await?;
148
149 if !response.status().is_success() {
150 let status = response.status();
151 let error_text = response
152 .text()
153 .await
154 .unwrap_or_else(|_| "Unknown error".to_string());
155 return Err(AuthError::InvalidResponse(format!(
156 "Token exchange failed with status {status}: {error_text}"
157 )));
158 }
159
160 let token_response: TokenResponse = response.json().await.map_err(|e| {
161 AuthError::InvalidResponse(format!("Failed to parse token response: {e}"))
162 })?;
163
164 let expires_at = SystemTime::now() + Duration::from_secs(token_response.expires_in);
165
166 Ok(AuthTokens {
167 access_token: token_response.access_token,
168 refresh_token: token_response.refresh_token,
169 expires_at,
170 id_token: None,
171 })
172 }
173
174 pub async fn refresh_tokens(&self, refresh_token: &str) -> Result<AuthTokens> {
176 #[derive(Serialize)]
177 struct RefreshRequest {
178 grant_type: String,
179 refresh_token: String,
180 client_id: String,
181 }
182
183 #[derive(Deserialize)]
184 struct TokenResponse {
185 access_token: String,
186 refresh_token: String,
187 expires_in: u64,
188 }
189
190 let request = RefreshRequest {
191 grant_type: "refresh_token".to_string(),
192 refresh_token: refresh_token.to_string(),
193 client_id: self.client_id.clone(),
194 };
195
196 let response = self
197 .http_client
198 .post(TOKEN_URL)
199 .json(&request)
200 .send()
201 .await?;
202
203 if !response.status().is_success() {
204 if response.status() == reqwest::StatusCode::UNAUTHORIZED {
205 return Err(AuthError::ReauthRequired);
206 }
207
208 let status = response.status();
209 let error_text = response
210 .text()
211 .await
212 .unwrap_or_else(|_| "Unknown error".to_string());
213 return Err(AuthError::InvalidResponse(format!(
214 "Token refresh failed with status {status}: {error_text}"
215 )));
216 }
217
218 let token_response: TokenResponse = response.json().await.map_err(|e| {
219 AuthError::InvalidResponse(format!("Failed to parse refresh response: {e}"))
220 })?;
221
222 let expires_at = SystemTime::now() + Duration::from_secs(token_response.expires_in);
223
224 Ok(AuthTokens {
225 access_token: token_response.access_token,
226 refresh_token: token_response.refresh_token,
227 expires_at,
228 id_token: None,
229 })
230 }
231}
232
233fn resolve_callback_input(input: &str, verifier: &str) -> Result<(String, String)> {
234 match AnthropicOAuth::parse_callback_code(input) {
235 Ok(pair) => Ok(pair),
236 Err(err) => {
237 let trimmed = input.trim();
238 let fallback_code = extract_code_only_from_str(trimmed).or_else(|| {
239 reqwest::Url::parse(trimmed)
240 .ok()
241 .and_then(|url| extract_code_only_from_url(&url))
242 });
243
244 if let Some(code) = fallback_code {
245 Ok((code, verifier.to_string()))
246 } else {
247 Err(err)
248 }
249 }
250 }
251}
252
253fn extract_code_state_from_url(url: &reqwest::Url) -> Option<(String, String)> {
254 if let Some(query) = url.query()
255 && let Some(pair) = extract_code_state_from_kv(query)
256 {
257 return Some(pair);
258 }
259
260 if let Some(fragment) = url.fragment()
261 && let Some(pair) = extract_code_state_from_kv(fragment)
262 {
263 return Some(pair);
264 }
265
266 None
267}
268
269fn extract_code_state_from_str(input: &str) -> Option<(String, String)> {
270 if let Some(pair) = extract_code_state_from_kv(input) {
271 return Some(pair);
272 }
273
274 if let Some(query_start) = input.find('?')
275 && let Some(pair) = extract_code_state_from_kv(&input[query_start + 1..])
276 {
277 return Some(pair);
278 }
279
280 if let Some(fragment_start) = input.find('#')
281 && let Some(pair) = extract_code_state_from_kv(&input[fragment_start + 1..])
282 {
283 return Some(pair);
284 }
285
286 None
287}
288
289fn extract_code_state_from_kv(raw: &str) -> Option<(String, String)> {
290 if raw.is_empty() {
291 return None;
292 }
293
294 let params: HashMap<String, String> = serde_urlencoded::from_str(raw).ok()?;
295 let code = params.get("code")?;
296 let state = params.get("state")?;
297 Some((code.clone(), state.clone()))
298}
299
300fn extract_code_only_from_url(url: &reqwest::Url) -> Option<String> {
301 if let Some(query) = url.query()
302 && let Some(code) = extract_code_only_from_kv(query)
303 {
304 return Some(code);
305 }
306
307 if let Some(fragment) = url.fragment()
308 && let Some(code) = extract_code_only_from_kv(fragment)
309 {
310 return Some(code);
311 }
312
313 None
314}
315
316fn extract_code_only_from_str(input: &str) -> Option<String> {
317 if let Some(code) = extract_code_only_from_kv(input) {
318 return Some(code);
319 }
320
321 if let Some(query_start) = input.find('?')
322 && let Some(code) = extract_code_only_from_kv(&input[query_start + 1..])
323 {
324 return Some(code);
325 }
326
327 if let Some(fragment_start) = input.find('#')
328 && let Some(code) = extract_code_only_from_kv(&input[fragment_start + 1..])
329 {
330 return Some(code);
331 }
332
333 None
334}
335
336fn extract_code_only_from_kv(raw: &str) -> Option<String> {
337 if raw.is_empty() {
338 return None;
339 }
340
341 let params: HashMap<String, String> = serde_urlencoded::from_str(raw).ok()?;
342 params.get("code").cloned()
343}
344
/// Parses the legacy `code#state` form.
///
/// Requires exactly one `#` and non-empty text on both sides of it.
fn extract_legacy_code_state(input: &str) -> Option<(String, String)> {
    let (code, state) = input.split_once('#')?;
    let well_formed = !code.is_empty() && !state.is_empty() && !state.contains('#');
    well_formed.then(|| (code.to_owned(), state.to_owned()))
}
353
354pub fn tokens_need_refresh(tokens: &AuthTokens) -> bool {
356 match tokens.expires_at.duration_since(SystemTime::now()) {
357 Ok(duration) => duration.as_secs() <= 300,
358 Err(_) => true,
359 }
360}
361
362pub fn get_oauth_headers(access_token: &str) -> Vec<HeaderPair> {
364 vec![
365 HeaderPair {
366 name: "authorization".to_string(),
367 value: format!("Bearer {access_token}"),
368 },
369 HeaderPair {
370 name: "anthropic-beta".to_string(),
371 value: "oauth-2025-04-20,interleaved-thinking-2025-05-14,claude-code-20250219"
372 .to_string(),
373 },
374 HeaderPair {
375 name: "user-agent".to_string(),
376 value: "claude-cli/2.1.2 (external, cli)".to_string(),
377 },
378 ]
379}
380
381pub async fn refresh_if_needed(
383 storage: &Arc<dyn AuthStorage>,
384 oauth_client: &AnthropicOAuth,
385) -> Result<AuthTokens> {
386 let credential = storage
387 .get_credential(PROVIDER_ID, CredentialType::OAuth2)
388 .await?
389 .ok_or(AuthError::ReauthRequired)?;
390
391 let mut tokens = match credential {
392 Credential::OAuth2(tokens) => tokens,
393 Credential::ApiKey { .. } => return Err(AuthError::ReauthRequired),
394 };
395
396 if tokens_need_refresh(&tokens) {
397 match oauth_client.refresh_tokens(&tokens.refresh_token).await {
398 Ok(new_tokens) => {
399 storage
400 .set_credential(PROVIDER_ID, Credential::OAuth2(new_tokens.clone()))
401 .await?;
402 tokens = new_tokens;
403 }
404 Err(AuthError::ReauthRequired) => {
405 storage
406 .remove_credential(PROVIDER_ID, CredentialType::OAuth2)
407 .await?;
408 return Err(AuthError::ReauthRequired);
409 }
410 Err(e) => return Err(e),
411 }
412 }
413
414 Ok(tokens)
415}
416
417async fn force_refresh(
418 storage: &Arc<dyn AuthStorage>,
419 oauth_client: &AnthropicOAuth,
420) -> Result<AuthTokens> {
421 let credential = storage
422 .get_credential(PROVIDER_ID, CredentialType::OAuth2)
423 .await?
424 .ok_or(AuthError::ReauthRequired)?;
425
426 let tokens = match credential {
427 Credential::OAuth2(tokens) => tokens,
428 Credential::ApiKey { .. } => return Err(AuthError::ReauthRequired),
429 };
430
431 match oauth_client.refresh_tokens(&tokens.refresh_token).await {
432 Ok(new_tokens) => {
433 storage
434 .set_credential(PROVIDER_ID, Credential::OAuth2(new_tokens.clone()))
435 .await?;
436 Ok(new_tokens)
437 }
438 Err(AuthError::ReauthRequired) => {
439 storage
440 .remove_credential(PROVIDER_ID, CredentialType::OAuth2)
441 .await?;
442 Err(AuthError::ReauthRequired)
443 }
444 Err(err) => Err(err),
445 }
446}
447
448fn generate_random_string(length: usize) -> String {
449 const CHARSET: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-._~";
450 let mut rng = rand::thread_rng();
451
452 (0..length)
453 .map(|_| {
454 let idx = rng.gen_range(0..CHARSET.len());
455 CHARSET[idx] as char
456 })
457 .collect()
458}
459
460fn sha256(data: &str) -> Vec<u8> {
461 let mut hasher = Sha256::new();
462 hasher.update(data.as_bytes());
463 hasher.finalize().to_vec()
464}
465
466fn base64_url_encode(data: &[u8]) -> String {
467 URL_SAFE_NO_PAD.encode(data)
468}
469
/// State of an in-progress Anthropic login, created by `start_auth`.
#[derive(Debug, Clone)]
pub struct AnthropicAuthState {
    pub kind: AnthropicAuthStateKind,
}

/// Phases of the Anthropic login flow.
///
/// NOTE(review): the derived `Debug` prints `verifier` verbatim; avoid logging this value.
#[derive(Debug, Clone)]
pub enum AnthropicAuthStateKind {
    /// The authorization URL has been issued and the flow awaits the pasted callback.
    /// `verifier` is the PKCE verifier needed for the code exchange.
    OAuthStarted { verifier: String, auth_url: String },
}
479
480pub struct AnthropicOAuthFlow {
481 storage: Arc<dyn AuthStorage>,
482 oauth_client: AnthropicOAuth,
483}
484
485impl AnthropicOAuthFlow {
486 pub fn new(storage: Arc<dyn AuthStorage>) -> Self {
487 Self {
488 storage,
489 oauth_client: AnthropicOAuth::new(),
490 }
491 }
492}
493
494#[async_trait]
495impl AuthenticationFlow for AnthropicOAuthFlow {
496 type State = AnthropicAuthState;
497
498 fn available_methods(&self) -> Vec<AuthMethod> {
499 vec![AuthMethod::OAuth]
500 }
501
502 async fn start_auth(&self, method: AuthMethod) -> Result<Self::State> {
503 match method {
504 AuthMethod::OAuth => {
505 let pkce = AnthropicOAuth::generate_pkce();
506 let auth_url = self.oauth_client.build_auth_url(&pkce);
507
508 Ok(AnthropicAuthState {
509 kind: AnthropicAuthStateKind::OAuthStarted {
510 verifier: pkce.verifier,
511 auth_url,
512 },
513 })
514 }
515 AuthMethod::ApiKey => Err(AuthError::UnsupportedMethod {
516 method: format!("{method:?}"),
517 provider: PROVIDER_ID.to_string(),
518 }),
519 }
520 }
521
522 async fn get_initial_progress(
523 &self,
524 state: &Self::State,
525 method: AuthMethod,
526 ) -> Result<AuthProgress> {
527 match method {
528 AuthMethod::OAuth => {
529 let AnthropicAuthStateKind::OAuthStarted { auth_url, .. } = &state.kind;
530 Ok(AuthProgress::OAuthStarted {
531 auth_url: auth_url.clone(),
532 })
533 }
534 AuthMethod::ApiKey => Err(AuthError::UnsupportedMethod {
535 method: format!("{method:?}"),
536 provider: PROVIDER_ID.to_string(),
537 }),
538 }
539 }
540
541 async fn handle_input(&self, state: &mut Self::State, input: &str) -> Result<AuthProgress> {
542 match &mut state.kind {
543 AnthropicAuthStateKind::OAuthStarted { verifier, .. } => {
544 if input.trim().is_empty() {
545 return Ok(AuthProgress::NeedInput(
546 "Paste the redirect URL or code from your browser".to_string(),
547 ));
548 }
549
550 let (code, state_param) = resolve_callback_input(input, verifier)?;
551
552 let tokens = self
553 .oauth_client
554 .exchange_code_for_tokens(&code, &state_param, verifier)
555 .await?;
556
557 self.storage
558 .set_credential(PROVIDER_ID, Credential::OAuth2(tokens))
559 .await?;
560
561 Ok(AuthProgress::Complete)
562 }
563 }
564 }
565
566 async fn is_authenticated(&self) -> Result<bool> {
567 if let Some(Credential::OAuth2(tokens)) = self
568 .storage
569 .get_credential(PROVIDER_ID, CredentialType::OAuth2)
570 .await?
571 {
572 return Ok(!tokens_need_refresh(&tokens));
573 }
574
575 Ok(false)
576 }
577
578 fn provider_name(&self) -> String {
579 PROVIDER_ID.to_string()
580 }
581}
582
583#[derive(Clone)]
584struct AnthropicHeaderProvider {
585 storage: Arc<dyn AuthStorage>,
586 oauth: AnthropicOAuth,
587}
588
589impl AnthropicHeaderProvider {
590 fn new(storage: Arc<dyn AuthStorage>) -> Self {
591 Self {
592 storage,
593 oauth: AnthropicOAuth::new(),
594 }
595 }
596
597 async fn header_pairs(&self, _ctx: AuthHeaderContext) -> Result<Vec<HeaderPair>> {
598 let tokens = refresh_if_needed(&self.storage, &self.oauth).await?;
599 Ok(get_oauth_headers(&tokens.access_token))
600 }
601}
602
603#[async_trait]
604impl AuthHeaderProvider for AnthropicHeaderProvider {
605 async fn headers(&self, ctx: AuthHeaderContext) -> Result<Vec<HeaderPair>> {
606 self.header_pairs(ctx).await
607 }
608
609 async fn on_auth_error(&self, _ctx: AuthErrorContext) -> Result<AuthErrorAction> {
610 match force_refresh(&self.storage, &self.oauth).await {
611 Ok(_) => Ok(AuthErrorAction::RetryOnce),
612 Err(AuthError::ReauthRequired) => Ok(AuthErrorAction::ReauthRequired),
613 Err(err) => Err(err),
614 }
615 }
616}
617
/// Auth plugin that provides Anthropic's OAuth login flow and request-header provider.
#[derive(Clone)]
pub struct AnthropicAuthPlugin;

impl Default for AnthropicAuthPlugin {
    fn default() -> Self {
        AnthropicAuthPlugin::new()
    }
}

impl AnthropicAuthPlugin {
    /// Creates the plugin; it holds no state of its own.
    pub fn new() -> Self {
        AnthropicAuthPlugin
    }
}
632
633#[async_trait]
634impl AuthPlugin for AnthropicAuthPlugin {
635 fn provider_id(&self) -> ProviderId {
636 ProviderId(PROVIDER_ID.to_string())
637 }
638
639 fn supported_methods(&self) -> Vec<AuthMethod> {
640 vec![AuthMethod::OAuth]
641 }
642
643 fn create_flow(&self, storage: Arc<dyn AuthStorage>) -> Option<Box<dyn DynAuthenticationFlow>> {
644 Some(Box::new(steer_auth_plugin::AuthFlowWrapper::new(
645 AnthropicOAuthFlow::new(storage),
646 )))
647 }
648
649 async fn resolve_auth(&self, storage: Arc<dyn AuthStorage>) -> Result<Option<AuthDirective>> {
650 let is_authenticated = self.is_authenticated(storage.clone()).await?;
651 if !is_authenticated {
652 return Ok(None);
653 }
654
655 let headers = Arc::new(AnthropicHeaderProvider::new(storage));
656 let directive = AnthropicAuth {
657 headers,
658 instruction_policy: Some(InstructionPolicy::Prefix(
659 "You are Claude Code, Anthropic's official CLI for Claude.".to_string(),
660 )),
661 query_params: Some(vec![QueryParam {
662 name: "beta".to_string(),
663 value: "true".to_string(),
664 }]),
665 };
666
667 Ok(Some(AuthDirective::Anthropic(directive)))
668 }
669
670 async fn is_authenticated(&self, storage: Arc<dyn AuthStorage>) -> Result<bool> {
671 if let Some(Credential::OAuth2(tokens)) = storage
672 .get_credential(PROVIDER_ID, CredentialType::OAuth2)
673 .await?
674 {
675 return Ok(!tokens_need_refresh(&tokens));
676 }
677
678 Ok(false)
679 }
680}
681
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;
    use tokio::sync::Mutex;

    /// In-memory `AuthStorage` keyed by `"{provider}-{credential_type}"`.
    struct TestAuthStorage {
        credentials: Arc<Mutex<HashMap<String, Credential>>>,
    }

    impl TestAuthStorage {
        fn new() -> Self {
            Self {
                credentials: Arc::new(Mutex::new(HashMap::new())),
            }
        }
    }

    #[async_trait]
    impl AuthStorage for TestAuthStorage {
        async fn get_credential(
            &self,
            provider: &str,
            credential_type: CredentialType,
        ) -> Result<Option<Credential>> {
            let key = format!("{provider}-{credential_type}");
            Ok(self.credentials.lock().await.get(&key).cloned())
        }

        async fn set_credential(&self, provider: &str, credential: Credential) -> Result<()> {
            let key = format!("{}-{}", provider, credential.credential_type());
            self.credentials.lock().await.insert(key, credential);
            Ok(())
        }

        async fn remove_credential(
            &self,
            provider: &str,
            credential_type: CredentialType,
        ) -> Result<()> {
            let key = format!("{provider}-{credential_type}");
            self.credentials.lock().await.remove(&key);
            Ok(())
        }
    }

    /// Builds placeholder OAuth tokens that expire at `expires_at`.
    fn tokens_expiring_at(expires_at: SystemTime) -> AuthTokens {
        AuthTokens {
            access_token: "access".to_string(),
            refresh_token: "refresh".to_string(),
            expires_at,
            id_token: None,
        }
    }

    #[test]
    fn test_pkce_generation() {
        let pkce = AnthropicOAuth::generate_pkce();

        assert_eq!(pkce.verifier.len(), 128);
        assert_eq!(pkce.challenge.len(), 43);

        let expected_challenge = base64_url_encode(&sha256(&pkce.verifier));
        assert_eq!(pkce.challenge, expected_challenge);
    }

    #[test]
    fn test_state_generation() {
        let pkce1 = AnthropicOAuth::generate_pkce();
        let pkce2 = AnthropicOAuth::generate_pkce();

        assert_ne!(pkce1.verifier, pkce2.verifier);
    }

    #[test]
    fn test_build_auth_url() {
        let oauth = AnthropicOAuth::new();
        let pkce = AnthropicOAuth::generate_pkce();
        let url = oauth.build_auth_url(&pkce);

        assert!(url.contains(AUTHORIZE_URL));
        assert!(url.contains(&format!("client_id={CLIENT_ID}")));
        assert!(url.contains("response_type=code"));
        assert!(url.contains("code_challenge="));
        assert!(url.contains("code_challenge_method=S256"));
        assert!(url.contains(
            "redirect_uri=https%3A%2F%2Fconsole.anthropic.com%2Foauth%2Fcode%2Fcallback"
        ));
    }

    #[test]
    fn test_parse_callback_code_from_url() {
        let input = "https://console.anthropic.com/oauth/code/callback?code=abc123&state=state456";
        let (code, state) = AnthropicOAuth::parse_callback_code(input).unwrap();
        assert_eq!(code, "abc123");
        assert_eq!(state, "state456");
    }

    #[test]
    fn test_parse_callback_code_from_query() {
        let input = "code=abc123&state=state456";
        let (code, state) = AnthropicOAuth::parse_callback_code(input).unwrap();
        assert_eq!(code, "abc123");
        assert_eq!(state, "state456");
    }

    #[test]
    fn test_parse_callback_code_from_fragment() {
        let input = "https://console.anthropic.com/oauth/code/callback#code=abc123&state=state456";
        let (code, state) = AnthropicOAuth::parse_callback_code(input).unwrap();
        assert_eq!(code, "abc123");
        assert_eq!(state, "state456");
    }

    #[test]
    fn test_parse_callback_code_legacy() {
        let input = "abc123#state456";
        let (code, state) = AnthropicOAuth::parse_callback_code(input).unwrap();
        assert_eq!(code, "abc123");
        assert_eq!(state, "state456");
    }

    #[test]
    fn test_parse_callback_code_rejects_unrecognized_input() {
        assert!(AnthropicOAuth::parse_callback_code("").is_err());
        assert!(AnthropicOAuth::parse_callback_code("   ").is_err());
        assert!(AnthropicOAuth::parse_callback_code("not a callback").is_err());
    }

    #[test]
    fn test_extract_code_only_from_query() {
        let input = "code=abc123";
        let code = extract_code_only_from_str(input).unwrap();
        assert_eq!(code, "abc123");
    }

    #[test]
    fn test_extract_code_only_from_url() {
        let input = "https://console.anthropic.com/oauth/code/callback?code=abc123";
        let code = extract_code_only_from_str(input).unwrap();
        assert_eq!(code, "abc123");
    }

    #[test]
    fn test_extract_code_only_from_fragment() {
        let input = "https://console.anthropic.com/oauth/code/callback#code=abc123";
        let code = extract_code_only_from_str(input).unwrap();
        assert_eq!(code, "abc123");
    }

    #[test]
    fn test_resolve_callback_input_code_only_uses_verifier() {
        let (code, state) = resolve_callback_input("code=abc123", "verifier-123").unwrap();
        assert_eq!(code, "abc123");
        assert_eq!(state, "verifier-123");
    }

    #[test]
    fn test_resolve_callback_input_prefers_explicit_state() {
        let (code, state) =
            resolve_callback_input("code=abc123&state=state456", "verifier-123").unwrap();
        assert_eq!(code, "abc123");
        assert_eq!(state, "state456");
    }

    #[test]
    fn test_resolve_callback_input_rejects_blank() {
        assert!(resolve_callback_input("   ", "verifier-123").is_err());
    }

    #[test]
    fn test_tokens_need_refresh() {
        let now = SystemTime::now();
        assert!(tokens_need_refresh(&tokens_expiring_at(
            now - Duration::from_secs(10)
        )));
        assert!(tokens_need_refresh(&tokens_expiring_at(
            now + Duration::from_secs(60)
        )));
        assert!(!tokens_need_refresh(&tokens_expiring_at(
            now + Duration::from_secs(3600)
        )));
    }

    #[tokio::test]
    async fn test_handle_input_empty_returns_need_input() {
        let storage = Arc::new(TestAuthStorage::new());
        let flow = AnthropicOAuthFlow::new(storage);
        let mut state = flow.start_auth(AuthMethod::OAuth).await.unwrap();

        let progress = flow.handle_input(&mut state, "").await.unwrap();

        match progress {
            AuthProgress::NeedInput(message) => {
                assert!(message.contains("Paste the redirect URL"));
            }
            other => panic!("Expected NeedInput, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn test_plugin_auth_with_absent_and_fresh_tokens() {
        let storage: Arc<dyn AuthStorage> = Arc::new(TestAuthStorage::new());
        let plugin = AnthropicAuthPlugin::new();

        assert!(!plugin.is_authenticated(storage.clone()).await.unwrap());
        assert!(plugin.resolve_auth(storage.clone()).await.unwrap().is_none());

        storage
            .set_credential(
                PROVIDER_ID,
                Credential::OAuth2(tokens_expiring_at(
                    SystemTime::now() + Duration::from_secs(3600),
                )),
            )
            .await
            .unwrap();

        assert!(plugin.is_authenticated(storage.clone()).await.unwrap());
        assert!(plugin.resolve_auth(storage).await.unwrap().is_some());
    }

    #[test]
    fn test_get_oauth_headers() {
        let headers = get_oauth_headers("test-token");

        assert_eq!(headers.len(), 3);

        let auth = headers.iter().find(|h| h.name == "authorization").unwrap();
        assert_eq!(auth.value, "Bearer test-token");

        let beta = headers.iter().find(|h| h.name == "anthropic-beta").unwrap();
        assert!(beta.value.contains("oauth-2025-04-20"));
        assert!(beta.value.contains("interleaved-thinking-2025-05-14"));
        assert!(beta.value.contains("claude-code-20250219"));

        let ua = headers.iter().find(|h| h.name == "user-agent").unwrap();
        assert_eq!(ua.value, "claude-cli/2.1.2 (external, cli)");
    }
}