1use serde::{Deserialize, Serialize};
2
3use crate::error::{AvError, Result};
4
5#[derive(Debug, Clone, Serialize, Deserialize)]
7#[serde(rename_all = "lowercase")]
8pub enum SsoProvider {
9 Okta,
10 Auth0,
11 Google,
12 Oidc,
14}
15
16#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct SsoConfig {
19 pub provider: SsoProvider,
21 pub issuer: String,
23 pub client_id: String,
25 pub client_secret: Option<String>,
27 #[serde(default = "default_redirect")]
29 pub redirect_uri: String,
30 #[serde(default = "default_scopes")]
32 pub scopes: Vec<String>,
33 #[serde(default = "default_identity_claim")]
35 pub identity_claim: String,
36 #[serde(default = "default_groups_claim")]
38 pub groups_claim: String,
39 #[serde(default)]
41 pub required: bool,
42}
43
/// Serde default for `SsoConfig::redirect_uri`: the local callback URL.
fn default_redirect() -> String {
    String::from("http://localhost:8400/callback")
}
47
/// Serde default for `SsoConfig::scopes`: `openid`, `email` and `profile`, in that order.
fn default_scopes() -> Vec<String> {
    ["openid", "email", "profile"]
        .iter()
        .map(|scope| scope.to_string())
        .collect()
}
51
/// Serde default for `SsoConfig::identity_claim`.
fn default_identity_claim() -> String {
    String::from("email")
}
55
/// Serde default for `SsoConfig::groups_claim`.
fn default_groups_claim() -> String {
    String::from("groups")
}
59
60#[derive(Debug, Deserialize)]
62struct TokenResponse {
63 access_token: String,
64 id_token: Option<String>,
65 #[allow(dead_code)]
66 token_type: Option<String>,
67 #[allow(dead_code)]
68 expires_in: Option<u64>,
69}
70
71#[derive(Debug, Clone, Serialize, Deserialize)]
73pub struct SsoIdentity {
74 pub identity: String,
76 pub groups: Vec<String>,
78 pub claims: serde_json::Value,
80}
81
82#[derive(Debug, Serialize, Deserialize)]
84struct CachedSsoSession {
85 identity: SsoIdentity,
86 expires_at: i64,
87}
88
89fn cache_path() -> std::path::PathBuf {
90 dirs::cache_dir()
91 .unwrap_or_else(|| std::path::PathBuf::from("."))
92 .join("audex")
93 .join("sso_session.json")
94}
95
96pub fn load_cached_identity() -> Option<SsoIdentity> {
98 let path = cache_path();
99 let cached: CachedSsoSession = crate::keystore::decrypt_from_file(&path).ok()??;
100 let now = chrono::Utc::now().timestamp();
101 if now < cached.expires_at {
102 Some(cached.identity)
103 } else {
104 None
105 }
106}
107
108fn save_cached_identity(identity: &SsoIdentity, ttl_secs: u64) {
109 let cached = CachedSsoSession {
110 identity: identity.clone(),
111 expires_at: chrono::Utc::now().timestamp() + ttl_secs as i64,
112 };
113 let path = cache_path();
114 if let Some(parent) = path.parent() {
115 let _ = std::fs::create_dir_all(parent);
116 }
117 let _ = crate::keystore::encrypt_to_file(&path, &cached);
118}
119
120pub async fn authenticate(config: &SsoConfig) -> Result<SsoIdentity> {
123 if let Some(cached) = load_cached_identity() {
125 return Ok(cached);
126 }
127
128 let discovery_url = format!(
129 "{}/.well-known/openid-configuration",
130 config.issuer.trim_end_matches('/')
131 );
132
133 let client = reqwest::Client::new();
134
135 let discovery: serde_json::Value = client
137 .get(&discovery_url)
138 .send()
139 .await
140 .map_err(|e| AvError::InvalidPolicy(format!("SSO discovery failed: {}", e)))?
141 .json()
142 .await
143 .map_err(|e| AvError::InvalidPolicy(format!("SSO discovery parse failed: {}", e)))?;
144
145 let auth_endpoint = discovery["authorization_endpoint"]
146 .as_str()
147 .ok_or_else(|| AvError::InvalidPolicy("Missing authorization_endpoint in OIDC discovery".to_string()))?;
148 let token_endpoint = discovery["token_endpoint"]
149 .as_str()
150 .ok_or_else(|| AvError::InvalidPolicy("Missing token_endpoint in OIDC discovery".to_string()))?;
151
152 let verifier = generate_pkce_verifier();
154 let challenge = generate_pkce_challenge(&verifier);
155 let state = generate_state();
156
157 let scopes = config.scopes.join(" ");
158 let auth_url = format!(
159 "{}?client_id={}&redirect_uri={}&response_type=code&scope={}&state={}&code_challenge={}&code_challenge_method=S256",
160 auth_endpoint,
161 urlencoding::encode(&config.client_id),
162 urlencoding::encode(&config.redirect_uri),
163 urlencoding::encode(&scopes),
164 urlencoding::encode(&state),
165 urlencoding::encode(&challenge),
166 );
167
168 eprintln!("Opening browser for SSO login...");
169 eprintln!("If the browser doesn't open, visit:\n {}\n", auth_url);
170
171 let _ = open_browser(&auth_url);
173
174 let code = listen_for_callback(&config.redirect_uri, &state).await?;
176
177 let mut token_params = vec![
179 ("grant_type", "authorization_code".to_string()),
180 ("code", code),
181 ("redirect_uri", config.redirect_uri.clone()),
182 ("client_id", config.client_id.clone()),
183 ("code_verifier", verifier),
184 ];
185 if let Some(ref secret) = config.client_secret {
186 token_params.push(("client_secret", secret.clone()));
187 }
188
189 let token_resp: TokenResponse = client
190 .post(token_endpoint)
191 .form(&token_params)
192 .send()
193 .await
194 .map_err(|e| AvError::InvalidPolicy(format!("SSO token exchange failed: {}", e)))?
195 .json()
196 .await
197 .map_err(|e| AvError::InvalidPolicy(format!("SSO token parse failed: {}", e)))?;
198
199 let claims = if let Some(ref id_token) = token_resp.id_token {
201 decode_jwt_claims(id_token)?
202 } else {
203 let userinfo_url = discovery["userinfo_endpoint"]
205 .as_str()
206 .ok_or_else(|| AvError::InvalidPolicy("No id_token and no userinfo_endpoint".to_string()))?;
207
208 client
209 .get(userinfo_url)
210 .bearer_auth(&token_resp.access_token)
211 .send()
212 .await
213 .map_err(|e| AvError::InvalidPolicy(format!("SSO userinfo failed: {}", e)))?
214 .json()
215 .await
216 .map_err(|e| AvError::InvalidPolicy(format!("SSO userinfo parse failed: {}", e)))?
217 };
218
219 let identity_str = claims[&config.identity_claim]
220 .as_str()
221 .unwrap_or("unknown")
222 .to_string();
223
224 let groups: Vec<String> = claims[&config.groups_claim]
225 .as_array()
226 .map(|arr| {
227 arr.iter()
228 .filter_map(|v| v.as_str().map(|s| s.to_string()))
229 .collect()
230 })
231 .unwrap_or_default();
232
233 let sso_identity = SsoIdentity {
234 identity: identity_str,
235 groups,
236 claims,
237 };
238
239 save_cached_identity(&sso_identity, 3600);
241
242 Ok(sso_identity)
243}
244
245fn decode_jwt_claims(token: &str) -> Result<serde_json::Value> {
247 let parts: Vec<&str> = token.split('.').collect();
248 if parts.len() != 3 {
249 return Err(AvError::InvalidPolicy("Invalid JWT format".to_string()));
250 }
251
252 use base64::Engine;
253 let payload = base64::engine::general_purpose::URL_SAFE_NO_PAD
254 .decode(parts[1])
255 .map_err(|e| AvError::InvalidPolicy(format!("JWT base64 decode failed: {}", e)))?;
256
257 serde_json::from_slice(&payload)
258 .map_err(|e| AvError::InvalidPolicy(format!("JWT claims parse failed: {}", e)))
259}
260
261fn generate_pkce_verifier() -> String {
263 use rand::Rng;
264 let mut rng = rand::rng();
265 let bytes: Vec<u8> = (0..32).map(|_| rng.random::<u8>()).collect();
266 use base64::Engine;
267 base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(&bytes)
268}
269
270fn generate_pkce_challenge(verifier: &str) -> String {
272 use sha2::Digest;
273 let hash = sha2::Sha256::digest(verifier.as_bytes());
274 use base64::Engine;
275 base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(hash)
276}
277
278fn generate_state() -> String {
280 use rand::Rng;
281 let mut rng = rand::rng();
282 let bytes: Vec<u8> = (0..16).map(|_| rng.random::<u8>()).collect();
283 hex::encode(bytes)
284}
285
/// Asks the operating system to open `url` in the user's default browser.
///
/// The launcher process is spawned and not waited on. On targets other than
/// macOS, Linux and Windows nothing is launched and `Ok(())` is returned.
///
/// # Errors
///
/// Returns the I/O error from spawning the platform launcher.
fn open_browser(url: &str) -> std::io::Result<()> {
    #[cfg(target_os = "macos")]
    {
        std::process::Command::new("open").arg(url).spawn()?;
    }
    #[cfg(target_os = "linux")]
    {
        std::process::Command::new("xdg-open").arg(url).spawn()?;
    }
    #[cfg(target_os = "windows")]
    {
        // Do not go through `cmd /c start`: cmd.exe treats every `&` in the
        // URL as a command separator, which cuts the query string off after
        // its first parameter. rundll32 receives the URL as a plain argument.
        std::process::Command::new("rundll32")
            .args(["url.dll,FileProtocolHandler", url])
            .spawn()?;
    }
    #[cfg(not(any(target_os = "macos", target_os = "linux", target_os = "windows")))]
    {
        // No known launcher on this target; mark the parameter as used.
        let _ = url;
    }
    Ok(())
}
302
303async fn listen_for_callback(redirect_uri: &str, expected_state: &str) -> Result<String> {
305 use tokio::io::AsyncReadExt;
306 use tokio::io::AsyncWriteExt;
307
308 let url: url::Url = redirect_uri
310 .parse()
311 .map_err(|e| AvError::InvalidPolicy(format!("Invalid redirect URI: {}", e)))?;
312 let port = url.port().unwrap_or(8400);
313
314 let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
315 .await
316 .map_err(|e| AvError::InvalidPolicy(format!("Failed to start callback listener on port {}: {}", port, e)))?;
317
318 eprintln!("Waiting for SSO callback on port {}...", port);
319
320 let (mut stream, _) = listener
321 .accept()
322 .await
323 .map_err(|e| AvError::InvalidPolicy(format!("Failed to accept callback: {}", e)))?;
324
325 let mut buf = vec![0u8; 4096];
326 let n = stream
327 .read(&mut buf)
328 .await
329 .map_err(|e| AvError::InvalidPolicy(format!("Failed to read callback: {}", e)))?;
330
331 let request = String::from_utf8_lossy(&buf[..n]);
332
333 let first_line = request.lines().next().unwrap_or("");
335 let path = first_line.split_whitespace().nth(1).unwrap_or("/");
336 let query_url = format!("http://localhost{}", path);
337 let parsed: url::Url = query_url
338 .parse()
339 .map_err(|_| AvError::InvalidPolicy("Failed to parse callback URL".to_string()))?;
340
341 let params: std::collections::HashMap<String, String> =
342 parsed.query_pairs().map(|(k, v)| (k.to_string(), v.to_string())).collect();
343
344 let state = params.get("state").cloned().unwrap_or_default();
346 if state != expected_state {
347 let response = "HTTP/1.1 400 Bad Request\r\nContent-Type: text/html\r\n\r\n<h1>SSO Error</h1><p>State mismatch. Please try again.</p>";
348 let _ = stream.write_all(response.as_bytes()).await;
349 return Err(AvError::InvalidPolicy("SSO state mismatch — possible CSRF".to_string()));
350 }
351
352 if let Some(error) = params.get("error") {
354 let desc = params.get("error_description").cloned().unwrap_or_default();
355 let response = format!("HTTP/1.1 400 Bad Request\r\nContent-Type: text/html\r\n\r\n<h1>SSO Error</h1><p>{}: {}</p>", error, desc);
356 let _ = stream.write_all(response.as_bytes()).await;
357 return Err(AvError::InvalidPolicy(format!("SSO error: {}: {}", error, desc)));
358 }
359
360 let code = params.get("code").cloned().ok_or_else(|| {
361 AvError::InvalidPolicy("No authorization code in callback".to_string())
362 })?;
363
364 let response = "HTTP/1.1 200 OK\r\nContent-Type: text/html\r\n\r\n<h1>Audex SSO Login Successful</h1><p>You can close this window and return to the terminal.</p>";
366 let _ = stream.write_all(response.as_bytes()).await;
367
368 Ok(code)
369}
370
371pub fn logout() {
373 let path = cache_path();
374 let _ = std::fs::remove_file(path);
375}
376
#[cfg(test)]
mod tests {
    use super::*;

    /// A fully specified Okta configuration keeps every explicit value.
    #[test]
    fn test_sso_config_deserialize() {
        let raw = r#"
provider = "okta"
issuer = "https://myorg.okta.com"
client_id = "0oa1234567890"
scopes = ["openid", "email", "groups"]
identity_claim = "email"
groups_claim = "groups"
required = true
"#;
        let parsed: SsoConfig = toml::from_str(raw).unwrap();

        assert!(matches!(parsed.provider, SsoProvider::Okta));
        assert_eq!(parsed.issuer, "https://myorg.okta.com");
        assert_eq!(parsed.client_id, "0oa1234567890");
        assert_eq!(parsed.scopes.len(), 3);
        assert!(parsed.required);
    }

    /// A minimal Google configuration picks up all serde defaults.
    #[test]
    fn test_sso_config_google() {
        let raw = r#"
provider = "google"
issuer = "https://accounts.google.com"
client_id = "12345.apps.googleusercontent.com"
client_secret = "GOCSPX-secret"
"#;
        let parsed: SsoConfig = toml::from_str(raw).unwrap();

        assert!(matches!(parsed.provider, SsoProvider::Google));
        assert!(parsed.client_secret.is_some());
        // Everything below was omitted from the input and must be defaulted.
        assert_eq!(parsed.identity_claim, "email");
        assert_eq!(parsed.groups_claim, "groups");
        assert_eq!(parsed.redirect_uri, "http://localhost:8400/callback");
        assert!(!parsed.required);
    }

    /// An identity survives a JSON round trip.
    #[test]
    fn test_sso_identity_serialize() {
        let original = SsoIdentity {
            identity: "alice@example.com".to_string(),
            groups: vec!["backend-team".to_string(), "devops".to_string()],
            claims: serde_json::json!({"email": "alice@example.com", "sub": "user123"}),
        };

        let encoded = serde_json::to_string(&original).unwrap();
        assert!(encoded.contains("alice@example.com"));
        assert!(encoded.contains("backend-team"));

        let restored: SsoIdentity = serde_json::from_str(&encoded).unwrap();
        assert_eq!(restored.identity, "alice@example.com");
        assert_eq!(restored.groups.len(), 2);
    }

    /// Verifier and challenge are both non-empty and differ from each other.
    #[test]
    fn test_pkce_verifier_and_challenge() {
        let verifier = generate_pkce_verifier();
        let challenge = generate_pkce_challenge(&verifier);

        assert!(!verifier.is_empty());
        assert!(!challenge.is_empty());
        assert_ne!(verifier, challenge);
    }

    /// Two state values differ, and each is 16 bytes rendered as hex.
    #[test]
    fn test_state_generation() {
        let first = generate_state();
        let second = generate_state();

        assert_ne!(first, second);
        // 16 random bytes hex-encoded give 32 characters.
        assert_eq!(first.len(), 32);
    }

    /// The payload segment of a three-part token is decoded into its claims.
    #[test]
    fn test_decode_jwt_claims() {
        use base64::Engine;

        // Only the payload segment is decoded, so header and signature can be
        // placeholders.
        let claims = serde_json::json!({"email": "test@example.com", "groups": ["admin"]});
        let payload = base64::engine::general_purpose::URL_SAFE_NO_PAD
            .encode(serde_json::to_vec(&claims).unwrap());
        let fake_jwt = format!("header.{}.signature", payload);

        let decoded = decode_jwt_claims(&fake_jwt).unwrap();
        assert_eq!(decoded["email"], "test@example.com");
        assert_eq!(decoded["groups"][0], "admin");
    }

    /// Tokens without exactly three segments are rejected.
    #[test]
    fn test_decode_jwt_invalid() {
        for malformed in ["not-a-jwt", "a.b"] {
            assert!(decode_jwt_claims(malformed).is_err());
        }
    }
}