Skip to main content

tryaudex_core/
sso.rs

1use serde::{Deserialize, Serialize};
2
3use crate::error::{AvError, Result};
4
/// SSO provider type.
///
/// Because of `rename_all = "lowercase"`, each variant is (de)serialized as its
/// lowercase name, e.g. `provider = "okta"` in the configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum SsoProvider {
    /// Serialized as `"okta"`.
    Okta,
    /// Serialized as `"auth0"`.
    Auth0,
    /// Serialized as `"google"`.
    Google,
    /// Generic OIDC provider (serialized as `"oidc"`).
    Oidc,
}
15
/// SSO configuration in `[sso]` config section.
///
/// `provider`, `issuer` and `client_id` have no serde default and must be present;
/// every other field is optional and falls back to the default noted on it.
// NOTE(review): `Debug` and `Serialize` are derived, so `client_secret` is included
// verbatim in debug output and serialized config — confirm this is acceptable.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SsoConfig {
    /// SSO provider type
    pub provider: SsoProvider,
    /// OIDC issuer URL (e.g. "https://myorg.okta.com", "https://accounts.google.com").
    /// The discovery document is fetched from `<issuer>/.well-known/openid-configuration`.
    pub issuer: String,
    /// OIDC client ID
    pub client_id: String,
    /// OIDC client secret (optional, for confidential clients).
    /// When set, it is sent as `client_secret` in the token request.
    pub client_secret: Option<String>,
    /// Redirect URI for the local callback; the callback listener takes its port
    /// from this URI. Defaults to `http://localhost:8400/callback`.
    #[serde(default = "default_redirect")]
    pub redirect_uri: String,
    /// OIDC scopes to request (default: `openid`, `email`, `profile`).
    /// They are joined with spaces when building the authorization URL.
    #[serde(default = "default_scopes")]
    pub scopes: Vec<String>,
    /// Claim to use as the identity (default: "email")
    #[serde(default = "default_identity_claim")]
    pub identity_claim: String,
    /// Claim to use for group/team membership (default: "groups")
    #[serde(default = "default_groups_claim")]
    pub groups_claim: String,
    /// Whether SSO is required (if false, falls back to local identity).
    /// Defaults to `false`. Not read by the functions in this file; presumably
    /// enforced by the caller — verify there.
    #[serde(default)]
    pub required: bool,
}
43
/// Serde default for `SsoConfig::redirect_uri`.
fn default_redirect() -> String {
    String::from("http://localhost:8400/callback")
}
47
/// Serde default for `SsoConfig::scopes`: `openid`, `email`, `profile`, in that order.
fn default_scopes() -> Vec<String> {
    ["openid", "email", "profile"]
        .iter()
        .map(|scope| scope.to_string())
        .collect()
}
51
/// Serde default for `SsoConfig::identity_claim`.
fn default_identity_claim() -> String {
    String::from("email")
}
55
/// Serde default for `SsoConfig::groups_claim`.
fn default_groups_claim() -> String {
    String::from("groups")
}
59
/// Token response from the OIDC provider.
///
/// Deserialized from the JSON body returned by the token endpoint.
#[derive(Debug, Deserialize)]
struct TokenResponse {
    /// Access token; used as the bearer token for the userinfo request when no
    /// `id_token` is returned.
    access_token: String,
    /// ID token (JWT) whose payload supplies the claims, when the provider returns one.
    id_token: Option<String>,
    // Deserialized but not read anywhere in this file, hence the lint allowance.
    #[allow(dead_code)]
    token_type: Option<String>,
    // Deserialized but not read anywhere in this file, hence the lint allowance.
    #[allow(dead_code)]
    expires_in: Option<u64>,
}
70
/// Decoded identity from an SSO token.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SsoIdentity {
    /// The user's identity (email, subject, etc.): the value of the configured
    /// identity claim, or the literal `"unknown"` when that claim is not a string.
    pub identity: String,
    /// Group/team memberships from the SSO provider: the string entries of the
    /// configured groups claim (empty when the claim is not an array).
    pub groups: Vec<String>,
    /// Raw claims, taken from the ID token payload or from the userinfo response.
    pub claims: serde_json::Value,
}
81
/// Cached SSO session stored on disk.
#[derive(Debug, Serialize, Deserialize)]
struct CachedSsoSession {
    /// The identity that was resolved at login time.
    identity: SsoIdentity,
    /// Expiry as a Unix timestamp in seconds (compared against
    /// `chrono::Utc::now().timestamp()` when the cache is loaded).
    expires_at: i64,
}
88
89fn cache_path() -> std::path::PathBuf {
90    dirs::cache_dir()
91        .unwrap_or_else(|| std::path::PathBuf::from("."))
92        .join("audex")
93        .join("sso_session.json")
94}
95
96/// Load cached SSO session if still valid (encrypted at rest).
97pub fn load_cached_identity() -> Option<SsoIdentity> {
98    let path = cache_path();
99    let cached: CachedSsoSession = crate::keystore::decrypt_from_file(&path).ok()??;
100    let now = chrono::Utc::now().timestamp();
101    if now < cached.expires_at {
102        Some(cached.identity)
103    } else {
104        None
105    }
106}
107
108fn save_cached_identity(identity: &SsoIdentity, ttl_secs: u64) {
109    let cached = CachedSsoSession {
110        identity: identity.clone(),
111        expires_at: chrono::Utc::now().timestamp() + ttl_secs as i64,
112    };
113    let path = cache_path();
114    if let Some(parent) = path.parent() {
115        let _ = std::fs::create_dir_all(parent);
116    }
117    let _ = crate::keystore::encrypt_to_file(&path, &cached);
118}
119
120/// Authenticate via SSO using the device authorization flow.
121/// Opens a browser for login and listens on a local callback URL.
122pub async fn authenticate(config: &SsoConfig) -> Result<SsoIdentity> {
123    // Check cache first
124    if let Some(cached) = load_cached_identity() {
125        return Ok(cached);
126    }
127
128    let discovery_url = format!(
129        "{}/.well-known/openid-configuration",
130        config.issuer.trim_end_matches('/')
131    );
132
133    let client = reqwest::Client::new();
134
135    // Fetch OIDC discovery document
136    let discovery: serde_json::Value = client
137        .get(&discovery_url)
138        .send()
139        .await
140        .map_err(|e| AvError::InvalidPolicy(format!("SSO discovery failed: {}", e)))?
141        .json()
142        .await
143        .map_err(|e| AvError::InvalidPolicy(format!("SSO discovery parse failed: {}", e)))?;
144
145    let auth_endpoint = discovery["authorization_endpoint"]
146        .as_str()
147        .ok_or_else(|| AvError::InvalidPolicy("Missing authorization_endpoint in OIDC discovery".to_string()))?;
148    let token_endpoint = discovery["token_endpoint"]
149        .as_str()
150        .ok_or_else(|| AvError::InvalidPolicy("Missing token_endpoint in OIDC discovery".to_string()))?;
151
152    // Generate PKCE challenge
153    let verifier = generate_pkce_verifier();
154    let challenge = generate_pkce_challenge(&verifier);
155    let state = generate_state();
156
157    let scopes = config.scopes.join(" ");
158    let auth_url = format!(
159        "{}?client_id={}&redirect_uri={}&response_type=code&scope={}&state={}&code_challenge={}&code_challenge_method=S256",
160        auth_endpoint,
161        urlencoding::encode(&config.client_id),
162        urlencoding::encode(&config.redirect_uri),
163        urlencoding::encode(&scopes),
164        urlencoding::encode(&state),
165        urlencoding::encode(&challenge),
166    );
167
168    eprintln!("Opening browser for SSO login...");
169    eprintln!("If the browser doesn't open, visit:\n  {}\n", auth_url);
170
171    // Try to open browser
172    let _ = open_browser(&auth_url);
173
174    // Start local server to receive the callback
175    let code = listen_for_callback(&config.redirect_uri, &state).await?;
176
177    // Exchange code for tokens
178    let mut token_params = vec![
179        ("grant_type", "authorization_code".to_string()),
180        ("code", code),
181        ("redirect_uri", config.redirect_uri.clone()),
182        ("client_id", config.client_id.clone()),
183        ("code_verifier", verifier),
184    ];
185    if let Some(ref secret) = config.client_secret {
186        token_params.push(("client_secret", secret.clone()));
187    }
188
189    let token_resp: TokenResponse = client
190        .post(token_endpoint)
191        .form(&token_params)
192        .send()
193        .await
194        .map_err(|e| AvError::InvalidPolicy(format!("SSO token exchange failed: {}", e)))?
195        .json()
196        .await
197        .map_err(|e| AvError::InvalidPolicy(format!("SSO token parse failed: {}", e)))?;
198
199    // Decode the ID token (we only need the claims, not full validation)
200    let claims = if let Some(ref id_token) = token_resp.id_token {
201        decode_jwt_claims(id_token)?
202    } else {
203        // Use userinfo endpoint as fallback
204        let userinfo_url = discovery["userinfo_endpoint"]
205            .as_str()
206            .ok_or_else(|| AvError::InvalidPolicy("No id_token and no userinfo_endpoint".to_string()))?;
207
208        client
209            .get(userinfo_url)
210            .bearer_auth(&token_resp.access_token)
211            .send()
212            .await
213            .map_err(|e| AvError::InvalidPolicy(format!("SSO userinfo failed: {}", e)))?
214            .json()
215            .await
216            .map_err(|e| AvError::InvalidPolicy(format!("SSO userinfo parse failed: {}", e)))?
217    };
218
219    let identity_str = claims[&config.identity_claim]
220        .as_str()
221        .unwrap_or("unknown")
222        .to_string();
223
224    let groups: Vec<String> = claims[&config.groups_claim]
225        .as_array()
226        .map(|arr| {
227            arr.iter()
228                .filter_map(|v| v.as_str().map(|s| s.to_string()))
229                .collect()
230        })
231        .unwrap_or_default();
232
233    let sso_identity = SsoIdentity {
234        identity: identity_str,
235        groups,
236        claims,
237    };
238
239    // Cache for 1 hour
240    save_cached_identity(&sso_identity, 3600);
241
242    Ok(sso_identity)
243}
244
245/// Decode JWT claims without signature verification (we trust the HTTPS connection).
246fn decode_jwt_claims(token: &str) -> Result<serde_json::Value> {
247    let parts: Vec<&str> = token.split('.').collect();
248    if parts.len() != 3 {
249        return Err(AvError::InvalidPolicy("Invalid JWT format".to_string()));
250    }
251
252    use base64::Engine;
253    let payload = base64::engine::general_purpose::URL_SAFE_NO_PAD
254        .decode(parts[1])
255        .map_err(|e| AvError::InvalidPolicy(format!("JWT base64 decode failed: {}", e)))?;
256
257    serde_json::from_slice(&payload)
258        .map_err(|e| AvError::InvalidPolicy(format!("JWT claims parse failed: {}", e)))
259}
260
261/// Generate a random PKCE code verifier.
262fn generate_pkce_verifier() -> String {
263    use rand::Rng;
264    let mut rng = rand::rng();
265    let bytes: Vec<u8> = (0..32).map(|_| rng.random::<u8>()).collect();
266    use base64::Engine;
267    base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(&bytes)
268}
269
270/// Generate PKCE code challenge from verifier (S256).
271fn generate_pkce_challenge(verifier: &str) -> String {
272    use sha2::Digest;
273    let hash = sha2::Sha256::digest(verifier.as_bytes());
274    use base64::Engine;
275    base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(hash)
276}
277
278/// Generate a random state parameter.
279fn generate_state() -> String {
280    use rand::Rng;
281    let mut rng = rand::rng();
282    let bytes: Vec<u8> = (0..16).map(|_| rng.random::<u8>()).collect();
283    hex::encode(bytes)
284}
285
/// Try to open a URL in the default browser.
///
/// Spawns the platform launcher (`open` on macOS, `xdg-open` on Linux,
/// `rundll32 url.dll,FileProtocolHandler` on Windows) and returns without waiting
/// for it. On any other platform nothing is launched and `Ok(())` is returned.
///
/// # Errors
///
/// Returns the `std::io::Error` from spawning the launcher process.
fn open_browser(url: &str) -> std::io::Result<()> {
    #[cfg(target_os = "macos")]
    {
        std::process::Command::new("open").arg(url).spawn()?;
    }
    #[cfg(target_os = "linux")]
    {
        std::process::Command::new("xdg-open").arg(url).spawn()?;
    }
    #[cfg(target_os = "windows")]
    {
        // Do not route the URL through `cmd /c start`: cmd.exe treats the `&`
        // between query parameters as a command separator, so everything after the
        // first parameter would be cut off. rundll32 receives the URL as a plain
        // argument with no shell parsing involved.
        std::process::Command::new("rundll32")
            .args(["url.dll,FileProtocolHandler", url])
            .spawn()?;
    }
    #[cfg(not(any(target_os = "macos", target_os = "linux", target_os = "windows")))]
    {
        // No known launcher on this platform; the caller prints the URL as a fallback.
        let _ = url;
    }
    Ok(())
}
302
303/// Listen on the callback URL for the authorization code.
304async fn listen_for_callback(redirect_uri: &str, expected_state: &str) -> Result<String> {
305    use tokio::io::AsyncReadExt;
306    use tokio::io::AsyncWriteExt;
307
308    // Parse port from redirect URI
309    let url: url::Url = redirect_uri
310        .parse()
311        .map_err(|e| AvError::InvalidPolicy(format!("Invalid redirect URI: {}", e)))?;
312    let port = url.port().unwrap_or(8400);
313
314    let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
315        .await
316        .map_err(|e| AvError::InvalidPolicy(format!("Failed to start callback listener on port {}: {}", port, e)))?;
317
318    eprintln!("Waiting for SSO callback on port {}...", port);
319
320    let (mut stream, _) = listener
321        .accept()
322        .await
323        .map_err(|e| AvError::InvalidPolicy(format!("Failed to accept callback: {}", e)))?;
324
325    let mut buf = vec![0u8; 4096];
326    let n = stream
327        .read(&mut buf)
328        .await
329        .map_err(|e| AvError::InvalidPolicy(format!("Failed to read callback: {}", e)))?;
330
331    let request = String::from_utf8_lossy(&buf[..n]);
332
333    // Parse the query string from GET request
334    let first_line = request.lines().next().unwrap_or("");
335    let path = first_line.split_whitespace().nth(1).unwrap_or("/");
336    let query_url = format!("http://localhost{}", path);
337    let parsed: url::Url = query_url
338        .parse()
339        .map_err(|_| AvError::InvalidPolicy("Failed to parse callback URL".to_string()))?;
340
341    let params: std::collections::HashMap<String, String> =
342        parsed.query_pairs().map(|(k, v)| (k.to_string(), v.to_string())).collect();
343
344    // Verify state
345    let state = params.get("state").cloned().unwrap_or_default();
346    if state != expected_state {
347        let response = "HTTP/1.1 400 Bad Request\r\nContent-Type: text/html\r\n\r\n<h1>SSO Error</h1><p>State mismatch. Please try again.</p>";
348        let _ = stream.write_all(response.as_bytes()).await;
349        return Err(AvError::InvalidPolicy("SSO state mismatch — possible CSRF".to_string()));
350    }
351
352    // Check for error
353    if let Some(error) = params.get("error") {
354        let desc = params.get("error_description").cloned().unwrap_or_default();
355        let response = format!("HTTP/1.1 400 Bad Request\r\nContent-Type: text/html\r\n\r\n<h1>SSO Error</h1><p>{}: {}</p>", error, desc);
356        let _ = stream.write_all(response.as_bytes()).await;
357        return Err(AvError::InvalidPolicy(format!("SSO error: {}: {}", error, desc)));
358    }
359
360    let code = params.get("code").cloned().ok_or_else(|| {
361        AvError::InvalidPolicy("No authorization code in callback".to_string())
362    })?;
363
364    // Send success response
365    let response = "HTTP/1.1 200 OK\r\nContent-Type: text/html\r\n\r\n<h1>Audex SSO Login Successful</h1><p>You can close this window and return to the terminal.</p>";
366    let _ = stream.write_all(response.as_bytes()).await;
367
368    Ok(code)
369}
370
371/// Clear the cached SSO session.
372pub fn logout() {
373    let path = cache_path();
374    let _ = std::fs::remove_file(path);
375}
376
/// Unit tests for config parsing, identity (de)serialization, PKCE/state generation
/// and JWT payload decoding. Nothing here touches the network or the session cache.
#[cfg(test)]
mod tests {
    use super::*;

    /// A fully specified `[sso]` section parses into the expected values.
    #[test]
    fn test_sso_config_deserialize() {
        let toml_str = r#"
provider = "okta"
issuer = "https://myorg.okta.com"
client_id = "0oa1234567890"
scopes = ["openid", "email", "groups"]
identity_claim = "email"
groups_claim = "groups"
required = true
"#;
        let config: SsoConfig = toml::from_str(toml_str).unwrap();
        assert!(matches!(config.provider, SsoProvider::Okta));
        assert_eq!(config.issuer, "https://myorg.okta.com");
        assert_eq!(config.client_id, "0oa1234567890");
        assert!(config.required);
        assert_eq!(config.scopes.len(), 3);
    }

    /// Omitted optional fields fall back to their serde defaults.
    #[test]
    fn test_sso_config_google() {
        let toml_str = r#"
provider = "google"
issuer = "https://accounts.google.com"
client_id = "12345.apps.googleusercontent.com"
client_secret = "GOCSPX-secret"
"#;
        let config: SsoConfig = toml::from_str(toml_str).unwrap();
        assert!(matches!(config.provider, SsoProvider::Google));
        assert!(config.client_secret.is_some());
        // Defaults
        assert_eq!(config.identity_claim, "email");
        assert_eq!(config.groups_claim, "groups");
        assert_eq!(config.redirect_uri, "http://localhost:8400/callback");
        assert_eq!(config.scopes, vec!["openid", "email", "profile"]);
        assert!(!config.required);
    }

    /// `SsoIdentity` survives a JSON round trip.
    #[test]
    fn test_sso_identity_serialize() {
        let identity = SsoIdentity {
            identity: "alice@example.com".to_string(),
            groups: vec!["backend-team".to_string(), "devops".to_string()],
            claims: serde_json::json!({"email": "alice@example.com", "sub": "user123"}),
        };
        let json = serde_json::to_string(&identity).unwrap();
        assert!(json.contains("alice@example.com"));
        assert!(json.contains("backend-team"));
        // Roundtrip
        let parsed: SsoIdentity = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.identity, "alice@example.com");
        assert_eq!(parsed.groups.len(), 2);
    }

    /// The verifier is 32 random bytes as unpadded base64url, and the challenge
    /// derived from it differs from the verifier itself.
    #[test]
    fn test_pkce_verifier_and_challenge() {
        let verifier = generate_pkce_verifier();
        // 32 bytes encode to 43 unpadded base64 characters.
        assert_eq!(verifier.len(), 43);
        assert!(verifier
            .chars()
            .all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '_'));
        let challenge = generate_pkce_challenge(&verifier);
        assert!(!challenge.is_empty());
        assert_ne!(verifier, challenge);
    }

    /// The S256 transform matches the worked example in RFC 7636, Appendix B.
    #[test]
    fn test_pkce_challenge_rfc7636_vector() {
        let challenge = generate_pkce_challenge("dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk");
        assert_eq!(challenge, "E9Melhoa2OwvFrEMTJguCH5uAHUUNVWBVFUoX5o8ZVM");
    }

    /// State values are random and hex encoded.
    #[test]
    fn test_state_generation() {
        let state1 = generate_state();
        let state2 = generate_state();
        assert_ne!(state1, state2);
        assert_eq!(state1.len(), 32); // 16 bytes = 32 hex chars
        assert!(state1.chars().all(|c| c.is_ascii_hexdigit()));
    }

    /// The payload segment of a well-formed token is decoded into JSON claims.
    #[test]
    fn test_decode_jwt_claims() {
        // Create a minimal JWT with base64url-encoded payload
        use base64::Engine;
        let claims = serde_json::json!({"email": "test@example.com", "groups": ["admin"]});
        let payload = base64::engine::general_purpose::URL_SAFE_NO_PAD
            .encode(serde_json::to_vec(&claims).unwrap());
        let fake_jwt = format!("header.{}.signature", payload);

        let decoded = decode_jwt_claims(&fake_jwt).unwrap();
        assert_eq!(decoded["email"], "test@example.com");
        assert_eq!(decoded["groups"][0], "admin");
    }

    /// Malformed tokens are rejected rather than partially decoded.
    #[test]
    fn test_decode_jwt_invalid() {
        // Wrong number of segments.
        assert!(decode_jwt_claims("not-a-jwt").is_err());
        assert!(decode_jwt_claims("a.b").is_err());
        assert!(decode_jwt_claims("a.b.c.d").is_err());
        // Payload is not valid base64url.
        assert!(decode_jwt_claims("a.!!!.c").is_err());
        // Payload decodes ("abc") but is not JSON.
        assert!(decode_jwt_claims("a.YWJj.c").is_err());
    }
}