Skip to main content

shunt/
oauth.rs

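//! OAuth 2.0 PKCE flow + token refresh for claude.ai accounts.
//!
//! Claude Code authenticates via OAuth, not API keys. Credentials are stored
//! in ~/.claude/.credentials.json (on macOS, the Keychain is tried first) and
//! sent as `Authorization: Bearer <token>`.
//!
//! This module also handles OpenAI / Codex accounts: the Codex device-code login,
//! token refresh, and import/export of ~/.codex/auth.json.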
5use anyhow::{bail, Context, Result};
6use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine as _};
7use serde::{Deserialize, Serialize};
8use sha2::{Digest, Sha256};
9use std::path::PathBuf;
10use std::time::{SystemTime, UNIX_EPOCH};
11
12// ---------------------------------------------------------------------------
13// Anthropic OAuth constants
14// ---------------------------------------------------------------------------
15
/// OAuth client ID sent with the claude.ai authorize request, the authorization-code
/// exchange, and refresh-token requests.
pub const OAUTH_CLIENT_ID: &str = "9d1c250a-e61b-44d9-88ed-5944d1962f5e";
/// Browser authorization endpoint opened by the PKCE login flow.
pub const OAUTH_AUTHORIZE_URL: &str = "https://claude.ai/oauth/authorize";
/// Token endpoint used for both the authorization-code exchange and refresh-token grants.
pub const OAUTH_TOKEN_URL: &str = "https://platform.claude.com/v1/oauth/token";
19
20// ---------------------------------------------------------------------------
21// OpenAI / Codex OAuth constants
22// ---------------------------------------------------------------------------
23
/// OAuth client ID sent with the OpenAI / Codex device-code request, the code
/// exchange, and refresh-token requests.
pub const OPENAI_OAUTH_CLIENT_ID: &str = "app_EMoamEEZ73f0CkXaXp7hrann";
/// OpenAI token endpoint: exchanges the device-flow authorization code and refreshes tokens.
pub const OPENAI_OAUTH_TOKEN_URL: &str = "https://auth.openai.com/oauth/token";
/// Device flow step 1: returns a `device_auth_id` / `user_code` pair (plus a poll interval).
pub const OPENAI_DEVICE_CODE_URL: &str = "https://auth.openai.com/api/accounts/deviceauth/usercode";
/// Device flow step 2: polled until the user approves the code.
pub const OPENAI_DEVICE_TOKEN_URL: &str = "https://auth.openai.com/api/accounts/deviceauth/token";
28
29// ---------------------------------------------------------------------------
30// Credential type
31// ---------------------------------------------------------------------------
32
33#[derive(Debug, Clone, Serialize, Deserialize)]
34pub struct OAuthCredential {
35    pub access_token: String,
36    pub refresh_token: String,
37    /// Milliseconds since Unix epoch
38    pub expires_at: u64,
39    /// Account email, fetched from roles endpoint after auth
40    #[serde(default, skip_serializing_if = "Option::is_none")]
41    pub email: Option<String>,
42    /// OpenAI id_token — required by the Codex CLI's ~/.codex/auth.json
43    #[serde(default, skip_serializing_if = "Option::is_none")]
44    pub id_token: Option<String>,
45}
46
47impl OAuthCredential {
48    /// True if the token expires within the next 5 minutes.
49    pub fn needs_refresh(&self) -> bool {
50        let now_ms = SystemTime::now()
51            .duration_since(UNIX_EPOCH)
52            .unwrap_or_default()
53            .as_millis() as u64;
54        now_ms >= self.expires_at.saturating_sub(5 * 60 * 1000)
55    }
56}
57
58// ---------------------------------------------------------------------------
// Auto-import from Claude Code's own credential file
// (defined further below, after the Codex import and JWT helpers)
60// ---------------------------------------------------------------------------
61
62// ---------------------------------------------------------------------------
63// Auto-import from Codex CLI's credential file (~/.codex/auth.json)
64// ---------------------------------------------------------------------------
65
/// Raw format used by ~/.codex/auth.json
/// The tokens are nested under a "tokens" key; there is no top-level expires_at.
/// Expiry is read from the JWT `exp` claim inside the access_token.
#[derive(Deserialize)]
struct CodexAuth {
    /// The "tokens" object. Any other top-level keys are ignored when deserializing.
    tokens: CodexTokens,
}

/// The "tokens" object inside ~/.codex/auth.json.
#[derive(Deserialize)]
struct CodexTokens {
    /// JWT access token; its `exp` claim supplies the credential's expiry.
    access_token: String,
    /// Optional; imported as an empty string when absent.
    #[serde(default)]
    refresh_token: Option<String>,
    /// Optional OpenAI id_token, carried through to `OAuthCredential::id_token`.
    #[serde(default)]
    id_token: Option<String>,
}
82
83/// Write credentials to ~/.codex/auth.json so the Codex CLI can use them without re-login.
84///
85/// Called automatically after add-account and token refresh for OpenAI accounts.
86pub fn write_codex_auth_file(cred: &OAuthCredential) {
87    let Some(ref id_token) = cred.id_token else { return };
88    let path = codex_credentials_path();
89    if let Some(parent) = path.parent() {
90        let _ = std::fs::create_dir_all(parent);
91    }
92    let auth = serde_json::json!({
93        "tokens": {
94            "access_token": cred.access_token,
95            "refresh_token": cred.refresh_token,
96            "id_token": id_token,
97        }
98    });
99    if let Ok(text) = serde_json::to_string_pretty(&auth) {
100        let _ = std::fs::write(&path, text);
101    }
102}
103
104pub fn codex_credentials_path() -> PathBuf {
105    dirs::home_dir()
106        .unwrap_or_else(|| PathBuf::from("."))
107        .join(".codex")
108        .join("auth.json")
109}
110
111/// Read the OAuth credential from the Codex CLI's stored auth file.
112pub fn read_codex_credentials() -> Option<OAuthCredential> {
113    let text = std::fs::read_to_string(codex_credentials_path()).ok()?;
114    let raw: CodexAuth = serde_json::from_str(&text).ok()?;
115
116    let now_ms = SystemTime::now()
117        .duration_since(UNIX_EPOCH)
118        .unwrap_or_default()
119        .as_millis() as u64;
120
121    // Extract exp from the JWT payload without verifying signature.
122    let expires_at = jwt_exp_ms(&raw.tokens.access_token)
123        .unwrap_or(now_ms + 3600 * 1000); // default: 1 hour from now
124
125    Some(OAuthCredential {
126        access_token: raw.tokens.access_token,
127        refresh_token: raw.tokens.refresh_token.unwrap_or_default(),
128        expires_at,
129        email: None,
130        id_token: raw.tokens.id_token,
131    })
132}
133
134/// Decode the `exp` claim from a JWT payload (no signature verification).
135/// Returns expiry as Unix milliseconds.
136pub(crate) fn jwt_exp_ms(token: &str) -> Option<u64> {
137    let payload_b64 = token.splitn(3, '.').nth(1)?;
138    // base64url decode (no padding)
139    let decoded = base64_url_decode(payload_b64)?;
140    let v: serde_json::Value = serde_json::from_slice(&decoded).ok()?;
141    let exp_secs = v.get("exp")?.as_u64()?;
142    Some(exp_secs * 1000)
143}
144
/// Decode base64url text (RFC 4648 §5) into bytes.
///
/// Deliberately lenient:
/// - accepts the URL-safe alphabet (`-`, `_`) as well as the standard one (`+`, `/`);
/// - accepts input with or without trailing `=` padding.
///
/// Returns `None` if any character lies outside those alphabets, or if the
/// unpadded length is impossible for base64 (`len % 4 == 1`).
fn base64_url_decode(s: &str) -> Option<Vec<u8>> {
    /// Map one alphabet character to its 6-bit value.
    fn sextet(c: u8) -> Option<u32> {
        let v = match c {
            b'A'..=b'Z' => c - b'A',
            b'a'..=b'z' => c - b'a' + 26,
            b'0'..=b'9' => c - b'0' + 52,
            b'-' | b'+' => 62,
            b'_' | b'/' => 63,
            _ => return None,
        };
        Some(u32::from(v))
    }

    let data = s.trim_end_matches('=').as_bytes();
    // A lone trailing character carries only 6 bits, which is less than one byte.
    if data.len() % 4 == 1 {
        return None;
    }

    let mut out = Vec::with_capacity(data.len() * 3 / 4);
    for chunk in data.chunks(4) {
        let mut acc: u32 = 0;
        for &c in chunk {
            acc = (acc << 6) | sextet(c)?;
        }
        // Left-align a short final group as if it had been padded to four sextets.
        acc <<= 6 * (4 - chunk.len());
        // The 24 payload bits sit in bytes 1..4; a group of k chars yields k-1 bytes.
        let bytes = acc.to_be_bytes();
        out.extend_from_slice(&bytes[1..chunk.len()]);
    }
    Some(out)
}
184
185
/// Raw format used by ~/.claude/.credentials.json
/// (the same JSON shape is parsed from the macOS Keychain entry).
#[derive(Deserialize)]
#[serde(rename_all = "camelCase")]
struct ClaudeCredentials {
    /// The `claudeAiOauth` object; when absent or null, parsing yields no credential.
    claude_ai_oauth: Option<ClaudeOAuthRaw>,
}

/// The `claudeAiOauth` object (JSON keys `accessToken`, `refreshToken`, `expiresAt`).
/// Other keys in the object are ignored here.
#[derive(Deserialize)]
#[serde(rename_all = "camelCase")]
struct ClaudeOAuthRaw {
    access_token: String,
    refresh_token: String,
    /// Copied verbatim into `OAuthCredential::expires_at` (milliseconds since the Unix epoch).
    expires_at: u64,
}
200
201// ---------------------------------------------------------------------------
202// Session info (plan + identity) from stored credentials
203// ---------------------------------------------------------------------------
204
/// Plan and identity label read from Claude Code's stored credentials.
pub struct SessionInfo {
    /// Identity label. NOTE(review): `read_claude_session_info` fills this from the
    /// credentials' `rateLimitTier` value (or "unknown"), not from an email or
    /// account id — confirm this is intended.
    pub email_or_id: String,
    /// Subscription type from the credentials; "pro" when absent.
    pub plan: String,
}
209
210/// Read plan and identity from Claude Code's stored credentials JSON.
211/// Works for both keychain and file-based storage.
212pub fn read_claude_session_info() -> Option<SessionInfo> {
213    #[derive(serde::Deserialize)]
214    #[serde(rename_all = "camelCase")]
215    struct Outer {
216        claude_ai_oauth: Option<Inner>,
217    }
218    #[derive(serde::Deserialize)]
219    #[serde(rename_all = "camelCase")]
220    struct Inner {
221        subscription_type: Option<String>,
222        #[serde(rename = "rateLimitTier")]
223        rate_limit_tier: Option<String>,
224    }
225
226    let text = read_raw_credentials_json()?;
227    let outer: Outer = serde_json::from_str(&text).ok()?;
228    let inner = outer.claude_ai_oauth?;
229
230    let plan = inner.subscription_type.unwrap_or_else(|| "pro".into());
231    let email_or_id = inner.rate_limit_tier.unwrap_or_else(|| "unknown".into());
232
233    Some(SessionInfo { email_or_id, plan })
234}
235
236/// Returns the raw credentials JSON string from keychain (macOS) or file.
237fn read_raw_credentials_json() -> Option<String> {
238    #[cfg(target_os = "macos")]
239    {
240        let out = std::process::Command::new("security")
241            .args(["find-generic-password", "-s", "Claude Code-credentials", "-w"])
242            .output()
243            .ok()?;
244        if out.status.success() {
245            let s = String::from_utf8(out.stdout).ok()?;
246            return Some(s.trim().to_owned());
247        }
248    }
249    std::fs::read_to_string(claude_credentials_path()).ok()
250}
251
252pub fn claude_credentials_path() -> PathBuf {
253    dirs::home_dir()
254        .unwrap_or_else(|| PathBuf::from("."))
255        .join(".claude")
256        .join(".credentials.json")
257}
258
259/// Read the OAuth credential from Claude Code's own credential file.
260/// On macOS, tries the Keychain first (service "Claude Code-credentials"),
261/// then falls back to ~/.claude/.credentials.json.
262pub fn read_claude_credentials() -> Option<OAuthCredential> {
263    // macOS: try Keychain first
264    #[cfg(target_os = "macos")]
265    if let Some(cred) = read_claude_credentials_keychain() {
266        return Some(cred);
267    }
268
269    // Fallback: JSON file (older Claude Code versions / non-macOS)
270    let path = claude_credentials_path();
271    let text = std::fs::read_to_string(&path).ok()?;
272    parse_claude_credentials_json(&text)
273}
274
/// macOS: parse the credentials JSON obtained from `read_raw_credentials_json`.
///
/// NOTE(review): despite the name, `read_raw_credentials_json` may return the contents
/// of ~/.claude/.credentials.json when the Keychain lookup does not succeed, so this
/// is not strictly Keychain-only.
#[cfg(target_os = "macos")]
fn read_claude_credentials_keychain() -> Option<OAuthCredential> {
    read_raw_credentials_json().and_then(|json| parse_claude_credentials_json(&json))
}
280
281fn parse_claude_credentials_json(text: &str) -> Option<OAuthCredential> {
282    let raw: ClaudeCredentials = serde_json::from_str(text).ok()?;
283    let inner = raw.claude_ai_oauth?;
284    Some(OAuthCredential {
285        access_token: inner.access_token,
286        refresh_token: inner.refresh_token,
287        expires_at: inner.expires_at,
288        email: None,
289        id_token: None,
290    })
291}
292
293// ---------------------------------------------------------------------------
294// Token refresh
295// ---------------------------------------------------------------------------
296
297/// Refresh an expired access token. Returns the updated credential.
298pub async fn refresh_token(cred: &OAuthCredential) -> Result<OAuthCredential> {
299    let client = reqwest::Client::new();
300
301    let resp = client
302        .post(OAUTH_TOKEN_URL)
303        .header("content-type", "application/x-www-form-urlencoded")
304        .body(format!(
305            "grant_type=refresh_token&refresh_token={}&client_id={}",
306            urlencoding::encode(&cred.refresh_token),
307            OAUTH_CLIENT_ID,
308        ))
309        .send()
310        .await
311        .context("token refresh request failed")?;
312
313    if !resp.status().is_success() {
314        let status = resp.status();
315        let body = resp.text().await.unwrap_or_default();
316        bail!("token refresh failed ({status}): {body}");
317    }
318
319    let body: serde_json::Value = resp.json().await.context("token refresh: invalid JSON")?;
320
321    let access_token = body["access_token"]
322        .as_str()
323        .context("token refresh: missing access_token")?
324        .to_owned();
325
326    let refresh_token = body["refresh_token"]
327        .as_str()
328        .unwrap_or(&cred.refresh_token)
329        .to_owned();
330
331    // expires_in is seconds from now
332    let expires_in_secs = body["expires_in"].as_u64().unwrap_or(3600);
333    let now_ms = SystemTime::now()
334        .duration_since(UNIX_EPOCH)
335        .unwrap_or_default()
336        .as_millis() as u64;
337    let expires_at = now_ms + expires_in_secs * 1000;
338
339    Ok(OAuthCredential { access_token, refresh_token, expires_at, email: cred.email.clone(), id_token: None })
340}
341
342// ---------------------------------------------------------------------------
343// Account identity (email) from roles endpoint
344// ---------------------------------------------------------------------------
345
346/// Fetch the account email from the Anthropic roles endpoint.
347/// Returns `None` on any error (non-fatal).
348pub async fn fetch_account_email(access_token: &str) -> Option<String> {
349    let client = reqwest::Client::builder()
350        .timeout(std::time::Duration::from_secs(8))
351        .build()
352        .ok()?;
353    let resp = client
354        .get("https://api.anthropic.com/api/oauth/claude_cli/roles")
355        .header("authorization", format!("Bearer {access_token}"))
356        .header("anthropic-version", "2023-06-01")
357        .header("anthropic-dangerous-direct-browser-access", "true")
358        .send()
359        .await
360        .ok()?;
361
362    if !resp.status().is_success() {
363        return None;
364    }
365
366    let body: serde_json::Value = resp.json().await.ok()?;
367
368    // Try dedicated email fields first.
369    for field in &["email", "emailAddress", "email_address"] {
370        if let Some(e) = body[field].as_str().filter(|s| s.contains('@')) {
371            return Some(e.to_owned());
372        }
373    }
374
375    // Fall back to extracting from organization_name ("addr@example.com's Organization").
376    let org = body["organization_name"].as_str()?;
377    let email = org.strip_suffix("'s Organization").unwrap_or(org).trim();
378    if !email.is_empty() { Some(email.to_owned()) } else { None }
379}
380
381// ---------------------------------------------------------------------------
382// PKCE browser OAuth flow (for adding additional accounts)
383// ---------------------------------------------------------------------------
384
385struct Pkce {
386    verifier: String,
387    challenge: String,
388}
389
390fn generate_pkce() -> Pkce {
391    let verifier_bytes: [u8; 32] = rand_bytes();
392    let verifier = URL_SAFE_NO_PAD.encode(verifier_bytes);
393
394    let hash = Sha256::digest(verifier.as_bytes());
395    let challenge = URL_SAFE_NO_PAD.encode(hash);
396
397    Pkce { verifier, challenge }
398}
399
400/// Generate N cryptographically random bytes using the OS entropy source.
401/// Panics if the system RNG is unavailable (unrecoverable error in a security context).
402pub fn rand_bytes<const N: usize>() -> [u8; N] {
403    let mut bytes = [0u8; N];
404    getrandom::getrandom(&mut bytes)
405        .expect("OS random number generator unavailable — cannot generate secure random bytes");
406    bytes
407}
408
409fn random_state() -> String {
410    let bytes: [u8; 16] = rand_bytes();
411    hex::encode(bytes)
412}
413
/// Registered redirect URI for the claude.ai PKCE flow. Its callback page shows the
/// authorization code (as `CODE#STATE`) for the user to paste back into the CLI.
pub const OAUTH_REDIRECT_URI: &str = "https://platform.claude.com/oauth/code/callback";
415
416/// Run the PKCE OAuth flow using the registered redirect URI.
417///
418/// Opens the browser to claude.ai. After the user authorizes, the callback page
419/// displays a code (format: CODE#STATE). The user pastes it here; we split out
420/// the state and exchange the code at the token endpoint.
421pub async fn run_oauth_flow() -> Result<OAuthCredential> {
422    use std::io::{self, Write};
423
424    let pkce = generate_pkce();
425    let state = random_state();
426    let redirect_uri = OAUTH_REDIRECT_URI;
427
428    let scope = urlencoding::encode(
429        "user:inference user:profile user:file_upload user:mcp_servers user:sessions:claude_code",
430    );
431    let auth_url = format!(
432        "{base}?response_type=code\
433         &client_id={client_id}\
434         &redirect_uri={redirect}\
435         &scope={scope}\
436         &state={state}\
437         &code_challenge={challenge}\
438         &code_challenge_method=S256",
439        base = OAUTH_AUTHORIZE_URL,
440        client_id = OAUTH_CLIENT_ID,
441        redirect = urlencoding::encode(redirect_uri),
442        scope = scope,
443        state = state,
444        challenge = pkce.challenge,
445    );
446
447    println!("\nOpening browser for claude.ai login...");
448    println!("If it does not open automatically, visit:\n  {auth_url}\n");
449    open_browser(&auth_url);
450
451    println!("After you authorize, the page will show an authorization code.");
452    println!("Copy it and paste it here.");
453    println!();
454    print!("Paste code: ");
455    io::stdout().flush()?;
456
457    let mut pasted = String::new();
458    io::stdin().read_line(&mut pasted)?;
459    // Page shows "code#state"
460    let pasted = pasted.trim();
461    let (code, pasted_state) = if let Some((c, s)) = pasted.split_once('#') {
462        (c.trim(), s.trim())
463    } else {
464        (pasted, state.as_str())
465    };
466
467    if code.is_empty() {
468        bail!("No code entered.");
469    }
470
471    let cred = exchange_code(code, pasted_state, redirect_uri, &pkce.verifier).await?;
472    Ok(cred)
473}
474
475async fn exchange_code(code: &str, state: &str, redirect_uri: &str, verifier: &str) -> Result<OAuthCredential> {
476    let client = reqwest::Client::new();
477
478    let body = serde_json::json!({
479        "grant_type": "authorization_code",
480        "code": code,
481        "state": state,
482        "redirect_uri": redirect_uri,
483        "client_id": OAUTH_CLIENT_ID,
484        "code_verifier": verifier,
485    });
486
487    let resp = client
488        .post(OAUTH_TOKEN_URL)
489        .header("content-type", "application/json")
490        .header("anthropic-version", "2023-06-01")
491        .json(&body)
492        .send()
493        .await
494        .context("token exchange request failed")?;
495
496    if !resp.status().is_success() {
497        let status = resp.status();
498        let body = resp.text().await.unwrap_or_default();
499        bail!("token exchange failed ({status}): {body}");
500    }
501
502    let body: serde_json::Value = resp.json().await.context("token exchange: invalid JSON")?;
503
504    let access_token = body["access_token"]
505        .as_str()
506        .context("token exchange: missing access_token")?
507        .to_owned();
508    let refresh_token = body["refresh_token"]
509        .as_str()
510        .unwrap_or("")
511        .to_owned();
512    let expires_in = body["expires_in"].as_u64().unwrap_or(3600);
513    let now_ms = SystemTime::now()
514        .duration_since(UNIX_EPOCH)
515        .unwrap_or_default()
516        .as_millis() as u64;
517
518    Ok(OAuthCredential {
519        access_token,
520        refresh_token,
521        expires_at: now_ms + expires_in * 1000,
522        email: None,
523        id_token: None,
524    })
525}
526
527// ---------------------------------------------------------------------------
528// Token revocation
529// ---------------------------------------------------------------------------
530
531pub const OAUTH_REVOKE_URL: &str = "https://platform.claude.com/v1/oauth/revoke";
532
533/// Revoke an OAuth token on the server. Best-effort — errors are non-fatal.
534pub async fn revoke_token(access_token: &str) -> bool {
535    let client = reqwest::Client::builder()
536        .timeout(std::time::Duration::from_secs(8))
537        .build()
538        .unwrap_or_default();
539    client
540        .post(OAUTH_REVOKE_URL)
541        .header("content-type", "application/x-www-form-urlencoded")
542        .header("anthropic-version", "2023-06-01")
543        .body(format!("token={}", urlencoding::encode(access_token)))
544        .send()
545        .await
546        .map(|r| r.status().is_success())
547        .unwrap_or(false)
548}
549
550// ---------------------------------------------------------------------------
551// OpenAI token refresh
552// ---------------------------------------------------------------------------
553
554/// Refresh an expired OpenAI / Codex access token using the stored refresh_token.
555pub async fn refresh_openai_token(cred: &OAuthCredential) -> Result<OAuthCredential> {
556    let client = reqwest::Client::new();
557
558    let resp = client
559        .post(OPENAI_OAUTH_TOKEN_URL)
560        .header("content-type", "application/x-www-form-urlencoded")
561        .body(format!(
562            "grant_type=refresh_token&refresh_token={}&client_id={}",
563            urlencoding::encode(&cred.refresh_token),
564            OPENAI_OAUTH_CLIENT_ID,
565        ))
566        .send()
567        .await
568        .context("OpenAI token refresh request failed")?;
569
570    if !resp.status().is_success() {
571        let status = resp.status();
572        let body = resp.text().await.unwrap_or_default();
573        bail!("OpenAI token refresh failed ({status}): {body}");
574    }
575
576    let body: serde_json::Value = resp.json().await.context("OpenAI token refresh: invalid JSON")?;
577
578    let access_token = body["access_token"]
579        .as_str()
580        .context("OpenAI token refresh: missing access_token")?
581        .to_owned();
582
583    let refresh_token = body["refresh_token"]
584        .as_str()
585        .unwrap_or(&cred.refresh_token)
586        .to_owned();
587
588    let id_token = body["id_token"].as_str().map(|s| s.to_owned())
589        .or_else(|| cred.id_token.clone());
590
591    let expires_in_secs = body["expires_in"].as_u64().unwrap_or(3600);
592    let now_ms = SystemTime::now()
593        .duration_since(UNIX_EPOCH)
594        .unwrap_or_default()
595        .as_millis() as u64;
596
597    Ok(OAuthCredential {
598        access_token,
599        refresh_token,
600        expires_at: now_ms + expires_in_secs * 1000,
601        email: cred.email.clone(),
602        id_token,
603    })
604}
605
606// ---------------------------------------------------------------------------
607// OpenAI / Codex device code flow (custom 3-step, not RFC 8628)
608// ---------------------------------------------------------------------------
609//
610// Codex uses its own device auth protocol:
611//   1. POST /deviceauth/usercode  {"client_id"} → {device_auth_id, user_code, interval}
612//   2. Poll  POST /deviceauth/token  {"device_auth_id","user_code"} until 200
613//            → {authorization_code, code_verifier, code_challenge}
614//   3. POST /oauth/token  PKCE exchange → {access_token, refresh_token, id_token}
615//
616// Verification URI where the user enters the code: https://auth.openai.com/codex/device
617
618/// Run the Codex device authorization flow. No local HTTP server required.
619///
620/// Displays a short user_code; the user visits `https://auth.openai.com/codex/device`
621/// and enters it. We poll until authorized, then exchange for tokens.
622pub async fn run_openai_oauth_flow() -> Result<OAuthCredential> {
623    const VERIFY_URI: &str = "https://auth.openai.com/codex/device";
624    const TIMEOUT_SECS: u64 = 15 * 60;
625
626    let client = reqwest::Client::new();
627
628    // Step 1: request user code
629    let resp = client
630        .post(OPENAI_DEVICE_CODE_URL)
631        .header("content-type", "application/json")
632        .json(&serde_json::json!({"client_id": OPENAI_OAUTH_CLIENT_ID}))
633        .send()
634        .await
635        .context("Codex device code request failed")?;
636
637    if !resp.status().is_success() {
638        let status = resp.status();
639        let body = resp.text().await.unwrap_or_default();
640        bail!("Codex device code request failed ({status}): {body}");
641    }
642
643    let info: serde_json::Value = resp.json().await.context("device code: invalid JSON")?;
644    let device_auth_id = info["device_auth_id"].as_str().context("missing device_auth_id")?.to_owned();
645    let user_code = info["user_code"].as_str().context("missing user_code")?.to_owned();
646    let interval_secs = info["interval"].as_u64().unwrap_or(5);
647
648    println!();
649    println!("  Visit:  {VERIFY_URI}");
650    println!("  Code:   \x1b[1;33m{user_code}\x1b[0m");
651    println!();
652    println!("  Waiting for authorization...");
653
654    open_browser(VERIFY_URI);
655
656    // Step 2: poll until code is approved
657    let deadline = std::time::Instant::now() + std::time::Duration::from_secs(TIMEOUT_SECS);
658    let poll_interval = std::time::Duration::from_secs(interval_secs);
659    let poll_body = serde_json::json!({
660        "device_auth_id": device_auth_id,
661        "user_code": user_code,
662    });
663
664    let (authorization_code, code_verifier) = loop {
665        tokio::time::sleep(poll_interval).await;
666
667        if std::time::Instant::now() > deadline {
668            bail!("Device code expired (15 min). Run `shunt add-account` again.");
669        }
670
671        let resp = client
672            .post(OPENAI_DEVICE_TOKEN_URL)
673            .header("content-type", "application/json")
674            .json(&poll_body)
675            .send()
676            .await
677            .context("Codex device poll request failed")?;
678
679        let status = resp.status();
680        // 403/404 = still pending; any 2xx = authorized
681        if status.as_u16() == 403 || status.as_u16() == 404 {
682            continue;
683        }
684        if !status.is_success() {
685            let body = resp.text().await.unwrap_or_default();
686            bail!("Codex device poll error ({status}): {body}");
687        }
688
689        let body: serde_json::Value = resp.json().await.context("device poll: invalid JSON")?;
690        let code = body["authorization_code"].as_str().context("missing authorization_code")?.to_owned();
691        let verifier = body["code_verifier"].as_str().context("missing code_verifier")?.to_owned();
692        break (code, verifier);
693    };
694
695    // Step 3: exchange authorization_code for tokens
696    let redirect_uri = format!("https://auth.openai.com/deviceauth/callback");
697    let token_body = format!(
698        "grant_type=authorization_code&code={}&redirect_uri={}&client_id={}&code_verifier={}",
699        urlencoding::encode(&authorization_code),
700        urlencoding::encode(&redirect_uri),
701        OPENAI_OAUTH_CLIENT_ID,
702        urlencoding::encode(&code_verifier),
703    );
704    let resp = client
705        .post(OPENAI_OAUTH_TOKEN_URL)
706        .header("content-type", "application/x-www-form-urlencoded")
707        .body(token_body)
708        .send()
709        .await
710        .context("Codex token exchange failed")?;
711
712    if !resp.status().is_success() {
713        let status = resp.status();
714        let body = resp.text().await.unwrap_or_default();
715        bail!("Codex token exchange failed ({status}): {body}");
716    }
717
718    let body: serde_json::Value = resp.json().await.context("token exchange: invalid JSON")?;
719    let access_token = body["access_token"]
720        .as_str()
721        .or_else(|| body["id_token"].as_str())
722        .context("token exchange: missing access_token")?
723        .to_owned();
724    let refresh_token = body["refresh_token"].as_str().unwrap_or("").to_owned();
725    let id_token = body["id_token"].as_str().map(|s| s.to_owned());
726    let expires_at = jwt_exp_ms(&access_token).unwrap_or_else(|| {
727        let now_ms = SystemTime::now()
728            .duration_since(UNIX_EPOCH)
729            .unwrap_or_default()
730            .as_millis() as u64;
731        now_ms + body["expires_in"].as_u64().unwrap_or(3600) * 1000
732    });
733
734    Ok(OAuthCredential { access_token, refresh_token, expires_at, email: None, id_token })
735}
736
737// ---------------------------------------------------------------------------
738// OpenAI account identity
739// ---------------------------------------------------------------------------
740
741/// Fetch the account email from OpenAI's userinfo endpoint.
742pub async fn fetch_openai_account_email(access_token: &str) -> Option<String> {
743    let client = reqwest::Client::builder()
744        .timeout(std::time::Duration::from_secs(8))
745        .build()
746        .ok()?;
747    let resp = client
748        .get("https://auth.openai.com/userinfo")
749        .header("authorization", format!("Bearer {access_token}"))
750        .send()
751        .await
752        .ok()?;
753    if !resp.status().is_success() { return None; }
754    let body: serde_json::Value = resp.json().await.ok()?;
755    body["email"].as_str().map(|s| s.to_owned())
756}
757
/// Best-effort attempt to open `url` in the user's default browser.
///
/// Uses `open` on macOS, `xdg-open` on Linux and `explorer` on Windows. Other
/// platforms do nothing. The launcher is spawned without waiting, and any spawn
/// error is ignored; callers print the URL so the user can open it manually.
fn open_browser(url: &str) {
    // On Windows, use explorer.exe directly — avoids cmd.exe shell expansion of OAuth
    // URL special characters (& % etc.) that would misparse with `cmd /c start`.
    let launcher = if cfg!(target_os = "macos") {
        Some("open")
    } else if cfg!(target_os = "linux") {
        Some("xdg-open")
    } else if cfg!(target_os = "windows") {
        Some("explorer")
    } else {
        None
    };

    if let Some(program) = launcher {
        let _ = std::process::Command::new(program).arg(url).spawn();
    }
}
770
#[cfg(test)]
mod tests {
    use super::*;

    /// `rand_bytes` yields exactly the requested number of bytes.
    #[test]
    fn test_rand_bytes_correct_length() {
        let sixteen: [u8; 16] = rand_bytes();
        let thirty_two: [u8; 32] = rand_bytes();
        assert_eq!(sixteen.len(), 16);
        assert_eq!(thirty_two.len(), 32);
    }

    #[test]
    fn test_rand_bytes_not_all_zeros() {
        // 32 uniformly random bytes are all zero with probability 2^-256.
        let bytes: [u8; 32] = rand_bytes();
        assert!(
            !bytes.iter().all(|&b| b == 0),
            "rand_bytes must not return all-zero output"
        );
    }

    #[test]
    fn test_rand_bytes_unique() {
        // Two independent 32-byte draws collide with probability 2^-256.
        let (first, second): ([u8; 32], [u8; 32]) = (rand_bytes(), rand_bytes());
        assert_ne!(first, second, "rand_bytes must return unique values each call");
    }

    #[test]
    fn test_pkce_pair_properties() {
        let Pkce { verifier, challenge } = generate_pkce();
        // The verifier may contain only base64url characters (no padding).
        let url_safe = |c: char| c.is_alphanumeric() || c == '-' || c == '_';
        assert!(verifier.chars().all(url_safe), "PKCE verifier must be base64url-safe");
        // The challenge is the SHA-256 of the verifier, so the two must differ.
        assert_ne!(verifier, challenge, "PKCE challenge must not equal verifier");
        assert!(!challenge.is_empty());
        assert!(!verifier.is_empty());
    }
}