Skip to main content

tryaudex_core/
sso.rs

1use serde::{Deserialize, Serialize};
2
3use crate::error::{AvError, Result};
4
5/// SSO provider type.
6#[derive(Debug, Clone, Serialize, Deserialize)]
7#[serde(rename_all = "lowercase")]
8pub enum SsoProvider {
9    Okta,
10    Auth0,
11    Google,
12    /// Generic OIDC provider
13    Oidc,
14}
15
16/// SSO configuration in `[sso]` config section.
17#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct SsoConfig {
19    /// SSO provider type
20    pub provider: SsoProvider,
21    /// OIDC issuer URL (e.g. "https://myorg.okta.com", "https://accounts.google.com")
22    pub issuer: String,
23    /// OIDC client ID
24    pub client_id: String,
25    /// OIDC client secret (optional, for confidential clients)
26    pub client_secret: Option<String>,
27    /// Redirect URI for the device/local auth flow
28    #[serde(default = "default_redirect")]
29    pub redirect_uri: String,
30    /// OIDC scopes to request
31    #[serde(default = "default_scopes")]
32    pub scopes: Vec<String>,
33    /// Claim to use as the identity (default: "email")
34    #[serde(default = "default_identity_claim")]
35    pub identity_claim: String,
36    /// Claim to use for group/team membership (default: "groups")
37    #[serde(default = "default_groups_claim")]
38    pub groups_claim: String,
39    /// Whether SSO is required (if false, falls back to local identity)
40    #[serde(default)]
41    pub required: bool,
42}
43
/// Serde default for `SsoConfig::redirect_uri`: the loopback callback URL.
fn default_redirect() -> String {
    String::from("http://localhost:8400/callback")
}
47
/// Serde default for `SsoConfig::scopes`: `openid`, `email`, `profile`, in
/// that order.
fn default_scopes() -> Vec<String> {
    const SCOPES: [&str; 3] = ["openid", "email", "profile"];
    SCOPES.iter().map(|scope| (*scope).to_owned()).collect()
}
55
/// Serde default for `SsoConfig::identity_claim`.
fn default_identity_claim() -> String {
    String::from("email")
}
59
/// Serde default for `SsoConfig::groups_claim`.
fn default_groups_claim() -> String {
    String::from("groups")
}
63
64/// Token response from the OIDC provider.
65#[derive(Debug, Deserialize)]
66struct TokenResponse {
67    access_token: String,
68    id_token: Option<String>,
69    #[allow(dead_code)]
70    token_type: Option<String>,
71    #[allow(dead_code)]
72    expires_in: Option<u64>,
73}
74
75/// Decoded identity from an SSO token.
76#[derive(Debug, Clone, Serialize, Deserialize)]
77pub struct SsoIdentity {
78    /// The user's identity (email, subject, etc.)
79    pub identity: String,
80    /// Group/team memberships from the SSO provider
81    pub groups: Vec<String>,
82    /// Raw claims from the ID token
83    pub claims: serde_json::Value,
84}
85
86/// Cached SSO session stored on disk.
87#[derive(Debug, Serialize, Deserialize)]
88struct CachedSsoSession {
89    identity: SsoIdentity,
90    expires_at: i64,
91    /// Issuer URL that issued this identity — prevents cross-provider reuse.
92    #[serde(default)]
93    issuer: String,
94    /// Client ID the token was issued for.
95    #[serde(default)]
96    client_id: String,
97    /// Machine fingerprint at time of caching — prevents replay on another host.
98    #[serde(default)]
99    machine_id: String,
100}
101
/// Look up an identifier that is unique to this machine.
///
/// Tries `/etc/machine-id` first (trimmed, must be non-empty). On macOS it
/// then falls back to the `IOPlatformUUID` value reported by `ioreg`.
/// Returns an empty string when neither source yields a value; callers treat
/// that as "unknown".
fn read_machine_id() -> String {
    // Linux / systemd: /etc/machine-id
    match std::fs::read_to_string("/etc/machine-id") {
        Ok(contents) if !contents.trim().is_empty() => return contents.trim().to_string(),
        _ => {}
    }
    // macOS: IOPlatformUUID via ioreg
    #[cfg(target_os = "macos")]
    {
        let ioreg = std::process::Command::new("ioreg")
            .args(["-rd1", "-c", "IOPlatformExpertDevice"])
            .output();
        if let Ok(output) = ioreg {
            let stdout = String::from_utf8_lossy(&output.stdout);
            for row in stdout.lines().filter(|row| row.contains("IOPlatformUUID")) {
                // The value sits between the last two double quotes on the row.
                if let Some(closing) = row.rfind('"') {
                    let head = &row[..closing];
                    if let Some(opening) = head.rfind('"') {
                        return head[opening + 1..].to_string();
                    }
                }
            }
        }
    }
    String::new()
}
134
135fn cache_path() -> std::path::PathBuf {
136    dirs::cache_dir()
137        .unwrap_or_else(|| std::path::PathBuf::from("."))
138        .join("audex")
139        .join("sso_session.json")
140}
141
142/// Load cached SSO session if still valid and from the same provider and machine (encrypted at rest).
143pub fn load_cached_identity(issuer: &str, client_id: &str) -> Option<SsoIdentity> {
144    let path = cache_path();
145    let cached: CachedSsoSession = crate::keystore::decrypt_from_file(&path).ok()??;
146    let now = chrono::Utc::now().timestamp();
147    if now >= cached.expires_at {
148        return None;
149    }
150    // Reject cached identity from a different provider/client
151    if cached.issuer != issuer || cached.client_id != client_id {
152        tracing::info!("SSO cache issuer/client mismatch — re-authenticating");
153        return None;
154    }
155    // Reject cached identity from a different machine (prevents exfiltrated cache replay)
156    let current_machine_id = read_machine_id();
157    if !current_machine_id.is_empty() && cached.machine_id != current_machine_id {
158        tracing::info!("SSO cache machine-id mismatch — re-authenticating");
159        return None;
160    }
161    Some(cached.identity)
162}
163
164fn save_cached_identity(identity: &SsoIdentity, ttl_secs: u64, issuer: &str, client_id: &str) {
165    let cached = CachedSsoSession {
166        identity: identity.clone(),
167        expires_at: chrono::Utc::now().timestamp() + ttl_secs as i64,
168        issuer: issuer.to_string(),
169        client_id: client_id.to_string(),
170        machine_id: read_machine_id(),
171    };
172    let path = cache_path();
173    if let Some(parent) = path.parent() {
174        let _ = std::fs::create_dir_all(parent);
175    }
176    if let Err(e) = crate::keystore::encrypt_to_file(&path, &cached) {
177        tracing::warn!(error = %e, "Failed to cache SSO identity — next login will require re-authentication");
178    }
179}
180
181/// Authenticate via SSO using the device authorization flow.
182/// Opens a browser for login and listens on a local callback URL.
183pub async fn authenticate(config: &SsoConfig) -> Result<SsoIdentity> {
184    // Check cache first — only reuse if same issuer/client
185    if let Some(cached) = load_cached_identity(&config.issuer, &config.client_id) {
186        return Ok(cached);
187    }
188
189    let discovery_url = format!(
190        "{}/.well-known/openid-configuration",
191        config.issuer.trim_end_matches('/')
192    );
193
194    // R6-H23: do NOT follow redirects on the discovery request.  reqwest's
195    // default policy transparently follows up to 10 hops, which means an
196    // attacker who can inject a 302 on the discovery URL (a compromised
197    // CDN, a misconfigured reverse proxy, a rogue intermediate) can
198    // redirect the OIDC client to their own IdP's discovery document —
199    // returning whatever authorization_endpoint and jwks_uri they like.
200    // The subsequent set_issuer check is the *second* line of defence;
201    // blocking the redirect at the transport is the first.
202    // R6-M37: enforce HTTP timeouts on every SSO network call. Without a
203    // timeout, a slow-loris IdP (or a hung reverse proxy) would hang the
204    // CLI indefinitely while it waits for discovery/JWKS/userinfo bytes.
205    let discovery_client = reqwest::Client::builder()
206        .redirect(reqwest::redirect::Policy::none())
207        .timeout(std::time::Duration::from_secs(30))
208        .connect_timeout(std::time::Duration::from_secs(10))
209        .build()
210        .map_err(|e| AvError::InvalidPolicy(format!("SSO client build failed: {}", e)))?;
211
212    // Fetch OIDC discovery document.
213    // R6-M38: cap the discovery document at 1 MiB. A hostile/compromised
214    // IdP could otherwise stream an arbitrarily large JSON body into
215    // `.json()` and OOM the client.
216    let discovery: serde_json::Value = fetch_json_capped(
217        &discovery_client,
218        &discovery_url,
219        None,
220        1024 * 1024,
221        "SSO discovery",
222    )
223    .await?;
224
225    // R6-H23: validate `issuer` in the discovery document matches the
226    // configured issuer exactly (OIDC Discovery 1.0 §4.3).  Without this,
227    // a discovery document returned from a different origin (e.g. via a
228    // proxy that rewrites the host) would pass through and mint a JWT
229    // validation config pointing at the wrong issuer — later set_issuer
230    // on the JWT would then "succeed" because both sides come from the
231    // same attacker-controlled document.
232    let discovery_issuer = discovery["issuer"].as_str().ok_or_else(|| {
233        AvError::InvalidPolicy("OIDC discovery document missing `issuer` field".to_string())
234    })?;
235    if discovery_issuer.trim_end_matches('/') != config.issuer.trim_end_matches('/') {
236        return Err(AvError::InvalidPolicy(format!(
237            "OIDC discovery issuer mismatch: configured={}, received={}. \
238             This can indicate a redirect on the discovery URL or a misconfigured IdP.",
239            config.issuer, discovery_issuer
240        )));
241    }
242
243    // R6-M37: same timeouts for the token/userinfo client.
244    let client = reqwest::Client::builder()
245        .timeout(std::time::Duration::from_secs(30))
246        .connect_timeout(std::time::Duration::from_secs(10))
247        .build()
248        .map_err(|e| AvError::InvalidPolicy(format!("SSO client build failed: {}", e)))?;
249
250    let auth_endpoint = discovery["authorization_endpoint"]
251        .as_str()
252        .ok_or_else(|| {
253            AvError::InvalidPolicy("Missing authorization_endpoint in OIDC discovery".to_string())
254        })?;
255    let token_endpoint = discovery["token_endpoint"].as_str().ok_or_else(|| {
256        AvError::InvalidPolicy("Missing token_endpoint in OIDC discovery".to_string())
257    })?;
258
259    // Generate PKCE challenge
260    let verifier = generate_pkce_verifier();
261    let challenge = generate_pkce_challenge(&verifier);
262    let state = generate_state();
263    let nonce = generate_state(); // same format: random 16-byte hex
264
265    let scopes = config.scopes.join(" ");
266    let auth_url = format!(
267        "{}?client_id={}&redirect_uri={}&response_type=code&scope={}&state={}&nonce={}&code_challenge={}&code_challenge_method=S256",
268        auth_endpoint,
269        urlencoding::encode(&config.client_id),
270        urlencoding::encode(&config.redirect_uri),
271        urlencoding::encode(&scopes),
272        urlencoding::encode(&state),
273        urlencoding::encode(&nonce),
274        urlencoding::encode(&challenge),
275    );
276
277    eprintln!("Opening browser for SSO login...");
278    eprintln!("If the browser doesn't open, visit:\n  {}\n", auth_url);
279
280    // Try to open browser
281    let _ = open_browser(&auth_url);
282
283    // Start local server to receive the callback
284    let code = listen_for_callback(&config.redirect_uri, &state).await?;
285
286    // Exchange code for tokens
287    let mut token_params = vec![
288        ("grant_type", "authorization_code".to_string()),
289        ("code", code),
290        ("redirect_uri", config.redirect_uri.clone()),
291        ("client_id", config.client_id.clone()),
292        ("code_verifier", verifier),
293    ];
294    if let Some(ref secret) = config.client_secret {
295        token_params.push(("client_secret", secret.clone()));
296    }
297
298    let token_resp: TokenResponse = client
299        .post(token_endpoint)
300        .form(&token_params)
301        .send()
302        .await
303        .map_err(|e| AvError::InvalidPolicy(format!("SSO token exchange failed: {}", e)))?
304        .json()
305        .await
306        .map_err(|e| AvError::InvalidPolicy(format!("SSO token parse failed: {}", e)))?;
307
308    // Decode and verify the ID token (signature + claims)
309    //
310    // R6-H24: previously `jwks_uri` defaulted to the empty string when
311    // missing from the discovery document, which caused
312    // `decode_and_verify_jwt` to skip signature verification and fall
313    // back to transport-trust-only decoding with only a tracing::warn.
314    // An attacker who could influence the discovery document — via a
315    // compromised IdP, a misconfigured reverse proxy, or the redirect
316    // vector covered in R6-H23 — could return `{"jwks_uri": ""}` (or
317    // simply omit the key) and forge arbitrary ID tokens whose only
318    // protection was HTTPS transport.  Treat missing/empty jwks_uri as
319    // fatal when we actually have an id_token to verify.
320    let claims = if let Some(ref id_token) = token_resp.id_token {
321        let jwks_uri = discovery["jwks_uri"].as_str().unwrap_or_default();
322        if jwks_uri.is_empty() {
323            return Err(AvError::InvalidPolicy(
324                "OIDC discovery document is missing `jwks_uri`; refusing to accept an \
325                 id_token without signature verification. If the provider genuinely does \
326                 not expose JWKS, configure audex to use the userinfo endpoint instead."
327                    .to_string(),
328            ));
329        }
330        decode_and_verify_jwt(
331            &client,
332            jwks_uri,
333            id_token,
334            &config.issuer,
335            &config.client_id,
336        )
337        .await?
338    } else {
339        // Use userinfo endpoint as fallback
340        let userinfo_url = discovery["userinfo_endpoint"].as_str().ok_or_else(|| {
341            AvError::InvalidPolicy("No id_token and no userinfo_endpoint".to_string())
342        })?;
343
344        // R6-M38: cap userinfo body at 1 MiB.
345        fetch_json_capped(
346            &client,
347            userinfo_url,
348            Some(&token_resp.access_token),
349            1024 * 1024,
350            "SSO userinfo",
351        )
352        .await?
353    };
354
355    // R6-M36: require nonce presence and equality. Without the explicit
356    // presence check, a token missing the nonce claim would silently
357    // bypass replay protection mandated by OIDC Core §3.1.3.7. Userinfo
358    // responses (the id_token-less fallback above) legitimately do not
359    // carry a nonce; only enforce the check when we are verifying an
360    // id_token.
361    if token_resp.id_token.is_some() {
362        let token_nonce = claims
363            .get("nonce")
364            .and_then(|v| v.as_str())
365            .ok_or_else(|| {
366                AvError::InvalidPolicy(
367                    "SSO id_token is missing required `nonce` claim — \
368                     rejecting to prevent replay (OIDC Core §3.1.3.7)"
369                        .to_string(),
370                )
371            })?;
372        if token_nonce != nonce {
373            return Err(AvError::InvalidPolicy(
374                "SSO nonce mismatch — possible token replay".to_string(),
375            ));
376        }
377    }
378
379    // R6-M41: require the identity claim to be present and a string. The
380    // previous `.unwrap_or("unknown")` let an IdP whose userinfo returned
381    // a JSON scalar/array silently fall through to `"unknown"` as the
382    // principal identity — bypassing identity_claim entirely and letting
383    // every such caller map to the same "unknown" user in policy.
384    let identity_str = claims[&config.identity_claim]
385        .as_str()
386        .ok_or_else(|| {
387            AvError::InvalidPolicy(format!(
388                "SSO identity claim `{}` is missing or not a string. \
389                 Check the identity_claim configuration against the IdP's \
390                 token/userinfo shape.",
391                config.identity_claim
392            ))
393        })?
394        .to_string();
395
396    let groups: Vec<String> = claims[&config.groups_claim]
397        .as_array()
398        .map(|arr| {
399            arr.iter()
400                .filter_map(|v| v.as_str().map(|s| s.to_string()))
401                .collect()
402        })
403        .unwrap_or_default();
404
405    let sso_identity = SsoIdentity {
406        identity: identity_str,
407        groups,
408        claims,
409    };
410
411    // Cache for 1 hour
412    save_cached_identity(&sso_identity, 3600, &config.issuer, &config.client_id);
413
414    Ok(sso_identity)
415}
416
417/// Decode and verify a JWT ID token against the provider's JWKS keys.
418///
419/// Fetches the JWKS from the provider's `jwks_uri`, verifies the cryptographic
420/// signature, and validates `iss`, `aud`, and `exp` claims.
421///
422/// Falls back to transport-trust-only decoding if `jwks_uri` is unavailable,
423/// but logs a prominent warning.
424async fn decode_and_verify_jwt(
425    client: &reqwest::Client,
426    jwks_uri: &str,
427    token: &str,
428    expected_issuer: &str,
429    expected_audience: &str,
430) -> Result<serde_json::Value> {
431    use jsonwebtoken::{decode, DecodingKey, Validation};
432
433    let parts: Vec<&str> = token.split('.').collect();
434    if parts.len() != 3 {
435        return Err(AvError::InvalidPolicy("Invalid JWT format".to_string()));
436    }
437
438    // Try JWKS-based signature verification
439    if !jwks_uri.is_empty() {
440        let header = jsonwebtoken::decode_header(token)
441            .map_err(|e| AvError::InvalidPolicy(format!("JWT header decode failed: {}", e)))?;
442
443        let algorithm = header.alg;
444
445        // Enforce an allowlist of expected asymmetric algorithms to prevent
446        // algorithm confusion attacks where an attacker pairs a JWK with an
447        // unexpected algorithm (e.g. HMAC symmetric key as RSA).
448        const ALLOWED_ALGORITHMS: &[jsonwebtoken::Algorithm] = &[
449            jsonwebtoken::Algorithm::RS256,
450            jsonwebtoken::Algorithm::RS384,
451            jsonwebtoken::Algorithm::RS512,
452            jsonwebtoken::Algorithm::ES256,
453            jsonwebtoken::Algorithm::ES384,
454            jsonwebtoken::Algorithm::PS256,
455            jsonwebtoken::Algorithm::PS384,
456            jsonwebtoken::Algorithm::PS512,
457            jsonwebtoken::Algorithm::EdDSA,
458        ];
459        if !ALLOWED_ALGORITHMS.contains(&algorithm) {
460            return Err(AvError::InvalidPolicy(format!(
461                "JWT algorithm {:?} is not in the allowed set of asymmetric algorithms",
462                algorithm
463            )));
464        }
465
466        // Fetch JWKS from provider.
467        // R6-M38: cap JWKS body at 1 MiB so a hostile IdP can't OOM the
468        // verifier with a giant key set.
469        let jwks_resp: serde_json::Value =
470            fetch_json_capped(client, jwks_uri, None, 1024 * 1024, "JWKS").await?;
471
472        // Parse JWKs into typed structs
473        let jwk_set: jsonwebtoken::jwk::JwkSet = serde_json::from_value(jwks_resp)
474            .map_err(|e| AvError::InvalidPolicy(format!("JWKS parse into JwkSet failed: {}", e)))?;
475
476        // R6-M40: require the token header to carry a `kid` and match it
477        // exactly. The previous `jwk_set.keys.first()` fallback let a
478        // token with no kid verify against an arbitrary key from the
479        // JWKS — if the IdP rotated keys or an attacker could influence
480        // which entry landed first, verification could succeed against
481        // the wrong key.
482        let kid = header.kid.as_deref().ok_or_else(|| {
483            AvError::InvalidPolicy(
484                "JWT header is missing `kid` — refusing to pick an \
485                 arbitrary key from the JWKS. Reject tokens without kid \
486                 to prevent key-confusion attacks."
487                    .to_string(),
488            )
489        })?;
490        let matching_key = jwk_set
491            .keys
492            .iter()
493            .find(|k| k.common.key_id.as_deref() == Some(kid));
494
495        if let Some(jwk) = matching_key {
496            let decoding_key = DecodingKey::from_jwk(jwk)
497                .map_err(|e| AvError::InvalidPolicy(format!("JWK decode failed: {}", e)))?;
498
499            let mut validation = Validation::new(algorithm);
500            validation.set_issuer(&[expected_issuer]);
501            validation.set_audience(&[expected_audience]);
502
503            let token_data = decode::<serde_json::Value>(token, &decoding_key, &validation)
504                .map_err(|e| {
505                    AvError::InvalidPolicy(format!("JWT signature verification failed: {}", e))
506                })?;
507
508            return Ok(token_data.claims);
509        }
510
511        return Err(AvError::InvalidPolicy(format!(
512            "No matching JWK found for token kid={:?}",
513            kid
514        )));
515    }
516
517    // R6-H24: defence-in-depth.  The caller already rejects empty
518    // jwks_uri before we get here, but leave an explicit refusal so a
519    // future caller cannot reintroduce the "silently fall back to
520    // unverified" bug.  Any path that reaches this point with an
521    // empty jwks_uri is a programmer error, not a runtime condition
522    // we should paper over.
523    Err(AvError::InvalidPolicy(
524        "JWT signature verification cannot be skipped: jwks_uri is empty. \
525         This is a programming error — callers must reject missing jwks_uri \
526         before invoking decode_and_verify_jwt."
527            .to_string(),
528    ))
529}
530
/// Parse a JWT's claims WITHOUT checking its signature (test-only).
///
/// Checks the three-segment format, rejects tokens whose `exp` is in the
/// past (a missing `exp` only logs a warning), and requires `iss` to equal
/// `expected_issuer` and `aud` (string or array) to contain
/// `expected_audience`. Kept solely so unit tests can exercise the
/// claim-level checks; the production fallback that used to call this was
/// removed in R6-H24 to prevent silent signature-skip.
#[cfg(test)]
fn decode_jwt_claims_unverified(
    token: &str,
    expected_issuer: &str,
    expected_audience: &str,
) -> Result<serde_json::Value> {
    use base64::Engine;

    let segments: Vec<&str> = token.split('.').collect();
    if segments.len() != 3 {
        return Err(AvError::InvalidPolicy("Invalid JWT format".to_string()));
    }

    let payload_bytes = base64::engine::general_purpose::URL_SAFE_NO_PAD
        .decode(segments[1])
        .map_err(|e| AvError::InvalidPolicy(format!("JWT base64 decode failed: {}", e)))?;

    let claims: serde_json::Value = serde_json::from_slice(&payload_bytes)
        .map_err(|e| AvError::InvalidPolicy(format!("JWT claims parse failed: {}", e)))?;

    // Reject expired tokens to prevent replay attacks.
    match claims.get("exp").and_then(|v| v.as_i64()) {
        Some(exp) => {
            let now = chrono::Utc::now().timestamp();
            if now > exp {
                return Err(AvError::InvalidPolicy(format!(
                    "SSO token expired at {} (current time: {})",
                    exp, now
                )));
            }
        }
        None => tracing::warn!("SSO token has no 'exp' claim — cannot verify expiry"),
    }

    // Validate issuer
    let issuer = claims
        .get("iss")
        .and_then(|v| v.as_str())
        .ok_or_else(|| AvError::InvalidPolicy("SSO token missing 'iss' claim".to_string()))?;
    if issuer != expected_issuer {
        return Err(AvError::InvalidPolicy(format!(
            "SSO token issuer mismatch: got '{}', expected '{}'",
            issuer, expected_issuer
        )));
    }

    // Validate audience: either a single string or an array of strings.
    let audience_ok = match claims.get("aud") {
        Some(serde_json::Value::String(single)) => single == expected_audience,
        Some(serde_json::Value::Array(many)) => many
            .iter()
            .any(|entry| entry.as_str() == Some(expected_audience)),
        _ => false,
    };
    if !audience_ok {
        return Err(AvError::InvalidPolicy(format!(
            "SSO token audience does not include expected client_id '{}'",
            expected_audience
        )));
    }

    Ok(claims)
}
602
603/// Generate a random PKCE code verifier.
604fn generate_pkce_verifier() -> String {
605    use rand::Rng;
606    let mut rng = rand::rng();
607    let bytes: Vec<u8> = (0..32).map(|_| rng.random::<u8>()).collect();
608    use base64::Engine;
609    base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(&bytes)
610}
611
612/// Generate PKCE code challenge from verifier (S256).
613fn generate_pkce_challenge(verifier: &str) -> String {
614    use sha2::Digest;
615    let hash = sha2::Sha256::digest(verifier.as_bytes());
616    use base64::Engine;
617    base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(hash)
618}
619
620/// Generate a random state parameter.
621fn generate_state() -> String {
622    use rand::Rng;
623    let mut rng = rand::rng();
624    let bytes: Vec<u8> = (0..16).map(|_| rng.random::<u8>()).collect();
625    hex::encode(bytes)
626}
627
/// Escape HTML special characters to prevent XSS when embedding untrusted
/// input in HTML responses.
///
/// Replaces `&`, `<`, `>`, `"` and `'` with their entity forms; every other
/// character is copied through unchanged. The single quote is included so
/// the result is also safe inside single-quoted attribute values.
fn html_escape(s: &str) -> String {
    let mut out = String::with_capacity(s.len());
    for ch in s.chars() {
        match ch {
            '&' => out.push_str("&amp;"),
            '<' => out.push_str("&lt;"),
            '>' => out.push_str("&gt;"),
            '"' => out.push_str("&quot;"),
            '\'' => out.push_str("&#x27;"),
            _ => out.push(ch),
        }
    }
    out
}
642
/// Try to open a URL in the default browser.
///
/// Spawns the platform's URL opener (`open` on macOS, `xdg-open` on Linux,
/// `rundll32 url.dll,FileProtocolHandler` on Windows) without waiting for
/// it. On other platforms nothing is launched and `Ok(())` is returned.
///
/// # Errors
///
/// Returns the I/O error from spawning the opener process.
fn open_browser(url: &str) -> std::io::Result<()> {
    #[cfg(target_os = "macos")]
    {
        std::process::Command::new("open").arg(url).spawn()?;
    }
    #[cfg(target_os = "linux")]
    {
        std::process::Command::new("xdg-open").arg(url).spawn()?;
    }
    #[cfg(target_os = "windows")]
    {
        // Do not go through `cmd /c start`: cmd.exe treats `&` as a command
        // separator, and the authorization URL always contains `&`-joined
        // query parameters. That truncates the URL and would execute the
        // remainder as commands. Invoke the URL handler directly instead.
        std::process::Command::new("rundll32")
            .args(["url.dll,FileProtocolHandler", url])
            .spawn()?;
    }
    Ok(())
}
661
662/// R6-M38: fetch a JSON document with a hard body-size cap.
663///
664/// `.json()` on a reqwest response eagerly buffers the whole body with no
665/// limit, so a hostile or compromised IdP could stream gigabytes into
666/// discovery/userinfo/JWKS and OOM the client. Read the body one chunk
667/// at a time with a running byte counter and bail out at `max_bytes`.
668async fn fetch_json_capped(
669    client: &reqwest::Client,
670    url: &str,
671    bearer: Option<&str>,
672    max_bytes: usize,
673    label: &str,
674) -> Result<serde_json::Value> {
675    let mut req = client.get(url);
676    if let Some(token) = bearer {
677        req = req.bearer_auth(token);
678    }
679    let mut resp = req
680        .send()
681        .await
682        .map_err(|e| AvError::InvalidPolicy(format!("{} failed: {}", label, e)))?;
683    if !resp.status().is_success() {
684        return Err(AvError::InvalidPolicy(format!(
685            "{} returned HTTP {}",
686            label,
687            resp.status()
688        )));
689    }
690    let mut body: Vec<u8> = Vec::new();
691    loop {
692        match resp.chunk().await {
693            Ok(Some(chunk)) => {
694                if body.len() + chunk.len() > max_bytes {
695                    return Err(AvError::InvalidPolicy(format!(
696                        "{} response exceeds {} byte cap",
697                        label, max_bytes
698                    )));
699                }
700                body.extend_from_slice(&chunk);
701            }
702            Ok(None) => break,
703            Err(e) => {
704                return Err(AvError::InvalidPolicy(format!(
705                    "{} body read failed: {}",
706                    label, e
707                )));
708            }
709        }
710    }
711    serde_json::from_slice(&body)
712        .map_err(|e| AvError::InvalidPolicy(format!("{} parse failed: {}", label, e)))
713}
714
715/// Listen on the callback URL for the authorization code.
716async fn listen_for_callback(redirect_uri: &str, expected_state: &str) -> Result<String> {
717    use tokio::io::AsyncReadExt;
718    use tokio::io::AsyncWriteExt;
719
720    // Parse port from redirect URI
721    let url: url::Url = redirect_uri
722        .parse()
723        .map_err(|e| AvError::InvalidPolicy(format!("Invalid redirect URI: {}", e)))?;
724    let port = url.port().unwrap_or(8400);
725
726    let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
727        .await
728        .map_err(|e| {
729            AvError::InvalidPolicy(format!(
730                "Failed to start callback listener on port {}: {}",
731                port, e
732            ))
733        })?;
734
735    eprintln!(
736        "Waiting for SSO callback on port {} (5 minute timeout)...",
737        port
738    );
739
740    let (mut stream, _) =
741        tokio::time::timeout(std::time::Duration::from_secs(300), listener.accept())
742            .await
743            .map_err(|_| AvError::InvalidPolicy("SSO login timed out after 5 minutes".to_string()))?
744            .map_err(|e| AvError::InvalidPolicy(format!("Failed to accept callback: {}", e)))?;
745
746    // R6-M39: read the HTTP request until we see the end-of-headers
747    // delimiter (`\r\n\r\n`) or hit a bounded total size. The previous
748    // fixed 4096-byte single read truncated real-world browser callbacks
749    // whose cookies + referer + user-agent blew past 4 KiB, causing the
750    // code/state query string to be silently split and the login to
751    // fail with "Failed to parse callback URL". Cap the total read at
752    // 16 KiB to bound memory use under a malicious client.
753    const MAX_CALLBACK_REQUEST_BYTES: usize = 16 * 1024;
754    let mut buf: Vec<u8> = Vec::with_capacity(4096);
755    let mut scratch = [0u8; 4096];
756    loop {
757        let n = stream
758            .read(&mut scratch)
759            .await
760            .map_err(|e| AvError::InvalidPolicy(format!("Failed to read callback: {}", e)))?;
761        if n == 0 {
762            break;
763        }
764        if buf.len() + n > MAX_CALLBACK_REQUEST_BYTES {
765            return Err(AvError::InvalidPolicy(format!(
766                "SSO callback request exceeds {} byte cap",
767                MAX_CALLBACK_REQUEST_BYTES
768            )));
769        }
770        buf.extend_from_slice(&scratch[..n]);
771        // Only the request line + headers are needed to extract the
772        // code/state query parameters, so stop once we see the
773        // end-of-headers sentinel.
774        if buf.windows(4).any(|w| w == b"\r\n\r\n") {
775            break;
776        }
777    }
778    let request = String::from_utf8_lossy(&buf);
779
780    // Parse the query string from GET request
781    let first_line = request.lines().next().unwrap_or("");
782    let path = first_line.split_whitespace().nth(1).unwrap_or("/");
783    let query_url = format!("http://localhost{}", path);
784    let parsed: url::Url = query_url
785        .parse()
786        .map_err(|_| AvError::InvalidPolicy("Failed to parse callback URL".to_string()))?;
787
788    let params: std::collections::HashMap<String, String> = parsed
789        .query_pairs()
790        .map(|(k, v)| (k.to_string(), v.to_string()))
791        .collect();
792
793    // Verify state
794    let state = params.get("state").cloned().unwrap_or_default();
795    // Constant-time comparison for state to be consistent with other
796    // security-sensitive comparisons, even though loopback + random state
797    // makes timing attacks impractical.
798    let state_valid = {
799        let a = state.as_bytes();
800        let b = expected_state.as_bytes();
801        if a.len() != b.len() {
802            false
803        } else {
804            a.iter()
805                .zip(b.iter())
806                .fold(0u8, |acc, (x, y)| acc | (x ^ y))
807                == 0
808        }
809    };
810    if !state_valid {
811        let response = "HTTP/1.1 400 Bad Request\r\nContent-Type: text/html\r\n\r\n<h1>SSO Error</h1><p>State mismatch. Please try again.</p>";
812        let _ = stream.write_all(response.as_bytes()).await;
813        return Err(AvError::InvalidPolicy(
814            "SSO state mismatch — possible CSRF".to_string(),
815        ));
816    }
817
818    // Check for error
819    if let Some(error) = params.get("error") {
820        let desc = params.get("error_description").cloned().unwrap_or_default();
821        let escaped_error = html_escape(error);
822        let escaped_desc = html_escape(&desc);
823        let response = format!("HTTP/1.1 400 Bad Request\r\nContent-Type: text/html\r\n\r\n<h1>SSO Error</h1><p>{}: {}</p>", escaped_error, escaped_desc);
824        let _ = stream.write_all(response.as_bytes()).await;
825        return Err(AvError::InvalidPolicy(format!(
826            "SSO error: {}: {}",
827            error, desc
828        )));
829    }
830
831    let code = params
832        .get("code")
833        .cloned()
834        .ok_or_else(|| AvError::InvalidPolicy("No authorization code in callback".to_string()))?;
835
836    // Send success response
837    let response = "HTTP/1.1 200 OK\r\nContent-Type: text/html\r\n\r\n<h1>Audex SSO Login Successful</h1><p>You can close this window and return to the terminal.</p>";
838    let _ = stream.write_all(response.as_bytes()).await;
839
840    Ok(code)
841}
842
843/// Clear the cached SSO session.
844pub fn logout() {
845    let path = cache_path();
846    let _ = std::fs::remove_file(path);
847}
848
#[cfg(test)]
mod tests {
    use super::*;

    /// Build an unsigned, JWT-shaped token of the form
    /// `header.<payload>.signature`, where `<payload>` is the unpadded
    /// base64url encoding of `claims` serialized as JSON.
    fn fake_jwt(claims: &serde_json::Value) -> String {
        use base64::Engine;
        let encoded_claims = base64::engine::general_purpose::URL_SAFE_NO_PAD
            .encode(serde_json::to_vec(claims).unwrap());
        format!("header.{}.signature", encoded_claims)
    }

    /// Unix timestamp `offset_secs` seconds from now (negative = in the past).
    fn exp_in(offset_secs: i64) -> i64 {
        chrono::Utc::now().timestamp() + offset_secs
    }

    /// An explicit Okta config parses and keeps the values it was given.
    #[test]
    fn test_sso_config_deserialize() {
        let okta_toml = r#"
provider = "okta"
issuer = "https://myorg.okta.com"
client_id = "0oa1234567890"
scopes = ["openid", "email", "groups"]
identity_claim = "email"
groups_claim = "groups"
required = true
"#;
        let cfg: SsoConfig = toml::from_str(okta_toml).unwrap();
        assert!(matches!(cfg.provider, SsoProvider::Okta));
        assert_eq!(cfg.issuer, "https://myorg.okta.com");
        assert_eq!(cfg.client_id, "0oa1234567890");
        assert!(cfg.required);
        assert_eq!(cfg.scopes.len(), 3);
    }

    /// A minimal Google config parses, and omitted fields take their defaults.
    #[test]
    fn test_sso_config_google() {
        let google_toml = r#"
provider = "google"
issuer = "https://accounts.google.com"
client_id = "12345.apps.googleusercontent.com"
client_secret = "GOCSPX-secret"
"#;
        let cfg: SsoConfig = toml::from_str(google_toml).unwrap();
        assert!(matches!(cfg.provider, SsoProvider::Google));
        assert!(cfg.client_secret.is_some());
        // Fields absent from the TOML fall back to their serde defaults.
        assert_eq!(cfg.identity_claim, "email");
        assert_eq!(cfg.groups_claim, "groups");
        assert_eq!(cfg.redirect_uri, "http://localhost:8400/callback");
        assert!(!cfg.required);
    }

    /// `SsoIdentity` serializes to JSON and deserializes back intact.
    #[test]
    fn test_sso_identity_serialize() {
        let original = SsoIdentity {
            identity: String::from("alice@example.com"),
            groups: vec![String::from("backend-team"), String::from("devops")],
            claims: serde_json::json!({"email": "alice@example.com", "sub": "user123"}),
        };
        let serialized = serde_json::to_string(&original).unwrap();
        assert!(serialized.contains("alice@example.com"));
        assert!(serialized.contains("backend-team"));
        // Round-trip through JSON.
        let restored: SsoIdentity = serde_json::from_str(&serialized).unwrap();
        assert_eq!(restored.identity, "alice@example.com");
        assert_eq!(restored.groups.len(), 2);
    }

    /// The PKCE verifier and its challenge are both non-empty and distinct.
    #[test]
    fn test_pkce_verifier_and_challenge() {
        let verifier = generate_pkce_verifier();
        assert!(!verifier.is_empty());
        let challenge = generate_pkce_challenge(&verifier);
        assert!(!challenge.is_empty());
        assert_ne!(verifier, challenge);
    }

    /// Two generated state values differ and have the expected length.
    #[test]
    fn test_state_generation() {
        let first = generate_state();
        let second = generate_state();
        assert_ne!(first, second);
        assert_eq!(first.len(), 32); // 16 bytes = 32 hex chars
    }

    /// A well-formed, unexpired token with matching issuer and audience
    /// yields its claims.
    #[test]
    fn test_decode_jwt_claims_unverified() {
        let token = fake_jwt(&serde_json::json!({
            "email": "test@example.com",
            "groups": ["admin"],
            "exp": exp_in(3600),
            "iss": "https://accounts.example.com",
            "aud": "my-client-id"
        }));

        let decoded =
            decode_jwt_claims_unverified(&token, "https://accounts.example.com", "my-client-id")
                .unwrap();
        assert_eq!(decoded["email"], "test@example.com");
        assert_eq!(decoded["groups"][0], "admin");
    }

    /// Malformed token strings are rejected.
    #[test]
    fn test_decode_jwt_invalid() {
        for malformed in ["not-a-jwt", "a.b"] {
            assert!(decode_jwt_claims_unverified(malformed, "iss", "aud").is_err());
        }
    }

    /// A token whose `exp` lies in the past is rejected as expired.
    #[test]
    fn test_decode_jwt_expired() {
        let token = fake_jwt(&serde_json::json!({
            "email": "test@example.com",
            "exp": exp_in(-3600),
            "iss": "https://iss",
            "aud": "aud"
        }));
        let err = decode_jwt_claims_unverified(&token, "https://iss", "aud").unwrap_err();
        assert!(err.to_string().contains("expired"));
    }

    /// A token issued by a different issuer than expected is rejected.
    #[test]
    fn test_decode_jwt_wrong_issuer() {
        let token = fake_jwt(&serde_json::json!({
            "exp": exp_in(3600),
            "iss": "https://evil.com",
            "aud": "my-client"
        }));
        let err =
            decode_jwt_claims_unverified(&token, "https://legit.com", "my-client").unwrap_err();
        assert!(err.to_string().contains("issuer mismatch"));
    }

    /// A token addressed to a different audience than expected is rejected.
    #[test]
    fn test_decode_jwt_wrong_audience() {
        let token = fake_jwt(&serde_json::json!({
            "exp": exp_in(3600),
            "iss": "https://iss",
            "aud": "other-client"
        }));
        let err = decode_jwt_claims_unverified(&token, "https://iss", "my-client").unwrap_err();
        assert!(err.to_string().contains("audience"));
    }
}