1use std::collections::HashMap;
17use std::sync::atomic::{AtomicU64, Ordering};
18use std::sync::{Arc, RwLock};
19use std::time::{Duration, Instant};
20
21use jsonwebtoken::jwk::KeyAlgorithm;
22use jsonwebtoken::{decode_header, Algorithm, DecodingKey, Validation};
23use serde::Deserialize;
24use vellaveto_config::abac::FederationConfig;
25use vellaveto_types::abac::{
26 FederationAnchorInfo, FederationAnchorStatus, FederationStatus, FederationTrustAnchor,
27};
28use vellaveto_types::identity::AgentIdentity;
29
/// A federated token that passed validation, mapped onto a local identity.
#[derive(Debug, Clone)]
pub struct FederatedIdentity {
    /// Local identity produced by the matching anchor's identity mappings.
    pub identity: AgentIdentity,
    /// `org_id` of the trust anchor that validated the token.
    pub org_id: String,
    /// `trust_level` configured on that trust anchor.
    pub trust_level: String,
}
44
/// Errors produced while validating a federated JWT.
#[derive(Debug)]
pub enum FederationError {
    /// The anchor's JWKS could not be obtained: no `jwks_uri` configured,
    /// transport/HTTP failure, body size limit exceeded, or invalid JSON.
    JwksFetchFailed { org_id: String, source: String },
    /// Signature, standard-claim, extra-claim or issuer-pattern checks failed.
    JwtValidationFailed { org_id: String, source: String },
    /// The anchor's JWKS holds no usable key for the token's `kid`/`alg`.
    NoMatchingKey { org_id: String, kid: String },
    /// The header `alg` is not in `ALLOWED_ALGORITHMS`.
    DisallowedAlgorithm(String),
    /// The header is undecodable or lacks `kid`, or the payload has no
    /// string `iss` claim.
    InvalidHeader(String),
}
59
60impl std::fmt::Display for FederationError {
61 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
62 match self {
63 Self::JwksFetchFailed { org_id, source } => {
64 write!(f, "JWKS fetch failed for {}: {}", org_id, source)
65 }
66 Self::JwtValidationFailed { org_id, source } => {
67 write!(f, "JWT validation failed for {}: {}", org_id, source)
68 }
69 Self::NoMatchingKey { org_id, kid } => {
70 let safe_kid = vellaveto_types::sanitize_for_log(kid, 128);
75 write!(
76 f,
77 "no matching key in JWKS for org {}, kid '{}'",
78 org_id, safe_kid
79 )
80 }
81 Self::DisallowedAlgorithm(alg) => {
82 write!(f, "disallowed JWT algorithm: {}", alg)
83 }
84 Self::InvalidHeader(msg) => {
85 write!(f, "invalid JWT header: {}", msg)
86 }
87 }
88 }
89}
90
// Marker impl using the trait defaults (`source()` returns `None`); each
// variant carries its underlying cause as a `String` instead.
impl std::error::Error for FederationError {}
92
/// A JWKS document cached for one trust anchor.
struct CachedJwks {
    /// The key set as parsed from the anchor's `jwks_uri` response.
    keys: jsonwebtoken::jwk::JwkSet,
    /// When the key set was fetched; compared against the resolver's cache TTL.
    fetched_at: Instant,
}
98
/// Runtime state for one configured trust anchor.
struct CompiledAnchor {
    /// The anchor configuration (already passed `validate()`).
    config: FederationTrustAnchor,
    /// JWKS cache; `None` until the first successful fetch.
    jwks_cache: RwLock<Option<CachedJwks>>,
    /// Tokens this anchor validated successfully (saturating counter).
    success_count: AtomicU64,
    /// Validation failures attributed to this anchor (saturating counter).
    failure_count: AtomicU64,
}
106
/// Maximum number of claims other than `sub`, `iss` and `email` (i.e. entries
/// in `FederatedClaims::extra`) accepted in a federated token.
const MAX_EXTRA_CLAIMS: usize = 50;

/// Maximum JSON-serialized size, in bytes, of any single extra claim value.
const MAX_EXTRA_CLAIM_VALUE_BYTES: usize = 8192;

/// Maximum byte length kept from each string element of an array claim.
const MAX_CLAIM_ELEMENT_LEN: usize = 1024;

/// Maximum byte length of a comma-joined array claim value.
const MAX_CLAIM_JOINED_LEN: usize = 8192;

/// Maximum byte length of a claim value substituted into an `id_template`.
const MAX_TEMPLATE_CLAIM_VALUE_LEN: usize = 1024;
121
/// Claims decoded from a federated JWT.
///
/// `sub`, `iss` and `email` are typed fields. Every other claim, including
/// standard ones such as `exp` or `aud`, is collected into `extra` via
/// `#[serde(flatten)]`.
#[derive(Debug, Deserialize)]
struct FederatedClaims {
    #[serde(default)]
    sub: Option<String>,
    #[serde(default)]
    iss: Option<String>,
    #[serde(default)]
    email: Option<String>,
    /// All remaining claims; bounded by `validate_extra_claims`.
    #[serde(flatten)]
    extra: HashMap<String, serde_json::Value>,
}
135
136impl FederatedClaims {
137 fn validate_extra_claims(&self) -> Result<(), String> {
142 if self.extra.len() > MAX_EXTRA_CLAIMS {
143 return Err(format!(
144 "federated JWT has {} extra claims, max is {}",
145 self.extra.len(),
146 MAX_EXTRA_CLAIMS
147 ));
148 }
149 for (key, value) in &self.extra {
150 let serialized_len = serde_json::to_string(value).map(|s| s.len()).unwrap_or(0);
151 if serialized_len > MAX_EXTRA_CLAIM_VALUE_BYTES {
152 return Err(format!(
153 "federated JWT extra claim '{}' is {} bytes, max is {}",
154 key, serialized_len, MAX_EXTRA_CLAIM_VALUE_BYTES
155 ));
156 }
157 }
158 Ok(())
159 }
160}
161
/// Signature algorithms accepted in a federated token's header.
///
/// HMAC (`HS*`) algorithms are not listed, so tokens declaring them are
/// rejected with `DisallowedAlgorithm` before any JWKS lookup.
const ALLOWED_ALGORITHMS: &[Algorithm] = &[
    Algorithm::RS256,
    Algorithm::RS384,
    Algorithm::RS512,
    Algorithm::ES256,
    Algorithm::ES384,
    Algorithm::PS256,
    Algorithm::PS384,
    Algorithm::PS512,
    Algorithm::EdDSA,
];

/// Upper bound (1 MiB) on a JWKS response body. It is enforced against the
/// declared `Content-Length` and again while the body is streamed.
const MAX_JWKS_BODY_BYTES: usize = 1_048_576;
185
/// Validates JWTs issued by external organizations (trust anchors) and maps
/// them onto local [`AgentIdentity`] values.
pub struct FederationResolver {
    /// Compiled trust anchors in configuration order. The first anchor whose
    /// `issuer_pattern` matches a token's issuer is used.
    anchors: Vec<Arc<CompiledAnchor>>,
    /// Client used to download each anchor's JWKS document.
    http_client: reqwest::Client,
    /// How long a fetched JWKS is served from cache before being re-fetched.
    cache_ttl: Duration,
    /// Per-request timeout applied to JWKS downloads.
    fetch_timeout: Duration,
    /// When set, tokens must carry this value in `aud`. When `None`, audience
    /// validation is disabled.
    expected_audience: Option<String>,
}
195
196impl FederationResolver {
197 pub fn new(config: &FederationConfig, http_client: reqwest::Client) -> Result<Self, String> {
199 let mut anchors = Vec::with_capacity(config.trust_anchors.len());
200 for anchor_config in &config.trust_anchors {
201 anchor_config.validate()?;
202 anchors.push(Arc::new(CompiledAnchor {
203 config: anchor_config.clone(),
204 jwks_cache: RwLock::new(None),
205 success_count: AtomicU64::new(0),
206 failure_count: AtomicU64::new(0),
207 }));
208 }
209 Ok(Self {
210 anchors,
211 http_client,
212 cache_ttl: Duration::from_secs(config.jwks_cache_ttl_secs),
213 fetch_timeout: Duration::from_secs(config.jwks_fetch_timeout_secs),
214 expected_audience: config.expected_audience.clone(),
215 })
216 }
217
218 pub async fn validate_federated_token(
224 &self,
225 token: &str,
226 ) -> Result<Option<FederatedIdentity>, FederationError> {
227 let header =
229 decode_header(token).map_err(|e| FederationError::InvalidHeader(e.to_string()))?;
230
231 let alg = header.alg;
232 let alg_str = format!("{:?}", alg);
233 if !ALLOWED_ALGORITHMS.contains(&alg) {
234 return Err(FederationError::DisallowedAlgorithm(alg_str));
235 }
236
237 let kid = match header.kid {
241 Some(ref k) if !k.is_empty() => k.clone(),
242 _ => {
243 return Err(FederationError::InvalidHeader(
244 "JWT header missing required 'kid' field for federation".to_string(),
245 ))
246 }
247 };
248
249 let issuer = extract_issuer_from_payload(token)
251 .ok_or_else(|| FederationError::InvalidHeader("missing iss claim".to_string()))?;
252
253 let anchor = match self.find_matching_anchor(&issuer) {
259 Some(a) => a,
260 None => return Ok(None), };
262
263 let decoding_key =
265 self.get_decoding_key(&anchor, &kid, &alg)
266 .await
267 .inspect_err(|_| {
268 let _ = anchor.failure_count.fetch_update(
269 Ordering::SeqCst,
270 Ordering::SeqCst,
271 |v| Some(v.saturating_add(1)),
272 );
273 })?;
274
275 let mut validation = Validation::new(alg);
277 validation.validate_exp = true;
278 validation.validate_nbf = true;
281 validation.set_issuer(&[&issuer]);
282 if let Some(ref aud) = self.expected_audience {
285 validation.validate_aud = true;
286 validation.set_audience(&[aud]);
287 } else {
288 validation.validate_aud = false;
289 }
290 validation.leeway = 60;
292
293 let token_data = jsonwebtoken::decode::<FederatedClaims>(token, &decoding_key, &validation)
294 .map_err(|e| {
295 let _ =
296 anchor
297 .failure_count
298 .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |v| {
299 Some(v.saturating_add(1))
300 });
301 FederationError::JwtValidationFailed {
302 org_id: anchor.config.org_id.clone(),
303 source: e.to_string(),
304 }
305 })?;
306
307 if let Err(reason) = token_data.claims.validate_extra_claims() {
309 let _ = anchor
310 .failure_count
311 .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |v| {
312 Some(v.saturating_add(1))
313 });
314 return Err(FederationError::JwtValidationFailed {
315 org_id: anchor.config.org_id.clone(),
316 source: reason,
317 });
318 }
319
320 if let Some(ref verified_iss) = token_data.claims.iss {
326 if !issuer_pattern_matches(&anchor.config.issuer_pattern, verified_iss) {
327 let _ =
328 anchor
329 .failure_count
330 .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |v| {
331 Some(v.saturating_add(1))
332 });
333 return Err(FederationError::JwtValidationFailed {
334 org_id: anchor.config.org_id.clone(),
335 source: format!(
336 "verified token issuer '{}' does not match anchor pattern '{}'",
337 verified_iss, anchor.config.issuer_pattern
338 ),
339 });
340 }
341 }
342
343 let identity = self.apply_identity_mappings(&anchor, &token_data.claims);
345
346 let _ = anchor
347 .success_count
348 .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |v| {
349 Some(v.saturating_add(1))
350 });
351
352 Ok(Some(FederatedIdentity {
353 identity,
354 org_id: anchor.config.org_id.clone(),
355 trust_level: anchor.config.trust_level.clone(),
356 }))
357 }
358
359 pub fn status(&self) -> FederationStatus {
361 FederationStatus {
362 enabled: true,
363 trust_anchor_count: self.anchors.len(),
364 anchors: self
365 .anchors
366 .iter()
367 .map(|a| {
368 let cache = a.jwks_cache.read().ok();
369 let (cached, last_fetched) = match cache.as_ref().and_then(|c| c.as_ref()) {
370 Some(c) => (true, Some(format!("{:?}", c.fetched_at.elapsed()))),
371 None => (false, None),
372 };
373 FederationAnchorStatus {
374 org_id: a.config.org_id.clone(),
375 display_name: a.config.display_name.clone(),
376 issuer_pattern: a.config.issuer_pattern.clone(),
377 trust_level: a.config.trust_level.clone(),
378 has_jwks_uri: a.config.jwks_uri.is_some(),
379 jwks_cached: cached,
380 jwks_last_fetched: last_fetched,
381 identity_mapping_count: a.config.identity_mappings.len(),
382 successful_validations: a.success_count.load(Ordering::SeqCst),
383 failed_validations: a.failure_count.load(Ordering::SeqCst),
384 }
385 })
386 .collect(),
387 }
388 }
389
390 pub fn anchor_info(&self) -> Vec<FederationAnchorInfo> {
392 self.anchors
393 .iter()
394 .map(|a| FederationAnchorInfo {
395 org_id: a.config.org_id.clone(),
396 display_name: a.config.display_name.clone(),
397 issuer_pattern: a.config.issuer_pattern.clone(),
398 trust_level: a.config.trust_level.clone(),
399 has_jwks_uri: a.config.jwks_uri.is_some(),
400 identity_mapping_count: a.config.identity_mappings.len(),
401 })
402 .collect()
403 }
404
405 fn find_matching_anchor(&self, issuer: &str) -> Option<Arc<CompiledAnchor>> {
408 self.anchors
409 .iter()
410 .find(|a| issuer_pattern_matches(&a.config.issuer_pattern, issuer))
411 .cloned()
412 }
413
414 async fn get_decoding_key(
415 &self,
416 anchor: &CompiledAnchor,
417 kid: &str,
418 alg: &Algorithm,
419 ) -> Result<DecodingKey, FederationError> {
420 let jwks_uri =
421 anchor
422 .config
423 .jwks_uri
424 .as_deref()
425 .ok_or_else(|| FederationError::JwksFetchFailed {
426 org_id: anchor.config.org_id.clone(),
427 source: "no jwks_uri configured".to_string(),
428 })?;
429
430 {
435 match anchor.jwks_cache.read() {
436 Ok(cache_guard) => {
437 if let Some(ref cached) = *cache_guard {
438 if cached.fetched_at.elapsed() < self.cache_ttl {
439 return find_key_in_jwks(&cached.keys, kid, alg, &anchor.config.org_id);
440 }
441 }
442 }
443 Err(_) => {
444 tracing::error!(
445 target: "vellaveto::security",
446 org_id = %anchor.config.org_id,
447 "JWKS cache read lock poisoned — treating as cache miss (fail-closed)"
448 );
449 }
451 }
452 }
453
454 let jwks = self.fetch_jwks(jwks_uri, &anchor.config.org_id).await?;
456
457 let result = find_key_in_jwks(&jwks, kid, alg, &anchor.config.org_id);
459
460 {
464 let mut cache_guard = anchor.jwks_cache.write().unwrap_or_else(|e| {
465 tracing::warn!(
466 "JWKS cache write lock poisoned for org '{}'; clearing stale cache",
467 anchor.config.org_id
468 );
469 let mut guard = e.into_inner();
470 *guard = None; guard
472 });
473 *cache_guard = Some(CachedJwks {
474 keys: jwks,
475 fetched_at: Instant::now(),
476 });
477 }
478
479 result
480 }
481
482 async fn fetch_jwks(
483 &self,
484 uri: &str,
485 org_id: &str,
486 ) -> Result<jsonwebtoken::jwk::JwkSet, FederationError> {
487 let resp = self
488 .http_client
489 .get(uri)
490 .timeout(self.fetch_timeout)
491 .send()
492 .await
493 .map_err(|e| FederationError::JwksFetchFailed {
494 org_id: org_id.to_string(),
495 source: e.to_string(),
496 })?;
497
498 if !resp.status().is_success() {
499 return Err(FederationError::JwksFetchFailed {
500 org_id: org_id.to_string(),
501 source: format!("HTTP {}", resp.status()),
502 });
503 }
504
505 if let Some(len) = resp.content_length() {
508 if len as usize > MAX_JWKS_BODY_BYTES {
509 return Err(FederationError::JwksFetchFailed {
510 org_id: org_id.to_string(),
511 source: format!(
512 "JWKS Content-Length {} exceeds {} byte limit",
513 len, MAX_JWKS_BODY_BYTES
514 ),
515 });
516 }
517 }
518
519 let capacity = std::cmp::min(
522 resp.content_length().unwrap_or(8192) as usize,
523 MAX_JWKS_BODY_BYTES,
524 );
525 let mut body = Vec::with_capacity(capacity);
526 let mut resp = resp;
527 while let Some(chunk) =
528 resp.chunk()
529 .await
530 .map_err(|e| FederationError::JwksFetchFailed {
531 org_id: org_id.to_string(),
532 source: e.to_string(),
533 })?
534 {
535 if body.len().saturating_add(chunk.len()) > MAX_JWKS_BODY_BYTES {
536 return Err(FederationError::JwksFetchFailed {
537 org_id: org_id.to_string(),
538 source: format!("JWKS response exceeds {} byte limit", MAX_JWKS_BODY_BYTES),
539 });
540 }
541 body.extend_from_slice(&chunk);
542 }
543
544 serde_json::from_slice(&body).map_err(|e| FederationError::JwksFetchFailed {
545 org_id: org_id.to_string(),
546 source: format!("invalid JWKS JSON: {}", e),
547 })
548 }
549
550 fn apply_identity_mappings(
551 &self,
552 anchor: &CompiledAnchor,
553 claims: &FederatedClaims,
554 ) -> AgentIdentity {
555 let mut identity_claims: HashMap<String, serde_json::Value> = HashMap::new();
556
557 identity_claims.insert(
559 "federation.org_id".to_string(),
560 serde_json::Value::String(anchor.config.org_id.clone()),
561 );
562 identity_claims.insert(
563 "federation.trust_level".to_string(),
564 serde_json::Value::String(anchor.config.trust_level.clone()),
565 );
566 if let Some(ref iss) = claims.iss {
567 identity_claims.insert(
568 "federation.issuer".to_string(),
569 serde_json::Value::String(iss.clone()),
570 );
571 }
572
573 let mut subject = claims.sub.clone();
574
575 for mapping in &anchor.config.identity_mappings {
577 if let Some(value) = extract_claim_value(claims, &mapping.external_claim) {
578 let sanitized = sanitize_claim_for_template(&value);
583
584 let rendered = mapping
585 .id_template
586 .replace("{claim_value}", &sanitized)
587 .replace("{org_id}", &anchor.config.org_id);
588
589 identity_claims.insert(
590 "principal.type".to_string(),
591 serde_json::Value::String(mapping.internal_principal_type.clone()),
592 );
593 identity_claims.insert(
594 "principal.id".to_string(),
595 serde_json::Value::String(rendered.clone()),
596 );
597
598 if mapping.external_claim == "sub" || mapping.external_claim == "email" {
600 subject = Some(rendered);
601 }
602 }
603 }
604
605 AgentIdentity {
606 issuer: claims.iss.clone(),
607 subject,
608 audience: Vec::new(),
609 claims: identity_claims,
610 }
611 }
612}
613
614fn sanitize_claim_for_template(value: &str) -> String {
624 let sanitized: String = value
625 .chars()
626 .filter(|c| {
628 !c.is_control()
629 && !vellaveto_types::is_unicode_format_char(*c)
630 && *c != '{'
631 && *c != '}'
632 })
633 .collect();
634 if sanitized.len() > MAX_TEMPLATE_CLAIM_VALUE_LEN {
635 let mut end = MAX_TEMPLATE_CLAIM_VALUE_LEN;
637 while end > 0 && !sanitized.is_char_boundary(end) {
638 end -= 1;
639 }
640 sanitized[..end].to_string()
641 } else {
642 sanitized
643 }
644}
645
646fn extract_claim_value(claims: &FederatedClaims, claim_path: &str) -> Option<String> {
648 match claim_path {
650 "sub" => return claims.sub.clone(),
651 "iss" => return claims.iss.clone(),
652 "email" => return claims.email.clone(),
653 _ => {}
654 }
655
656 let parts: Vec<&str> = claim_path.splitn(10, '.').collect();
658 let mut current: Option<&serde_json::Value> = claims.extra.get(parts[0]);
659
660 for part in &parts[1..] {
661 current = current.and_then(|v| v.get(part));
662 }
663
664 current.and_then(|v| match v {
665 serde_json::Value::String(s) => Some(s.clone()),
666 serde_json::Value::Array(arr) => {
667 let mut total_len = 0usize;
669 let joined: Vec<String> = arr
670 .iter()
671 .take(64) .filter_map(|item| {
673 item.as_str().map(|s| {
674 if s.len() > MAX_CLAIM_ELEMENT_LEN {
676 s[..MAX_CLAIM_ELEMENT_LEN].to_string()
677 } else {
678 s.to_string()
679 }
680 })
681 })
682 .take_while(|s| {
683 let added = if total_len == 0 { s.len() } else { s.len() + 1 };
685 if total_len + added > MAX_CLAIM_JOINED_LEN {
686 return false;
687 }
688 total_len += added;
689 true
690 })
691 .collect();
692 if joined.is_empty() {
693 None
694 } else {
695 Some(joined.join(","))
696 }
697 }
698 serde_json::Value::Number(n) => Some(n.to_string()),
699 serde_json::Value::Bool(b) => Some(b.to_string()),
700 _ => None,
701 })
702}
703
/// Upper bound on the base64url payload segment that
/// `extract_issuer_from_payload` will decode before the signature is verified.
const MAX_JWT_PAYLOAD_B64_LEN: usize = 65_536;
708
709fn extract_issuer_from_payload(token: &str) -> Option<String> {
711 let parts: Vec<&str> = token.splitn(4, '.').collect();
712 if parts.len() < 2 {
713 return None;
714 }
715 if parts[1].len() > MAX_JWT_PAYLOAD_B64_LEN {
717 return None;
718 }
719 use base64::Engine;
720 let decoded = base64::engine::general_purpose::URL_SAFE_NO_PAD
721 .decode(parts[1])
722 .ok()?;
723 let payload: serde_json::Value = serde_json::from_slice(&decoded).ok()?;
724 payload.get("iss")?.as_str().map(String::from)
725}
726
/// Returns `true` when `issuer` matches `pattern`.
///
/// A pattern without `*` must equal the issuer exactly. In a glob pattern,
/// `*` matches any run of characters (including `/` and `.`). Runs of
/// consecutive `*` are collapsed first, and patterns that still split into
/// more than 11 literal segments (over 10 wildcards) never match.
fn issuer_pattern_matches(pattern: &str, issuer: &str) -> bool {
    if pattern == issuer {
        return true;
    }
    if !pattern.contains('*') {
        return false;
    }

    // `**` behaves exactly like `*`, so collapse repeats before splitting.
    let mut chars: Vec<char> = pattern.chars().collect();
    chars.dedup_by(|current, previous| *current == '*' && *previous == '*');
    let collapsed: String = chars.into_iter().collect();

    let segments: Vec<&str> = collapsed.split('*').collect();
    if segments.len() > 11 {
        return false;
    }

    let last_index = segments.len() - 1;
    let mut rest = issuer;
    for (index, segment) in segments.iter().enumerate() {
        if segment.is_empty() {
            continue;
        }
        if index == 0 {
            // Leading literal is anchored at the start.
            match rest.strip_prefix(*segment) {
                Some(tail) => rest = tail,
                None => return false,
            }
        } else if index == last_index {
            // Trailing literal is anchored at the end; nothing follows it.
            return rest.ends_with(*segment);
        } else {
            // Interior literals: leftmost occurrence, then continue after it.
            match rest.find(*segment) {
                Some(at) => rest = &rest[at + segment.len()..],
                None => return false,
            }
        }
    }
    true
}
793
794fn key_algorithm_to_algorithm(ka: &KeyAlgorithm) -> Option<Algorithm> {
801 match ka {
802 KeyAlgorithm::HS256 => Some(Algorithm::HS256),
803 KeyAlgorithm::HS384 => Some(Algorithm::HS384),
804 KeyAlgorithm::HS512 => Some(Algorithm::HS512),
805 KeyAlgorithm::ES256 => Some(Algorithm::ES256),
806 KeyAlgorithm::ES384 => Some(Algorithm::ES384),
807 KeyAlgorithm::RS256 => Some(Algorithm::RS256),
808 KeyAlgorithm::RS384 => Some(Algorithm::RS384),
809 KeyAlgorithm::RS512 => Some(Algorithm::RS512),
810 KeyAlgorithm::PS256 => Some(Algorithm::PS256),
811 KeyAlgorithm::PS384 => Some(Algorithm::PS384),
812 KeyAlgorithm::PS512 => Some(Algorithm::PS512),
813 KeyAlgorithm::EdDSA => Some(Algorithm::EdDSA),
814 _ => None,
816 }
817}
818
819fn find_key_in_jwks(
827 jwks: &jsonwebtoken::jwk::JwkSet,
828 kid: &str,
829 alg: &Algorithm,
830 org_id: &str,
831) -> Result<DecodingKey, FederationError> {
832 for key in &jwks.keys {
833 match &key.common.key_id {
836 Some(key_kid) => {
837 if !kid.is_empty() && key_kid != kid {
838 continue; }
840 }
841 None => {
842 tracing::warn!(
844 org_id = %org_id,
845 "skipping JWK without kid field in JWKS for org '{}'",
846 org_id
847 );
848 continue;
849 }
850 }
851
852 if let Some(ref key_alg) = key.common.key_algorithm {
854 match key_algorithm_to_algorithm(key_alg) {
855 Some(mapped) if &mapped == alg => {} _ => continue, }
858 }
859
860 if let Ok(dk) = DecodingKey::from_jwk(key) {
862 return Ok(dk);
863 }
864 }
865
866 Err(FederationError::NoMatchingKey {
867 org_id: org_id.to_string(),
868 kid: kid.to_string(),
869 })
870}
871
// Unit tests for the federation resolver. None of them performs network I/O.
// Token validation is driven only down paths that return before any JWKS
// fetch: an unmatched issuer, or a matched anchor without `jwks_uri`.
#[cfg(test)]
mod tests {
    use super::*;

    /// One exact-match anchor (`partner-org`, trust level `limited`) with a
    /// single `sub` -> `agent` mapping rendered as `partner-org:{claim_value}`.
    fn test_config() -> FederationConfig {
        FederationConfig {
            enabled: true,
            trust_anchors: vec![FederationTrustAnchor {
                org_id: "partner-org".to_string(),
                display_name: "Partner Organization".to_string(),
                jwks_uri: Some("https://auth.partner.com/.well-known/jwks.json".to_string()),
                issuer_pattern: "https://auth.partner.com".to_string(),
                identity_mappings: vec![vellaveto_types::abac::IdentityMapping {
                    external_claim: "sub".to_string(),
                    internal_principal_type: "agent".to_string(),
                    id_template: "partner-org:{claim_value}".to_string(),
                }],
                trust_level: "limited".to_string(),
            }],
            jwks_cache_ttl_secs: 300,
            jwks_fetch_timeout_secs: 10,
            expected_audience: None,
        }
    }

    /// One wildcard anchor (`acme`) matching issuers under
    /// `https://auth.acme.com/`, with no identity mappings.
    fn test_config_wildcard() -> FederationConfig {
        FederationConfig {
            enabled: true,
            trust_anchors: vec![FederationTrustAnchor {
                org_id: "acme".to_string(),
                display_name: "ACME Corp".to_string(),
                jwks_uri: Some("https://auth.acme.com/.well-known/jwks.json".to_string()),
                issuer_pattern: "https://auth.acme.com/*".to_string(),
                identity_mappings: vec![],
                trust_level: "full".to_string(),
            }],
            jwks_cache_ttl_secs: 300,
            jwks_fetch_timeout_secs: 10,
            expected_audience: None,
        }
    }

    // --- Construction -----------------------------------------------------

    #[test]
    fn test_new_valid_config_compiles_anchors() {
        let config = test_config();
        let client = reqwest::Client::new();
        let resolver = FederationResolver::new(&config, client);
        assert!(resolver.is_ok());
        let resolver = resolver.expect("test config should be valid");
        assert_eq!(resolver.anchors.len(), 1);
    }

    #[test]
    fn test_new_invalid_anchor_fails() {
        let config = FederationConfig {
            enabled: true,
            trust_anchors: vec![FederationTrustAnchor {
                // The empty org_id is the invalid field here; an anchor with
                // `jwks_uri: None` is accepted elsewhere in these tests.
                org_id: String::new(),
                display_name: "Bad".to_string(),
                jwks_uri: None,
                issuer_pattern: "https://x.com".to_string(),
                identity_mappings: vec![],
                trust_level: "limited".to_string(),
            }],
            jwks_cache_ttl_secs: 300,
            jwks_fetch_timeout_secs: 10,
            expected_audience: None,
        };
        let client = reqwest::Client::new();
        assert!(FederationResolver::new(&config, client).is_err());
    }

    // --- Issuer pattern matching -----------------------------------------

    #[test]
    fn test_issuer_pattern_exact_match() {
        assert!(issuer_pattern_matches(
            "https://auth.partner.com",
            "https://auth.partner.com"
        ));
    }

    #[test]
    fn test_issuer_pattern_no_match() {
        assert!(!issuer_pattern_matches(
            "https://auth.partner.com",
            "https://auth.evil.com"
        ));
    }

    #[test]
    fn test_issuer_pattern_glob_wildcard_suffix() {
        assert!(issuer_pattern_matches(
            "https://auth.acme.com/*",
            "https://auth.acme.com/tenant-1"
        ));
    }

    #[test]
    fn test_issuer_pattern_glob_wildcard_middle() {
        assert!(issuer_pattern_matches(
            "https://*.acme.com/auth",
            "https://tenant1.acme.com/auth"
        ));
    }

    #[test]
    fn test_issuer_pattern_glob_wildcard_no_match() {
        assert!(!issuer_pattern_matches(
            "https://auth.acme.com/*",
            "https://auth.evil.com/tenant"
        ));
    }

    // `**` must behave exactly like a single `*`.
    #[test]
    fn test_issuer_pattern_consecutive_wildcards_collapsed() {
        assert!(issuer_pattern_matches(
            "https://auth.acme.com/**",
            "https://auth.acme.com/tenant-1"
        ));
        assert!(!issuer_pattern_matches(
            "https://auth.acme.com/**",
            "https://auth.evil.com/tenant"
        ));
    }

    // The literal `.` between the wildcards must appear in the issuer.
    #[test]
    fn test_issuer_pattern_star_dot_star() {
        assert!(issuer_pattern_matches(
            "https://*.*",
            "https://auth.example.com"
        ));
        assert!(!issuer_pattern_matches(
            "https://*.*",
            "https://nosubdomain"
        ));
    }

    // --- Anchor selection --------------------------------------------------

    #[test]
    fn test_find_matching_anchor_exact() {
        let config = test_config();
        let client = reqwest::Client::new();
        let resolver = FederationResolver::new(&config, client).expect("valid config");
        let anchor = resolver.find_matching_anchor("https://auth.partner.com");
        assert!(anchor.is_some());
        assert_eq!(anchor.expect("should match").config.org_id, "partner-org");
    }

    #[test]
    fn test_find_matching_anchor_wildcard() {
        let config = test_config_wildcard();
        let client = reqwest::Client::new();
        let resolver = FederationResolver::new(&config, client).expect("valid config");
        let anchor = resolver.find_matching_anchor("https://auth.acme.com/tenant-1");
        assert!(anchor.is_some());
        assert_eq!(anchor.expect("should match").config.org_id, "acme");
    }

    #[test]
    fn test_find_matching_anchor_no_match() {
        let config = test_config();
        let client = reqwest::Client::new();
        let resolver = FederationResolver::new(&config, client).expect("valid config");
        assert!(resolver
            .find_matching_anchor("https://auth.evil.com")
            .is_none());
    }

    // --- Claim extraction --------------------------------------------------

    #[test]
    fn test_extract_claim_value_sub() {
        let claims = FederatedClaims {
            sub: Some("agent-123".to_string()),
            iss: Some("https://auth.example.com".to_string()),
            email: None,
            extra: HashMap::new(),
        };
        assert_eq!(
            extract_claim_value(&claims, "sub"),
            Some("agent-123".to_string())
        );
    }

    // Dotted paths descend into nested objects; string arrays are comma-joined.
    #[test]
    fn test_extract_claim_value_nested() {
        let mut extra = HashMap::new();
        extra.insert(
            "realm_access".to_string(),
            serde_json::json!({"roles": ["admin", "user"]}),
        );
        let claims = FederatedClaims {
            sub: None,
            iss: None,
            email: None,
            extra,
        };
        assert_eq!(
            extract_claim_value(&claims, "realm_access.roles"),
            Some("admin,user".to_string())
        );
    }

    #[test]
    fn test_extract_claim_value_missing() {
        let claims = FederatedClaims {
            sub: None,
            iss: None,
            email: None,
            extra: HashMap::new(),
        };
        assert_eq!(extract_claim_value(&claims, "nonexistent"), None);
    }

    // --- Identity mapping --------------------------------------------------

    #[test]
    fn test_apply_identity_mappings_injects_federation_metadata() {
        let config = test_config();
        let client = reqwest::Client::new();
        let resolver = FederationResolver::new(&config, client).expect("valid config");
        let claims = FederatedClaims {
            sub: Some("agent-456".to_string()),
            iss: Some("https://auth.partner.com".to_string()),
            email: None,
            extra: HashMap::new(),
        };
        let identity = resolver.apply_identity_mappings(&resolver.anchors[0], &claims);
        assert_eq!(
            identity.claims.get("federation.org_id"),
            Some(&serde_json::Value::String("partner-org".to_string()))
        );
        assert_eq!(
            identity.claims.get("federation.trust_level"),
            Some(&serde_json::Value::String("limited".to_string()))
        );
        assert_eq!(
            identity.claims.get("federation.issuer"),
            Some(&serde_json::Value::String(
                "https://auth.partner.com".to_string()
            ))
        );
    }

    #[test]
    fn test_apply_identity_mappings_template_substitution() {
        let config = test_config();
        let client = reqwest::Client::new();
        let resolver = FederationResolver::new(&config, client).expect("valid config");
        let claims = FederatedClaims {
            sub: Some("agent-789".to_string()),
            iss: Some("https://auth.partner.com".to_string()),
            email: None,
            extra: HashMap::new(),
        };
        let identity = resolver.apply_identity_mappings(&resolver.anchors[0], &claims);
        assert_eq!(
            identity.claims.get("principal.id"),
            Some(&serde_json::Value::String(
                "partner-org:agent-789".to_string()
            ))
        );
        assert_eq!(
            identity.claims.get("principal.type"),
            Some(&serde_json::Value::String("agent".to_string()))
        );
    }

    // NUL/LF/CR in the claim are removed (not replaced) before templating.
    #[test]
    fn test_apply_identity_mappings_sanitizes_control_chars() {
        let config = test_config();
        let client = reqwest::Client::new();
        let resolver = FederationResolver::new(&config, client).expect("valid config");
        let claims = FederatedClaims {
            sub: Some("agent\x00\x0A\x0Dinjected".to_string()),
            iss: Some("https://auth.partner.com".to_string()),
            email: None,
            extra: HashMap::new(),
        };
        let identity = resolver.apply_identity_mappings(&resolver.anchors[0], &claims);
        let principal_id = identity
            .claims
            .get("principal.id")
            .and_then(|v| v.as_str())
            .expect("principal.id must exist");
        assert!(!principal_id.contains('\x00'));
        assert!(!principal_id.contains('\x0A'));
        assert!(!principal_id.contains('\x0D'));
        assert!(principal_id.contains("agentinjected"));
    }

    // Placeholders inside the claim value must not be expanded: braces are stripped.
    #[test]
    fn test_apply_identity_mappings_strips_template_syntax() {
        let config = test_config();
        let client = reqwest::Client::new();
        let resolver = FederationResolver::new(&config, client).expect("valid config");
        let claims = FederatedClaims {
            sub: Some("{claim_value}evil{org_id}".to_string()),
            iss: Some("https://auth.partner.com".to_string()),
            email: None,
            extra: HashMap::new(),
        };
        let identity = resolver.apply_identity_mappings(&resolver.anchors[0], &claims);
        let principal_id = identity
            .claims
            .get("principal.id")
            .and_then(|v| v.as_str())
            .expect("principal.id must exist");
        assert!(!principal_id.contains('{'));
        assert!(!principal_id.contains('}'));
        assert!(principal_id.contains("claim_valueevilorg_id"));
    }

    #[test]
    fn test_sanitize_claim_for_template_truncates_long_values() {
        let long_value = "a".repeat(2000);
        let sanitized = sanitize_claim_for_template(&long_value);
        assert!(sanitized.len() <= MAX_TEMPLATE_CLAIM_VALUE_LEN);
    }

    // Signature algorithms map one-to-one; RSA1_5 has no mapping.
    #[test]
    fn test_key_algorithm_to_algorithm_explicit_mapping() {
        assert_eq!(
            key_algorithm_to_algorithm(&KeyAlgorithm::RS256),
            Some(Algorithm::RS256)
        );
        assert_eq!(
            key_algorithm_to_algorithm(&KeyAlgorithm::ES256),
            Some(Algorithm::ES256)
        );
        assert_eq!(
            key_algorithm_to_algorithm(&KeyAlgorithm::PS256),
            Some(Algorithm::PS256)
        );
        assert_eq!(
            key_algorithm_to_algorithm(&KeyAlgorithm::EdDSA),
            Some(Algorithm::EdDSA)
        );
        assert_eq!(key_algorithm_to_algorithm(&KeyAlgorithm::RSA1_5), None);
    }

    // --- Status / info -----------------------------------------------------

    #[test]
    fn test_status_reports_anchors() {
        let config = test_config();
        let client = reqwest::Client::new();
        let resolver = FederationResolver::new(&config, client).expect("valid config");
        let status = resolver.status();
        assert!(status.enabled);
        assert_eq!(status.trust_anchor_count, 1);
        assert_eq!(status.anchors.len(), 1);
        assert_eq!(status.anchors[0].org_id, "partner-org");
        assert_eq!(status.anchors[0].trust_level, "limited");
        // No token has been validated yet, so nothing has been fetched.
        assert!(!status.anchors[0].jwks_cached);
    }

    #[test]
    fn test_anchor_info() {
        let config = test_config();
        let client = reqwest::Client::new();
        let resolver = FederationResolver::new(&config, client).expect("valid config");
        let infos = resolver.anchor_info();
        assert_eq!(infos.len(), 1);
        assert_eq!(infos[0].org_id, "partner-org");
        assert!(infos[0].has_jwks_uri);
        assert_eq!(infos[0].identity_mapping_count, 1);
    }

    // --- Unverified issuer extraction -------------------------------------

    #[test]
    fn test_extract_issuer_from_payload_valid() {
        use base64::Engine;
        // Signature segment is irrelevant: the issuer is read before verification.
        let header = base64::engine::general_purpose::URL_SAFE_NO_PAD
            .encode(r#"{"alg":"RS256","typ":"JWT"}"#);
        let payload = base64::engine::general_purpose::URL_SAFE_NO_PAD
            .encode(r#"{"iss":"https://auth.example.com","sub":"test"}"#);
        let token = format!("{}.{}.fake-sig", header, payload);
        assert_eq!(
            extract_issuer_from_payload(&token),
            Some("https://auth.example.com".to_string())
        );
    }

    #[test]
    fn test_extract_issuer_from_payload_missing_iss() {
        use base64::Engine;
        let header = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(r#"{"alg":"RS256"}"#);
        let payload = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(r#"{"sub":"test"}"#);
        let token = format!("{}.{}.sig", header, payload);
        assert_eq!(extract_issuer_from_payload(&token), None);
    }

    // --- End-to-end validation (offline paths only) ------------------------

    // A token whose issuer matches no anchor is "not federated": Ok(None).
    #[tokio::test]
    async fn test_validate_unmatched_issuer_returns_none() {
        let config = test_config();
        let client = reqwest::Client::new();
        let resolver = FederationResolver::new(&config, client).expect("valid config");

        use base64::Engine;
        let header = base64::engine::general_purpose::URL_SAFE_NO_PAD
            .encode(r#"{"alg":"RS256","typ":"JWT","kid":"key-1"}"#);
        let payload = base64::engine::general_purpose::URL_SAFE_NO_PAD
            .encode(r#"{"iss":"https://auth.unknown.com","sub":"test","exp":9999999999}"#);
        let token = format!("{}.{}.fake-sig", header, payload);

        let result = resolver.validate_federated_token(&token).await;
        assert!(result.is_ok());
        assert!(result.expect("should be Ok").is_none());
    }

    // A matched anchor without `jwks_uri` fails with JwksFetchFailed before
    // any HTTP request is made.
    #[tokio::test]
    async fn test_validate_matched_issuer_no_jwks_uri_returns_error() {
        let config = FederationConfig {
            enabled: true,
            trust_anchors: vec![FederationTrustAnchor {
                org_id: "no-jwks".to_string(),
                display_name: "No JWKS".to_string(),
                // No JWKS endpoint: key lookup must fail.
                jwks_uri: None,
                issuer_pattern: "https://auth.nojwks.com".to_string(),
                identity_mappings: vec![],
                trust_level: "limited".to_string(),
            }],
            jwks_cache_ttl_secs: 300,
            jwks_fetch_timeout_secs: 10,
            expected_audience: None,
        };
        let client = reqwest::Client::new();
        let resolver = FederationResolver::new(&config, client).expect("valid config");

        use base64::Engine;
        let header = base64::engine::general_purpose::URL_SAFE_NO_PAD
            .encode(r#"{"alg":"RS256","typ":"JWT","kid":"key-1"}"#);
        let payload = base64::engine::general_purpose::URL_SAFE_NO_PAD
            .encode(r#"{"iss":"https://auth.nojwks.com","sub":"test","exp":9999999999}"#);
        let token = format!("{}.{}.fake-sig", header, payload);

        let result = resolver.validate_federated_token(&token).await;
        assert!(result.is_err());
        match result.expect_err("should be err") {
            FederationError::JwksFetchFailed { org_id, .. } => {
                assert_eq!(org_id, "no-jwks");
            }
            other => panic!("Expected JwksFetchFailed, got: {}", other),
        }
    }
}