1use serde::{Deserialize, Serialize};
2
3use crate::error::{AvError, Result};
4
5#[derive(Debug, Clone, Serialize, Deserialize)]
7#[serde(rename_all = "lowercase")]
8pub enum SsoProvider {
9 Okta,
10 Auth0,
11 Google,
12 Oidc,
14}
15
16#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct SsoConfig {
19 pub provider: SsoProvider,
21 pub issuer: String,
23 pub client_id: String,
25 pub client_secret: Option<String>,
27 #[serde(default = "default_redirect")]
29 pub redirect_uri: String,
30 #[serde(default = "default_scopes")]
32 pub scopes: Vec<String>,
33 #[serde(default = "default_identity_claim")]
35 pub identity_claim: String,
36 #[serde(default = "default_groups_claim")]
38 pub groups_claim: String,
39 #[serde(default)]
41 pub required: bool,
42}
43
/// Serde default for [`SsoConfig::redirect_uri`]: a loopback callback on port 8400.
fn default_redirect() -> String {
    String::from("http://localhost:8400/callback")
}
47
/// Serde default for [`SsoConfig::scopes`]: `openid`, `email` and `profile`, in that order.
fn default_scopes() -> Vec<String> {
    ["openid", "email", "profile"]
        .iter()
        .map(|&scope| String::from(scope))
        .collect()
}
55
/// Serde default for [`SsoConfig::identity_claim`]: the `email` claim.
fn default_identity_claim() -> String {
    String::from("email")
}
59
/// Serde default for [`SsoConfig::groups_claim`]: the `groups` claim.
fn default_groups_claim() -> String {
    "groups".to_owned()
}
63
64#[derive(Debug, Deserialize)]
66struct TokenResponse {
67 access_token: String,
68 id_token: Option<String>,
69 #[allow(dead_code)]
70 token_type: Option<String>,
71 #[allow(dead_code)]
72 expires_in: Option<u64>,
73}
74
75#[derive(Debug, Clone, Serialize, Deserialize)]
77pub struct SsoIdentity {
78 pub identity: String,
80 pub groups: Vec<String>,
82 pub claims: serde_json::Value,
84}
85
86#[derive(Debug, Serialize, Deserialize)]
88struct CachedSsoSession {
89 identity: SsoIdentity,
90 expires_at: i64,
91}
92
93fn cache_path() -> std::path::PathBuf {
94 dirs::cache_dir()
95 .unwrap_or_else(|| std::path::PathBuf::from("."))
96 .join("audex")
97 .join("sso_session.json")
98}
99
100pub fn load_cached_identity() -> Option<SsoIdentity> {
102 let path = cache_path();
103 let cached: CachedSsoSession = crate::keystore::decrypt_from_file(&path).ok()??;
104 let now = chrono::Utc::now().timestamp();
105 if now < cached.expires_at {
106 Some(cached.identity)
107 } else {
108 None
109 }
110}
111
112fn save_cached_identity(identity: &SsoIdentity, ttl_secs: u64) {
113 let cached = CachedSsoSession {
114 identity: identity.clone(),
115 expires_at: chrono::Utc::now().timestamp() + ttl_secs as i64,
116 };
117 let path = cache_path();
118 if let Some(parent) = path.parent() {
119 let _ = std::fs::create_dir_all(parent);
120 }
121 let _ = crate::keystore::encrypt_to_file(&path, &cached);
122}
123
124pub async fn authenticate(config: &SsoConfig) -> Result<SsoIdentity> {
127 if let Some(cached) = load_cached_identity() {
129 return Ok(cached);
130 }
131
132 let discovery_url = format!(
133 "{}/.well-known/openid-configuration",
134 config.issuer.trim_end_matches('/')
135 );
136
137 let client = reqwest::Client::new();
138
139 let discovery: serde_json::Value = client
141 .get(&discovery_url)
142 .send()
143 .await
144 .map_err(|e| AvError::InvalidPolicy(format!("SSO discovery failed: {}", e)))?
145 .json()
146 .await
147 .map_err(|e| AvError::InvalidPolicy(format!("SSO discovery parse failed: {}", e)))?;
148
149 let auth_endpoint = discovery["authorization_endpoint"]
150 .as_str()
151 .ok_or_else(|| {
152 AvError::InvalidPolicy("Missing authorization_endpoint in OIDC discovery".to_string())
153 })?;
154 let token_endpoint = discovery["token_endpoint"].as_str().ok_or_else(|| {
155 AvError::InvalidPolicy("Missing token_endpoint in OIDC discovery".to_string())
156 })?;
157
158 let verifier = generate_pkce_verifier();
160 let challenge = generate_pkce_challenge(&verifier);
161 let state = generate_state();
162
163 let scopes = config.scopes.join(" ");
164 let auth_url = format!(
165 "{}?client_id={}&redirect_uri={}&response_type=code&scope={}&state={}&code_challenge={}&code_challenge_method=S256",
166 auth_endpoint,
167 urlencoding::encode(&config.client_id),
168 urlencoding::encode(&config.redirect_uri),
169 urlencoding::encode(&scopes),
170 urlencoding::encode(&state),
171 urlencoding::encode(&challenge),
172 );
173
174 eprintln!("Opening browser for SSO login...");
175 eprintln!("If the browser doesn't open, visit:\n {}\n", auth_url);
176
177 let _ = open_browser(&auth_url);
179
180 let code = listen_for_callback(&config.redirect_uri, &state).await?;
182
183 let mut token_params = vec![
185 ("grant_type", "authorization_code".to_string()),
186 ("code", code),
187 ("redirect_uri", config.redirect_uri.clone()),
188 ("client_id", config.client_id.clone()),
189 ("code_verifier", verifier),
190 ];
191 if let Some(ref secret) = config.client_secret {
192 token_params.push(("client_secret", secret.clone()));
193 }
194
195 let token_resp: TokenResponse = client
196 .post(token_endpoint)
197 .form(&token_params)
198 .send()
199 .await
200 .map_err(|e| AvError::InvalidPolicy(format!("SSO token exchange failed: {}", e)))?
201 .json()
202 .await
203 .map_err(|e| AvError::InvalidPolicy(format!("SSO token parse failed: {}", e)))?;
204
205 let claims = if let Some(ref id_token) = token_resp.id_token {
207 decode_jwt_claims(id_token)?
208 } else {
209 let userinfo_url = discovery["userinfo_endpoint"].as_str().ok_or_else(|| {
211 AvError::InvalidPolicy("No id_token and no userinfo_endpoint".to_string())
212 })?;
213
214 client
215 .get(userinfo_url)
216 .bearer_auth(&token_resp.access_token)
217 .send()
218 .await
219 .map_err(|e| AvError::InvalidPolicy(format!("SSO userinfo failed: {}", e)))?
220 .json()
221 .await
222 .map_err(|e| AvError::InvalidPolicy(format!("SSO userinfo parse failed: {}", e)))?
223 };
224
225 let identity_str = claims[&config.identity_claim]
226 .as_str()
227 .unwrap_or("unknown")
228 .to_string();
229
230 let groups: Vec<String> = claims[&config.groups_claim]
231 .as_array()
232 .map(|arr| {
233 arr.iter()
234 .filter_map(|v| v.as_str().map(|s| s.to_string()))
235 .collect()
236 })
237 .unwrap_or_default();
238
239 let sso_identity = SsoIdentity {
240 identity: identity_str,
241 groups,
242 claims,
243 };
244
245 save_cached_identity(&sso_identity, 3600);
247
248 Ok(sso_identity)
249}
250
251fn decode_jwt_claims(token: &str) -> Result<serde_json::Value> {
253 let parts: Vec<&str> = token.split('.').collect();
254 if parts.len() != 3 {
255 return Err(AvError::InvalidPolicy("Invalid JWT format".to_string()));
256 }
257
258 use base64::Engine;
259 let payload = base64::engine::general_purpose::URL_SAFE_NO_PAD
260 .decode(parts[1])
261 .map_err(|e| AvError::InvalidPolicy(format!("JWT base64 decode failed: {}", e)))?;
262
263 serde_json::from_slice(&payload)
264 .map_err(|e| AvError::InvalidPolicy(format!("JWT claims parse failed: {}", e)))
265}
266
267fn generate_pkce_verifier() -> String {
269 use rand::Rng;
270 let mut rng = rand::rng();
271 let bytes: Vec<u8> = (0..32).map(|_| rng.random::<u8>()).collect();
272 use base64::Engine;
273 base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(&bytes)
274}
275
276fn generate_pkce_challenge(verifier: &str) -> String {
278 use sha2::Digest;
279 let hash = sha2::Sha256::digest(verifier.as_bytes());
280 use base64::Engine;
281 base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(hash)
282}
283
284fn generate_state() -> String {
286 use rand::Rng;
287 let mut rng = rand::rng();
288 let bytes: Vec<u8> = (0..16).map(|_| rng.random::<u8>()).collect();
289 hex::encode(bytes)
290}
291
/// Asks the platform's default URL handler to open `url`, without waiting for it to exit.
///
/// - macOS: `open`.
/// - Other unix systems: `xdg-open`.
/// - Windows: `rundll32 url.dll,FileProtocolHandler`. This is deliberately not
///   `cmd /c start`, because cmd.exe would treat every `&` in the URL as a command separator
///   and truncate the query string.
///
/// # Errors
///
/// Returns the spawn error if the launcher cannot be started, or
/// [`std::io::ErrorKind::Unsupported`] on platforms with no known launcher.
fn open_browser(url: &str) -> std::io::Result<()> {
    #[cfg(target_os = "macos")]
    {
        std::process::Command::new("open").arg(url).spawn()?;
        Ok(())
    }
    #[cfg(all(unix, not(target_os = "macos")))]
    {
        std::process::Command::new("xdg-open").arg(url).spawn()?;
        Ok(())
    }
    #[cfg(windows)]
    {
        std::process::Command::new("rundll32")
            .args(["url.dll,FileProtocolHandler", url])
            .spawn()?;
        Ok(())
    }
    #[cfg(not(any(unix, windows)))]
    {
        let _ = url;
        Err(std::io::Error::new(
            std::io::ErrorKind::Unsupported,
            "no known browser launcher for this platform",
        ))
    }
}
310
311async fn listen_for_callback(redirect_uri: &str, expected_state: &str) -> Result<String> {
313 use tokio::io::AsyncReadExt;
314 use tokio::io::AsyncWriteExt;
315
316 let url: url::Url = redirect_uri
318 .parse()
319 .map_err(|e| AvError::InvalidPolicy(format!("Invalid redirect URI: {}", e)))?;
320 let port = url.port().unwrap_or(8400);
321
322 let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
323 .await
324 .map_err(|e| {
325 AvError::InvalidPolicy(format!(
326 "Failed to start callback listener on port {}: {}",
327 port, e
328 ))
329 })?;
330
331 eprintln!("Waiting for SSO callback on port {}...", port);
332
333 let (mut stream, _) = listener
334 .accept()
335 .await
336 .map_err(|e| AvError::InvalidPolicy(format!("Failed to accept callback: {}", e)))?;
337
338 let mut buf = vec![0u8; 4096];
339 let n = stream
340 .read(&mut buf)
341 .await
342 .map_err(|e| AvError::InvalidPolicy(format!("Failed to read callback: {}", e)))?;
343
344 let request = String::from_utf8_lossy(&buf[..n]);
345
346 let first_line = request.lines().next().unwrap_or("");
348 let path = first_line.split_whitespace().nth(1).unwrap_or("/");
349 let query_url = format!("http://localhost{}", path);
350 let parsed: url::Url = query_url
351 .parse()
352 .map_err(|_| AvError::InvalidPolicy("Failed to parse callback URL".to_string()))?;
353
354 let params: std::collections::HashMap<String, String> = parsed
355 .query_pairs()
356 .map(|(k, v)| (k.to_string(), v.to_string()))
357 .collect();
358
359 let state = params.get("state").cloned().unwrap_or_default();
361 if state != expected_state {
362 let response = "HTTP/1.1 400 Bad Request\r\nContent-Type: text/html\r\n\r\n<h1>SSO Error</h1><p>State mismatch. Please try again.</p>";
363 let _ = stream.write_all(response.as_bytes()).await;
364 return Err(AvError::InvalidPolicy(
365 "SSO state mismatch — possible CSRF".to_string(),
366 ));
367 }
368
369 if let Some(error) = params.get("error") {
371 let desc = params.get("error_description").cloned().unwrap_or_default();
372 let response = format!("HTTP/1.1 400 Bad Request\r\nContent-Type: text/html\r\n\r\n<h1>SSO Error</h1><p>{}: {}</p>", error, desc);
373 let _ = stream.write_all(response.as_bytes()).await;
374 return Err(AvError::InvalidPolicy(format!(
375 "SSO error: {}: {}",
376 error, desc
377 )));
378 }
379
380 let code = params
381 .get("code")
382 .cloned()
383 .ok_or_else(|| AvError::InvalidPolicy("No authorization code in callback".to_string()))?;
384
385 let response = "HTTP/1.1 200 OK\r\nContent-Type: text/html\r\n\r\n<h1>Audex SSO Login Successful</h1><p>You can close this window and return to the terminal.</p>";
387 let _ = stream.write_all(response.as_bytes()).await;
388
389 Ok(code)
390}
391
392pub fn logout() {
394 let path = cache_path();
395 let _ = std::fs::remove_file(path);
396}
397
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sso_config_deserialize() {
        // A config that sets every field explicitly.
        let toml_str = r#"
provider = "okta"
issuer = "https://myorg.okta.com"
client_id = "0oa1234567890"
scopes = ["openid", "email", "groups"]
identity_claim = "email"
groups_claim = "groups"
required = true
"#;
        let config: SsoConfig = toml::from_str(toml_str).unwrap();

        assert!(matches!(config.provider, SsoProvider::Okta));
        assert_eq!(config.issuer, "https://myorg.okta.com");
        assert_eq!(config.client_id, "0oa1234567890");
        assert!(config.required);
        assert_eq!(config.scopes.len(), 3);
    }

    #[test]
    fn test_sso_config_google() {
        // A minimal config: everything not listed must come from the serde defaults.
        let toml_str = r#"
provider = "google"
issuer = "https://accounts.google.com"
client_id = "12345.apps.googleusercontent.com"
client_secret = "GOCSPX-secret"
"#;
        let config: SsoConfig = toml::from_str(toml_str).unwrap();

        assert!(matches!(config.provider, SsoProvider::Google));
        assert!(config.client_secret.is_some());
        assert_eq!(config.identity_claim, "email");
        assert_eq!(config.groups_claim, "groups");
        assert_eq!(config.redirect_uri, "http://localhost:8400/callback");
        assert!(!config.required);
    }

    #[test]
    fn test_sso_identity_serialize() {
        let identity = SsoIdentity {
            identity: "alice@example.com".to_string(),
            groups: vec!["backend-team".to_string(), "devops".to_string()],
            claims: serde_json::json!({"email": "alice@example.com", "sub": "user123"}),
        };

        let json = serde_json::to_string(&identity).unwrap();
        assert!(json.contains("alice@example.com"));
        assert!(json.contains("backend-team"));

        // Round-trip back through JSON.
        let parsed: SsoIdentity = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.identity, "alice@example.com");
        assert_eq!(parsed.groups.len(), 2);
    }

    #[test]
    fn test_pkce_verifier_and_challenge() {
        let verifier = generate_pkce_verifier();
        assert!(!verifier.is_empty());

        let challenge = generate_pkce_challenge(&verifier);
        assert!(!challenge.is_empty());
        assert_ne!(verifier, challenge);
    }

    #[test]
    fn test_state_generation() {
        let state1 = generate_state();
        let state2 = generate_state();
        assert_ne!(state1, state2);
        // 16 random bytes, hex-encoded.
        assert_eq!(state1.len(), 32);
    }

    #[test]
    fn test_decode_jwt_claims() {
        use base64::Engine;

        // Only the payload segment is decoded, so header and signature can be placeholders.
        let claims = serde_json::json!({"email": "test@example.com", "groups": ["admin"]});
        let payload = base64::engine::general_purpose::URL_SAFE_NO_PAD
            .encode(serde_json::to_vec(&claims).unwrap());
        let fake_jwt = format!("header.{}.signature", payload);

        let decoded = decode_jwt_claims(&fake_jwt).unwrap();
        assert_eq!(decoded["email"], "test@example.com");
        assert_eq!(decoded["groups"][0], "admin");
    }

    #[test]
    fn test_decode_jwt_invalid() {
        // Anything other than exactly three segments is rejected.
        assert!(decode_jwt_claims("not-a-jwt").is_err());
        assert!(decode_jwt_claims("a.b").is_err());
    }
}