1use crate::auth::{AuthError, AuthStorage, AuthTokens, Credential, CredentialType, Result};
2use crate::auth::{AuthMethod, AuthProgress, AuthenticationFlow};
3use crate::config::provider;
4use async_trait::async_trait;
5use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
6use serde::{Deserialize, Serialize};
7use sha2::{Digest, Sha256};
8use std::sync::Arc;
9use std::time::{Duration, SystemTime};
10
/// Base URL of the authorization page; `build_auth_url` appends the query string to it.
const AUTHORIZE_URL: &str = "https://claude.ai/oauth/authorize";
/// Endpoint that receives both the authorization-code exchange and the refresh POST.
const TOKEN_URL: &str = "https://console.anthropic.com/v1/oauth/token";
/// OAuth client identifier copied into every `AnthropicOAuth` by `AnthropicOAuth::new`.
const CLIENT_ID: &str = "9d1c250a-e61b-44d9-88ed-5944d1962f5e";
/// Redirect URI sent with the authorization request and the code exchange.
const REDIRECT_URI: &str = "https://console.anthropic.com/oauth/code/callback";
/// Space-separated scopes sent as the `scope` query parameter of the authorization URL.
const SCOPES: &str = "org:create_api_key user:profile user:inference";
17
/// PKCE verifier/challenge pair as produced by `AnthropicOAuth::generate_pkce`.
#[derive(Debug)]
pub struct PkceChallenge {
    /// Random verifier string; sent as `code_verifier` during the token exchange.
    pub verifier: String,
    /// Unpadded URL-safe base64 of the SHA-256 digest of `verifier`.
    pub challenge: String,
}
23
/// Client for the Anthropic OAuth endpoints: builds authorization URLs, exchanges
/// authorization codes for tokens and refreshes tokens.
pub struct AnthropicOAuth {
    /// Client identifier included in authorization, exchange and refresh requests.
    client_id: String,
    /// Redirect URI included in authorization and exchange requests.
    redirect_uri: String,
    /// HTTP client used for the POST requests to the token endpoint.
    http_client: reqwest::Client,
}
29
30impl Default for AnthropicOAuth {
31 fn default() -> Self {
32 Self::new()
33 }
34}
35
36impl AnthropicOAuth {
37 pub fn new() -> Self {
38 Self {
39 client_id: CLIENT_ID.to_string(),
40 redirect_uri: REDIRECT_URI.to_string(),
41 http_client: reqwest::Client::new(),
42 }
43 }
44
45 pub fn generate_pkce() -> PkceChallenge {
47 let verifier = generate_random_string(128);
48 let challenge = base64_url_encode(&sha256(&verifier));
49 PkceChallenge {
50 verifier,
51 challenge,
52 }
53 }
54
55 pub fn build_auth_url(&self, pkce: &PkceChallenge) -> String {
57 let params = [
59 ("code", "true"),
60 ("client_id", &self.client_id),
61 ("response_type", "code"),
62 ("redirect_uri", &self.redirect_uri),
63 ("scope", SCOPES),
64 ("code_challenge", &pkce.challenge),
65 ("code_challenge_method", "S256"),
66 ("state", &pkce.verifier),
67 ];
68
69 let query = serde_urlencoded::to_string(params).unwrap();
70 format!("{AUTHORIZE_URL}?{query}")
71 }
72
73 pub fn parse_callback_code(callback_code: &str) -> Result<(String, String)> {
76 let parts: Vec<&str> = callback_code.split('#').collect();
77 if parts.len() != 2 {
78 return Err(AuthError::InvalidResponse(
79 "Invalid callback code format. Expected format: code#state".to_string(),
80 ));
81 }
82 Ok((parts[0].to_string(), parts[1].to_string()))
83 }
84
85 pub async fn exchange_code_for_tokens(
87 &self,
88 code: &str,
89 state: &str,
90 pkce_verifier: &str,
91 ) -> Result<AuthTokens> {
92 #[derive(Serialize)]
93 struct TokenRequest {
94 code: String,
95 state: String,
96 grant_type: String,
97 client_id: String,
98 redirect_uri: String,
99 code_verifier: String,
100 }
101
102 #[derive(Deserialize)]
103 struct TokenResponse {
104 access_token: String,
105 refresh_token: String,
106 expires_in: u64,
107 }
108
109 let request = TokenRequest {
110 code: code.to_string(),
111 state: state.to_string(),
112 grant_type: "authorization_code".to_string(),
113 client_id: self.client_id.clone(),
114 redirect_uri: self.redirect_uri.clone(),
115 code_verifier: pkce_verifier.to_string(),
116 };
117
118 let response = self
119 .http_client
120 .post(TOKEN_URL)
121 .json(&request)
122 .send()
123 .await?;
124
125 if !response.status().is_success() {
126 let status = response.status();
127 let error_text = response
128 .text()
129 .await
130 .unwrap_or_else(|_| "Unknown error".to_string());
131 return Err(AuthError::InvalidResponse(format!(
132 "Token exchange failed with status {status}: {error_text}"
133 )));
134 }
135
136 let token_response: TokenResponse = response.json().await.map_err(|e| {
137 AuthError::InvalidResponse(format!("Failed to parse token response: {e}"))
138 })?;
139
140 let expires_at = SystemTime::now() + Duration::from_secs(token_response.expires_in);
141
142 Ok(AuthTokens {
143 access_token: token_response.access_token,
144 refresh_token: token_response.refresh_token,
145 expires_at,
146 })
147 }
148
149 pub async fn refresh_tokens(&self, refresh_token: &str) -> Result<AuthTokens> {
151 #[derive(Serialize)]
152 struct RefreshRequest {
153 grant_type: String,
154 refresh_token: String,
155 client_id: String,
156 }
157
158 #[derive(Deserialize)]
159 struct TokenResponse {
160 access_token: String,
161 refresh_token: String,
162 expires_in: u64,
163 }
164
165 let request = RefreshRequest {
166 grant_type: "refresh_token".to_string(),
167 refresh_token: refresh_token.to_string(),
168 client_id: self.client_id.clone(),
169 };
170
171 let response = self
172 .http_client
173 .post(TOKEN_URL)
174 .json(&request)
175 .send()
176 .await?;
177
178 if !response.status().is_success() {
179 if response.status() == reqwest::StatusCode::UNAUTHORIZED {
180 return Err(AuthError::ReauthRequired);
181 }
182
183 let status = response.status();
184 let error_text = response
185 .text()
186 .await
187 .unwrap_or_else(|_| "Unknown error".to_string());
188 return Err(AuthError::InvalidResponse(format!(
189 "Token refresh failed with status {status}: {error_text}"
190 )));
191 }
192
193 let token_response: TokenResponse = response.json().await.map_err(|e| {
194 AuthError::InvalidResponse(format!("Failed to parse refresh response: {e}"))
195 })?;
196
197 let expires_at = SystemTime::now() + Duration::from_secs(token_response.expires_in);
198
199 Ok(AuthTokens {
200 access_token: token_response.access_token,
201 refresh_token: token_response.refresh_token,
202 expires_at,
203 })
204 }
205}
206
207pub fn tokens_need_refresh(tokens: &AuthTokens) -> bool {
209 match tokens.expires_at.duration_since(SystemTime::now()) {
210 Ok(duration) => duration.as_secs() <= 300, Err(_) => true, }
213}
214
/// Builds the HTTP headers for an OAuth-authenticated request.
///
/// Returns two `(name, value)` pairs, in this order: `authorization` carrying the
/// bearer token, then `anthropic-beta` set to `oauth-2025-04-20`.
pub fn get_oauth_headers(access_token: &str) -> Vec<(String, String)> {
    let authorization = (
        String::from("authorization"),
        ["Bearer ", access_token].concat(),
    );
    let beta = (
        String::from("anthropic-beta"),
        String::from("oauth-2025-04-20"),
    );
    vec![authorization, beta]
}
225
226pub async fn refresh_if_needed(
228 storage: &Arc<dyn AuthStorage>,
229 oauth_client: &AnthropicOAuth,
230) -> Result<AuthTokens> {
231 let credential = storage
232 .get_credential(&provider::anthropic().storage_key(), CredentialType::OAuth2)
233 .await?
234 .ok_or(AuthError::ReauthRequired)?;
235
236 let mut tokens = match credential {
237 Credential::OAuth2(tokens) => tokens,
238 _ => return Err(AuthError::ReauthRequired),
239 };
240
241 if tokens_need_refresh(&tokens) {
242 match oauth_client.refresh_tokens(&tokens.refresh_token).await {
244 Ok(new_tokens) => {
245 storage
246 .set_credential("anthropic", Credential::OAuth2(new_tokens.clone()))
247 .await?;
248 tokens = new_tokens;
249 }
250 Err(AuthError::ReauthRequired) => {
251 storage
253 .remove_credential("anthropic", CredentialType::OAuth2)
254 .await?;
255 return Err(AuthError::ReauthRequired);
256 }
257 Err(e) => return Err(e),
258 }
259 }
260
261 Ok(tokens)
262}
263
264fn generate_random_string(length: usize) -> String {
266 use rand::Rng;
267
268 const CHARSET: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-._~";
269 let mut rng = rand::thread_rng();
270
271 (0..length)
272 .map(|_| {
273 let idx = rng.gen_range(0..CHARSET.len());
274 CHARSET[idx] as char
275 })
276 .collect()
277}
278
279fn sha256(data: &str) -> Vec<u8> {
280 let mut hasher = Sha256::new();
281 hasher.update(data.as_bytes());
282 hasher.finalize().to_vec()
283}
284
285fn base64_url_encode(data: &[u8]) -> String {
286 URL_SAFE_NO_PAD.encode(data)
287}
288
/// State carried between the steps of an `AnthropicOAuthFlow` authentication.
#[derive(Debug, Clone)]
pub struct AnthropicAuthState {
    /// The step the flow is currently in.
    pub kind: AnthropicAuthStateKind,
}
294
/// The individual steps of the Anthropic authentication flow.
#[derive(Debug, Clone)]
pub enum AnthropicAuthStateKind {
    /// No method has been started; `handle_input` rejects input in this state.
    Initial,
    /// OAuth was started: holds the PKCE verifier and the URL the user must open.
    OAuthStarted { verifier: String, auth_url: String },
    /// The API-key method was chosen; the next input is stored as the key.
    AwaitingApiKey,
}
304
/// `AuthenticationFlow` implementation for Anthropic, offering OAuth and API-key login.
pub struct AnthropicOAuthFlow {
    /// Credential store that receives the resulting tokens or API key.
    storage: Arc<dyn AuthStorage>,
    /// OAuth client used to build the authorization URL and exchange the code.
    oauth_client: AnthropicOAuth,
}
310
311impl AnthropicOAuthFlow {
312 pub fn new(storage: Arc<dyn AuthStorage>) -> Self {
313 Self {
314 storage,
315 oauth_client: AnthropicOAuth::new(),
316 }
317 }
318}
319
320#[async_trait]
321impl AuthenticationFlow for AnthropicOAuthFlow {
322 type State = AnthropicAuthState;
323
324 fn available_methods(&self) -> Vec<AuthMethod> {
325 vec![AuthMethod::OAuth, AuthMethod::ApiKey]
326 }
327
328 async fn start_auth(&self, method: AuthMethod) -> Result<Self::State> {
329 match method {
330 AuthMethod::OAuth => {
331 let pkce = AnthropicOAuth::generate_pkce();
332 let auth_url = self.oauth_client.build_auth_url(&pkce);
333
334 Ok(AnthropicAuthState {
335 kind: AnthropicAuthStateKind::OAuthStarted {
336 verifier: pkce.verifier,
337 auth_url,
338 },
339 })
340 }
341 AuthMethod::ApiKey => Ok(AnthropicAuthState {
342 kind: AnthropicAuthStateKind::AwaitingApiKey,
343 }),
344 }
345 }
346
347 async fn get_initial_progress(
348 &self,
349 state: &Self::State,
350 method: AuthMethod,
351 ) -> Result<AuthProgress> {
352 match method {
353 AuthMethod::OAuth => {
354 if let AnthropicAuthStateKind::OAuthStarted { auth_url, .. } = &state.kind {
355 Ok(AuthProgress::OAuthStarted {
356 auth_url: auth_url.clone(),
357 })
358 } else {
359 Err(AuthError::InvalidState(
360 "Invalid state for OAuth".to_string(),
361 ))
362 }
363 }
364 AuthMethod::ApiKey => Ok(AuthProgress::NeedInput("Enter your API key".to_string())),
365 }
366 }
367
368 async fn handle_input(&self, state: &mut Self::State, input: &str) -> Result<AuthProgress> {
369 match &mut state.kind {
370 AnthropicAuthStateKind::Initial => Err(AuthError::InvalidState(
371 "No input expected in initial state".to_string(),
372 )),
373 AnthropicAuthStateKind::OAuthStarted { verifier, .. } => {
374 let (code, state_param) = if input.contains("code=") && input.contains("state=") {
376 let url = reqwest::Url::parse(input).map_err(|_| {
378 AuthError::InvalidCredential("Invalid redirect URL".to_string())
379 })?;
380
381 let params: std::collections::HashMap<_, _> = url.query_pairs().collect();
382 let code = params
383 .get("code")
384 .ok_or_else(|| AuthError::MissingInput("code parameter".to_string()))?;
385 let state = params
386 .get("state")
387 .ok_or_else(|| AuthError::MissingInput("state parameter".to_string()))?;
388
389 (code.to_string(), state.to_string())
390 } else {
391 AnthropicOAuth::parse_callback_code(input)?
393 };
394
395 let tokens = self
397 .oauth_client
398 .exchange_code_for_tokens(&code, &state_param, verifier)
399 .await?;
400
401 self.storage
403 .set_credential("anthropic", Credential::OAuth2(tokens))
404 .await?;
405
406 Ok(AuthProgress::Complete)
407 }
408 AnthropicAuthStateKind::AwaitingApiKey => {
409 if input.trim().is_empty() {
410 return Err(AuthError::InvalidCredential(
411 "API key cannot be empty".to_string(),
412 ));
413 }
414
415 self.storage
417 .set_credential(
418 "anthropic",
419 Credential::ApiKey {
420 value: input.to_string(),
421 },
422 )
423 .await?;
424
425 Ok(AuthProgress::Complete)
426 }
427 }
428 }
429
430 async fn is_authenticated(&self) -> Result<bool> {
431 if let Some(Credential::OAuth2(tokens)) = self
433 .storage
434 .get_credential(&provider::anthropic().storage_key(), CredentialType::OAuth2)
435 .await?
436 {
437 return Ok(!tokens_need_refresh(&tokens));
439 }
440
441 Ok(self
443 .storage
444 .get_credential(&provider::anthropic().storage_key(), CredentialType::ApiKey)
445 .await?
446 .is_some())
447 }
448
449 fn provider_name(&self) -> String {
450 provider::anthropic().storage_key()
451 }
452}
453
#[cfg(test)]
mod tests {
    use super::*;
    use crate::auth::{AuthStorage, Credential, CredentialType};
    use async_trait::async_trait;
    use std::collections::HashMap;
    use tokio::sync::Mutex;

    /// In-memory storage double: remembers written credentials per provider key,
    /// but always answers lookups with `None`.
    struct MockAuthStorage {
        credentials: Arc<Mutex<HashMap<String, Credential>>>,
    }

    impl MockAuthStorage {
        fn new() -> Self {
            let credentials = Arc::new(Mutex::new(HashMap::new()));
            Self { credentials }
        }
    }

    #[async_trait]
    impl AuthStorage for MockAuthStorage {
        async fn get_credential(
            &self,
            _provider: &str,
            _credential_type: CredentialType,
        ) -> Result<Option<Credential>> {
            Ok(None)
        }

        async fn set_credential(&self, provider: &str, credential: Credential) -> Result<()> {
            self.credentials
                .lock()
                .await
                .insert(provider.to_string(), credential);
            Ok(())
        }

        async fn remove_credential(
            &self,
            provider: &str,
            _credential_type: CredentialType,
        ) -> Result<()> {
            self.credentials.lock().await.remove(provider);
            Ok(())
        }
    }

    /// The verifier has the requested length and the challenge is its S256 digest.
    #[test]
    fn test_pkce_generation() {
        let generated = AnthropicOAuth::generate_pkce();

        assert_eq!(generated.verifier.len(), 128);
        // 32 digest bytes encode to 43 unpadded base64 characters.
        assert_eq!(generated.challenge.len(), 43);

        let recomputed = base64_url_encode(&sha256(&generated.verifier));
        assert_eq!(generated.challenge, recomputed);
    }

    /// The verifier (which is also used as the state value) is 128 characters long.
    #[test]
    fn test_state_generation() {
        let generated = AnthropicOAuth::generate_pkce();
        assert_eq!(generated.verifier.len(), 128);
    }

    /// The authorization URL carries every expected query parameter.
    #[test]
    fn test_auth_url_building() {
        let client = AnthropicOAuth::new();
        let generated = AnthropicOAuth::generate_pkce();

        let auth_url = client.build_auth_url(&generated);

        assert!(auth_url.contains(AUTHORIZE_URL));
        assert!(auth_url.contains(&format!("client_id={CLIENT_ID}")));
        assert!(auth_url.contains("response_type=code"));
        assert!(auth_url.contains("state="));
        assert!(auth_url.contains(&format!("code_challenge={}", &generated.challenge)));
        assert!(auth_url.contains("code_challenge_method=S256"));
        assert!(auth_url.contains("code=true"));
        assert!(auth_url.contains(
            "redirect_uri=https%3A%2F%2Fconsole.anthropic.com%2Foauth%2Fcode%2Fcallback"
        ));
    }

    /// `code#state` parses; a missing or an extra `#` is rejected.
    #[test]
    fn test_parse_callback_code() {
        let parsed = AnthropicOAuth::parse_callback_code("abc123#xyz789").unwrap();
        assert_eq!(parsed.0, "abc123");
        assert_eq!(parsed.1, "xyz789");

        assert!(AnthropicOAuth::parse_callback_code("abc123").is_err());

        assert!(AnthropicOAuth::parse_callback_code("abc#123#xyz").is_err());
    }

    /// The API-key method stores the entered key under the `anthropic` entry.
    #[tokio::test]
    async fn test_auth_flow_api_key() {
        let storage = Arc::new(MockAuthStorage::new());
        let auth_flow = AnthropicOAuthFlow::new(storage.clone());

        let methods = auth_flow.available_methods();
        assert_eq!(methods.len(), 2);
        assert!(methods.contains(&AuthMethod::OAuth));
        assert!(methods.contains(&AuthMethod::ApiKey));

        let mut state = auth_flow.start_auth(AuthMethod::ApiKey).await.unwrap();
        assert!(matches!(state.kind, AnthropicAuthStateKind::AwaitingApiKey));

        let progress = auth_flow
            .handle_input(&mut state, "test-api-key")
            .await
            .unwrap();
        assert!(matches!(progress, AuthProgress::Complete));

        let creds = storage.credentials.lock().await;
        assert!(creds.contains_key("anthropic"));
        match creds.get("anthropic") {
            Some(Credential::ApiKey { value }) => assert_eq!(value, "test-api-key"),
            _ => panic!("Expected API key credential"),
        }
    }

    /// Starting OAuth yields a state holding a verifier and a well-formed URL.
    #[tokio::test]
    async fn test_auth_flow_oauth_start() {
        let storage = Arc::new(MockAuthStorage::new());
        let auth_flow = AnthropicOAuthFlow::new(storage);

        let state = auth_flow.start_auth(AuthMethod::OAuth).await.unwrap();

        match &state.kind {
            AnthropicAuthStateKind::OAuthStarted { auth_url, verifier } => {
                assert!(auth_url.contains(AUTHORIZE_URL));
                assert!(auth_url.contains("client_id="));
                assert!(auth_url.contains("code_challenge="));
                assert!(!verifier.is_empty());
            }
            _ => panic!("Expected OAuth started state"),
        }
    }
}