Skip to main content

shunt/
oauth.rs

//! OAuth 2.0 PKCE flow + token refresh for claude.ai accounts, plus the
//! OpenAI / Codex device-code flow.
//!
//! Claude Code authenticates via OAuth, not API keys. Credentials are stored
//! in ~/.claude/.credentials.json and sent as `Authorization: Bearer <token>`.
5use anyhow::{bail, Context, Result};
6use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine as _};
7use serde::{Deserialize, Serialize};
8use sha2::{Digest, Sha256};
9use std::path::PathBuf;
10use std::time::{SystemTime, UNIX_EPOCH};
11
12// ---------------------------------------------------------------------------
13// Anthropic OAuth constants
14// ---------------------------------------------------------------------------
15
/// Client ID sent to the Anthropic authorize, token-exchange and refresh endpoints.
pub const OAUTH_CLIENT_ID: &str = "9d1c250a-e61b-44d9-88ed-5944d1962f5e";

/// Browser page where the PKCE authorization flow starts.
pub const OAUTH_AUTHORIZE_URL: &str = "https://claude.ai/oauth/authorize";

/// Token endpoint used both to exchange an authorization code and to refresh tokens.
pub const OAUTH_TOKEN_URL: &str = "https://platform.claude.com/v1/oauth/token";
19
20// ---------------------------------------------------------------------------
21// OpenAI / Codex OAuth constants
22// ---------------------------------------------------------------------------
23
/// Client ID used for every OpenAI / Codex OAuth request.
pub const OPENAI_OAUTH_CLIENT_ID: &str = "app_EMoamEEZ73f0CkXaXp7hrann";

/// Token endpoint: authorization-code exchange and refresh.
pub const OPENAI_OAUTH_TOKEN_URL: &str = "https://auth.openai.com/oauth/token";

/// Device flow step 1: request a user code.
pub const OPENAI_DEVICE_CODE_URL: &str = "https://auth.openai.com/api/accounts/deviceauth/usercode";

/// Device flow step 2: poll until the user code has been approved.
pub const OPENAI_DEVICE_TOKEN_URL: &str = "https://auth.openai.com/api/accounts/deviceauth/token";
28
29// ---------------------------------------------------------------------------
30// Credential type
31// ---------------------------------------------------------------------------
32
33#[derive(Debug, Clone, Serialize, Deserialize)]
34pub struct OAuthCredential {
35    pub access_token: String,
36    pub refresh_token: String,
37    /// Milliseconds since Unix epoch
38    pub expires_at: u64,
39    /// Account email, fetched from roles endpoint after auth
40    #[serde(default, skip_serializing_if = "Option::is_none")]
41    pub email: Option<String>,
42    /// OpenAI id_token — required by the Codex CLI's ~/.codex/auth.json
43    #[serde(default, skip_serializing_if = "Option::is_none")]
44    pub id_token: Option<String>,
45}
46
47impl OAuthCredential {
48    /// True if the token expires within the next 5 minutes.
49    pub fn needs_refresh(&self) -> bool {
50        let now_ms = SystemTime::now()
51            .duration_since(UNIX_EPOCH)
52            .unwrap_or_default()
53            .as_millis() as u64;
54        now_ms >= self.expires_at.saturating_sub(5 * 60 * 1000)
55    }
56}
57
58// ---------------------------------------------------------------------------
59// Auto-import from Claude Code's own credential file
60// ---------------------------------------------------------------------------
61
62// ---------------------------------------------------------------------------
63// Auto-import from Codex CLI's credential file (~/.codex/auth.json)
64// ---------------------------------------------------------------------------
65
66/// Raw format used by ~/.codex/auth.json
67/// The tokens are nested under a "tokens" key; there is no top-level expires_at.
68/// Expiry is read from the JWT `exp` claim inside the access_token.
69#[derive(Deserialize)]
70struct CodexAuth {
71    tokens: CodexTokens,
72}
73
74#[derive(Deserialize)]
75struct CodexTokens {
76    access_token: String,
77    #[serde(default)]
78    refresh_token: Option<String>,
79    #[serde(default)]
80    id_token: Option<String>,
81}
82
83/// Write credentials to ~/.codex/auth.json so the Codex CLI can use them without re-login.
84///
85/// Called automatically after add-account and token refresh for OpenAI accounts.
86pub fn write_codex_auth_file(cred: &OAuthCredential) {
87    let Some(ref id_token) = cred.id_token else { return };
88    let path = codex_credentials_path();
89    if let Some(parent) = path.parent() {
90        let _ = std::fs::create_dir_all(parent);
91    }
92    let auth = serde_json::json!({
93        "tokens": {
94            "access_token": cred.access_token,
95            "refresh_token": cred.refresh_token,
96            "id_token": id_token,
97        }
98    });
99    if let Ok(text) = serde_json::to_string_pretty(&auth) {
100        let _ = std::fs::write(&path, text);
101    }
102}
103
104pub fn codex_credentials_path() -> PathBuf {
105    dirs::home_dir()
106        .unwrap_or_else(|| PathBuf::from("."))
107        .join(".codex")
108        .join("auth.json")
109}
110
111/// Read the OAuth credential from the Codex CLI's stored auth file.
112pub fn read_codex_credentials() -> Option<OAuthCredential> {
113    let text = std::fs::read_to_string(codex_credentials_path()).ok()?;
114    let raw: CodexAuth = serde_json::from_str(&text).ok()?;
115
116    let now_ms = SystemTime::now()
117        .duration_since(UNIX_EPOCH)
118        .unwrap_or_default()
119        .as_millis() as u64;
120
121    // Extract exp from the JWT payload without verifying signature.
122    let expires_at = jwt_exp_ms(&raw.tokens.access_token)
123        .unwrap_or(now_ms + 3600 * 1000); // default: 1 hour from now
124
125    Some(OAuthCredential {
126        access_token: raw.tokens.access_token,
127        refresh_token: raw.tokens.refresh_token.unwrap_or_default(),
128        expires_at,
129        email: None,
130        id_token: raw.tokens.id_token,
131    })
132}
133
134/// Decode the `exp` claim from a JWT payload (no signature verification).
135/// Returns expiry as Unix milliseconds.
136pub(crate) fn jwt_exp_ms(token: &str) -> Option<u64> {
137    let payload_b64 = token.splitn(3, '.').nth(1)?;
138    // base64url decode (no padding)
139    let decoded = base64_url_decode(payload_b64)?;
140    let v: serde_json::Value = serde_json::from_slice(&decoded).ok()?;
141    let exp_secs = v.get("exp")?.as_u64()?;
142    Some(exp_secs * 1000)
143}
144
/// Decode base64url text (RFC 4648 §5), as found in JWT segments.
///
/// Accepts the URL-safe alphabet (`-`, `_`) and, for leniency, the standard one
/// (`+`, `/`). Trailing `=` padding is optional.
///
/// Returns `None` instead of silently producing garbage bytes when:
/// - the input contains a character outside the alphabet (including `=`
///   anywhere other than the end), or
/// - its length is one that no valid encoding can produce (a lone trailing
///   character).
fn base64_url_decode(s: &str) -> Option<Vec<u8>> {
    /// Map one base64 character to its 6-bit value.
    fn sextet(c: u8) -> Option<u32> {
        let value = match c {
            b'A'..=b'Z' => c - b'A',
            b'a'..=b'z' => c - b'a' + 26,
            b'0'..=b'9' => c - b'0' + 52,
            b'-' | b'+' => 62,
            b'_' | b'/' => 63,
            _ => return None,
        };
        Some(u32::from(value))
    }

    let data = s.trim_end_matches('=').as_bytes();
    let mut out = Vec::with_capacity(data.len() * 3 / 4);
    for chunk in data.chunks(4) {
        // A single leftover character carries only 6 bits — not even one byte.
        if chunk.len() == 1 {
            return None;
        }
        let mut group = 0u32;
        for &c in chunk {
            group = (group << 6) | sextet(c)?;
        }
        // Left-align a short final group so it has the same 24-bit layout as a full one.
        group <<= 6 * (4 - chunk.len() as u32);
        // Bytes 1..=3 of the big-endian u32 hold the 24 decoded bits;
        // a group of k characters yields k - 1 bytes.
        out.extend_from_slice(&group.to_be_bytes()[1..chunk.len()]);
    }
    Some(out)
}
184
185
186/// Raw format used by ~/.claude/.credentials.json
187#[derive(Deserialize)]
188#[serde(rename_all = "camelCase")]
189struct ClaudeCredentials {
190    claude_ai_oauth: Option<ClaudeOAuthRaw>,
191}
192
193#[derive(Deserialize)]
194#[serde(rename_all = "camelCase")]
195struct ClaudeOAuthRaw {
196    access_token: String,
197    refresh_token: String,
198    expires_at: u64,
199}
200
201// ---------------------------------------------------------------------------
202// Session info (plan + identity) from stored credentials
203// ---------------------------------------------------------------------------
204
/// Plan and identity label read from Claude Code's stored credentials.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SessionInfo {
    /// Identity label for the session.
    ///
    /// `read_claude_session_info` fills this from the credential's
    /// `rateLimitTier` field, not from an email address.
    pub email_or_id: String,
    /// Subscription plan, taken from the stored `subscriptionType`.
    pub plan: String,
}
209
210/// Read plan and identity from Claude Code's stored credentials JSON.
211/// Works for both keychain and file-based storage.
212pub fn read_claude_session_info() -> Option<SessionInfo> {
213    #[derive(serde::Deserialize)]
214    #[serde(rename_all = "camelCase")]
215    struct Outer {
216        claude_ai_oauth: Option<Inner>,
217    }
218    #[derive(serde::Deserialize)]
219    #[serde(rename_all = "camelCase")]
220    struct Inner {
221        subscription_type: Option<String>,
222        #[serde(rename = "rateLimitTier")]
223        rate_limit_tier: Option<String>,
224    }
225
226    let text = read_raw_credentials_json()?;
227    let outer: Outer = serde_json::from_str(&text).ok()?;
228    let inner = outer.claude_ai_oauth?;
229
230    let plan = inner.subscription_type.unwrap_or_else(|| "pro".into());
231    let email_or_id = inner.rate_limit_tier.unwrap_or_else(|| "unknown".into());
232
233    Some(SessionInfo { email_or_id, plan })
234}
235
236/// Returns the raw credentials JSON string from keychain (macOS) or file.
237fn read_raw_credentials_json() -> Option<String> {
238    #[cfg(target_os = "macos")]
239    {
240        let out = std::process::Command::new("security")
241            .args(["find-generic-password", "-s", "Claude Code-credentials", "-w"])
242            .output()
243            .ok()?;
244        if out.status.success() {
245            let s = String::from_utf8(out.stdout).ok()?;
246            return Some(s.trim().to_owned());
247        }
248    }
249    std::fs::read_to_string(claude_credentials_path()).ok()
250}
251
252pub fn claude_credentials_path() -> PathBuf {
253    dirs::home_dir()
254        .unwrap_or_else(|| PathBuf::from("."))
255        .join(".claude")
256        .join(".credentials.json")
257}
258
259/// Read the OAuth credential from Claude Code's own credential file.
260/// On macOS, tries the Keychain first (service "Claude Code-credentials"),
261/// then falls back to ~/.claude/.credentials.json.
262pub fn read_claude_credentials() -> Option<OAuthCredential> {
263    // macOS: try Keychain first
264    #[cfg(target_os = "macos")]
265    if let Some(cred) = read_claude_credentials_keychain() {
266        return Some(cred);
267    }
268
269    // Fallback: JSON file (older Claude Code versions / non-macOS)
270    let path = claude_credentials_path();
271    let text = std::fs::read_to_string(&path).ok()?;
272    parse_claude_credentials_json(&text)
273}
274
/// macOS only: parse the credentials JSON obtained from `read_raw_credentials_json`.
///
/// That helper queries the Keychain and may itself fall back to the credentials
/// file, so despite the name this does not read the Keychain exclusively.
#[cfg(target_os = "macos")]
fn read_claude_credentials_keychain() -> Option<OAuthCredential> {
    read_raw_credentials_json().and_then(|json| parse_claude_credentials_json(&json))
}
280
281fn parse_claude_credentials_json(text: &str) -> Option<OAuthCredential> {
282    let raw: ClaudeCredentials = serde_json::from_str(text).ok()?;
283    let inner = raw.claude_ai_oauth?;
284    Some(OAuthCredential {
285        access_token: inner.access_token,
286        refresh_token: inner.refresh_token,
287        expires_at: inner.expires_at,
288        email: None,
289        id_token: None,
290    })
291}
292
293// ---------------------------------------------------------------------------
294// Token refresh
295// ---------------------------------------------------------------------------
296
297/// Refresh an expired access token. Returns the updated credential.
298pub async fn refresh_token(cred: &OAuthCredential) -> Result<OAuthCredential> {
299    let client = reqwest::Client::new();
300
301    let resp = client
302        .post(OAUTH_TOKEN_URL)
303        .header("content-type", "application/x-www-form-urlencoded")
304        .body(format!(
305            "grant_type=refresh_token&refresh_token={}&client_id={}",
306            urlencoding::encode(&cred.refresh_token),
307            OAUTH_CLIENT_ID,
308        ))
309        .send()
310        .await
311        .context("token refresh request failed")?;
312
313    if !resp.status().is_success() {
314        let status = resp.status();
315        let body = resp.text().await.unwrap_or_default();
316        bail!("token refresh failed ({status}): {body}");
317    }
318
319    let body: serde_json::Value = resp.json().await.context("token refresh: invalid JSON")?;
320
321    let access_token = body["access_token"]
322        .as_str()
323        .context("token refresh: missing access_token")?
324        .to_owned();
325
326    let refresh_token = body["refresh_token"]
327        .as_str()
328        .unwrap_or(&cred.refresh_token)
329        .to_owned();
330
331    // expires_in is seconds from now
332    let expires_in_secs = body["expires_in"].as_u64().unwrap_or(3600);
333    let now_ms = SystemTime::now()
334        .duration_since(UNIX_EPOCH)
335        .unwrap_or_default()
336        .as_millis() as u64;
337    let expires_at = now_ms + expires_in_secs * 1000;
338
339    Ok(OAuthCredential { access_token, refresh_token, expires_at, email: cred.email.clone(), id_token: None })
340}
341
342// ---------------------------------------------------------------------------
343// Account identity (email) from roles endpoint
344// ---------------------------------------------------------------------------
345
346/// Fetch the account email from the Anthropic roles endpoint.
347/// Returns `None` on any error (non-fatal).
348pub async fn fetch_account_email(access_token: &str) -> Option<String> {
349    let client = reqwest::Client::builder()
350        .timeout(std::time::Duration::from_secs(8))
351        .build()
352        .ok()?;
353    let resp = client
354        .get("https://api.anthropic.com/api/oauth/claude_cli/roles")
355        .header("authorization", format!("Bearer {access_token}"))
356        .header("anthropic-version", "2023-06-01")
357        .header("anthropic-dangerous-direct-browser-access", "true")
358        .send()
359        .await
360        .ok()?;
361
362    if !resp.status().is_success() {
363        return None;
364    }
365
366    let body: serde_json::Value = resp.json().await.ok()?;
367    // organization_name is "email's Organization" — extract email prefix
368    let org = body["organization_name"].as_str()?;
369    if let Some(email) = org.strip_suffix("'s Organization") {
370        Some(email.to_owned())
371    } else {
372        Some(org.to_owned())
373    }
374}
375
376// ---------------------------------------------------------------------------
377// PKCE browser OAuth flow (for adding additional accounts)
378// ---------------------------------------------------------------------------
379
380struct Pkce {
381    verifier: String,
382    challenge: String,
383}
384
385fn generate_pkce() -> Pkce {
386    let verifier_bytes: [u8; 32] = rand_bytes();
387    let verifier = URL_SAFE_NO_PAD.encode(verifier_bytes);
388
389    let hash = Sha256::digest(verifier.as_bytes());
390    let challenge = URL_SAFE_NO_PAD.encode(hash);
391
392    Pkce { verifier, challenge }
393}
394
395/// Generate N cryptographically random bytes using the OS entropy source.
396/// Panics if the system RNG is unavailable (unrecoverable error in a security context).
397pub fn rand_bytes<const N: usize>() -> [u8; N] {
398    let mut bytes = [0u8; N];
399    getrandom::getrandom(&mut bytes)
400        .expect("OS random number generator unavailable — cannot generate secure random bytes");
401    bytes
402}
403
404fn random_state() -> String {
405    let bytes: [u8; 16] = rand_bytes();
406    hex::encode(bytes)
407}
408
409pub const OAUTH_REDIRECT_URI: &str = "https://platform.claude.com/oauth/code/callback";
410
411/// Run the PKCE OAuth flow using the registered redirect URI.
412///
413/// Opens the browser to claude.ai. After the user authorizes, the callback page
414/// displays a code (format: CODE#STATE). The user pastes it here; we split out
415/// the state and exchange the code at the token endpoint.
416pub async fn run_oauth_flow() -> Result<OAuthCredential> {
417    use std::io::{self, Write};
418
419    let pkce = generate_pkce();
420    let state = random_state();
421    let redirect_uri = OAUTH_REDIRECT_URI;
422
423    let scope = urlencoding::encode(
424        "user:inference user:profile user:file_upload user:mcp_servers user:sessions:claude_code",
425    );
426    let auth_url = format!(
427        "{base}?response_type=code\
428         &client_id={client_id}\
429         &redirect_uri={redirect}\
430         &scope={scope}\
431         &state={state}\
432         &code_challenge={challenge}\
433         &code_challenge_method=S256",
434        base = OAUTH_AUTHORIZE_URL,
435        client_id = OAUTH_CLIENT_ID,
436        redirect = urlencoding::encode(redirect_uri),
437        scope = scope,
438        state = state,
439        challenge = pkce.challenge,
440    );
441
442    println!("\nOpening browser for claude.ai login...");
443    println!("If it does not open automatically, visit:\n  {auth_url}\n");
444    open_browser(&auth_url);
445
446    println!("After you authorize, the page will show an authorization code.");
447    println!("Copy it and paste it here.");
448    println!();
449    print!("Paste code: ");
450    io::stdout().flush()?;
451
452    let mut pasted = String::new();
453    io::stdin().read_line(&mut pasted)?;
454    // Page shows "code#state"
455    let pasted = pasted.trim();
456    let (code, pasted_state) = if let Some((c, s)) = pasted.split_once('#') {
457        (c.trim(), s.trim())
458    } else {
459        (pasted, state.as_str())
460    };
461
462    if code.is_empty() {
463        bail!("No code entered.");
464    }
465
466    let cred = exchange_code(code, pasted_state, redirect_uri, &pkce.verifier).await?;
467    Ok(cred)
468}
469
470async fn exchange_code(code: &str, state: &str, redirect_uri: &str, verifier: &str) -> Result<OAuthCredential> {
471    let client = reqwest::Client::new();
472
473    let body = serde_json::json!({
474        "grant_type": "authorization_code",
475        "code": code,
476        "state": state,
477        "redirect_uri": redirect_uri,
478        "client_id": OAUTH_CLIENT_ID,
479        "code_verifier": verifier,
480    });
481
482    let resp = client
483        .post(OAUTH_TOKEN_URL)
484        .header("content-type", "application/json")
485        .header("anthropic-version", "2023-06-01")
486        .json(&body)
487        .send()
488        .await
489        .context("token exchange request failed")?;
490
491    if !resp.status().is_success() {
492        let status = resp.status();
493        let body = resp.text().await.unwrap_or_default();
494        bail!("token exchange failed ({status}): {body}");
495    }
496
497    let body: serde_json::Value = resp.json().await.context("token exchange: invalid JSON")?;
498
499    let access_token = body["access_token"]
500        .as_str()
501        .context("token exchange: missing access_token")?
502        .to_owned();
503    let refresh_token = body["refresh_token"]
504        .as_str()
505        .unwrap_or("")
506        .to_owned();
507    let expires_in = body["expires_in"].as_u64().unwrap_or(3600);
508    let now_ms = SystemTime::now()
509        .duration_since(UNIX_EPOCH)
510        .unwrap_or_default()
511        .as_millis() as u64;
512
513    Ok(OAuthCredential {
514        access_token,
515        refresh_token,
516        expires_at: now_ms + expires_in * 1000,
517        email: None,
518        id_token: None,
519    })
520}
521
522// ---------------------------------------------------------------------------
523// Token revocation
524// ---------------------------------------------------------------------------
525
526pub const OAUTH_REVOKE_URL: &str = "https://platform.claude.com/v1/oauth/revoke";
527
528/// Revoke an OAuth token on the server. Best-effort — errors are non-fatal.
529pub async fn revoke_token(access_token: &str) -> bool {
530    let client = reqwest::Client::builder()
531        .timeout(std::time::Duration::from_secs(8))
532        .build()
533        .unwrap_or_default();
534    client
535        .post(OAUTH_REVOKE_URL)
536        .header("content-type", "application/x-www-form-urlencoded")
537        .header("anthropic-version", "2023-06-01")
538        .body(format!("token={}", urlencoding::encode(access_token)))
539        .send()
540        .await
541        .map(|r| r.status().is_success())
542        .unwrap_or(false)
543}
544
545// ---------------------------------------------------------------------------
546// OpenAI token refresh
547// ---------------------------------------------------------------------------
548
549/// Refresh an expired OpenAI / Codex access token using the stored refresh_token.
550pub async fn refresh_openai_token(cred: &OAuthCredential) -> Result<OAuthCredential> {
551    let client = reqwest::Client::new();
552
553    let resp = client
554        .post(OPENAI_OAUTH_TOKEN_URL)
555        .header("content-type", "application/x-www-form-urlencoded")
556        .body(format!(
557            "grant_type=refresh_token&refresh_token={}&client_id={}",
558            urlencoding::encode(&cred.refresh_token),
559            OPENAI_OAUTH_CLIENT_ID,
560        ))
561        .send()
562        .await
563        .context("OpenAI token refresh request failed")?;
564
565    if !resp.status().is_success() {
566        let status = resp.status();
567        let body = resp.text().await.unwrap_or_default();
568        bail!("OpenAI token refresh failed ({status}): {body}");
569    }
570
571    let body: serde_json::Value = resp.json().await.context("OpenAI token refresh: invalid JSON")?;
572
573    let access_token = body["access_token"]
574        .as_str()
575        .context("OpenAI token refresh: missing access_token")?
576        .to_owned();
577
578    let refresh_token = body["refresh_token"]
579        .as_str()
580        .unwrap_or(&cred.refresh_token)
581        .to_owned();
582
583    let id_token = body["id_token"].as_str().map(|s| s.to_owned())
584        .or_else(|| cred.id_token.clone());
585
586    let expires_in_secs = body["expires_in"].as_u64().unwrap_or(3600);
587    let now_ms = SystemTime::now()
588        .duration_since(UNIX_EPOCH)
589        .unwrap_or_default()
590        .as_millis() as u64;
591
592    Ok(OAuthCredential {
593        access_token,
594        refresh_token,
595        expires_at: now_ms + expires_in_secs * 1000,
596        email: cred.email.clone(),
597        id_token,
598    })
599}
600
601// ---------------------------------------------------------------------------
602// OpenAI / Codex device code flow (custom 3-step, not RFC 8628)
603// ---------------------------------------------------------------------------
604//
605// Codex uses its own device auth protocol:
606//   1. POST /deviceauth/usercode  {"client_id"} → {device_auth_id, user_code, interval}
607//   2. Poll  POST /deviceauth/token  {"device_auth_id","user_code"} until 200
608//            → {authorization_code, code_verifier, code_challenge}
609//   3. POST /oauth/token  PKCE exchange → {access_token, refresh_token, id_token}
610//
611// Verification URI where the user enters the code: https://auth.openai.com/codex/device
612
613/// Run the Codex device authorization flow. No local HTTP server required.
614///
615/// Displays a short user_code; the user visits `https://auth.openai.com/codex/device`
616/// and enters it. We poll until authorized, then exchange for tokens.
617pub async fn run_openai_oauth_flow() -> Result<OAuthCredential> {
618    const VERIFY_URI: &str = "https://auth.openai.com/codex/device";
619    const TIMEOUT_SECS: u64 = 15 * 60;
620
621    let client = reqwest::Client::new();
622
623    // Step 1: request user code
624    let resp = client
625        .post(OPENAI_DEVICE_CODE_URL)
626        .header("content-type", "application/json")
627        .json(&serde_json::json!({"client_id": OPENAI_OAUTH_CLIENT_ID}))
628        .send()
629        .await
630        .context("Codex device code request failed")?;
631
632    if !resp.status().is_success() {
633        let status = resp.status();
634        let body = resp.text().await.unwrap_or_default();
635        bail!("Codex device code request failed ({status}): {body}");
636    }
637
638    let info: serde_json::Value = resp.json().await.context("device code: invalid JSON")?;
639    let device_auth_id = info["device_auth_id"].as_str().context("missing device_auth_id")?.to_owned();
640    let user_code = info["user_code"].as_str().context("missing user_code")?.to_owned();
641    let interval_secs = info["interval"].as_u64().unwrap_or(5);
642
643    println!();
644    println!("  Visit:  {VERIFY_URI}");
645    println!("  Code:   \x1b[1;33m{user_code}\x1b[0m");
646    println!();
647    println!("  Waiting for authorization...");
648
649    open_browser(VERIFY_URI);
650
651    // Step 2: poll until code is approved
652    let deadline = std::time::Instant::now() + std::time::Duration::from_secs(TIMEOUT_SECS);
653    let poll_interval = std::time::Duration::from_secs(interval_secs);
654    let poll_body = serde_json::json!({
655        "device_auth_id": device_auth_id,
656        "user_code": user_code,
657    });
658
659    let (authorization_code, code_verifier) = loop {
660        tokio::time::sleep(poll_interval).await;
661
662        if std::time::Instant::now() > deadline {
663            bail!("Device code expired (15 min). Run `shunt add-account` again.");
664        }
665
666        let resp = client
667            .post(OPENAI_DEVICE_TOKEN_URL)
668            .header("content-type", "application/json")
669            .json(&poll_body)
670            .send()
671            .await
672            .context("Codex device poll request failed")?;
673
674        let status = resp.status();
675        // 403/404 = still pending; any 2xx = authorized
676        if status.as_u16() == 403 || status.as_u16() == 404 {
677            continue;
678        }
679        if !status.is_success() {
680            let body = resp.text().await.unwrap_or_default();
681            bail!("Codex device poll error ({status}): {body}");
682        }
683
684        let body: serde_json::Value = resp.json().await.context("device poll: invalid JSON")?;
685        let code = body["authorization_code"].as_str().context("missing authorization_code")?.to_owned();
686        let verifier = body["code_verifier"].as_str().context("missing code_verifier")?.to_owned();
687        break (code, verifier);
688    };
689
690    // Step 3: exchange authorization_code for tokens
691    let redirect_uri = format!("https://auth.openai.com/deviceauth/callback");
692    let token_body = format!(
693        "grant_type=authorization_code&code={}&redirect_uri={}&client_id={}&code_verifier={}",
694        urlencoding::encode(&authorization_code),
695        urlencoding::encode(&redirect_uri),
696        OPENAI_OAUTH_CLIENT_ID,
697        urlencoding::encode(&code_verifier),
698    );
699    let resp = client
700        .post(OPENAI_OAUTH_TOKEN_URL)
701        .header("content-type", "application/x-www-form-urlencoded")
702        .body(token_body)
703        .send()
704        .await
705        .context("Codex token exchange failed")?;
706
707    if !resp.status().is_success() {
708        let status = resp.status();
709        let body = resp.text().await.unwrap_or_default();
710        bail!("Codex token exchange failed ({status}): {body}");
711    }
712
713    let body: serde_json::Value = resp.json().await.context("token exchange: invalid JSON")?;
714    let access_token = body["access_token"]
715        .as_str()
716        .or_else(|| body["id_token"].as_str())
717        .context("token exchange: missing access_token")?
718        .to_owned();
719    let refresh_token = body["refresh_token"].as_str().unwrap_or("").to_owned();
720    let id_token = body["id_token"].as_str().map(|s| s.to_owned());
721    let expires_at = jwt_exp_ms(&access_token).unwrap_or_else(|| {
722        let now_ms = SystemTime::now()
723            .duration_since(UNIX_EPOCH)
724            .unwrap_or_default()
725            .as_millis() as u64;
726        now_ms + body["expires_in"].as_u64().unwrap_or(3600) * 1000
727    });
728
729    Ok(OAuthCredential { access_token, refresh_token, expires_at, email: None, id_token })
730}
731
732// ---------------------------------------------------------------------------
733// OpenAI account identity
734// ---------------------------------------------------------------------------
735
736/// Fetch the account email from OpenAI's userinfo endpoint.
737pub async fn fetch_openai_account_email(access_token: &str) -> Option<String> {
738    let client = reqwest::Client::builder()
739        .timeout(std::time::Duration::from_secs(8))
740        .build()
741        .ok()?;
742    let resp = client
743        .get("https://auth.openai.com/userinfo")
744        .header("authorization", format!("Bearer {access_token}"))
745        .send()
746        .await
747        .ok()?;
748    if !resp.status().is_success() { return None; }
749    let body: serde_json::Value = resp.json().await.ok()?;
750    body["email"].as_str().map(|s| s.to_owned())
751}
752
/// Best-effort attempt to open `url` in the user's default browser.
///
/// Spawn failures are ignored; the callers in this module also print the URL so
/// the user can open it by hand.
///
/// The opener depends on the platform:
/// - macOS: `open`
/// - other Unix systems (Linux, the BSDs, …): `xdg-open`
/// - Windows: `explorer`
/// - any other target: nothing is done.
fn open_browser(url: &str) {
    #[cfg(target_os = "macos")]
    { std::process::Command::new("open").arg(url).spawn().ok(); }

    #[cfg(all(unix, not(target_os = "macos")))]
    { std::process::Command::new("xdg-open").arg(url).spawn().ok(); }

    // Use explorer.exe directly — avoids cmd.exe shell expansion of OAuth URL
    // special characters (& % etc.) that would misparse with `cmd /c start`.
    #[cfg(target_os = "windows")]
    { std::process::Command::new("explorer").arg(url).spawn().ok(); }

    // No known opener: consume `url` so the build stays warning-free under `-D warnings`.
    #[cfg(not(any(unix, windows)))]
    let _ = url;
}
765
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_rand_bytes_correct_length() {
        let a: [u8; 16] = rand_bytes();
        assert_eq!(a.len(), 16);
        let b: [u8; 32] = rand_bytes();
        assert_eq!(b.len(), 32);
    }

    #[test]
    fn test_rand_bytes_not_all_zeros() {
        // The probability of 32 random bytes all being zero is 1/2^256 — effectively impossible.
        let bytes: [u8; 32] = rand_bytes();
        assert!(bytes.iter().any(|&b| b != 0), "rand_bytes must not return all-zero output");
    }

    #[test]
    fn test_rand_bytes_unique() {
        // Two calls must not return the same value (probability 1/2^256 they'd collide).
        let a: [u8; 32] = rand_bytes();
        let b: [u8; 32] = rand_bytes();
        assert_ne!(a, b, "rand_bytes must return unique values each call");
    }

    #[test]
    fn test_pkce_pair_properties() {
        let pkce = generate_pkce();
        // 32 random bytes encode to 43 unpadded base64url chars (RFC 7636 allows 43..=128).
        assert_eq!(pkce.verifier.len(), 43, "PKCE verifier must be 43 chars");
        assert!(
            pkce.verifier.chars().all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '_'),
            "PKCE verifier must be base64url-safe"
        );
        // For the S256 method the challenge must be exactly BASE64URL(SHA256(verifier)).
        let expected = URL_SAFE_NO_PAD.encode(Sha256::digest(pkce.verifier.as_bytes()));
        assert_eq!(pkce.challenge, expected, "PKCE challenge must be the S256 digest of the verifier");
        assert_ne!(pkce.verifier, pkce.challenge, "PKCE challenge must not equal verifier");
    }

    #[test]
    fn test_random_state_is_lowercase_hex() {
        let state = random_state();
        assert_eq!(state.len(), 32);
        assert!(state.chars().all(|c| c.is_ascii_digit() || ('a'..='f').contains(&c)));
    }

    #[test]
    fn test_base64_url_decode_known_vectors() {
        assert_eq!(base64_url_decode("TWFu"), Some(b"Man".to_vec()));
        assert_eq!(base64_url_decode("eyJleHAiOjEyM30"), Some(br#"{"exp":123}"#.to_vec()));
        // URL-safe alphabet: '-' and '_' stand for '+' and '/'.
        assert_eq!(base64_url_decode("-_8"), Some(vec![0xfb, 0xff]));
        assert_eq!(base64_url_decode(""), Some(vec![]));
    }

    #[test]
    fn test_jwt_exp_ms_reads_exp_claim() {
        // Header and signature segments are ignored; the payload is {"exp":123}.
        assert_eq!(jwt_exp_ms("h.eyJleHAiOjEyM30.s"), Some(123_000));
        assert_eq!(jwt_exp_ms("no-dots"), None);
    }
}