Skip to main content

tryaudex_core/
sso.rs

1use serde::{Deserialize, Serialize};
2
3use crate::error::{AvError, Result};
4
/// SSO provider type.
///
/// Variants are (de)serialized in lowercase (`"okta"`, `"auth0"`, `"google"`,
/// `"oidc"`) because of `rename_all = "lowercase"`.
// NOTE(review): nothing in this file branches on the variant; the login flow
// relies only on the issuer's OIDC discovery document.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum SsoProvider {
    /// Okta.
    Okta,
    /// Auth0.
    Auth0,
    /// Google.
    Google,
    /// Generic OIDC provider
    Oidc,
}
15
/// SSO configuration in `[sso]` config section.
///
/// `provider`, `issuer` and `client_id` have no serde default and must be
/// present; every other field falls back to the default noted on it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SsoConfig {
    /// SSO provider type
    pub provider: SsoProvider,
    /// OIDC issuer URL (e.g. "https://myorg.okta.com", "https://accounts.google.com").
    /// `/.well-known/openid-configuration` is appended to it to locate the
    /// discovery document.
    pub issuer: String,
    /// OIDC client ID
    pub client_id: String,
    /// OIDC client secret (optional, for confidential clients).
    /// When set it is sent as `client_secret` in the token request.
    // NOTE(review): the derived `Debug` and `Serialize` impls emit this value
    // verbatim; avoid logging or re-serializing the whole config.
    pub client_secret: Option<String>,
    /// Redirect URI for the local callback listener
    /// (default: "http://localhost:8400/callback").
    #[serde(default = "default_redirect")]
    pub redirect_uri: String,
    /// OIDC scopes to request (default: "openid", "email", "profile").
    /// They are joined with single spaces in the authorization request.
    #[serde(default = "default_scopes")]
    pub scopes: Vec<String>,
    /// Claim to use as the identity (default: "email")
    #[serde(default = "default_identity_claim")]
    pub identity_claim: String,
    /// Claim to use for group/team membership (default: "groups")
    #[serde(default = "default_groups_claim")]
    pub groups_claim: String,
    /// Whether SSO is required (if false, falls back to local identity).
    /// Defaults to `false`.
    // NOTE(review): the fallback itself is not implemented in this file —
    // confirm against the caller that reads this flag.
    #[serde(default)]
    pub required: bool,
}
43
/// Serde default for `SsoConfig::redirect_uri`: the loopback callback URL.
fn default_redirect() -> String {
    String::from("http://localhost:8400/callback")
}
47
/// Serde default for `SsoConfig::scopes`: `openid`, `email` and `profile`,
/// in that order.
fn default_scopes() -> Vec<String> {
    ["openid", "email", "profile"]
        .iter()
        .map(|scope| scope.to_string())
        .collect()
}
55
/// Serde default for `SsoConfig::identity_claim`.
fn default_identity_claim() -> String {
    String::from("email")
}
59
/// Serde default for `SsoConfig::groups_claim`.
fn default_groups_claim() -> String {
    String::from("groups")
}
63
/// Token response from the OIDC provider.
///
/// Deserialized from the JSON body returned by the token endpoint. Only
/// `access_token` is mandatory for deserialization to succeed.
// NOTE(review): the derived `Debug` impl prints the raw tokens; avoid logging
// this struct.
#[derive(Debug, Deserialize)]
struct TokenResponse {
    /// Bearer token; used for the userinfo request when no ID token is returned.
    access_token: String,
    /// ID token (JWT) whose payload is decoded to obtain the user's claims.
    id_token: Option<String>,
    /// Accepted from the response but not read anywhere in this file.
    #[allow(dead_code)]
    token_type: Option<String>,
    /// Accepted from the response but not read anywhere in this file.
    #[allow(dead_code)]
    expires_in: Option<u64>,
}
74
/// Decoded identity from an SSO token.
///
/// Produced by `authenticate` and also stored in the on-disk session cache.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SsoIdentity {
    /// The user's identity (email, subject, etc.), read from the claim named
    /// by `SsoConfig::identity_claim`.
    pub identity: String,
    /// Group/team memberships from the SSO provider: the string elements of
    /// the array found under `SsoConfig::groups_claim` (empty if absent).
    pub groups: Vec<String>,
    /// Raw claims from the ID token, or from the userinfo endpoint when the
    /// provider returned no ID token.
    pub claims: serde_json::Value,
}
85
/// Cached SSO session stored on disk.
///
/// Written and read through `crate::keystore` at the path returned by
/// `cache_path`.
#[derive(Debug, Serialize, Deserialize)]
struct CachedSsoSession {
    /// Identity to hand back while the session is still valid.
    identity: SsoIdentity,
    /// Expiry as a Unix timestamp in seconds (UTC); the session is valid
    /// only while the current time is strictly before it.
    expires_at: i64,
}
92
93fn cache_path() -> std::path::PathBuf {
94    dirs::cache_dir()
95        .unwrap_or_else(|| std::path::PathBuf::from("."))
96        .join("audex")
97        .join("sso_session.json")
98}
99
100/// Load cached SSO session if still valid (encrypted at rest).
101pub fn load_cached_identity() -> Option<SsoIdentity> {
102    let path = cache_path();
103    let cached: CachedSsoSession = crate::keystore::decrypt_from_file(&path).ok()??;
104    let now = chrono::Utc::now().timestamp();
105    if now < cached.expires_at {
106        Some(cached.identity)
107    } else {
108        None
109    }
110}
111
112fn save_cached_identity(identity: &SsoIdentity, ttl_secs: u64) {
113    let cached = CachedSsoSession {
114        identity: identity.clone(),
115        expires_at: chrono::Utc::now().timestamp() + ttl_secs as i64,
116    };
117    let path = cache_path();
118    if let Some(parent) = path.parent() {
119        let _ = std::fs::create_dir_all(parent);
120    }
121    let _ = crate::keystore::encrypt_to_file(&path, &cached);
122}
123
124/// Authenticate via SSO using the device authorization flow.
125/// Opens a browser for login and listens on a local callback URL.
126pub async fn authenticate(config: &SsoConfig) -> Result<SsoIdentity> {
127    // Check cache first
128    if let Some(cached) = load_cached_identity() {
129        return Ok(cached);
130    }
131
132    let discovery_url = format!(
133        "{}/.well-known/openid-configuration",
134        config.issuer.trim_end_matches('/')
135    );
136
137    let client = reqwest::Client::new();
138
139    // Fetch OIDC discovery document
140    let discovery: serde_json::Value = client
141        .get(&discovery_url)
142        .send()
143        .await
144        .map_err(|e| AvError::InvalidPolicy(format!("SSO discovery failed: {}", e)))?
145        .json()
146        .await
147        .map_err(|e| AvError::InvalidPolicy(format!("SSO discovery parse failed: {}", e)))?;
148
149    let auth_endpoint = discovery["authorization_endpoint"]
150        .as_str()
151        .ok_or_else(|| {
152            AvError::InvalidPolicy("Missing authorization_endpoint in OIDC discovery".to_string())
153        })?;
154    let token_endpoint = discovery["token_endpoint"].as_str().ok_or_else(|| {
155        AvError::InvalidPolicy("Missing token_endpoint in OIDC discovery".to_string())
156    })?;
157
158    // Generate PKCE challenge
159    let verifier = generate_pkce_verifier();
160    let challenge = generate_pkce_challenge(&verifier);
161    let state = generate_state();
162
163    let scopes = config.scopes.join(" ");
164    let auth_url = format!(
165        "{}?client_id={}&redirect_uri={}&response_type=code&scope={}&state={}&code_challenge={}&code_challenge_method=S256",
166        auth_endpoint,
167        urlencoding::encode(&config.client_id),
168        urlencoding::encode(&config.redirect_uri),
169        urlencoding::encode(&scopes),
170        urlencoding::encode(&state),
171        urlencoding::encode(&challenge),
172    );
173
174    eprintln!("Opening browser for SSO login...");
175    eprintln!("If the browser doesn't open, visit:\n  {}\n", auth_url);
176
177    // Try to open browser
178    let _ = open_browser(&auth_url);
179
180    // Start local server to receive the callback
181    let code = listen_for_callback(&config.redirect_uri, &state).await?;
182
183    // Exchange code for tokens
184    let mut token_params = vec![
185        ("grant_type", "authorization_code".to_string()),
186        ("code", code),
187        ("redirect_uri", config.redirect_uri.clone()),
188        ("client_id", config.client_id.clone()),
189        ("code_verifier", verifier),
190    ];
191    if let Some(ref secret) = config.client_secret {
192        token_params.push(("client_secret", secret.clone()));
193    }
194
195    let token_resp: TokenResponse = client
196        .post(token_endpoint)
197        .form(&token_params)
198        .send()
199        .await
200        .map_err(|e| AvError::InvalidPolicy(format!("SSO token exchange failed: {}", e)))?
201        .json()
202        .await
203        .map_err(|e| AvError::InvalidPolicy(format!("SSO token parse failed: {}", e)))?;
204
205    // Decode the ID token (we only need the claims, not full validation)
206    let claims = if let Some(ref id_token) = token_resp.id_token {
207        decode_jwt_claims(id_token)?
208    } else {
209        // Use userinfo endpoint as fallback
210        let userinfo_url = discovery["userinfo_endpoint"].as_str().ok_or_else(|| {
211            AvError::InvalidPolicy("No id_token and no userinfo_endpoint".to_string())
212        })?;
213
214        client
215            .get(userinfo_url)
216            .bearer_auth(&token_resp.access_token)
217            .send()
218            .await
219            .map_err(|e| AvError::InvalidPolicy(format!("SSO userinfo failed: {}", e)))?
220            .json()
221            .await
222            .map_err(|e| AvError::InvalidPolicy(format!("SSO userinfo parse failed: {}", e)))?
223    };
224
225    let identity_str = claims[&config.identity_claim]
226        .as_str()
227        .unwrap_or("unknown")
228        .to_string();
229
230    let groups: Vec<String> = claims[&config.groups_claim]
231        .as_array()
232        .map(|arr| {
233            arr.iter()
234                .filter_map(|v| v.as_str().map(|s| s.to_string()))
235                .collect()
236        })
237        .unwrap_or_default();
238
239    let sso_identity = SsoIdentity {
240        identity: identity_str,
241        groups,
242        claims,
243    };
244
245    // Cache for 1 hour
246    save_cached_identity(&sso_identity, 3600);
247
248    Ok(sso_identity)
249}
250
251/// Decode JWT claims without signature verification (we trust the HTTPS connection).
252fn decode_jwt_claims(token: &str) -> Result<serde_json::Value> {
253    let parts: Vec<&str> = token.split('.').collect();
254    if parts.len() != 3 {
255        return Err(AvError::InvalidPolicy("Invalid JWT format".to_string()));
256    }
257
258    use base64::Engine;
259    let payload = base64::engine::general_purpose::URL_SAFE_NO_PAD
260        .decode(parts[1])
261        .map_err(|e| AvError::InvalidPolicy(format!("JWT base64 decode failed: {}", e)))?;
262
263    serde_json::from_slice(&payload)
264        .map_err(|e| AvError::InvalidPolicy(format!("JWT claims parse failed: {}", e)))
265}
266
267/// Generate a random PKCE code verifier.
268fn generate_pkce_verifier() -> String {
269    use rand::Rng;
270    let mut rng = rand::rng();
271    let bytes: Vec<u8> = (0..32).map(|_| rng.random::<u8>()).collect();
272    use base64::Engine;
273    base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(&bytes)
274}
275
276/// Generate PKCE code challenge from verifier (S256).
277fn generate_pkce_challenge(verifier: &str) -> String {
278    use sha2::Digest;
279    let hash = sha2::Sha256::digest(verifier.as_bytes());
280    use base64::Engine;
281    base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(hash)
282}
283
284/// Generate a random state parameter.
285fn generate_state() -> String {
286    use rand::Rng;
287    let mut rng = rand::rng();
288    let bytes: Vec<u8> = (0..16).map(|_| rng.random::<u8>()).collect();
289    hex::encode(bytes)
290}
291
/// Try to open a URL in the default browser.
///
/// Spawns the platform's opener (`open` on macOS, `xdg-open` on Linux,
/// `cmd /c start` on Windows) and returns without waiting for it. On any
/// other target nothing is launched and `Ok(())` is returned.
///
/// # Errors
///
/// Returns the `std::io::Error` from spawning the opener process.
fn open_browser(url: &str) -> std::io::Result<()> {
    #[cfg(target_os = "macos")]
    {
        std::process::Command::new("open").arg(url).spawn()?;
    }
    #[cfg(target_os = "linux")]
    {
        std::process::Command::new("xdg-open").arg(url).spawn()?;
    }
    #[cfg(target_os = "windows")]
    {
        use std::os::windows::process::CommandExt;

        // `cmd` treats an unquoted `&` as a command separator, which would cut
        // the URL at its first query parameter and run the rest as commands.
        // Pass an explicit empty window title (otherwise `start` takes the
        // first quoted argument as the title) followed by the quoted URL.
        // Any literal `"` is percent-encoded so it cannot end the quoting.
        let quoted_url = format!("\"{}\"", url.replace('"', "%22"));
        std::process::Command::new("cmd")
            .args(["/c", "start"])
            .raw_arg("\"\"")
            .raw_arg(quoted_url)
            .spawn()?;
    }
    Ok(())
}
310
311/// Listen on the callback URL for the authorization code.
312async fn listen_for_callback(redirect_uri: &str, expected_state: &str) -> Result<String> {
313    use tokio::io::AsyncReadExt;
314    use tokio::io::AsyncWriteExt;
315
316    // Parse port from redirect URI
317    let url: url::Url = redirect_uri
318        .parse()
319        .map_err(|e| AvError::InvalidPolicy(format!("Invalid redirect URI: {}", e)))?;
320    let port = url.port().unwrap_or(8400);
321
322    let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
323        .await
324        .map_err(|e| {
325            AvError::InvalidPolicy(format!(
326                "Failed to start callback listener on port {}: {}",
327                port, e
328            ))
329        })?;
330
331    eprintln!("Waiting for SSO callback on port {}...", port);
332
333    let (mut stream, _) = listener
334        .accept()
335        .await
336        .map_err(|e| AvError::InvalidPolicy(format!("Failed to accept callback: {}", e)))?;
337
338    let mut buf = vec![0u8; 4096];
339    let n = stream
340        .read(&mut buf)
341        .await
342        .map_err(|e| AvError::InvalidPolicy(format!("Failed to read callback: {}", e)))?;
343
344    let request = String::from_utf8_lossy(&buf[..n]);
345
346    // Parse the query string from GET request
347    let first_line = request.lines().next().unwrap_or("");
348    let path = first_line.split_whitespace().nth(1).unwrap_or("/");
349    let query_url = format!("http://localhost{}", path);
350    let parsed: url::Url = query_url
351        .parse()
352        .map_err(|_| AvError::InvalidPolicy("Failed to parse callback URL".to_string()))?;
353
354    let params: std::collections::HashMap<String, String> = parsed
355        .query_pairs()
356        .map(|(k, v)| (k.to_string(), v.to_string()))
357        .collect();
358
359    // Verify state
360    let state = params.get("state").cloned().unwrap_or_default();
361    if state != expected_state {
362        let response = "HTTP/1.1 400 Bad Request\r\nContent-Type: text/html\r\n\r\n<h1>SSO Error</h1><p>State mismatch. Please try again.</p>";
363        let _ = stream.write_all(response.as_bytes()).await;
364        return Err(AvError::InvalidPolicy(
365            "SSO state mismatch — possible CSRF".to_string(),
366        ));
367    }
368
369    // Check for error
370    if let Some(error) = params.get("error") {
371        let desc = params.get("error_description").cloned().unwrap_or_default();
372        let response = format!("HTTP/1.1 400 Bad Request\r\nContent-Type: text/html\r\n\r\n<h1>SSO Error</h1><p>{}: {}</p>", error, desc);
373        let _ = stream.write_all(response.as_bytes()).await;
374        return Err(AvError::InvalidPolicy(format!(
375            "SSO error: {}: {}",
376            error, desc
377        )));
378    }
379
380    let code = params
381        .get("code")
382        .cloned()
383        .ok_or_else(|| AvError::InvalidPolicy("No authorization code in callback".to_string()))?;
384
385    // Send success response
386    let response = "HTTP/1.1 200 OK\r\nContent-Type: text/html\r\n\r\n<h1>Audex SSO Login Successful</h1><p>You can close this window and return to the terminal.</p>";
387    let _ = stream.write_all(response.as_bytes()).await;
388
389    Ok(code)
390}
391
392/// Clear the cached SSO session.
393pub fn logout() {
394    let path = cache_path();
395    let _ = std::fs::remove_file(path);
396}
397
#[cfg(test)]
mod tests {
    //! Unit tests for config parsing, identity serialization and the pure
    //! helpers (PKCE, state, JWT payload decoding). The network-facing login
    //! flow is not exercised here.

    use super::*;

    /// A fully specified `[sso]` section parses and keeps explicit values.
    #[test]
    fn test_sso_config_deserialize() {
        let toml_str = r#"
provider = "okta"
issuer = "https://myorg.okta.com"
client_id = "0oa1234567890"
scopes = ["openid", "email", "groups"]
identity_claim = "email"
groups_claim = "groups"
required = true
"#;
        let config: SsoConfig = toml::from_str(toml_str).unwrap();
        assert!(matches!(config.provider, SsoProvider::Okta));
        assert_eq!(config.issuer, "https://myorg.okta.com");
        assert_eq!(config.client_id, "0oa1234567890");
        assert!(config.required);
        assert_eq!(config.scopes, vec!["openid", "email", "groups"]);
        // Not given in the TOML, so the optional secret must be absent.
        assert!(config.client_secret.is_none());
    }

    /// A minimal section picks up every serde default.
    #[test]
    fn test_sso_config_google() {
        let toml_str = r#"
provider = "google"
issuer = "https://accounts.google.com"
client_id = "12345.apps.googleusercontent.com"
client_secret = "GOCSPX-secret"
"#;
        let config: SsoConfig = toml::from_str(toml_str).unwrap();
        assert!(matches!(config.provider, SsoProvider::Google));
        assert_eq!(config.client_secret.as_deref(), Some("GOCSPX-secret"));
        // Defaults
        assert_eq!(config.identity_claim, "email");
        assert_eq!(config.groups_claim, "groups");
        assert_eq!(config.redirect_uri, "http://localhost:8400/callback");
        assert_eq!(config.scopes, vec!["openid", "email", "profile"]);
        assert!(!config.required);
    }

    /// `SsoIdentity` survives a JSON round trip.
    #[test]
    fn test_sso_identity_serialize() {
        let identity = SsoIdentity {
            identity: "alice@example.com".to_string(),
            groups: vec!["backend-team".to_string(), "devops".to_string()],
            claims: serde_json::json!({"email": "alice@example.com", "sub": "user123"}),
        };
        let json = serde_json::to_string(&identity).unwrap();
        assert!(json.contains("alice@example.com"));
        assert!(json.contains("backend-team"));
        // Roundtrip
        let parsed: SsoIdentity = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.identity, "alice@example.com");
        assert_eq!(parsed.groups, vec!["backend-team", "devops"]);
        assert_eq!(parsed.claims["sub"], "user123");
    }

    /// Verifier and challenge have the expected shape and are not constant.
    #[test]
    fn test_pkce_verifier_and_challenge() {
        let verifier = generate_pkce_verifier();
        // 32 bytes -> 43 unpadded base64url characters.
        assert_eq!(verifier.len(), 43);
        assert!(verifier
            .chars()
            .all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '_'));
        let challenge = generate_pkce_challenge(&verifier);
        // SHA-256 digest is 32 bytes as well.
        assert_eq!(challenge.len(), 43);
        assert_ne!(verifier, challenge);
        assert_ne!(verifier, generate_pkce_verifier());
    }

    /// The S256 transform matches the worked example in RFC 7636 Appendix B.
    #[test]
    fn test_pkce_challenge_rfc7636_vector() {
        assert_eq!(
            generate_pkce_challenge("dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"),
            "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM"
        );
    }

    /// State values are unique 32-character hex strings.
    #[test]
    fn test_state_generation() {
        let state1 = generate_state();
        let state2 = generate_state();
        assert_ne!(state1, state2);
        assert_eq!(state1.len(), 32); // 16 bytes = 32 hex chars
        assert!(state1.chars().all(|c| c.is_ascii_hexdigit()));
    }

    /// The payload segment of a well-formed token is decoded to its claims.
    #[test]
    fn test_decode_jwt_claims() {
        // Create a minimal JWT with base64url-encoded payload
        use base64::Engine;
        let claims = serde_json::json!({"email": "test@example.com", "groups": ["admin"]});
        let payload = base64::engine::general_purpose::URL_SAFE_NO_PAD
            .encode(serde_json::to_vec(&claims).unwrap());
        let fake_jwt = format!("header.{}.signature", payload);

        let decoded = decode_jwt_claims(&fake_jwt).unwrap();
        assert_eq!(decoded["email"], "test@example.com");
        assert_eq!(decoded["groups"][0], "admin");
    }

    /// Every malformed shape is rejected rather than panicking.
    #[test]
    fn test_decode_jwt_invalid() {
        use base64::Engine;

        // Wrong number of segments.
        assert!(decode_jwt_claims("").is_err());
        assert!(decode_jwt_claims("not-a-jwt").is_err());
        assert!(decode_jwt_claims("a.b").is_err());
        assert!(decode_jwt_claims("a.b.c.d").is_err());
        // Three segments, but the payload is not valid base64url.
        assert!(decode_jwt_claims("header.!!!.signature").is_err());
        // Valid base64url, but the decoded bytes are not JSON.
        let not_json = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(b"not-json");
        assert!(decode_jwt_claims(&format!("header.{}.signature", not_json)).is_err());
    }
}