Skip to main content

shunt/
oauth.rs

//! OAuth 2.0 login flows and token refresh for claude.ai and OpenAI/Codex accounts.
//!
//! Claude Code authenticates via OAuth, not API keys. Its credentials are stored
//! in ~/.claude/.credentials.json (or the macOS keychain) and sent as
//! `Authorization: Bearer <token>`. Codex credentials live in ~/.codex/auth.json.
5use anyhow::{bail, Context, Result};
6use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine as _};
7use serde::{Deserialize, Serialize};
8use sha2::{Digest, Sha256};
9use std::path::PathBuf;
10use std::time::{SystemTime, UNIX_EPOCH};
11
12// ---------------------------------------------------------------------------
13// Anthropic OAuth constants
14// ---------------------------------------------------------------------------
15
/// Client ID sent in Anthropic authorize, token-exchange and refresh requests.
pub const OAUTH_CLIENT_ID: &str = "9d1c250a-e61b-44d9-88ed-5944d1962f5e";
/// Browser authorization endpoint for the claude.ai PKCE flow.
pub const OAUTH_AUTHORIZE_URL: &str = "https://claude.ai/oauth/authorize";
/// Anthropic token endpoint, used for both code exchange and refresh.
pub const OAUTH_TOKEN_URL: &str = "https://platform.claude.com/v1/oauth/token";

// ---------------------------------------------------------------------------
// OpenAI / Codex OAuth constants
// ---------------------------------------------------------------------------

/// Client ID sent in OpenAI/Codex device-code, exchange and refresh requests.
pub const OPENAI_OAUTH_CLIENT_ID: &str = "app_EMoamEEZ73f0CkXaXp7hrann";
/// OpenAI token endpoint (authorization-code exchange and refresh).
pub const OPENAI_OAUTH_TOKEN_URL: &str = "https://auth.openai.com/oauth/token";
/// Device flow step 1: request a user code.
pub const OPENAI_DEVICE_CODE_URL: &str =
    "https://auth.openai.com/api/accounts/deviceauth/usercode";
/// Device flow step 2: poll until the user has approved the code.
pub const OPENAI_DEVICE_TOKEN_URL: &str =
    "https://auth.openai.com/api/accounts/deviceauth/token";
28
29// ---------------------------------------------------------------------------
30// Credential type
31// ---------------------------------------------------------------------------
32
33#[derive(Debug, Clone, Serialize, Deserialize)]
34pub struct OAuthCredential {
35    pub access_token: String,
36    pub refresh_token: String,
37    /// Milliseconds since Unix epoch
38    pub expires_at: u64,
39    /// Account email, fetched from roles endpoint after auth
40    #[serde(default, skip_serializing_if = "Option::is_none")]
41    pub email: Option<String>,
42    /// OpenAI id_token — required by the Codex CLI's ~/.codex/auth.json
43    #[serde(default, skip_serializing_if = "Option::is_none")]
44    pub id_token: Option<String>,
45}
46
47impl OAuthCredential {
48    /// True if the token expires within the next 5 minutes.
49    pub fn needs_refresh(&self) -> bool {
50        let now_ms = SystemTime::now()
51            .duration_since(UNIX_EPOCH)
52            .unwrap_or_default()
53            .as_millis() as u64;
54        now_ms >= self.expires_at.saturating_sub(5 * 60 * 1000)
55    }
56}
57
58// ---------------------------------------------------------------------------
59// Auto-import from Claude Code's own credential file
60// ---------------------------------------------------------------------------
61
62// ---------------------------------------------------------------------------
63// Auto-import from Codex CLI's credential file (~/.codex/auth.json)
64// ---------------------------------------------------------------------------
65
66/// Raw format used by ~/.codex/auth.json
67/// The tokens are nested under a "tokens" key; there is no top-level expires_at.
68/// Expiry is read from the JWT `exp` claim inside the access_token.
69#[derive(Deserialize)]
70struct CodexAuth {
71    tokens: CodexTokens,
72}
73
74#[derive(Deserialize)]
75struct CodexTokens {
76    access_token: String,
77    #[serde(default)]
78    refresh_token: Option<String>,
79    #[serde(default)]
80    id_token: Option<String>,
81}
82
83/// Write credentials to ~/.codex/auth.json so the Codex CLI can use them without re-login.
84///
85/// Called automatically after add-account and token refresh for OpenAI accounts.
86pub fn write_codex_auth_file(cred: &OAuthCredential) {
87    let Some(ref id_token) = cred.id_token else { return };
88    let path = codex_credentials_path();
89    if let Some(parent) = path.parent() {
90        let _ = std::fs::create_dir_all(parent);
91    }
92    let auth = serde_json::json!({
93        "tokens": {
94            "access_token": cred.access_token,
95            "refresh_token": cred.refresh_token,
96            "id_token": id_token,
97        }
98    });
99    if let Ok(text) = serde_json::to_string_pretty(&auth) {
100        let _ = std::fs::write(&path, text);
101    }
102}
103
104pub fn codex_credentials_path() -> PathBuf {
105    dirs::home_dir()
106        .unwrap_or_else(|| PathBuf::from("."))
107        .join(".codex")
108        .join("auth.json")
109}
110
111/// Read the OAuth credential from the Codex CLI's stored auth file.
112pub fn read_codex_credentials() -> Option<OAuthCredential> {
113    let text = std::fs::read_to_string(codex_credentials_path()).ok()?;
114    let raw: CodexAuth = serde_json::from_str(&text).ok()?;
115
116    let now_ms = SystemTime::now()
117        .duration_since(UNIX_EPOCH)
118        .unwrap_or_default()
119        .as_millis() as u64;
120
121    // Extract exp from the JWT payload without verifying signature.
122    let expires_at = jwt_exp_ms(&raw.tokens.access_token)
123        .unwrap_or(now_ms + 3600 * 1000); // default: 1 hour from now
124
125    Some(OAuthCredential {
126        access_token: raw.tokens.access_token,
127        refresh_token: raw.tokens.refresh_token.unwrap_or_default(),
128        expires_at,
129        email: None,
130        id_token: raw.tokens.id_token,
131    })
132}
133
134/// Decode the `exp` claim from a JWT payload (no signature verification).
135/// Returns expiry as Unix milliseconds.
136pub(crate) fn jwt_exp_ms(token: &str) -> Option<u64> {
137    let payload_b64 = token.splitn(3, '.').nth(1)?;
138    // base64url decode (no padding)
139    let decoded = base64_url_decode(payload_b64)?;
140    let v: serde_json::Value = serde_json::from_slice(&decoded).ok()?;
141    let exp_secs = v.get("exp")?.as_u64()?;
142    Some(exp_secs * 1000)
143}
144
/// Decode base64url text (RFC 4648 §5), as used for JWT segments.
///
/// Trailing `=` padding is optional, and the standard alphabet's `+` and `/`
/// are accepted alongside `-` and `_`. Returns `None` if any character is
/// outside those alphabets (including `=` anywhere but the end) or if the
/// unpadded length is impossible (a single dangling character).
fn base64_url_decode(s: &str) -> Option<Vec<u8>> {
    /// Map one alphabet character to its 6-bit value.
    fn sextet(c: u8) -> Option<u32> {
        let v = match c {
            b'A'..=b'Z' => c - b'A',
            b'a'..=b'z' => c - b'a' + 26,
            b'0'..=b'9' => c - b'0' + 52,
            b'-' | b'+' => 62,
            b'_' | b'/' => 63,
            _ => return None,
        };
        Some(u32::from(v))
    }

    let data = s.trim_end_matches('=').as_bytes();
    // 1 leftover char carries only 6 bits — not enough for a byte.
    if data.len() % 4 == 1 {
        return None;
    }

    let mut out = Vec::with_capacity(data.len() * 3 / 4);
    let mut acc: u32 = 0; // pending bits, always < 2^bits after each step
    let mut bits: u32 = 0;
    for &c in data {
        acc = (acc << 6) | sextet(c)?;
        bits += 6;
        if bits >= 8 {
            bits -= 8;
            out.push((acc >> bits) as u8);
            acc &= (1 << bits) - 1;
        }
    }
    Some(out)
}
184
185
186/// Raw format used by ~/.claude/.credentials.json
187#[derive(Deserialize)]
188#[serde(rename_all = "camelCase")]
189struct ClaudeCredentials {
190    claude_ai_oauth: Option<ClaudeOAuthRaw>,
191}
192
193#[derive(Deserialize)]
194#[serde(rename_all = "camelCase")]
195struct ClaudeOAuthRaw {
196    access_token: String,
197    refresh_token: String,
198    expires_at: u64,
199}
200
201// ---------------------------------------------------------------------------
202// Session info (plan + identity) from stored credentials
203// ---------------------------------------------------------------------------
204
/// Subscription plan and a display identifier read from Claude Code's stored
/// credentials.
///
/// Derives the standard traits so callers can log, copy and compare it.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SessionInfo {
    /// Display identifier for the account.
    /// NOTE(review): `read_claude_session_info` fills this from
    /// `rateLimitTier`, which is not an email or account id.
    pub email_or_id: String,
    /// Subscription plan name.
    pub plan: String,
}
209
210/// Read plan and identity from Claude Code's stored credentials JSON.
211/// Works for both keychain and file-based storage.
212pub fn read_claude_session_info() -> Option<SessionInfo> {
213    #[derive(serde::Deserialize)]
214    #[serde(rename_all = "camelCase")]
215    struct Outer {
216        claude_ai_oauth: Option<Inner>,
217    }
218    #[derive(serde::Deserialize)]
219    #[serde(rename_all = "camelCase")]
220    struct Inner {
221        subscription_type: Option<String>,
222        #[serde(rename = "rateLimitTier")]
223        rate_limit_tier: Option<String>,
224    }
225
226    let text = read_raw_credentials_json()?;
227    let outer: Outer = serde_json::from_str(&text).ok()?;
228    let inner = outer.claude_ai_oauth?;
229
230    let plan = inner.subscription_type.unwrap_or_else(|| "pro".into());
231    let email_or_id = inner.rate_limit_tier.unwrap_or_else(|| "unknown".into());
232
233    Some(SessionInfo { email_or_id, plan })
234}
235
236/// Returns the raw credentials JSON string from keychain (macOS) or file.
237fn read_raw_credentials_json() -> Option<String> {
238    #[cfg(target_os = "macos")]
239    {
240        // `security` can hang indefinitely in SSH sessions (no GUI keychain
241        // context). Run it in a thread with a 5-second timeout and fall
242        // through to the file-based fallback if it doesn't respond.
243        let (tx, rx) = std::sync::mpsc::channel();
244        std::thread::spawn(move || {
245            let out = std::process::Command::new("security")
246                .args(["find-generic-password", "-s", "Claude Code-credentials", "-w"])
247                .output()
248                .ok();
249            let _ = tx.send(out);
250        });
251        if let Ok(Some(out)) = rx.recv_timeout(std::time::Duration::from_secs(5)) {
252            if out.status.success() {
253                if let Ok(s) = String::from_utf8(out.stdout) {
254                    return Some(s.trim().to_owned());
255                }
256            }
257        }
258    }
259    std::fs::read_to_string(claude_credentials_path()).ok()
260}
261
262pub fn claude_credentials_path() -> PathBuf {
263    dirs::home_dir()
264        .unwrap_or_else(|| PathBuf::from("."))
265        .join(".claude")
266        .join(".credentials.json")
267}
268
269/// Read the OAuth credential from Claude Code's own credential file.
270/// On macOS, tries the Keychain first (service "Claude Code-credentials"),
271/// then falls back to ~/.claude/.credentials.json.
272pub fn read_claude_credentials() -> Option<OAuthCredential> {
273    // macOS: try Keychain first
274    #[cfg(target_os = "macos")]
275    if let Some(cred) = read_claude_credentials_keychain() {
276        return Some(cred);
277    }
278
279    // Fallback: JSON file (older Claude Code versions / non-macOS)
280    let path = claude_credentials_path();
281    let text = std::fs::read_to_string(&path).ok()?;
282    parse_claude_credentials_json(&text)
283}
284
285#[cfg(target_os = "macos")]
286fn read_claude_credentials_keychain() -> Option<OAuthCredential> {
287    let text = read_raw_credentials_json()?;
288    parse_claude_credentials_json(&text)
289}
290
291fn parse_claude_credentials_json(text: &str) -> Option<OAuthCredential> {
292    let raw: ClaudeCredentials = serde_json::from_str(text).ok()?;
293    let inner = raw.claude_ai_oauth?;
294    Some(OAuthCredential {
295        access_token: inner.access_token,
296        refresh_token: inner.refresh_token,
297        expires_at: inner.expires_at,
298        email: None,
299        id_token: None,
300    })
301}
302
303// ---------------------------------------------------------------------------
304// Token refresh
305// ---------------------------------------------------------------------------
306
307/// Refresh an expired access token. Returns the updated credential.
308pub async fn refresh_token(cred: &OAuthCredential) -> Result<OAuthCredential> {
309    let client = reqwest::Client::new();
310
311    let resp = client
312        .post(OAUTH_TOKEN_URL)
313        .header("content-type", "application/x-www-form-urlencoded")
314        .body(format!(
315            "grant_type=refresh_token&refresh_token={}&client_id={}",
316            urlencoding::encode(&cred.refresh_token),
317            OAUTH_CLIENT_ID,
318        ))
319        .send()
320        .await
321        .context("token refresh request failed")?;
322
323    if !resp.status().is_success() {
324        let status = resp.status();
325        let body = resp.text().await.unwrap_or_default();
326        bail!("token refresh failed ({status}): {body}");
327    }
328
329    let body: serde_json::Value = resp.json().await.context("token refresh: invalid JSON")?;
330
331    let access_token = body["access_token"]
332        .as_str()
333        .context("token refresh: missing access_token")?
334        .to_owned();
335
336    let refresh_token = body["refresh_token"]
337        .as_str()
338        .unwrap_or(&cred.refresh_token)
339        .to_owned();
340
341    // expires_in is seconds from now
342    let expires_in_secs = body["expires_in"].as_u64().unwrap_or(3600);
343    let now_ms = SystemTime::now()
344        .duration_since(UNIX_EPOCH)
345        .unwrap_or_default()
346        .as_millis() as u64;
347    let expires_at = now_ms + expires_in_secs * 1000;
348
349    Ok(OAuthCredential { access_token, refresh_token, expires_at, email: cred.email.clone(), id_token: None })
350}
351
352// ---------------------------------------------------------------------------
353// Account identity (email) from roles endpoint
354// ---------------------------------------------------------------------------
355
356/// Fetch the account email from the Anthropic roles endpoint.
357/// Returns `None` on any error (non-fatal).
358pub async fn fetch_account_email(access_token: &str) -> Option<String> {
359    let client = reqwest::Client::builder()
360        .timeout(std::time::Duration::from_secs(8))
361        .build()
362        .ok()?;
363    let resp = client
364        .get("https://api.anthropic.com/api/oauth/claude_cli/roles")
365        .header("authorization", format!("Bearer {access_token}"))
366        .header("anthropic-version", "2023-06-01")
367        .header("anthropic-dangerous-direct-browser-access", "true")
368        .send()
369        .await
370        .ok()?;
371
372    if !resp.status().is_success() {
373        return None;
374    }
375
376    let body: serde_json::Value = resp.json().await.ok()?;
377
378    // Try dedicated email fields first.
379    for field in &["email", "emailAddress", "email_address"] {
380        if let Some(e) = body[field].as_str().filter(|s| s.contains('@')) {
381            return Some(e.to_owned());
382        }
383    }
384
385    // Fall back to extracting from organization_name ("addr@example.com's Organization").
386    let org = body["organization_name"].as_str()?;
387    let email = org.strip_suffix("'s Organization").unwrap_or(org).trim();
388    if !email.is_empty() { Some(email.to_owned()) } else { None }
389}
390
391// ---------------------------------------------------------------------------
392// PKCE browser OAuth flow (for adding additional accounts)
393// ---------------------------------------------------------------------------
394
395struct Pkce {
396    verifier: String,
397    challenge: String,
398}
399
400fn generate_pkce() -> Pkce {
401    let verifier_bytes: [u8; 32] = rand_bytes();
402    let verifier = URL_SAFE_NO_PAD.encode(verifier_bytes);
403
404    let hash = Sha256::digest(verifier.as_bytes());
405    let challenge = URL_SAFE_NO_PAD.encode(hash);
406
407    Pkce { verifier, challenge }
408}
409
410/// Generate N cryptographically random bytes using the OS entropy source.
411/// Panics if the system RNG is unavailable (unrecoverable error in a security context).
412pub fn rand_bytes<const N: usize>() -> [u8; N] {
413    let mut bytes = [0u8; N];
414    getrandom::getrandom(&mut bytes)
415        .expect("OS random number generator unavailable — cannot generate secure random bytes");
416    bytes
417}
418
419fn random_state() -> String {
420    let bytes: [u8; 16] = rand_bytes();
421    hex::encode(bytes)
422}
423
424pub const OAUTH_REDIRECT_URI: &str = "https://platform.claude.com/oauth/code/callback";
425
426/// Run the PKCE OAuth flow using the registered redirect URI.
427///
428/// Opens the browser to claude.ai. After the user authorizes, the callback page
429/// displays a code (format: CODE#STATE). The user pastes it here; we split out
430/// the state and exchange the code at the token endpoint.
431pub async fn run_oauth_flow() -> Result<OAuthCredential> {
432    use std::io::{self, Write};
433
434    let pkce = generate_pkce();
435    let state = random_state();
436    let redirect_uri = OAUTH_REDIRECT_URI;
437
438    let scope = urlencoding::encode(
439        "user:inference user:profile user:file_upload user:mcp_servers user:sessions:claude_code",
440    );
441    let auth_url = format!(
442        "{base}?response_type=code\
443         &client_id={client_id}\
444         &redirect_uri={redirect}\
445         &scope={scope}\
446         &state={state}\
447         &code_challenge={challenge}\
448         &code_challenge_method=S256",
449        base = OAUTH_AUTHORIZE_URL,
450        client_id = OAUTH_CLIENT_ID,
451        redirect = urlencoding::encode(redirect_uri),
452        scope = scope,
453        state = state,
454        challenge = pkce.challenge,
455    );
456
457    println!("\nOpening browser for claude.ai login...");
458    println!("If it does not open automatically, visit:\n  {auth_url}\n");
459    open_browser(&auth_url);
460
461    println!("After you authorize, the page will show an authorization code.");
462    println!("Copy it and paste it here.");
463    println!();
464    print!("Paste code: ");
465    io::stdout().flush()?;
466
467    let mut pasted = String::new();
468    io::stdin().read_line(&mut pasted)?;
469    // Page shows "code#state"
470    let pasted = pasted.trim();
471    let (code, pasted_state) = if let Some((c, s)) = pasted.split_once('#') {
472        (c.trim(), s.trim())
473    } else {
474        (pasted, state.as_str())
475    };
476
477    if code.is_empty() {
478        bail!("No code entered.");
479    }
480
481    let cred = exchange_code(code, pasted_state, redirect_uri, &pkce.verifier).await?;
482    Ok(cred)
483}
484
485async fn exchange_code(code: &str, state: &str, redirect_uri: &str, verifier: &str) -> Result<OAuthCredential> {
486    let client = reqwest::Client::new();
487
488    let body = serde_json::json!({
489        "grant_type": "authorization_code",
490        "code": code,
491        "state": state,
492        "redirect_uri": redirect_uri,
493        "client_id": OAUTH_CLIENT_ID,
494        "code_verifier": verifier,
495    });
496
497    let resp = client
498        .post(OAUTH_TOKEN_URL)
499        .header("content-type", "application/json")
500        .header("anthropic-version", "2023-06-01")
501        .json(&body)
502        .send()
503        .await
504        .context("token exchange request failed")?;
505
506    if !resp.status().is_success() {
507        let status = resp.status();
508        let body = resp.text().await.unwrap_or_default();
509        bail!("token exchange failed ({status}): {body}");
510    }
511
512    let body: serde_json::Value = resp.json().await.context("token exchange: invalid JSON")?;
513
514    let access_token = body["access_token"]
515        .as_str()
516        .context("token exchange: missing access_token")?
517        .to_owned();
518    let refresh_token = body["refresh_token"]
519        .as_str()
520        .unwrap_or("")
521        .to_owned();
522    let expires_in = body["expires_in"].as_u64().unwrap_or(3600);
523    let now_ms = SystemTime::now()
524        .duration_since(UNIX_EPOCH)
525        .unwrap_or_default()
526        .as_millis() as u64;
527
528    Ok(OAuthCredential {
529        access_token,
530        refresh_token,
531        expires_at: now_ms + expires_in * 1000,
532        email: None,
533        id_token: None,
534    })
535}
536
537// ---------------------------------------------------------------------------
538// Token revocation
539// ---------------------------------------------------------------------------
540
541pub const OAUTH_REVOKE_URL: &str = "https://platform.claude.com/v1/oauth/revoke";
542
543/// Revoke an OAuth token on the server. Best-effort — errors are non-fatal.
544pub async fn revoke_token(access_token: &str) -> bool {
545    let client = reqwest::Client::builder()
546        .timeout(std::time::Duration::from_secs(8))
547        .build()
548        .unwrap_or_default();
549    client
550        .post(OAUTH_REVOKE_URL)
551        .header("content-type", "application/x-www-form-urlencoded")
552        .header("anthropic-version", "2023-06-01")
553        .body(format!("token={}", urlencoding::encode(access_token)))
554        .send()
555        .await
556        .map(|r| r.status().is_success())
557        .unwrap_or(false)
558}
559
560// ---------------------------------------------------------------------------
561// OpenAI token refresh
562// ---------------------------------------------------------------------------
563
564/// Refresh an expired OpenAI / Codex access token using the stored refresh_token.
565pub async fn refresh_openai_token(cred: &OAuthCredential) -> Result<OAuthCredential> {
566    let client = reqwest::Client::new();
567
568    let resp = client
569        .post(OPENAI_OAUTH_TOKEN_URL)
570        .header("content-type", "application/x-www-form-urlencoded")
571        .body(format!(
572            "grant_type=refresh_token&refresh_token={}&client_id={}",
573            urlencoding::encode(&cred.refresh_token),
574            OPENAI_OAUTH_CLIENT_ID,
575        ))
576        .send()
577        .await
578        .context("OpenAI token refresh request failed")?;
579
580    if !resp.status().is_success() {
581        let status = resp.status();
582        let body = resp.text().await.unwrap_or_default();
583        bail!("OpenAI token refresh failed ({status}): {body}");
584    }
585
586    let body: serde_json::Value = resp.json().await.context("OpenAI token refresh: invalid JSON")?;
587
588    let access_token = body["access_token"]
589        .as_str()
590        .context("OpenAI token refresh: missing access_token")?
591        .to_owned();
592
593    let refresh_token = body["refresh_token"]
594        .as_str()
595        .unwrap_or(&cred.refresh_token)
596        .to_owned();
597
598    let id_token = body["id_token"].as_str().map(|s| s.to_owned())
599        .or_else(|| cred.id_token.clone());
600
601    let expires_in_secs = body["expires_in"].as_u64().unwrap_or(3600);
602    let now_ms = SystemTime::now()
603        .duration_since(UNIX_EPOCH)
604        .unwrap_or_default()
605        .as_millis() as u64;
606
607    Ok(OAuthCredential {
608        access_token,
609        refresh_token,
610        expires_at: now_ms + expires_in_secs * 1000,
611        email: cred.email.clone(),
612        id_token,
613    })
614}
615
616// ---------------------------------------------------------------------------
617// OpenAI / Codex device code flow (custom 3-step, not RFC 8628)
618// ---------------------------------------------------------------------------
619//
620// Codex uses its own device auth protocol:
621//   1. POST /deviceauth/usercode  {"client_id"} → {device_auth_id, user_code, interval}
622//   2. Poll  POST /deviceauth/token  {"device_auth_id","user_code"} until 200
623//            → {authorization_code, code_verifier, code_challenge}
624//   3. POST /oauth/token  PKCE exchange → {access_token, refresh_token, id_token}
625//
626// Verification URI where the user enters the code: https://auth.openai.com/codex/device
627
628/// Run the Codex device authorization flow. No local HTTP server required.
629///
630/// Displays a short user_code; the user visits `https://auth.openai.com/codex/device`
631/// and enters it. We poll until authorized, then exchange for tokens.
632pub async fn run_openai_oauth_flow() -> Result<OAuthCredential> {
633    const VERIFY_URI: &str = "https://auth.openai.com/codex/device";
634    const TIMEOUT_SECS: u64 = 15 * 60;
635
636    let client = reqwest::Client::new();
637
638    // Step 1: request user code
639    let resp = client
640        .post(OPENAI_DEVICE_CODE_URL)
641        .header("content-type", "application/json")
642        .json(&serde_json::json!({"client_id": OPENAI_OAUTH_CLIENT_ID}))
643        .send()
644        .await
645        .context("Codex device code request failed")?;
646
647    if !resp.status().is_success() {
648        let status = resp.status();
649        let body = resp.text().await.unwrap_or_default();
650        bail!("Codex device code request failed ({status}): {body}");
651    }
652
653    let info: serde_json::Value = resp.json().await.context("device code: invalid JSON")?;
654    let device_auth_id = info["device_auth_id"].as_str().context("missing device_auth_id")?.to_owned();
655    let user_code = info["user_code"].as_str().context("missing user_code")?.to_owned();
656    let interval_secs = info["interval"].as_u64().unwrap_or(5);
657
658    println!();
659    println!("  Visit:  {VERIFY_URI}");
660    println!("  Code:   \x1b[1;33m{user_code}\x1b[0m");
661    println!();
662    println!("  Waiting for authorization...");
663
664    open_browser(VERIFY_URI);
665
666    // Step 2: poll until code is approved
667    let deadline = std::time::Instant::now() + std::time::Duration::from_secs(TIMEOUT_SECS);
668    let poll_interval = std::time::Duration::from_secs(interval_secs);
669    let poll_body = serde_json::json!({
670        "device_auth_id": device_auth_id,
671        "user_code": user_code,
672    });
673
674    let (authorization_code, code_verifier) = loop {
675        tokio::time::sleep(poll_interval).await;
676
677        if std::time::Instant::now() > deadline {
678            bail!("Device code expired (15 min). Run `shunt add-account` again.");
679        }
680
681        let resp = client
682            .post(OPENAI_DEVICE_TOKEN_URL)
683            .header("content-type", "application/json")
684            .json(&poll_body)
685            .send()
686            .await
687            .context("Codex device poll request failed")?;
688
689        let status = resp.status();
690        // 403/404 = still pending; any 2xx = authorized
691        if status.as_u16() == 403 || status.as_u16() == 404 {
692            continue;
693        }
694        if !status.is_success() {
695            let body = resp.text().await.unwrap_or_default();
696            bail!("Codex device poll error ({status}): {body}");
697        }
698
699        let body: serde_json::Value = resp.json().await.context("device poll: invalid JSON")?;
700        let code = body["authorization_code"].as_str().context("missing authorization_code")?.to_owned();
701        let verifier = body["code_verifier"].as_str().context("missing code_verifier")?.to_owned();
702        break (code, verifier);
703    };
704
705    // Step 3: exchange authorization_code for tokens
706    let redirect_uri = format!("https://auth.openai.com/deviceauth/callback");
707    let token_body = format!(
708        "grant_type=authorization_code&code={}&redirect_uri={}&client_id={}&code_verifier={}",
709        urlencoding::encode(&authorization_code),
710        urlencoding::encode(&redirect_uri),
711        OPENAI_OAUTH_CLIENT_ID,
712        urlencoding::encode(&code_verifier),
713    );
714    let resp = client
715        .post(OPENAI_OAUTH_TOKEN_URL)
716        .header("content-type", "application/x-www-form-urlencoded")
717        .body(token_body)
718        .send()
719        .await
720        .context("Codex token exchange failed")?;
721
722    if !resp.status().is_success() {
723        let status = resp.status();
724        let body = resp.text().await.unwrap_or_default();
725        bail!("Codex token exchange failed ({status}): {body}");
726    }
727
728    let body: serde_json::Value = resp.json().await.context("token exchange: invalid JSON")?;
729    let access_token = body["access_token"]
730        .as_str()
731        .or_else(|| body["id_token"].as_str())
732        .context("token exchange: missing access_token")?
733        .to_owned();
734    let refresh_token = body["refresh_token"].as_str().unwrap_or("").to_owned();
735    let id_token = body["id_token"].as_str().map(|s| s.to_owned());
736    let expires_at = jwt_exp_ms(&access_token).unwrap_or_else(|| {
737        let now_ms = SystemTime::now()
738            .duration_since(UNIX_EPOCH)
739            .unwrap_or_default()
740            .as_millis() as u64;
741        now_ms + body["expires_in"].as_u64().unwrap_or(3600) * 1000
742    });
743
744    Ok(OAuthCredential { access_token, refresh_token, expires_at, email: None, id_token })
745}
746
747// ---------------------------------------------------------------------------
748// OpenAI account identity
749// ---------------------------------------------------------------------------
750
751/// Fetch the account email from OpenAI's userinfo endpoint.
752pub async fn fetch_openai_account_email(access_token: &str) -> Option<String> {
753    let client = reqwest::Client::builder()
754        .timeout(std::time::Duration::from_secs(8))
755        .build()
756        .ok()?;
757    let resp = client
758        .get("https://auth.openai.com/userinfo")
759        .header("authorization", format!("Bearer {access_token}"))
760        .send()
761        .await
762        .ok()?;
763    if !resp.status().is_success() { return None; }
764    let body: serde_json::Value = resp.json().await.ok()?;
765    body["email"].as_str().map(|s| s.to_owned())
766}
767
/// Best-effort: open `url` in the user's default browser.
///
/// Uses `open` on macOS, `xdg-open` on Linux and `explorer` on Windows. It
/// does nothing on other platforms. The launcher is spawned without waiting,
/// and spawn errors are ignored.
fn open_browser(url: &str) {
    #[cfg(target_os = "macos")]
    let launcher: Option<&str> = Some("open");

    #[cfg(target_os = "linux")]
    let launcher: Option<&str> = Some("xdg-open");

    // explorer.exe directly: `cmd /c start` would shell-expand OAuth URL
    // characters such as `&` and `%` and misparse the URL.
    #[cfg(target_os = "windows")]
    let launcher: Option<&str> = Some("explorer");

    #[cfg(not(any(target_os = "macos", target_os = "linux", target_os = "windows")))]
    let launcher: Option<&str> = None;

    if let Some(program) = launcher {
        let _ = std::process::Command::new(program).arg(url).spawn();
    }
}
780
#[cfg(test)]
mod tests {
    use super::*;
    // Explicit so `URL_SAFE_NO_PAD.encode` resolves regardless of glob handling of `as _` imports.
    use base64::Engine as _;

    #[test]
    fn test_rand_bytes_correct_length() {
        let a: [u8; 16] = rand_bytes();
        assert_eq!(a.len(), 16);
        let b: [u8; 32] = rand_bytes();
        assert_eq!(b.len(), 32);
    }

    #[test]
    fn test_rand_bytes_not_all_zeros() {
        // The probability of 32 random bytes all being zero is 1/2^256 — effectively impossible.
        let bytes: [u8; 32] = rand_bytes();
        assert!(bytes.iter().any(|&b| b != 0), "rand_bytes must not return all-zero output");
    }

    #[test]
    fn test_rand_bytes_unique() {
        // Two calls must not return the same value (probability 1/2^256 they'd collide).
        let a: [u8; 32] = rand_bytes();
        let b: [u8; 32] = rand_bytes();
        assert_ne!(a, b, "rand_bytes must return unique values each call");
    }

    #[test]
    fn test_pkce_pair_properties() {
        let pkce = generate_pkce();
        // Unpadded base64url is ASCII-only: A-Z a-z 0-9 - _
        // (`char::is_alphanumeric` would wrongly accept non-ASCII letters).
        assert!(
            pkce.verifier.chars().all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '_'),
            "PKCE verifier must be base64url-safe"
        );
        // 32 random bytes encode to 43 chars, the RFC 7636 minimum verifier length.
        assert_eq!(pkce.verifier.len(), 43);
        // Challenge must differ from verifier (it's the SHA-256 hash)
        assert_ne!(pkce.verifier, pkce.challenge, "PKCE challenge must not equal verifier");
        // S256: challenge = BASE64URL(SHA256(verifier)).
        let expected = URL_SAFE_NO_PAD.encode(Sha256::digest(pkce.verifier.as_bytes()));
        assert_eq!(pkce.challenge, expected, "PKCE challenge must be S256 of the verifier");
        assert!(!pkce.challenge.is_empty());
        assert!(!pkce.verifier.is_empty());
    }

    #[test]
    fn test_base64_url_decode_roundtrip() {
        let inputs: [&[u8]; 6] = [b"", b"a", b"ab", b"abc", b"\xfb\xff", br#"{"exp":1700000000}"#];
        for input in inputs {
            let encoded = URL_SAFE_NO_PAD.encode(input);
            assert_eq!(base64_url_decode(&encoded).as_deref(), Some(input));
        }
    }

    #[test]
    fn test_jwt_exp_ms() {
        let payload = URL_SAFE_NO_PAD.encode(br#"{"exp":1700000000}"#);
        assert_eq!(jwt_exp_ms(&format!("h.{payload}.s")), Some(1_700_000_000_000));

        let no_exp = URL_SAFE_NO_PAD.encode(br#"{"sub":"x"}"#);
        assert_eq!(jwt_exp_ms(&format!("h.{no_exp}.s")), None);

        assert_eq!(jwt_exp_ms("not-a-jwt"), None);
    }
}