1use serde::{Deserialize, Serialize};
2
3use crate::error::{AvError, Result};
4
5#[derive(Debug, Clone, Serialize, Deserialize)]
7#[serde(rename_all = "lowercase")]
8pub enum SsoProvider {
9 Okta,
10 Auth0,
11 Google,
12 Oidc,
14}
15
16#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct SsoConfig {
19 pub provider: SsoProvider,
21 pub issuer: String,
23 pub client_id: String,
25 pub client_secret: Option<String>,
27 #[serde(default = "default_redirect")]
29 pub redirect_uri: String,
30 #[serde(default = "default_scopes")]
32 pub scopes: Vec<String>,
33 #[serde(default = "default_identity_claim")]
35 pub identity_claim: String,
36 #[serde(default = "default_groups_claim")]
38 pub groups_claim: String,
39 #[serde(default)]
41 pub required: bool,
42}
43
/// Serde default for `SsoConfig::redirect_uri`: the loopback callback on port 8400.
fn default_redirect() -> String {
    String::from("http://localhost:8400/callback")
}
47
/// Serde default for `SsoConfig::scopes`: `openid`, `email`, `profile`, in that order.
fn default_scopes() -> Vec<String> {
    ["openid", "email", "profile"]
        .iter()
        .map(|&scope| scope.to_owned())
        .collect()
}
55
/// Serde default for `SsoConfig::identity_claim`: the `email` claim.
fn default_identity_claim() -> String {
    "email".to_owned()
}
59
/// Serde default for `SsoConfig::groups_claim`: the `groups` claim.
fn default_groups_claim() -> String {
    String::from("groups")
}
63
64#[derive(Debug, Deserialize)]
66struct TokenResponse {
67 access_token: String,
68 id_token: Option<String>,
69 #[allow(dead_code)]
70 token_type: Option<String>,
71 #[allow(dead_code)]
72 expires_in: Option<u64>,
73}
74
75#[derive(Debug, Clone, Serialize, Deserialize)]
77pub struct SsoIdentity {
78 pub identity: String,
80 pub groups: Vec<String>,
82 pub claims: serde_json::Value,
84}
85
86#[derive(Debug, Serialize, Deserialize)]
88struct CachedSsoSession {
89 identity: SsoIdentity,
90 expires_at: i64,
91 #[serde(default)]
93 issuer: String,
94 #[serde(default)]
96 client_id: String,
97 #[serde(default)]
99 machine_id: String,
100}
101
/// Best-effort stable identifier of this host, used to bind the SSO cache to one machine.
///
/// Sources, in order:
/// 1. `/etc/machine-id`, trimmed; ignored when blank or unreadable.
/// 2. On macOS only: the `IOPlatformUUID` value printed by `ioreg -rd1 -c IOPlatformExpertDevice`.
///
/// Returns an empty string when no source yields an id; `load_cached_identity` skips its
/// machine check in that case.
fn read_machine_id() -> String {
    let etc_machine_id = std::fs::read_to_string("/etc/machine-id")
        .ok()
        .map(|raw| raw.trim().to_string())
        .filter(|id| !id.is_empty());
    if let Some(id) = etc_machine_id {
        return id;
    }

    #[cfg(target_os = "macos")]
    {
        if let Ok(output) = std::process::Command::new("ioreg")
            .args(["-rd1", "-c", "IOPlatformExpertDevice"])
            .output()
        {
            let listing = String::from_utf8_lossy(&output.stdout);
            // Matching lines look like `"IOPlatformUUID" = "XXXXXXXX-..."`; the value is the
            // text between the last two double quotes on the line.
            let uuid = listing
                .lines()
                .filter(|line| line.contains("IOPlatformUUID"))
                .find_map(|line| {
                    let before_closing = &line[..line.rfind('"')?];
                    let opening = before_closing.rfind('"')?;
                    Some(before_closing[opening + 1..].to_string())
                });
            if let Some(uuid) = uuid {
                return uuid;
            }
        }
    }

    String::new()
}
134
135fn cache_path() -> std::path::PathBuf {
136 dirs::cache_dir()
137 .unwrap_or_else(|| std::path::PathBuf::from("."))
138 .join("audex")
139 .join("sso_session.json")
140}
141
142pub fn load_cached_identity(issuer: &str, client_id: &str) -> Option<SsoIdentity> {
144 let path = cache_path();
145 let cached: CachedSsoSession = crate::keystore::decrypt_from_file(&path).ok()??;
146 let now = chrono::Utc::now().timestamp();
147 if now >= cached.expires_at {
148 return None;
149 }
150 if cached.issuer != issuer || cached.client_id != client_id {
152 tracing::info!("SSO cache issuer/client mismatch — re-authenticating");
153 return None;
154 }
155 let current_machine_id = read_machine_id();
157 if !current_machine_id.is_empty() && cached.machine_id != current_machine_id {
158 tracing::info!("SSO cache machine-id mismatch — re-authenticating");
159 return None;
160 }
161 Some(cached.identity)
162}
163
164fn save_cached_identity(identity: &SsoIdentity, ttl_secs: u64, issuer: &str, client_id: &str) {
165 let cached = CachedSsoSession {
166 identity: identity.clone(),
167 expires_at: chrono::Utc::now().timestamp() + ttl_secs as i64,
168 issuer: issuer.to_string(),
169 client_id: client_id.to_string(),
170 machine_id: read_machine_id(),
171 };
172 let path = cache_path();
173 if let Some(parent) = path.parent() {
174 let _ = std::fs::create_dir_all(parent);
175 }
176 if let Err(e) = crate::keystore::encrypt_to_file(&path, &cached) {
177 tracing::warn!(error = %e, "Failed to cache SSO identity — next login will require re-authentication");
178 }
179}
180
181pub async fn authenticate(config: &SsoConfig) -> Result<SsoIdentity> {
184 if let Some(cached) = load_cached_identity(&config.issuer, &config.client_id) {
186 return Ok(cached);
187 }
188
189 let discovery_url = format!(
190 "{}/.well-known/openid-configuration",
191 config.issuer.trim_end_matches('/')
192 );
193
194 let discovery_client = reqwest::Client::builder()
206 .redirect(reqwest::redirect::Policy::none())
207 .timeout(std::time::Duration::from_secs(30))
208 .connect_timeout(std::time::Duration::from_secs(10))
209 .build()
210 .map_err(|e| AvError::InvalidPolicy(format!("SSO client build failed: {}", e)))?;
211
212 let discovery: serde_json::Value = fetch_json_capped(
217 &discovery_client,
218 &discovery_url,
219 None,
220 1024 * 1024,
221 "SSO discovery",
222 )
223 .await?;
224
225 let discovery_issuer = discovery["issuer"].as_str().ok_or_else(|| {
233 AvError::InvalidPolicy("OIDC discovery document missing `issuer` field".to_string())
234 })?;
235 if discovery_issuer.trim_end_matches('/') != config.issuer.trim_end_matches('/') {
236 return Err(AvError::InvalidPolicy(format!(
237 "OIDC discovery issuer mismatch: configured={}, received={}. \
238 This can indicate a redirect on the discovery URL or a misconfigured IdP.",
239 config.issuer, discovery_issuer
240 )));
241 }
242
243 let client = reqwest::Client::builder()
245 .timeout(std::time::Duration::from_secs(30))
246 .connect_timeout(std::time::Duration::from_secs(10))
247 .build()
248 .map_err(|e| AvError::InvalidPolicy(format!("SSO client build failed: {}", e)))?;
249
250 let auth_endpoint = discovery["authorization_endpoint"]
251 .as_str()
252 .ok_or_else(|| {
253 AvError::InvalidPolicy("Missing authorization_endpoint in OIDC discovery".to_string())
254 })?;
255 let token_endpoint = discovery["token_endpoint"].as_str().ok_or_else(|| {
256 AvError::InvalidPolicy("Missing token_endpoint in OIDC discovery".to_string())
257 })?;
258
259 let verifier = generate_pkce_verifier();
261 let challenge = generate_pkce_challenge(&verifier);
262 let state = generate_state();
263 let nonce = generate_state(); let scopes = config.scopes.join(" ");
266 let auth_url = format!(
267 "{}?client_id={}&redirect_uri={}&response_type=code&scope={}&state={}&nonce={}&code_challenge={}&code_challenge_method=S256",
268 auth_endpoint,
269 urlencoding::encode(&config.client_id),
270 urlencoding::encode(&config.redirect_uri),
271 urlencoding::encode(&scopes),
272 urlencoding::encode(&state),
273 urlencoding::encode(&nonce),
274 urlencoding::encode(&challenge),
275 );
276
277 eprintln!("Opening browser for SSO login...");
278 eprintln!("If the browser doesn't open, visit:\n {}\n", auth_url);
279
280 let _ = open_browser(&auth_url);
282
283 let code = listen_for_callback(&config.redirect_uri, &state).await?;
285
286 let mut token_params = vec![
288 ("grant_type", "authorization_code".to_string()),
289 ("code", code),
290 ("redirect_uri", config.redirect_uri.clone()),
291 ("client_id", config.client_id.clone()),
292 ("code_verifier", verifier),
293 ];
294 if let Some(ref secret) = config.client_secret {
295 token_params.push(("client_secret", secret.clone()));
296 }
297
298 let token_resp: TokenResponse = client
299 .post(token_endpoint)
300 .form(&token_params)
301 .send()
302 .await
303 .map_err(|e| AvError::InvalidPolicy(format!("SSO token exchange failed: {}", e)))?
304 .json()
305 .await
306 .map_err(|e| AvError::InvalidPolicy(format!("SSO token parse failed: {}", e)))?;
307
308 let claims = if let Some(ref id_token) = token_resp.id_token {
321 let jwks_uri = discovery["jwks_uri"].as_str().unwrap_or_default();
322 if jwks_uri.is_empty() {
323 return Err(AvError::InvalidPolicy(
324 "OIDC discovery document is missing `jwks_uri`; refusing to accept an \
325 id_token without signature verification. If the provider genuinely does \
326 not expose JWKS, configure audex to use the userinfo endpoint instead."
327 .to_string(),
328 ));
329 }
330 decode_and_verify_jwt(
331 &client,
332 jwks_uri,
333 id_token,
334 &config.issuer,
335 &config.client_id,
336 )
337 .await?
338 } else {
339 let userinfo_url = discovery["userinfo_endpoint"].as_str().ok_or_else(|| {
341 AvError::InvalidPolicy("No id_token and no userinfo_endpoint".to_string())
342 })?;
343
344 fetch_json_capped(
346 &client,
347 userinfo_url,
348 Some(&token_resp.access_token),
349 1024 * 1024,
350 "SSO userinfo",
351 )
352 .await?
353 };
354
355 if token_resp.id_token.is_some() {
362 let token_nonce = claims
363 .get("nonce")
364 .and_then(|v| v.as_str())
365 .ok_or_else(|| {
366 AvError::InvalidPolicy(
367 "SSO id_token is missing required `nonce` claim — \
368 rejecting to prevent replay (OIDC Core §3.1.3.7)"
369 .to_string(),
370 )
371 })?;
372 if token_nonce != nonce {
373 return Err(AvError::InvalidPolicy(
374 "SSO nonce mismatch — possible token replay".to_string(),
375 ));
376 }
377 }
378
379 let identity_str = claims[&config.identity_claim]
385 .as_str()
386 .ok_or_else(|| {
387 AvError::InvalidPolicy(format!(
388 "SSO identity claim `{}` is missing or not a string. \
389 Check the identity_claim configuration against the IdP's \
390 token/userinfo shape.",
391 config.identity_claim
392 ))
393 })?
394 .to_string();
395
396 let groups: Vec<String> = claims[&config.groups_claim]
397 .as_array()
398 .map(|arr| {
399 arr.iter()
400 .filter_map(|v| v.as_str().map(|s| s.to_string()))
401 .collect()
402 })
403 .unwrap_or_default();
404
405 let sso_identity = SsoIdentity {
406 identity: identity_str,
407 groups,
408 claims,
409 };
410
411 save_cached_identity(&sso_identity, 3600, &config.issuer, &config.client_id);
413
414 Ok(sso_identity)
415}
416
417async fn decode_and_verify_jwt(
425 client: &reqwest::Client,
426 jwks_uri: &str,
427 token: &str,
428 expected_issuer: &str,
429 expected_audience: &str,
430) -> Result<serde_json::Value> {
431 use jsonwebtoken::{decode, DecodingKey, Validation};
432
433 let parts: Vec<&str> = token.split('.').collect();
434 if parts.len() != 3 {
435 return Err(AvError::InvalidPolicy("Invalid JWT format".to_string()));
436 }
437
438 if !jwks_uri.is_empty() {
440 let header = jsonwebtoken::decode_header(token)
441 .map_err(|e| AvError::InvalidPolicy(format!("JWT header decode failed: {}", e)))?;
442
443 let algorithm = header.alg;
444
445 const ALLOWED_ALGORITHMS: &[jsonwebtoken::Algorithm] = &[
449 jsonwebtoken::Algorithm::RS256,
450 jsonwebtoken::Algorithm::RS384,
451 jsonwebtoken::Algorithm::RS512,
452 jsonwebtoken::Algorithm::ES256,
453 jsonwebtoken::Algorithm::ES384,
454 jsonwebtoken::Algorithm::PS256,
455 jsonwebtoken::Algorithm::PS384,
456 jsonwebtoken::Algorithm::PS512,
457 jsonwebtoken::Algorithm::EdDSA,
458 ];
459 if !ALLOWED_ALGORITHMS.contains(&algorithm) {
460 return Err(AvError::InvalidPolicy(format!(
461 "JWT algorithm {:?} is not in the allowed set of asymmetric algorithms",
462 algorithm
463 )));
464 }
465
466 let jwks_resp: serde_json::Value =
470 fetch_json_capped(client, jwks_uri, None, 1024 * 1024, "JWKS").await?;
471
472 let jwk_set: jsonwebtoken::jwk::JwkSet = serde_json::from_value(jwks_resp)
474 .map_err(|e| AvError::InvalidPolicy(format!("JWKS parse into JwkSet failed: {}", e)))?;
475
476 let kid = header.kid.as_deref().ok_or_else(|| {
483 AvError::InvalidPolicy(
484 "JWT header is missing `kid` — refusing to pick an \
485 arbitrary key from the JWKS. Reject tokens without kid \
486 to prevent key-confusion attacks."
487 .to_string(),
488 )
489 })?;
490 let matching_key = jwk_set
491 .keys
492 .iter()
493 .find(|k| k.common.key_id.as_deref() == Some(kid));
494
495 if let Some(jwk) = matching_key {
496 let decoding_key = DecodingKey::from_jwk(jwk)
497 .map_err(|e| AvError::InvalidPolicy(format!("JWK decode failed: {}", e)))?;
498
499 let mut validation = Validation::new(algorithm);
500 validation.set_issuer(&[expected_issuer]);
501 validation.set_audience(&[expected_audience]);
502
503 let token_data = decode::<serde_json::Value>(token, &decoding_key, &validation)
504 .map_err(|e| {
505 AvError::InvalidPolicy(format!("JWT signature verification failed: {}", e))
506 })?;
507
508 return Ok(token_data.claims);
509 }
510
511 return Err(AvError::InvalidPolicy(format!(
512 "No matching JWK found for token kid={:?}",
513 kid
514 )));
515 }
516
517 Err(AvError::InvalidPolicy(
524 "JWT signature verification cannot be skipped: jwks_uri is empty. \
525 This is a programming error — callers must reject missing jwks_uri \
526 before invoking decode_and_verify_jwt."
527 .to_string(),
528 ))
529}
530
/// Test-only helper: decodes a JWT payload WITHOUT verifying its signature.
///
/// Rejects:
/// - tokens that are not three `.`-separated segments;
/// - a payload that is not base64url JSON;
/// - an `exp` in the past (a missing `exp` only logs a warning);
/// - a missing or different `iss`;
/// - an `aud` (string or array) that does not contain `expected_audience`.
#[cfg(test)]
fn decode_jwt_claims_unverified(
    token: &str,
    expected_issuer: &str,
    expected_audience: &str,
) -> Result<serde_json::Value> {
    use base64::Engine;

    let payload_b64 = match token.split('.').collect::<Vec<_>>().as_slice() {
        [_, payload, _] => *payload,
        _ => return Err(AvError::InvalidPolicy("Invalid JWT format".to_string())),
    };

    let payload = base64::engine::general_purpose::URL_SAFE_NO_PAD
        .decode(payload_b64)
        .map_err(|e| AvError::InvalidPolicy(format!("JWT base64 decode failed: {}", e)))?;
    let claims: serde_json::Value = serde_json::from_slice(&payload)
        .map_err(|e| AvError::InvalidPolicy(format!("JWT claims parse failed: {}", e)))?;

    match claims.get("exp").and_then(serde_json::Value::as_i64) {
        Some(exp) => {
            let now = chrono::Utc::now().timestamp();
            if now > exp {
                return Err(AvError::InvalidPolicy(format!(
                    "SSO token expired at {} (current time: {})",
                    exp, now
                )));
            }
        }
        None => {
            tracing::warn!("SSO token has no 'exp' claim — cannot verify expiry");
        }
    }

    let issuer = claims
        .get("iss")
        .and_then(serde_json::Value::as_str)
        .ok_or_else(|| AvError::InvalidPolicy("SSO token missing 'iss' claim".to_string()))?;
    if issuer != expected_issuer {
        return Err(AvError::InvalidPolicy(format!(
            "SSO token issuer mismatch: got '{}', expected '{}'",
            issuer, expected_issuer
        )));
    }

    let audience_ok = match claims.get("aud") {
        Some(serde_json::Value::String(single)) => single == expected_audience,
        Some(serde_json::Value::Array(list)) => list
            .iter()
            .filter_map(serde_json::Value::as_str)
            .any(|aud| aud == expected_audience),
        _ => false,
    };
    if !audience_ok {
        return Err(AvError::InvalidPolicy(format!(
            "SSO token audience does not include expected client_id '{}'",
            expected_audience
        )));
    }

    Ok(claims)
}
602
603fn generate_pkce_verifier() -> String {
605 use rand::Rng;
606 let mut rng = rand::rng();
607 let bytes: Vec<u8> = (0..32).map(|_| rng.random::<u8>()).collect();
608 use base64::Engine;
609 base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(&bytes)
610}
611
612fn generate_pkce_challenge(verifier: &str) -> String {
614 use sha2::Digest;
615 let hash = sha2::Sha256::digest(verifier.as_bytes());
616 use base64::Engine;
617 base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(hash)
618}
619
620fn generate_state() -> String {
622 use rand::Rng;
623 let mut rng = rand::rng();
624 let bytes: Vec<u8> = (0..16).map(|_| rng.random::<u8>()).collect();
625 hex::encode(bytes)
626}
627
/// Escapes `s` for interpolation into HTML element content or a quoted attribute value.
///
/// Each of the five HTML-significant characters is replaced by a character reference:
/// - `&` becomes `&amp;`
/// - `<` becomes `&lt;`
/// - `>` becomes `&gt;`
/// - `"` becomes `&quot;`
/// - `'` becomes `&#x27;`
///
/// All other characters pass through unchanged. The callback listener uses this for
/// IdP-supplied `error` / `error_description` values, which arrive in a URL anyone can craft.
fn html_escape(s: &str) -> String {
    let mut out = String::with_capacity(s.len());
    for ch in s.chars() {
        match ch {
            '&' => out.push_str("&amp;"),
            '<' => out.push_str("&lt;"),
            '>' => out.push_str("&gt;"),
            '"' => out.push_str("&quot;"),
            '\'' => out.push_str("&#x27;"),
            _ => out.push(ch),
        }
    }
    out
}
642
/// Asks the operating system to open `url` in the default browser.
///
/// This is best effort: the launcher is spawned and not waited for, and the caller prints
/// the URL anyway. On targets with no known launcher this is a no-op returning `Ok(())`.
///
/// # Errors
/// Returns the spawn error when the launcher program cannot be started.
fn open_browser(url: &str) -> std::io::Result<()> {
    #[cfg(target_os = "macos")]
    {
        std::process::Command::new("open").arg(url).spawn()?;
    }
    // `xdg-open` is the freedesktop launcher used on Linux and the BSDs.
    #[cfg(all(unix, not(target_os = "macos")))]
    {
        std::process::Command::new("xdg-open").arg(url).spawn()?;
    }
    // Do not go through `cmd /c start <url>`: cmd.exe treats every `&` in the URL as a
    // command separator. The authorization URL would be cut off after its first query
    // parameter, and the remaining fragments would be executed as commands.
    // `rundll32 url.dll,FileProtocolHandler` receives the URL as a plain argument and hands
    // it to the registered protocol handler.
    #[cfg(target_os = "windows")]
    {
        std::process::Command::new("rundll32")
            .args(["url.dll,FileProtocolHandler", url])
            .spawn()?;
    }
    #[cfg(not(any(unix, target_os = "windows")))]
    {
        let _ = url;
    }
    Ok(())
}
661
662async fn fetch_json_capped(
669 client: &reqwest::Client,
670 url: &str,
671 bearer: Option<&str>,
672 max_bytes: usize,
673 label: &str,
674) -> Result<serde_json::Value> {
675 let mut req = client.get(url);
676 if let Some(token) = bearer {
677 req = req.bearer_auth(token);
678 }
679 let mut resp = req
680 .send()
681 .await
682 .map_err(|e| AvError::InvalidPolicy(format!("{} failed: {}", label, e)))?;
683 if !resp.status().is_success() {
684 return Err(AvError::InvalidPolicy(format!(
685 "{} returned HTTP {}",
686 label,
687 resp.status()
688 )));
689 }
690 let mut body: Vec<u8> = Vec::new();
691 loop {
692 match resp.chunk().await {
693 Ok(Some(chunk)) => {
694 if body.len() + chunk.len() > max_bytes {
695 return Err(AvError::InvalidPolicy(format!(
696 "{} response exceeds {} byte cap",
697 label, max_bytes
698 )));
699 }
700 body.extend_from_slice(&chunk);
701 }
702 Ok(None) => break,
703 Err(e) => {
704 return Err(AvError::InvalidPolicy(format!(
705 "{} body read failed: {}",
706 label, e
707 )));
708 }
709 }
710 }
711 serde_json::from_slice(&body)
712 .map_err(|e| AvError::InvalidPolicy(format!("{} parse failed: {}", label, e)))
713}
714
715async fn listen_for_callback(redirect_uri: &str, expected_state: &str) -> Result<String> {
717 use tokio::io::AsyncReadExt;
718 use tokio::io::AsyncWriteExt;
719
720 let url: url::Url = redirect_uri
722 .parse()
723 .map_err(|e| AvError::InvalidPolicy(format!("Invalid redirect URI: {}", e)))?;
724 let port = url.port().unwrap_or(8400);
725
726 let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
727 .await
728 .map_err(|e| {
729 AvError::InvalidPolicy(format!(
730 "Failed to start callback listener on port {}: {}",
731 port, e
732 ))
733 })?;
734
735 eprintln!(
736 "Waiting for SSO callback on port {} (5 minute timeout)...",
737 port
738 );
739
740 let (mut stream, _) =
741 tokio::time::timeout(std::time::Duration::from_secs(300), listener.accept())
742 .await
743 .map_err(|_| AvError::InvalidPolicy("SSO login timed out after 5 minutes".to_string()))?
744 .map_err(|e| AvError::InvalidPolicy(format!("Failed to accept callback: {}", e)))?;
745
746 const MAX_CALLBACK_REQUEST_BYTES: usize = 16 * 1024;
754 let mut buf: Vec<u8> = Vec::with_capacity(4096);
755 let mut scratch = [0u8; 4096];
756 loop {
757 let n = stream
758 .read(&mut scratch)
759 .await
760 .map_err(|e| AvError::InvalidPolicy(format!("Failed to read callback: {}", e)))?;
761 if n == 0 {
762 break;
763 }
764 if buf.len() + n > MAX_CALLBACK_REQUEST_BYTES {
765 return Err(AvError::InvalidPolicy(format!(
766 "SSO callback request exceeds {} byte cap",
767 MAX_CALLBACK_REQUEST_BYTES
768 )));
769 }
770 buf.extend_from_slice(&scratch[..n]);
771 if buf.windows(4).any(|w| w == b"\r\n\r\n") {
775 break;
776 }
777 }
778 let request = String::from_utf8_lossy(&buf);
779
780 let first_line = request.lines().next().unwrap_or("");
782 let path = first_line.split_whitespace().nth(1).unwrap_or("/");
783 let query_url = format!("http://localhost{}", path);
784 let parsed: url::Url = query_url
785 .parse()
786 .map_err(|_| AvError::InvalidPolicy("Failed to parse callback URL".to_string()))?;
787
788 let params: std::collections::HashMap<String, String> = parsed
789 .query_pairs()
790 .map(|(k, v)| (k.to_string(), v.to_string()))
791 .collect();
792
793 let state = params.get("state").cloned().unwrap_or_default();
795 let state_valid = {
799 let a = state.as_bytes();
800 let b = expected_state.as_bytes();
801 if a.len() != b.len() {
802 false
803 } else {
804 a.iter()
805 .zip(b.iter())
806 .fold(0u8, |acc, (x, y)| acc | (x ^ y))
807 == 0
808 }
809 };
810 if !state_valid {
811 let response = "HTTP/1.1 400 Bad Request\r\nContent-Type: text/html\r\n\r\n<h1>SSO Error</h1><p>State mismatch. Please try again.</p>";
812 let _ = stream.write_all(response.as_bytes()).await;
813 return Err(AvError::InvalidPolicy(
814 "SSO state mismatch — possible CSRF".to_string(),
815 ));
816 }
817
818 if let Some(error) = params.get("error") {
820 let desc = params.get("error_description").cloned().unwrap_or_default();
821 let escaped_error = html_escape(error);
822 let escaped_desc = html_escape(&desc);
823 let response = format!("HTTP/1.1 400 Bad Request\r\nContent-Type: text/html\r\n\r\n<h1>SSO Error</h1><p>{}: {}</p>", escaped_error, escaped_desc);
824 let _ = stream.write_all(response.as_bytes()).await;
825 return Err(AvError::InvalidPolicy(format!(
826 "SSO error: {}: {}",
827 error, desc
828 )));
829 }
830
831 let code = params
832 .get("code")
833 .cloned()
834 .ok_or_else(|| AvError::InvalidPolicy("No authorization code in callback".to_string()))?;
835
836 let response = "HTTP/1.1 200 OK\r\nContent-Type: text/html\r\n\r\n<h1>Audex SSO Login Successful</h1><p>You can close this window and return to the terminal.</p>";
838 let _ = stream.write_all(response.as_bytes()).await;
839
840 Ok(code)
841}
842
843pub fn logout() {
845 let path = cache_path();
846 let _ = std::fs::remove_file(path);
847}
848
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an unsigned JWT-shaped string, `header.<payload>.signature`, whose middle
    /// segment is the unpadded base64url encoding of `claims` serialized as JSON.
    fn fake_jwt(claims: &serde_json::Value) -> String {
        use base64::Engine;
        let payload = base64::engine::general_purpose::URL_SAFE_NO_PAD
            .encode(serde_json::to_vec(claims).unwrap());
        format!("header.{}.signature", payload)
    }

    #[test]
    fn test_sso_config_deserialize() {
        let toml_str = r#"
provider = "okta"
issuer = "https://myorg.okta.com"
client_id = "0oa1234567890"
scopes = ["openid", "email", "groups"]
identity_claim = "email"
groups_claim = "groups"
required = true
"#;
        let config: SsoConfig = toml::from_str(toml_str).unwrap();
        assert!(matches!(config.provider, SsoProvider::Okta));
        assert_eq!(config.issuer, "https://myorg.okta.com");
        assert_eq!(config.client_id, "0oa1234567890");
        assert!(config.required);
        assert_eq!(config.scopes.len(), 3);
    }

    #[test]
    fn test_sso_config_google() {
        // Only the mandatory fields plus a secret; everything else comes from serde defaults.
        let toml_str = r#"
provider = "google"
issuer = "https://accounts.google.com"
client_id = "12345.apps.googleusercontent.com"
client_secret = "GOCSPX-secret"
"#;
        let config: SsoConfig = toml::from_str(toml_str).unwrap();
        assert!(matches!(config.provider, SsoProvider::Google));
        assert!(config.client_secret.is_some());
        assert_eq!(config.identity_claim, "email");
        assert_eq!(config.groups_claim, "groups");
        assert_eq!(config.redirect_uri, "http://localhost:8400/callback");
        assert!(!config.required);
    }

    #[test]
    fn test_sso_identity_serialize() {
        let identity = SsoIdentity {
            identity: "alice@example.com".to_string(),
            groups: vec!["backend-team".to_string(), "devops".to_string()],
            claims: serde_json::json!({"email": "alice@example.com", "sub": "user123"}),
        };

        let json = serde_json::to_string(&identity).unwrap();
        assert!(json.contains("alice@example.com"));
        assert!(json.contains("backend-team"));

        let round_tripped: SsoIdentity = serde_json::from_str(&json).unwrap();
        assert_eq!(round_tripped.identity, "alice@example.com");
        assert_eq!(round_tripped.groups.len(), 2);
    }

    #[test]
    fn test_pkce_verifier_and_challenge() {
        let verifier = generate_pkce_verifier();
        let challenge = generate_pkce_challenge(&verifier);
        assert!(!verifier.is_empty());
        assert!(!challenge.is_empty());
        assert_ne!(verifier, challenge);
    }

    #[test]
    fn test_state_generation() {
        let (first, second) = (generate_state(), generate_state());
        assert_ne!(first, second);
        assert_eq!(first.len(), 32);
    }

    #[test]
    fn test_decode_jwt_claims_unverified() {
        let future_exp = chrono::Utc::now().timestamp() + 3600;
        let token = fake_jwt(&serde_json::json!({
            "email": "test@example.com",
            "groups": ["admin"],
            "exp": future_exp,
            "iss": "https://accounts.example.com",
            "aud": "my-client-id"
        }));

        let decoded =
            decode_jwt_claims_unverified(&token, "https://accounts.example.com", "my-client-id")
                .unwrap();
        assert_eq!(decoded["email"], "test@example.com");
        assert_eq!(decoded["groups"][0], "admin");
    }

    #[test]
    fn test_decode_jwt_invalid() {
        for malformed in ["not-a-jwt", "a.b"] {
            assert!(decode_jwt_claims_unverified(malformed, "iss", "aud").is_err());
        }
    }

    #[test]
    fn test_decode_jwt_expired() {
        let past_exp = chrono::Utc::now().timestamp() - 3600;
        let token = fake_jwt(&serde_json::json!({
            "email": "test@example.com",
            "exp": past_exp,
            "iss": "https://iss",
            "aud": "aud"
        }));
        let err = decode_jwt_claims_unverified(&token, "https://iss", "aud").unwrap_err();
        assert!(err.to_string().contains("expired"));
    }

    #[test]
    fn test_decode_jwt_wrong_issuer() {
        let future_exp = chrono::Utc::now().timestamp() + 3600;
        let token = fake_jwt(&serde_json::json!({
            "exp": future_exp,
            "iss": "https://evil.com",
            "aud": "my-client"
        }));
        let err =
            decode_jwt_claims_unverified(&token, "https://legit.com", "my-client").unwrap_err();
        assert!(err.to_string().contains("issuer mismatch"));
    }

    #[test]
    fn test_decode_jwt_wrong_audience() {
        let future_exp = chrono::Utc::now().timestamp() + 3600;
        let token = fake_jwt(&serde_json::json!({
            "exp": future_exp,
            "iss": "https://iss",
            "aud": "other-client"
        }));
        let err = decode_jwt_claims_unverified(&token, "https://iss", "my-client").unwrap_err();
        assert!(err.to_string().contains("audience"));
    }
}