Skip to main content

shunt/
oauth.rs

//! OAuth 2.0 PKCE flow + token refresh for claude.ai accounts, plus the
//! OpenAI/Codex device-code login and token refresh.
//!
//! Claude Code authenticates via OAuth, not API keys. Credentials are stored
//! in ~/.claude/.credentials.json (on macOS the Keychain is checked first) and
//! sent as `Authorization: Bearer <token>`.
5use anyhow::{bail, Context, Result};
6use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine as _};
7use serde::{Deserialize, Serialize};
8use sha2::{Digest, Sha256};
9use std::path::PathBuf;
10use std::time::{SystemTime, UNIX_EPOCH};
11use zeroize::ZeroizeOnDrop;
12
13// ---------------------------------------------------------------------------
14// Anthropic OAuth constants
15// ---------------------------------------------------------------------------
16
/// OAuth client ID sent with Anthropic authorize, token-exchange, and refresh requests.
pub const OAUTH_CLIENT_ID: &str = "9d1c250a-e61b-44d9-88ed-5944d1962f5e";
/// Browser authorization endpoint where the PKCE login starts.
pub const OAUTH_AUTHORIZE_URL: &str = "https://claude.ai/oauth/authorize";
/// Token endpoint used for both the authorization-code exchange and refresh grants.
pub const OAUTH_TOKEN_URL: &str = "https://platform.claude.com/v1/oauth/token";
20
21// ---------------------------------------------------------------------------
22// OpenAI / Codex OAuth constants
23// ---------------------------------------------------------------------------
24
/// OAuth client ID used for the OpenAI/Codex device login and token refresh.
pub const OPENAI_OAUTH_CLIENT_ID: &str = "app_EMoamEEZ73f0CkXaXp7hrann";
/// OpenAI token endpoint (authorization-code exchange and refresh grants).
pub const OPENAI_OAUTH_TOKEN_URL: &str = "https://auth.openai.com/oauth/token";
/// Codex device flow, step 1: request a user code.
pub const OPENAI_DEVICE_CODE_URL: &str = "https://auth.openai.com/api/accounts/deviceauth/usercode";
/// Codex device flow, step 2: poll until the user code has been approved.
pub const OPENAI_DEVICE_TOKEN_URL: &str = "https://auth.openai.com/api/accounts/deviceauth/token";
29
30// ---------------------------------------------------------------------------
31// Credential type
32// ---------------------------------------------------------------------------
33
34/// #20: ZeroizeOnDrop wipes access_token, refresh_token, and id_token from
35/// memory when the credential is dropped (e.g. after a token rotation).
36/// Clone is implemented manually so the derive doesn't conflict with ZeroizeOnDrop.
37#[derive(Serialize, Deserialize, ZeroizeOnDrop)]
38pub struct OAuthCredential {
39    pub access_token: String,
40    pub refresh_token: String,
41    /// Milliseconds since Unix epoch
42    #[zeroize(skip)]
43    pub expires_at: u64,
44    /// Account email, fetched from roles endpoint after auth
45    #[serde(default, skip_serializing_if = "Option::is_none")]
46    #[zeroize(skip)]
47    pub email: Option<String>,
48    /// OpenAI id_token — required by the Codex CLI's ~/.codex/auth.json
49    #[serde(default, skip_serializing_if = "Option::is_none")]
50    pub id_token: Option<String>,
51}
52
53impl Clone for OAuthCredential {
54    fn clone(&self) -> Self {
55        Self {
56            access_token: self.access_token.clone(),
57            refresh_token: self.refresh_token.clone(),
58            expires_at: self.expires_at,
59            email: self.email.clone(),
60            id_token: self.id_token.clone(),
61        }
62    }
63}
64
65impl std::fmt::Debug for OAuthCredential {
66    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
67        f.debug_struct("OAuthCredential")
68            .field("access_token", &"[REDACTED]")
69            .field("refresh_token", &"[REDACTED]")
70            .field("expires_at", &self.expires_at)
71            .field("email", &self.email)
72            .field("id_token", &self.id_token.as_ref().map(|_| "[REDACTED]"))
73            .finish()
74    }
75}
76
77impl OAuthCredential {
78    /// True if the token expires within the next 5 minutes.
79    pub fn needs_refresh(&self) -> bool {
80        let now_ms = SystemTime::now()
81            .duration_since(UNIX_EPOCH)
82            .unwrap_or_default()
83            .as_millis() as u64;
84        now_ms >= self.expires_at.saturating_sub(5 * 60 * 1000)
85    }
86}
87
88// ---------------------------------------------------------------------------
89// Auto-import from Claude Code's own credential file
90// ---------------------------------------------------------------------------
91
92// ---------------------------------------------------------------------------
93// Auto-import from Codex CLI's credential file (~/.codex/auth.json)
94// ---------------------------------------------------------------------------
95
96/// Raw format used by ~/.codex/auth.json
97/// The tokens are nested under a "tokens" key; there is no top-level expires_at.
98/// Expiry is read from the JWT `exp` claim inside the access_token.
99#[derive(Deserialize)]
100struct CodexAuth {
101    tokens: CodexTokens,
102}
103
104#[derive(Deserialize)]
105struct CodexTokens {
106    access_token: String,
107    #[serde(default)]
108    refresh_token: Option<String>,
109    #[serde(default)]
110    id_token: Option<String>,
111}
112
113/// Write credentials to ~/.codex/auth.json so the Codex CLI can use them without re-login.
114///
115/// Called automatically after add-account and token refresh for OpenAI accounts.
116pub fn write_codex_auth_file(cred: &OAuthCredential) {
117    let Some(ref id_token) = cred.id_token else { return };
118    let path = codex_credentials_path();
119    if let Some(parent) = path.parent() {
120        let _ = std::fs::create_dir_all(parent);
121    }
122    let auth = serde_json::json!({
123        "tokens": {
124            "access_token": cred.access_token,
125            "refresh_token": cred.refresh_token,
126            "id_token": id_token,
127        }
128    });
129    if let Ok(text) = serde_json::to_string_pretty(&auth) {
130        let tmp = path.with_extension("tmp");
131        if std::fs::write(&tmp, &text).is_ok() {
132            let _ = std::fs::rename(&tmp, &path);
133        }
134    }
135}
136
137pub fn codex_credentials_path() -> PathBuf {
138    dirs::home_dir()
139        .unwrap_or_else(|| PathBuf::from("."))
140        .join(".codex")
141        .join("auth.json")
142}
143
144/// Read the OAuth credential from the Codex CLI's stored auth file.
145pub fn read_codex_credentials() -> Option<OAuthCredential> {
146    let text = std::fs::read_to_string(codex_credentials_path()).ok()?;
147    let raw: CodexAuth = serde_json::from_str(&text).ok()?;
148
149    let now_ms = SystemTime::now()
150        .duration_since(UNIX_EPOCH)
151        .unwrap_or_default()
152        .as_millis() as u64;
153
154    // Extract exp from the JWT payload without verifying signature.
155    let expires_at = jwt_exp_ms(&raw.tokens.access_token)
156        .unwrap_or(now_ms + 3600 * 1000); // default: 1 hour from now
157
158    Some(OAuthCredential {
159        access_token: raw.tokens.access_token,
160        refresh_token: raw.tokens.refresh_token.unwrap_or_default(),
161        expires_at,
162        email: None,
163        id_token: raw.tokens.id_token,
164    })
165}
166
167/// Decode the `exp` claim from a JWT payload (no signature verification).
168/// Returns expiry as Unix milliseconds.
169///
170/// Applies a sanity cap (#8): rejects tokens already expired or claiming to
171/// expire more than 25 hours in the future (which would suggest a forged `exp`).
172/// Callers fall back to `now + 1h` when this returns None.
173pub(crate) fn jwt_exp_ms(token: &str) -> Option<u64> {
174    let payload_b64 = token.splitn(3, '.').nth(1)?;
175    let decoded = base64_url_decode(payload_b64)?;
176    let v: serde_json::Value = serde_json::from_slice(&decoded).ok()?;
177    let exp_secs = v.get("exp")?.as_u64()?;
178    let exp_ms = exp_secs.saturating_mul(1_000);
179    let now_ms = SystemTime::now()
180        .duration_since(UNIX_EPOCH)
181        .unwrap_or_default()
182        .as_millis() as u64;
183    // Reject already-expired tokens or tokens expiring more than 25 hours out.
184    let max_exp_ms = now_ms.saturating_add(25 * 60 * 60 * 1_000);
185    if exp_ms > max_exp_ms || exp_ms < now_ms {
186        return None;
187    }
188    Some(exp_ms)
189}
190
191/// Minimal base64url decoder (no padding, URL-safe alphabet).
192fn base64_url_decode(s: &str) -> Option<Vec<u8>> {
193    URL_SAFE_NO_PAD.decode(s).ok()
194}
195
196
197/// Raw format used by ~/.claude/.credentials.json
198#[derive(Deserialize)]
199#[serde(rename_all = "camelCase")]
200struct ClaudeCredentials {
201    claude_ai_oauth: Option<ClaudeOAuthRaw>,
202}
203
204#[derive(Deserialize)]
205#[serde(rename_all = "camelCase")]
206struct ClaudeOAuthRaw {
207    access_token: String,
208    refresh_token: String,
209    expires_at: u64,
210}
211
212// ---------------------------------------------------------------------------
213// Session info (plan + identity) from stored credentials
214// ---------------------------------------------------------------------------
215
/// Plan and account label read from Claude Code's stored credentials.
pub struct SessionInfo {
    /// Account label. NOTE(review): `read_claude_session_info` fills this from
    /// `rateLimitTier`, which does not look like an email or ID — confirm intent.
    pub email_or_id: String,
    /// Subscription type from `subscriptionType` (defaults to `"pro"` when absent).
    pub plan: String,
}
220
221/// Read plan and identity from Claude Code's stored credentials JSON.
222/// Works for both keychain and file-based storage.
223pub fn read_claude_session_info() -> Option<SessionInfo> {
224    #[derive(serde::Deserialize)]
225    #[serde(rename_all = "camelCase")]
226    struct Outer {
227        claude_ai_oauth: Option<Inner>,
228    }
229    #[derive(serde::Deserialize)]
230    #[serde(rename_all = "camelCase")]
231    struct Inner {
232        subscription_type: Option<String>,
233        #[serde(rename = "rateLimitTier")]
234        rate_limit_tier: Option<String>,
235    }
236
237    let text = read_raw_credentials_json()?;
238    let outer: Outer = serde_json::from_str(&text).ok()?;
239    let inner = outer.claude_ai_oauth?;
240
241    let plan = inner.subscription_type.unwrap_or_else(|| "pro".into());
242    let email_or_id = inner.rate_limit_tier.unwrap_or_else(|| "unknown".into());
243
244    Some(SessionInfo { email_or_id, plan })
245}
246
247/// Returns the raw credentials JSON string from keychain (macOS) or file.
248fn read_raw_credentials_json() -> Option<String> {
249    #[cfg(target_os = "macos")]
250    {
251        // `security` can hang indefinitely in SSH sessions (no GUI keychain
252        // context). Run it in a thread with a 5-second timeout and fall
253        // through to the file-based fallback if it doesn't respond.
254        let (tx, rx) = std::sync::mpsc::channel();
255        std::thread::spawn(move || {
256            let out = std::process::Command::new("security")
257                .args(["find-generic-password", "-s", "Claude Code-credentials", "-w"])
258                .output()
259                .ok();
260            let _ = tx.send(out);
261        });
262        if let Ok(Some(out)) = rx.recv_timeout(std::time::Duration::from_secs(5)) {
263            if out.status.success() {
264                if let Ok(s) = String::from_utf8(out.stdout) {
265                    return Some(s.trim().to_owned());
266                }
267            }
268        }
269    }
270    std::fs::read_to_string(claude_credentials_path()).ok()
271}
272
273pub fn claude_credentials_path() -> PathBuf {
274    dirs::home_dir()
275        .unwrap_or_else(|| PathBuf::from("."))
276        .join(".claude")
277        .join(".credentials.json")
278}
279
280/// Read the OAuth credential from Claude Code's own credential file.
281/// On macOS, tries the Keychain first (service "Claude Code-credentials"),
282/// then falls back to ~/.claude/.credentials.json.
283pub fn read_claude_credentials() -> Option<OAuthCredential> {
284    // macOS: try Keychain first
285    #[cfg(target_os = "macos")]
286    if let Some(cred) = read_claude_credentials_keychain() {
287        return Some(cred);
288    }
289
290    // Fallback: JSON file (older Claude Code versions / non-macOS)
291    let path = claude_credentials_path();
292    let text = std::fs::read_to_string(&path).ok()?;
293    parse_claude_credentials_json(&text)
294}
295
/// macOS: parse credentials obtained through `read_raw_credentials_json`.
///
/// That helper falls back to the credentials file when the Keychain lookup
/// fails, so the result is not necessarily Keychain-sourced.
#[cfg(target_os = "macos")]
fn read_claude_credentials_keychain() -> Option<OAuthCredential> {
    read_raw_credentials_json()
        .as_deref()
        .and_then(parse_claude_credentials_json)
}
301
302fn parse_claude_credentials_json(text: &str) -> Option<OAuthCredential> {
303    let raw: ClaudeCredentials = serde_json::from_str(text).ok()?;
304    let inner = raw.claude_ai_oauth?;
305    Some(OAuthCredential {
306        access_token: inner.access_token,
307        refresh_token: inner.refresh_token,
308        expires_at: inner.expires_at,
309        email: None,
310        id_token: None,
311    })
312}
313
314// ---------------------------------------------------------------------------
315// Token refresh
316// ---------------------------------------------------------------------------
317
318/// Refresh an expired access token. Returns the updated credential.
319pub async fn refresh_token(cred: &OAuthCredential) -> Result<OAuthCredential> {
320    let client = reqwest::Client::new();
321
322    let resp = client
323        .post(OAUTH_TOKEN_URL)
324        .header("content-type", "application/x-www-form-urlencoded")
325        .body(format!(
326            "grant_type=refresh_token&refresh_token={}&client_id={}",
327            urlencoding::encode(&cred.refresh_token),
328            OAUTH_CLIENT_ID,
329        ))
330        .send()
331        .await
332        .context("token refresh request failed")?;
333
334    if !resp.status().is_success() {
335        let status = resp.status();
336        let body = resp.text().await.unwrap_or_default();
337        let err = serde_json::from_str::<serde_json::Value>(&body).ok()
338            .and_then(|v| v["error"].as_str().or_else(|| v["error_description"].as_str()).map(String::from))
339            .unwrap_or_else(|| "unknown error".to_string());
340        bail!("token refresh failed ({status}): {err}");
341    }
342
343    let body: serde_json::Value = resp.json().await.context("token refresh: invalid JSON")?;
344
345    let access_token = body["access_token"]
346        .as_str()
347        .context("token refresh: missing access_token")?
348        .to_owned();
349
350    let refresh_token = body["refresh_token"]
351        .as_str()
352        .unwrap_or(&cred.refresh_token)
353        .to_owned();
354
355    // expires_in is seconds from now
356    let expires_in_secs = body["expires_in"].as_u64().unwrap_or(3600);
357    let now_ms = SystemTime::now()
358        .duration_since(UNIX_EPOCH)
359        .unwrap_or_default()
360        .as_millis() as u64;
361    let expires_at = now_ms + expires_in_secs * 1000;
362
363    Ok(OAuthCredential { access_token, refresh_token, expires_at, email: cred.email.clone(), id_token: None })
364}
365
366// ---------------------------------------------------------------------------
367// Account identity (email) from roles endpoint
368// ---------------------------------------------------------------------------
369
370/// Fetch the account email from the Anthropic roles endpoint.
371/// Returns `None` on any error (non-fatal).
372pub async fn fetch_account_email(access_token: &str) -> Option<String> {
373    let client = reqwest::Client::builder()
374        .timeout(std::time::Duration::from_secs(8))
375        .build()
376        .ok()?;
377    let resp = client
378        .get("https://api.anthropic.com/api/oauth/claude_cli/roles")
379        .header("authorization", format!("Bearer {access_token}"))
380        .header("anthropic-version", "2023-06-01")
381        .header("anthropic-dangerous-direct-browser-access", "true")
382        .send()
383        .await
384        .ok()?;
385
386    if !resp.status().is_success() {
387        return None;
388    }
389
390    let body: serde_json::Value = resp.json().await.ok()?;
391
392    // Try dedicated email fields first.
393    for field in &["email", "emailAddress", "email_address"] {
394        if let Some(e) = body[field].as_str().filter(|s| s.contains('@')) {
395            return Some(e.to_owned());
396        }
397    }
398
399    // Fall back to extracting from organization_name ("addr@example.com's Organization").
400    let org = body["organization_name"].as_str()?;
401    let email = org.strip_suffix("'s Organization").unwrap_or(org).trim();
402    if !email.is_empty() { Some(email.to_owned()) } else { None }
403}
404
405// ---------------------------------------------------------------------------
406// PKCE browser OAuth flow (for adding additional accounts)
407// ---------------------------------------------------------------------------
408
/// A PKCE (RFC 7636) verifier/challenge pair for one authorization attempt.
struct Pkce {
    /// Random secret, sent only to the token endpoint as `code_verifier`.
    verifier: String,
    /// base64url(SHA-256(verifier)), sent in the authorize URL with method `S256`.
    challenge: String,
}
413
414fn generate_pkce() -> Pkce {
415    let verifier_bytes: [u8; 32] = rand_bytes();
416    let verifier = URL_SAFE_NO_PAD.encode(verifier_bytes);
417
418    let hash = Sha256::digest(verifier.as_bytes());
419    let challenge = URL_SAFE_NO_PAD.encode(hash);
420
421    Pkce { verifier, challenge }
422}
423
424/// Generate N cryptographically random bytes using the OS entropy source.
425/// Panics if the system RNG is unavailable (unrecoverable error in a security context).
426pub fn rand_bytes<const N: usize>() -> [u8; N] {
427    let mut bytes = [0u8; N];
428    getrandom::getrandom(&mut bytes)
429        .expect("OS random number generator unavailable — cannot generate secure random bytes");
430    bytes
431}
432
433fn random_state() -> String {
434    let bytes: [u8; 16] = rand_bytes();
435    hex::encode(bytes)
436}
437
/// Registered redirect URI for the Anthropic PKCE flow; the page it serves shows a `CODE#STATE` string to paste.
pub const OAUTH_REDIRECT_URI: &str = "https://platform.claude.com/oauth/code/callback";
439
440/// Run the PKCE OAuth flow using the registered redirect URI.
441///
442/// Opens the browser to claude.ai. After the user authorizes, the callback page
443/// displays a code (format: CODE#STATE). The user pastes it here; we split out
444/// the state and exchange the code at the token endpoint.
445pub async fn run_oauth_flow() -> Result<OAuthCredential> {
446    use std::io::{self, Write};
447
448    let pkce = generate_pkce();
449    let state = random_state();
450    let redirect_uri = OAUTH_REDIRECT_URI;
451
452    let scope = urlencoding::encode(
453        "user:inference user:profile user:file_upload user:mcp_servers user:sessions:claude_code",
454    );
455    let auth_url = format!(
456        "{base}?response_type=code\
457         &client_id={client_id}\
458         &redirect_uri={redirect}\
459         &scope={scope}\
460         &state={state}\
461         &code_challenge={challenge}\
462         &code_challenge_method=S256",
463        base = OAUTH_AUTHORIZE_URL,
464        client_id = OAUTH_CLIENT_ID,
465        redirect = urlencoding::encode(redirect_uri),
466        scope = scope,
467        state = state,
468        challenge = pkce.challenge,
469    );
470
471    println!("\nOpening browser for claude.ai login...");
472    println!("If it does not open automatically, visit:\n  {auth_url}\n");
473    open_browser(&auth_url);
474
475    println!("After you authorize, the page will show an authorization code.");
476    println!("Copy it and paste it here.");
477    println!();
478    print!("Paste code: ");
479    io::stdout().flush()?;
480
481    let mut pasted = String::new();
482    io::stdin().read_line(&mut pasted)?;
483    // Page shows "code#state"
484    let pasted = pasted.trim();
485    let (code, pasted_state) = if let Some((c, s)) = pasted.split_once('#') {
486        (c.trim(), s.trim())
487    } else {
488        (pasted, state.as_str())
489    };
490
491    if code.is_empty() {
492        bail!("No code entered.");
493    }
494
495    let cred = exchange_code(code, pasted_state, redirect_uri, &pkce.verifier).await?;
496    Ok(cred)
497}
498
499async fn exchange_code(code: &str, state: &str, redirect_uri: &str, verifier: &str) -> Result<OAuthCredential> {
500    let client = reqwest::Client::new();
501
502    let body = serde_json::json!({
503        "grant_type": "authorization_code",
504        "code": code,
505        "state": state,
506        "redirect_uri": redirect_uri,
507        "client_id": OAUTH_CLIENT_ID,
508        "code_verifier": verifier,
509    });
510
511    let resp = client
512        .post(OAUTH_TOKEN_URL)
513        .header("content-type", "application/json")
514        .header("anthropic-version", "2023-06-01")
515        .json(&body)
516        .send()
517        .await
518        .context("token exchange request failed")?;
519
520    if !resp.status().is_success() {
521        let status = resp.status();
522        let body = resp.text().await.unwrap_or_default();
523        let err = serde_json::from_str::<serde_json::Value>(&body).ok()
524            .and_then(|v| v["error"].as_str().or_else(|| v["error_description"].as_str()).map(String::from))
525            .unwrap_or_else(|| "unknown error".to_string());
526        bail!("token exchange failed ({status}): {err}");
527    }
528
529    let body: serde_json::Value = resp.json().await.context("token exchange: invalid JSON")?;
530
531    let access_token = body["access_token"]
532        .as_str()
533        .context("token exchange: missing access_token")?
534        .to_owned();
535    let refresh_token = body["refresh_token"]
536        .as_str()
537        .unwrap_or("")
538        .to_owned();
539    let expires_in = body["expires_in"].as_u64().unwrap_or(3600);
540    let now_ms = SystemTime::now()
541        .duration_since(UNIX_EPOCH)
542        .unwrap_or_default()
543        .as_millis() as u64;
544
545    Ok(OAuthCredential {
546        access_token,
547        refresh_token,
548        expires_at: now_ms + expires_in * 1000,
549        email: None,
550        id_token: None,
551    })
552}
553
554// ---------------------------------------------------------------------------
555// Token revocation
556// ---------------------------------------------------------------------------
557
/// Anthropic endpoint for best-effort token revocation (form-encoded `token=`).
pub const OAUTH_REVOKE_URL: &str = "https://platform.claude.com/v1/oauth/revoke";
559
560/// Revoke an OAuth token on the server. Best-effort — errors are non-fatal.
561pub async fn revoke_token(access_token: &str) -> bool {
562    let client = reqwest::Client::builder()
563        .timeout(std::time::Duration::from_secs(8))
564        .build()
565        .unwrap_or_default();
566    client
567        .post(OAUTH_REVOKE_URL)
568        .header("content-type", "application/x-www-form-urlencoded")
569        .header("anthropic-version", "2023-06-01")
570        .body(format!("token={}", urlencoding::encode(access_token)))
571        .send()
572        .await
573        .map(|r| r.status().is_success())
574        .unwrap_or(false)
575}
576
577// ---------------------------------------------------------------------------
578// OpenAI token refresh
579// ---------------------------------------------------------------------------
580
581/// Refresh an expired OpenAI / Codex access token using the stored refresh_token.
582pub async fn refresh_openai_token(cred: &OAuthCredential) -> Result<OAuthCredential> {
583    let client = reqwest::Client::new();
584
585    let resp = client
586        .post(OPENAI_OAUTH_TOKEN_URL)
587        .header("content-type", "application/x-www-form-urlencoded")
588        .body(format!(
589            "grant_type=refresh_token&refresh_token={}&client_id={}",
590            urlencoding::encode(&cred.refresh_token),
591            OPENAI_OAUTH_CLIENT_ID,
592        ))
593        .send()
594        .await
595        .context("OpenAI token refresh request failed")?;
596
597    if !resp.status().is_success() {
598        let status = resp.status();
599        let body = resp.text().await.unwrap_or_default();
600        let err = serde_json::from_str::<serde_json::Value>(&body).ok()
601            .and_then(|v| v["error"].as_str().or_else(|| v["error_description"].as_str()).map(String::from))
602            .unwrap_or_else(|| "unknown error".to_string());
603        bail!("OpenAI token refresh failed ({status}): {err}");
604    }
605
606    let body: serde_json::Value = resp.json().await.context("OpenAI token refresh: invalid JSON")?;
607
608    let access_token = body["access_token"]
609        .as_str()
610        .context("OpenAI token refresh: missing access_token")?
611        .to_owned();
612
613    let refresh_token = body["refresh_token"]
614        .as_str()
615        .unwrap_or(&cred.refresh_token)
616        .to_owned();
617
618    let id_token = body["id_token"].as_str().map(|s| s.to_owned())
619        .or_else(|| cred.id_token.clone());
620
621    let expires_in_secs = body["expires_in"].as_u64().unwrap_or(3600);
622    let now_ms = SystemTime::now()
623        .duration_since(UNIX_EPOCH)
624        .unwrap_or_default()
625        .as_millis() as u64;
626
627    Ok(OAuthCredential {
628        access_token,
629        refresh_token,
630        expires_at: now_ms + expires_in_secs * 1000,
631        email: cred.email.clone(),
632        id_token,
633    })
634}
635
636// ---------------------------------------------------------------------------
637// OpenAI / Codex device code flow (custom 3-step, not RFC 8628)
638// ---------------------------------------------------------------------------
639//
640// Codex uses its own device auth protocol:
641//   1. POST /deviceauth/usercode  {"client_id"} → {device_auth_id, user_code, interval}
642//   2. Poll  POST /deviceauth/token  {"device_auth_id","user_code"} until 200
643//            → {authorization_code, code_verifier, code_challenge}
644//   3. POST /oauth/token  PKCE exchange → {access_token, refresh_token, id_token}
645//
646// Verification URI where the user enters the code: https://auth.openai.com/codex/device
647
648/// Run the Codex device authorization flow. No local HTTP server required.
649///
650/// Displays a short user_code; the user visits `https://auth.openai.com/codex/device`
651/// and enters it. We poll until authorized, then exchange for tokens.
652pub async fn run_openai_oauth_flow() -> Result<OAuthCredential> {
653    const VERIFY_URI: &str = "https://auth.openai.com/codex/device";
654    const TIMEOUT_SECS: u64 = 15 * 60;
655
656    let client = reqwest::Client::new();
657
658    // Step 1: request user code
659    let resp = client
660        .post(OPENAI_DEVICE_CODE_URL)
661        .header("content-type", "application/json")
662        .json(&serde_json::json!({"client_id": OPENAI_OAUTH_CLIENT_ID}))
663        .send()
664        .await
665        .context("Codex device code request failed")?;
666
667    if !resp.status().is_success() {
668        let status = resp.status();
669        let body = resp.text().await.unwrap_or_default();
670        let err = serde_json::from_str::<serde_json::Value>(&body).ok()
671            .and_then(|v| v["error"].as_str().or_else(|| v["error_description"].as_str()).map(String::from))
672            .unwrap_or_else(|| "unknown error".to_string());
673        bail!("Codex device code request failed ({status}): {err}");
674    }
675
676    let info: serde_json::Value = resp.json().await.context("device code: invalid JSON")?;
677    let device_auth_id = info["device_auth_id"].as_str().context("missing device_auth_id")?.to_owned();
678    let user_code = info["user_code"].as_str().context("missing user_code")?.to_owned();
679    let interval_secs = info["interval"].as_u64().unwrap_or(5);
680
681    println!();
682    println!("  Visit:  {VERIFY_URI}");
683    println!("  Code:   \x1b[1;33m{user_code}\x1b[0m");
684    println!();
685    println!("  Waiting for authorization...");
686
687    open_browser(VERIFY_URI);
688
689    // Step 2: poll until code is approved
690    let deadline = std::time::Instant::now() + std::time::Duration::from_secs(TIMEOUT_SECS);
691    let poll_interval = std::time::Duration::from_secs(interval_secs);
692    let poll_body = serde_json::json!({
693        "device_auth_id": device_auth_id,
694        "user_code": user_code,
695    });
696
697    let (authorization_code, code_verifier) = loop {
698        tokio::time::sleep(poll_interval).await;
699
700        if std::time::Instant::now() > deadline {
701            bail!("Device code expired (15 min). Run `shunt add-account` again.");
702        }
703
704        let resp = client
705            .post(OPENAI_DEVICE_TOKEN_URL)
706            .header("content-type", "application/json")
707            .json(&poll_body)
708            .send()
709            .await
710            .context("Codex device poll request failed")?;
711
712        let status = resp.status();
713        // 403/404 = still pending; any 2xx = authorized
714        if status.as_u16() == 403 || status.as_u16() == 404 {
715            continue;
716        }
717        if !status.is_success() {
718            let body = resp.text().await.unwrap_or_default();
719            let err = serde_json::from_str::<serde_json::Value>(&body).ok()
720                .and_then(|v| v["error"].as_str().or_else(|| v["error_description"].as_str()).map(String::from))
721                .unwrap_or_else(|| "unknown error".to_string());
722            bail!("Codex device poll error ({status}): {err}");
723        }
724
725        let body: serde_json::Value = resp.json().await.context("device poll: invalid JSON")?;
726        let code = body["authorization_code"].as_str().context("missing authorization_code")?.to_owned();
727        let verifier = body["code_verifier"].as_str().context("missing code_verifier")?.to_owned();
728        break (code, verifier);
729    };
730
731    // Step 3: exchange authorization_code for tokens
732    let redirect_uri = format!("https://auth.openai.com/deviceauth/callback");
733    let token_body = format!(
734        "grant_type=authorization_code&code={}&redirect_uri={}&client_id={}&code_verifier={}",
735        urlencoding::encode(&authorization_code),
736        urlencoding::encode(&redirect_uri),
737        OPENAI_OAUTH_CLIENT_ID,
738        urlencoding::encode(&code_verifier),
739    );
740    let resp = client
741        .post(OPENAI_OAUTH_TOKEN_URL)
742        .header("content-type", "application/x-www-form-urlencoded")
743        .body(token_body)
744        .send()
745        .await
746        .context("Codex token exchange failed")?;
747
748    if !resp.status().is_success() {
749        let status = resp.status();
750        let body = resp.text().await.unwrap_or_default();
751        let err = serde_json::from_str::<serde_json::Value>(&body).ok()
752            .and_then(|v| v["error"].as_str().or_else(|| v["error_description"].as_str()).map(String::from))
753            .unwrap_or_else(|| "unknown error".to_string());
754        bail!("Codex token exchange failed ({status}): {err}");
755    }
756
757    let body: serde_json::Value = resp.json().await.context("token exchange: invalid JSON")?;
758    let access_token = body["access_token"]
759        .as_str()
760        .or_else(|| body["id_token"].as_str())
761        .context("token exchange: missing access_token")?
762        .to_owned();
763    let refresh_token = body["refresh_token"].as_str().unwrap_or("").to_owned();
764    let id_token = body["id_token"].as_str().map(|s| s.to_owned());
765    let expires_at = jwt_exp_ms(&access_token).unwrap_or_else(|| {
766        let now_ms = SystemTime::now()
767            .duration_since(UNIX_EPOCH)
768            .unwrap_or_default()
769            .as_millis() as u64;
770        now_ms + body["expires_in"].as_u64().unwrap_or(3600) * 1000
771    });
772
773    Ok(OAuthCredential { access_token, refresh_token, expires_at, email: None, id_token })
774}
775
776// ---------------------------------------------------------------------------
777// OpenAI account identity
778// ---------------------------------------------------------------------------
779
780/// Fetch the account email from OpenAI's userinfo endpoint.
781pub async fn fetch_openai_account_email(access_token: &str) -> Option<String> {
782    let client = reqwest::Client::builder()
783        .timeout(std::time::Duration::from_secs(8))
784        .build()
785        .ok()?;
786    let resp = client
787        .get("https://auth.openai.com/userinfo")
788        .header("authorization", format!("Bearer {access_token}"))
789        .send()
790        .await
791        .ok()?;
792    if !resp.status().is_success() { return None; }
793    let body: serde_json::Value = resp.json().await.ok()?;
794    body["email"].as_str().map(|s| s.to_owned())
795}
796
/// Best-effort attempt to open `url` in the user's default browser.
///
/// Uses `open` on macOS, `explorer` on Windows, and `xdg-open` on every other
/// Unix (Linux, the BSDs, …). On other targets nothing is launched. Spawn
/// failures are ignored and the child is not waited on; the login flows also
/// print the URL, so the user can open it manually.
fn open_browser(url: &str) {
    #[cfg(target_os = "macos")]
    const OPENER: Option<&str> = Some("open");
    // Use explorer.exe directly — avoids cmd.exe shell expansion of OAuth URL
    // special characters (& % etc.) that would misparse with `cmd /c start`.
    #[cfg(windows)]
    const OPENER: Option<&str> = Some("explorer");
    #[cfg(all(unix, not(target_os = "macos")))]
    const OPENER: Option<&str> = Some("xdg-open");
    #[cfg(not(any(unix, windows)))]
    const OPENER: Option<&str> = None;

    if let Some(program) = OPENER {
        let _ = std::process::Command::new(program).arg(url).spawn();
    }
}
809
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_rand_bytes_correct_length() {
        let small: [u8; 16] = rand_bytes();
        let large: [u8; 32] = rand_bytes();
        assert_eq!(small.len(), 16);
        assert_eq!(large.len(), 32);
    }

    #[test]
    fn test_rand_bytes_not_all_zeros() {
        // 32 OS-random bytes are all zero with probability 2^-256.
        let bytes: [u8; 32] = rand_bytes();
        assert!(!bytes.iter().all(|&b| b == 0), "rand_bytes must not return all-zero output");
    }

    #[test]
    fn test_rand_bytes_unique() {
        // Two independent 32-byte draws collide with probability 2^-256.
        let (first, second): ([u8; 32], [u8; 32]) = (rand_bytes(), rand_bytes());
        assert_ne!(first, second, "rand_bytes must return unique values each call");
    }

    #[test]
    fn test_jwt_exp_ms_sanity_cap() {
        use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine as _};
        use std::time::{SystemTime, UNIX_EPOCH};

        /// Unsigned test JWT whose `exp` lies `offset_secs` from now.
        fn token_expiring_in(offset_secs: i64) -> String {
            let now = SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_secs() as i64;
            let claims = serde_json::json!({"sub": "test", "exp": (now + offset_secs) as u64});
            let header = URL_SAFE_NO_PAD.encode(b"{\"alg\":\"HS256\"}");
            let payload = URL_SAFE_NO_PAD.encode(claims.to_string().as_bytes());
            format!("{header}.{payload}.fakesig")
        }

        // (offset from now in seconds, expected to be accepted?, failure message)
        let cases: [(i64, bool, &str); 4] = [
            (-3600, false, "expired token must return None"),
            (3600, true, "1h-future token must return Some"),
            (86400, true, "24h-future token must return Some"),
            (26 * 3600, false, "26h-future token must return None (forged exp)"),
        ];
        for (offset_secs, accepted, why) in cases {
            assert_eq!(jwt_exp_ms(&token_expiring_in(offset_secs)).is_some(), accepted, "{why}");
        }
    }

    #[test]
    fn test_pkce_pair_properties() {
        let Pkce { verifier, challenge } = generate_pkce();
        let is_base64url = |c: char| c.is_alphanumeric() || c == '-' || c == '_';
        // Verifier must be base64url-safe (no padding, only URL-safe chars).
        assert!(verifier.chars().all(is_base64url), "PKCE verifier must be base64url-safe");
        // The challenge is a SHA-256 hash of the verifier, so it must differ from it.
        assert_ne!(verifier, challenge, "PKCE challenge must not equal verifier");
        assert!(!challenge.is_empty());
        assert!(!verifier.is_empty());
    }
}