1use std::collections::HashMap;
17use std::sync::atomic::{AtomicU64, Ordering};
18use std::sync::{Arc, RwLock};
19use std::time::{Duration, Instant};
20
21use jsonwebtoken::jwk::KeyAlgorithm;
22use jsonwebtoken::{decode_header, Algorithm, DecodingKey, Validation};
23use serde::Deserialize;
24use vellaveto_config::abac::FederationConfig;
25use vellaveto_types::abac::{
26 FederationAnchorInfo, FederationAnchorStatus, FederationStatus, FederationTrustAnchor,
27};
28use vellaveto_types::identity::AgentIdentity;
29
/// Outcome of a successful federated token validation.
///
/// Produced by `FederationResolver::validate_federated_token` once a token
/// has been verified against a matching trust anchor.
#[derive(Debug, Clone)]
pub struct FederatedIdentity {
    /// Identity built from the token claims and the anchor's identity mappings.
    pub identity: AgentIdentity,
    /// `org_id` copied from the trust anchor that matched the token issuer.
    pub org_id: String,
    /// `trust_level` copied from the trust anchor that matched the token issuer.
    pub trust_level: String,
}
44
/// Errors produced while resolving a federated JWT.
#[derive(Debug)]
pub enum FederationError {
    /// The JWKS for `org_id` could not be obtained: no `jwks_uri` configured,
    /// transport error, non-success HTTP status, oversized body, or invalid JSON.
    JwksFetchFailed { org_id: String, source: String },
    /// The token failed signature/claim validation, carried too many or too
    /// large extra claims, or its verified issuer did not match the anchor.
    JwtValidationFailed { org_id: String, source: String },
    /// No usable key for the requested `kid` was found in the anchor's JWKS.
    NoMatchingKey { org_id: String, kid: String },
    /// The token header names an algorithm outside `ALLOWED_ALGORITHMS`;
    /// carries the `Debug` rendering of that algorithm.
    DisallowedAlgorithm(String),
    /// The token header could not be decoded, lacks a non-empty `kid`, or the
    /// payload has no readable `iss` claim.
    InvalidHeader(String),
}
59
60impl std::fmt::Display for FederationError {
61 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
62 match self {
63 Self::JwksFetchFailed { org_id, source } => {
64 write!(f, "JWKS fetch failed for {org_id}: {source}")
65 }
66 Self::JwtValidationFailed { org_id, source } => {
67 write!(f, "JWT validation failed for {org_id}: {source}")
68 }
69 Self::NoMatchingKey { org_id, kid } => {
70 let safe_kid = vellaveto_types::sanitize_for_log(kid, 128);
75 write!(
76 f,
77 "no matching key in JWKS for org {org_id}, kid '{safe_kid}'"
78 )
79 }
80 Self::DisallowedAlgorithm(alg) => {
81 write!(f, "disallowed JWT algorithm: {alg}")
82 }
83 Self::InvalidHeader(msg) => {
84 write!(f, "invalid JWT header: {msg}")
85 }
86 }
87 }
88}
89
// Marker impl so `FederationError` can be used wherever `std::error::Error`
// is expected; no `source()` override is provided.
impl std::error::Error for FederationError {}
91
/// A fetched JWKS together with the instant it was stored in the cache.
struct CachedJwks {
    /// The parsed key set.
    keys: jsonwebtoken::jwk::JwkSet,
    /// When the entry was written; compared against the resolver's cache TTL.
    fetched_at: Instant,
}
97
/// Runtime state kept for one configured trust anchor.
struct CompiledAnchor {
    /// The anchor configuration this state was built from.
    config: FederationTrustAnchor,
    /// Most recently fetched JWKS, or `None` if nothing has been cached yet.
    jwks_cache: RwLock<Option<CachedJwks>>,
    /// Number of tokens successfully validated against this anchor.
    success_count: AtomicU64,
    /// Number of failed validations attributed to this anchor.
    failure_count: AtomicU64,
}
105
/// Maximum number of non-standard (flattened) claims accepted in a federated JWT.
const MAX_EXTRA_CLAIMS: usize = 50;

/// Maximum JSON-serialized size, in bytes, of any single extra claim value.
const MAX_EXTRA_CLAIM_VALUE_BYTES: usize = 8192;

/// Maximum byte length kept from one string element of an array-valued claim.
const MAX_CLAIM_ELEMENT_LEN: usize = 1024;

/// Maximum total byte length (including separators) of a joined array claim.
const MAX_CLAIM_JOINED_LEN: usize = 8192;

/// Maximum byte length of a claim value after sanitization for template use.
const MAX_TEMPLATE_CLAIM_VALUE_LEN: usize = 1024;
120
/// Claims deserialized from a federated JWT payload.
#[derive(Debug, Deserialize)]
struct FederatedClaims {
    /// `sub` claim, if present.
    #[serde(default)]
    sub: Option<String>,
    /// `iss` claim, if present.
    #[serde(default)]
    iss: Option<String>,
    /// `email` claim, if present.
    #[serde(default)]
    email: Option<String>,
    /// Every other top-level claim, captured via `serde(flatten)`.
    #[serde(flatten)]
    extra: HashMap<String, serde_json::Value>,
}
134
135impl FederatedClaims {
136 fn validate_extra_claims(&self) -> Result<(), String> {
141 if self.extra.len() > MAX_EXTRA_CLAIMS {
142 return Err(format!(
143 "federated JWT has {} extra claims, max is {}",
144 self.extra.len(),
145 MAX_EXTRA_CLAIMS
146 ));
147 }
148 for (key, value) in &self.extra {
149 let serialized_len = serde_json::to_string(value).map(|s| s.len()).unwrap_or(0);
150 if serialized_len > MAX_EXTRA_CLAIM_VALUE_BYTES {
151 return Err(format!(
152 "federated JWT extra claim '{key}' is {serialized_len} bytes, max is {MAX_EXTRA_CLAIM_VALUE_BYTES}"
153 ));
154 }
155 }
156 Ok(())
157 }
158}
159
/// Signature algorithms a federated token header may declare.
///
/// Only asymmetric families (RSA PKCS#1, RSA-PSS, ECDSA, EdDSA) are listed;
/// the HMAC (`HS*`) variants are absent.
// NOTE(review): presumably HS* is excluded to prevent key-confusion attacks
// against public JWKS material — confirm against the design notes.
const ALLOWED_ALGORITHMS: &[Algorithm] = &[
    Algorithm::RS256,
    Algorithm::RS384,
    Algorithm::RS512,
    Algorithm::ES256,
    Algorithm::ES384,
    Algorithm::PS256,
    Algorithm::PS384,
    Algorithm::PS512,
    Algorithm::EdDSA,
];
172
/// Upper bound (1 MiB) on the size of a JWKS HTTP response body; checked both
/// against the declared Content-Length and while streaming the body.
const MAX_JWKS_BODY_BYTES: usize = 1_048_576;
183
/// Validates JWTs issued by federated organizations against configured
/// trust anchors and maps their claims to an internal identity.
pub struct FederationResolver {
    /// One entry per configured trust anchor, in configuration order.
    anchors: Vec<Arc<CompiledAnchor>>,
    /// HTTP client used to fetch JWKS documents.
    http_client: reqwest::Client,
    /// How long a cached JWKS is considered fresh.
    cache_ttl: Duration,
    /// Per-request timeout applied to JWKS fetches.
    fetch_timeout: Duration,
    /// When set, tokens must carry this audience; when `None`, audience
    /// validation is disabled.
    expected_audience: Option<String>,
}
193
194impl FederationResolver {
195 pub fn new(config: &FederationConfig, http_client: reqwest::Client) -> Result<Self, String> {
197 let mut anchors = Vec::with_capacity(config.trust_anchors.len());
198 for anchor_config in &config.trust_anchors {
199 anchor_config.validate()?;
200 anchors.push(Arc::new(CompiledAnchor {
201 config: anchor_config.clone(),
202 jwks_cache: RwLock::new(None),
203 success_count: AtomicU64::new(0),
204 failure_count: AtomicU64::new(0),
205 }));
206 }
207 Ok(Self {
208 anchors,
209 http_client,
210 cache_ttl: Duration::from_secs(config.jwks_cache_ttl_secs),
211 fetch_timeout: Duration::from_secs(config.jwks_fetch_timeout_secs),
212 expected_audience: config.expected_audience.clone(),
213 })
214 }
215
216 pub async fn validate_federated_token(
222 &self,
223 token: &str,
224 ) -> Result<Option<FederatedIdentity>, FederationError> {
225 let header =
227 decode_header(token).map_err(|e| FederationError::InvalidHeader(e.to_string()))?;
228
229 let alg = header.alg;
230 let alg_str = format!("{alg:?}");
231 if !ALLOWED_ALGORITHMS.contains(&alg) {
232 return Err(FederationError::DisallowedAlgorithm(alg_str));
233 }
234
235 let kid = match header.kid {
239 Some(ref k) if !k.is_empty() => k.clone(),
240 _ => {
241 return Err(FederationError::InvalidHeader(
242 "JWT header missing required 'kid' field for federation".to_string(),
243 ))
244 }
245 };
246
247 let issuer = extract_issuer_from_payload(token)
249 .ok_or_else(|| FederationError::InvalidHeader("missing iss claim".to_string()))?;
250
251 let anchor = match self.find_matching_anchor(&issuer) {
257 Some(a) => a,
258 None => return Ok(None), };
260
261 let decoding_key =
263 self.get_decoding_key(&anchor, &kid, &alg)
264 .await
265 .inspect_err(|_| {
266 let _ = anchor.failure_count.fetch_update(
267 Ordering::SeqCst,
268 Ordering::SeqCst,
269 |v| Some(v.saturating_add(1)),
270 );
271 })?;
272
273 let mut validation = Validation::new(alg);
275 validation.validate_exp = true;
276 validation.validate_nbf = true;
279 validation.set_issuer(&[&issuer]);
280 if let Some(ref aud) = self.expected_audience {
283 validation.validate_aud = true;
284 validation.set_audience(&[aud]);
285 } else {
286 validation.validate_aud = false;
287 }
288 validation.leeway = 60;
290
291 let token_data = jsonwebtoken::decode::<FederatedClaims>(token, &decoding_key, &validation)
292 .map_err(|e| {
293 let _ =
294 anchor
295 .failure_count
296 .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |v| {
297 Some(v.saturating_add(1))
298 });
299 FederationError::JwtValidationFailed {
300 org_id: anchor.config.org_id.clone(),
301 source: e.to_string(),
302 }
303 })?;
304
305 if let Err(reason) = token_data.claims.validate_extra_claims() {
307 let _ = anchor
308 .failure_count
309 .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |v| {
310 Some(v.saturating_add(1))
311 });
312 return Err(FederationError::JwtValidationFailed {
313 org_id: anchor.config.org_id.clone(),
314 source: reason,
315 });
316 }
317
318 if let Some(ref verified_iss) = token_data.claims.iss {
324 if !issuer_pattern_matches(&anchor.config.issuer_pattern, verified_iss) {
325 let _ =
326 anchor
327 .failure_count
328 .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |v| {
329 Some(v.saturating_add(1))
330 });
331 return Err(FederationError::JwtValidationFailed {
332 org_id: anchor.config.org_id.clone(),
333 source: format!(
334 "verified token issuer '{}' does not match anchor pattern '{}'",
335 verified_iss, anchor.config.issuer_pattern
336 ),
337 });
338 }
339 }
340
341 let identity = self.apply_identity_mappings(&anchor, &token_data.claims);
343
344 let _ = anchor
345 .success_count
346 .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |v| {
347 Some(v.saturating_add(1))
348 });
349
350 Ok(Some(FederatedIdentity {
351 identity,
352 org_id: anchor.config.org_id.clone(),
353 trust_level: anchor.config.trust_level.clone(),
354 }))
355 }
356
357 pub fn status(&self) -> FederationStatus {
359 FederationStatus {
360 enabled: true,
361 trust_anchor_count: self.anchors.len(),
362 anchors: self
363 .anchors
364 .iter()
365 .map(|a| {
366 let cache = a.jwks_cache.read().ok();
367 let (cached, last_fetched) = match cache.as_ref().and_then(|c| c.as_ref()) {
368 Some(c) => (true, Some(format!("{:?}", c.fetched_at.elapsed()))),
369 None => (false, None),
370 };
371 FederationAnchorStatus {
372 org_id: a.config.org_id.clone(),
373 display_name: a.config.display_name.clone(),
374 issuer_pattern: a.config.issuer_pattern.clone(),
375 trust_level: a.config.trust_level.clone(),
376 has_jwks_uri: a.config.jwks_uri.is_some(),
377 jwks_cached: cached,
378 jwks_last_fetched: last_fetched,
379 identity_mapping_count: a.config.identity_mappings.len(),
380 successful_validations: a.success_count.load(Ordering::SeqCst),
381 failed_validations: a.failure_count.load(Ordering::SeqCst),
382 }
383 })
384 .collect(),
385 }
386 }
387
388 pub fn anchor_info(&self) -> Vec<FederationAnchorInfo> {
390 self.anchors
391 .iter()
392 .map(|a| FederationAnchorInfo {
393 org_id: a.config.org_id.clone(),
394 display_name: a.config.display_name.clone(),
395 issuer_pattern: a.config.issuer_pattern.clone(),
396 trust_level: a.config.trust_level.clone(),
397 has_jwks_uri: a.config.jwks_uri.is_some(),
398 identity_mapping_count: a.config.identity_mappings.len(),
399 })
400 .collect()
401 }
402
403 fn find_matching_anchor(&self, issuer: &str) -> Option<Arc<CompiledAnchor>> {
406 self.anchors
407 .iter()
408 .find(|a| issuer_pattern_matches(&a.config.issuer_pattern, issuer))
409 .cloned()
410 }
411
412 async fn get_decoding_key(
413 &self,
414 anchor: &CompiledAnchor,
415 kid: &str,
416 alg: &Algorithm,
417 ) -> Result<DecodingKey, FederationError> {
418 let jwks_uri =
419 anchor
420 .config
421 .jwks_uri
422 .as_deref()
423 .ok_or_else(|| FederationError::JwksFetchFailed {
424 org_id: anchor.config.org_id.clone(),
425 source: "no jwks_uri configured".to_string(),
426 })?;
427
428 {
433 match anchor.jwks_cache.read() {
434 Ok(cache_guard) => {
435 if let Some(ref cached) = *cache_guard {
436 if cached.fetched_at.elapsed() < self.cache_ttl {
437 return find_key_in_jwks(&cached.keys, kid, alg, &anchor.config.org_id);
438 }
439 }
440 }
441 Err(_) => {
442 tracing::error!(
443 target: "vellaveto::security",
444 org_id = %anchor.config.org_id,
445 "JWKS cache read lock poisoned — treating as cache miss (fail-closed)"
446 );
447 }
449 }
450 }
451
452 let jwks = self.fetch_jwks(jwks_uri, &anchor.config.org_id).await?;
454
455 let result = find_key_in_jwks(&jwks, kid, alg, &anchor.config.org_id);
457
458 {
462 let mut cache_guard = anchor.jwks_cache.write().unwrap_or_else(|e| {
463 tracing::warn!(
464 "JWKS cache write lock poisoned for org '{}'; clearing stale cache",
465 anchor.config.org_id
466 );
467 let mut guard = e.into_inner();
468 *guard = None; guard
470 });
471 *cache_guard = Some(CachedJwks {
472 keys: jwks,
473 fetched_at: Instant::now(),
474 });
475 }
476
477 result
478 }
479
480 async fn fetch_jwks(
481 &self,
482 uri: &str,
483 org_id: &str,
484 ) -> Result<jsonwebtoken::jwk::JwkSet, FederationError> {
485 let resp = self
486 .http_client
487 .get(uri)
488 .timeout(self.fetch_timeout)
489 .send()
490 .await
491 .map_err(|e| FederationError::JwksFetchFailed {
492 org_id: org_id.to_string(),
493 source: e.to_string(),
494 })?;
495
496 if !resp.status().is_success() {
497 return Err(FederationError::JwksFetchFailed {
498 org_id: org_id.to_string(),
499 source: format!("HTTP {}", resp.status()),
500 });
501 }
502
503 if let Some(len) = resp.content_length() {
507 if len > MAX_JWKS_BODY_BYTES as u64 {
508 return Err(FederationError::JwksFetchFailed {
509 org_id: org_id.to_string(),
510 source: format!(
511 "JWKS Content-Length {len} exceeds {MAX_JWKS_BODY_BYTES} byte limit"
512 ),
513 });
514 }
515 }
516
517 let capacity = std::cmp::min(
520 resp.content_length().unwrap_or(8192) as usize,
521 MAX_JWKS_BODY_BYTES,
522 );
523 let mut body = Vec::with_capacity(capacity);
524 let mut resp = resp;
525 while let Some(chunk) =
526 resp.chunk()
527 .await
528 .map_err(|e| FederationError::JwksFetchFailed {
529 org_id: org_id.to_string(),
530 source: e.to_string(),
531 })?
532 {
533 if body.len().saturating_add(chunk.len()) > MAX_JWKS_BODY_BYTES {
534 return Err(FederationError::JwksFetchFailed {
535 org_id: org_id.to_string(),
536 source: format!("JWKS response exceeds {MAX_JWKS_BODY_BYTES} byte limit"),
537 });
538 }
539 body.extend_from_slice(&chunk);
540 }
541
542 serde_json::from_slice(&body).map_err(|e| FederationError::JwksFetchFailed {
543 org_id: org_id.to_string(),
544 source: format!("invalid JWKS JSON: {e}"),
545 })
546 }
547
548 fn apply_identity_mappings(
549 &self,
550 anchor: &CompiledAnchor,
551 claims: &FederatedClaims,
552 ) -> AgentIdentity {
553 let mut identity_claims: HashMap<String, serde_json::Value> = HashMap::new();
554
555 identity_claims.insert(
557 "federation.org_id".to_string(),
558 serde_json::Value::String(anchor.config.org_id.clone()),
559 );
560 identity_claims.insert(
561 "federation.trust_level".to_string(),
562 serde_json::Value::String(anchor.config.trust_level.clone()),
563 );
564 if let Some(ref iss) = claims.iss {
565 identity_claims.insert(
566 "federation.issuer".to_string(),
567 serde_json::Value::String(iss.clone()),
568 );
569 }
570
571 let mut subject = claims.sub.clone();
572
573 for mapping in &anchor.config.identity_mappings {
575 if let Some(value) = extract_claim_value(claims, &mapping.external_claim) {
576 let sanitized = sanitize_claim_for_template(&value);
581
582 let rendered = mapping
583 .id_template
584 .replace("{claim_value}", &sanitized)
585 .replace("{org_id}", &anchor.config.org_id);
586
587 identity_claims.insert(
588 "principal.type".to_string(),
589 serde_json::Value::String(mapping.internal_principal_type.clone()),
590 );
591 identity_claims.insert(
592 "principal.id".to_string(),
593 serde_json::Value::String(rendered.clone()),
594 );
595
596 if mapping.external_claim == "sub" || mapping.external_claim == "email" {
598 subject = Some(rendered);
599 }
600 }
601 }
602
603 AgentIdentity {
604 issuer: claims.iss.clone(),
605 subject,
606 audience: Vec::new(),
607 claims: identity_claims,
608 }
609 }
610}
611
612fn sanitize_claim_for_template(value: &str) -> String {
622 let sanitized: String = value
623 .chars()
624 .filter(|c| {
626 !c.is_control()
627 && !vellaveto_types::is_unicode_format_char(*c)
628 && *c != '{'
629 && *c != '}'
630 })
631 .collect();
632 if sanitized.len() > MAX_TEMPLATE_CLAIM_VALUE_LEN {
633 let mut end = MAX_TEMPLATE_CLAIM_VALUE_LEN;
635 while end > 0 && !sanitized.is_char_boundary(end) {
636 end -= 1;
637 }
638 sanitized[..end].to_string()
639 } else {
640 sanitized
641 }
642}
643
644fn extract_claim_value(claims: &FederatedClaims, claim_path: &str) -> Option<String> {
646 match claim_path {
648 "sub" => return claims.sub.clone(),
649 "iss" => return claims.iss.clone(),
650 "email" => return claims.email.clone(),
651 _ => {}
652 }
653
654 let parts: Vec<&str> = claim_path.splitn(10, '.').collect();
656 let mut current: Option<&serde_json::Value> = claims.extra.get(parts[0]);
657
658 for part in &parts[1..] {
659 current = current.and_then(|v| v.get(part));
660 }
661
662 current.and_then(|v| match v {
663 serde_json::Value::String(s) => Some(s.clone()),
664 serde_json::Value::Array(arr) => {
665 let mut total_len = 0usize;
667 let joined: Vec<String> = arr
668 .iter()
669 .take(64) .filter_map(|item| {
671 item.as_str().map(|s| {
672 if s.len() > MAX_CLAIM_ELEMENT_LEN {
674 s[..MAX_CLAIM_ELEMENT_LEN].to_string()
675 } else {
676 s.to_string()
677 }
678 })
679 })
680 .take_while(|s| {
681 let added = if total_len == 0 { s.len() } else { s.len() + 1 };
683 if total_len + added > MAX_CLAIM_JOINED_LEN {
684 return false;
685 }
686 total_len += added;
687 true
688 })
689 .collect();
690 if joined.is_empty() {
691 None
692 } else {
693 Some(joined.join(","))
694 }
695 }
696 serde_json::Value::Number(n) => Some(n.to_string()),
697 serde_json::Value::Bool(b) => Some(b.to_string()),
698 _ => None,
699 })
700}
701
/// Maximum length of the base64url payload segment that
/// `extract_issuer_from_payload` will decode; longer payloads yield `None`.
const MAX_JWT_PAYLOAD_B64_LEN: usize = 65_536;
706
707fn extract_issuer_from_payload(token: &str) -> Option<String> {
709 let parts: Vec<&str> = token.splitn(4, '.').collect();
710 if parts.len() < 2 {
711 return None;
712 }
713 if parts[1].len() > MAX_JWT_PAYLOAD_B64_LEN {
715 return None;
716 }
717 use base64::Engine;
718 let decoded = base64::engine::general_purpose::URL_SAFE_NO_PAD
719 .decode(parts[1])
720 .ok()?;
721 let payload: serde_json::Value = serde_json::from_slice(&decoded).ok()?;
722 payload.get("iss")?.as_str().map(String::from)
723}
724
/// Tests `issuer` against a trust-anchor issuer `pattern`.
///
/// A pattern without `*` must equal the issuer exactly. Otherwise `*` matches
/// any (possibly empty) run of characters; consecutive `*` are collapsed into
/// one. Patterns that split into more than 11 literal segments (more than 10
/// wildcards after collapsing) never match.
fn issuer_pattern_matches(pattern: &str, issuer: &str) -> bool {
    if pattern == issuer {
        return true;
    }
    if !pattern.contains('*') {
        return false;
    }

    // Collapse runs of '*' so "a**b" behaves exactly like "a*b".
    let mut collapsed = String::with_capacity(pattern.len());
    for ch in pattern.chars() {
        if ch == '*' && collapsed.ends_with('*') {
            continue;
        }
        collapsed.push(ch);
    }

    let segments: Vec<&str> = collapsed.split('*').collect();
    if segments.len() > 11 {
        return false;
    }

    let last_index = segments.len() - 1;
    let mut rest = issuer;
    for (index, &segment) in segments.iter().enumerate() {
        if segment.is_empty() {
            // An empty segment comes from a leading or trailing '*'.
            continue;
        }
        if index == 0 {
            // The first literal segment anchors the start of the issuer.
            match rest.strip_prefix(segment) {
                Some(tail) => rest = tail,
                None => return false,
            }
        } else if index == last_index {
            // The final literal segment anchors the end of what is left.
            if !rest.ends_with(segment) {
                return false;
            }
            rest = "";
        } else {
            // Interior segments take the leftmost occurrence.
            match rest.find(segment) {
                Some(pos) => rest = &rest[pos + segment.len()..],
                None => return false,
            }
        }
    }
    true
}
791
792fn key_algorithm_to_algorithm(ka: &KeyAlgorithm) -> Option<Algorithm> {
799 match ka {
800 KeyAlgorithm::HS256 => Some(Algorithm::HS256),
801 KeyAlgorithm::HS384 => Some(Algorithm::HS384),
802 KeyAlgorithm::HS512 => Some(Algorithm::HS512),
803 KeyAlgorithm::ES256 => Some(Algorithm::ES256),
804 KeyAlgorithm::ES384 => Some(Algorithm::ES384),
805 KeyAlgorithm::RS256 => Some(Algorithm::RS256),
806 KeyAlgorithm::RS384 => Some(Algorithm::RS384),
807 KeyAlgorithm::RS512 => Some(Algorithm::RS512),
808 KeyAlgorithm::PS256 => Some(Algorithm::PS256),
809 KeyAlgorithm::PS384 => Some(Algorithm::PS384),
810 KeyAlgorithm::PS512 => Some(Algorithm::PS512),
811 KeyAlgorithm::EdDSA => Some(Algorithm::EdDSA),
812 _ => None,
814 }
815}
816
817fn find_key_in_jwks(
825 jwks: &jsonwebtoken::jwk::JwkSet,
826 kid: &str,
827 alg: &Algorithm,
828 org_id: &str,
829) -> Result<DecodingKey, FederationError> {
830 for key in &jwks.keys {
831 match &key.common.key_id {
834 Some(key_kid) => {
835 if !kid.is_empty() && key_kid != kid {
836 continue; }
838 }
839 None => {
840 tracing::warn!(
842 org_id = %org_id,
843 "skipping JWK without kid field in JWKS for org '{}'",
844 org_id
845 );
846 continue;
847 }
848 }
849
850 if let Some(ref key_alg) = key.common.key_algorithm {
852 match key_algorithm_to_algorithm(key_alg) {
853 Some(mapped) if &mapped == alg => {} _ => continue, }
856 }
857
858 if let Ok(dk) = DecodingKey::from_jwk(key) {
860 return Ok(dk);
861 }
862 }
863
864 Err(FederationError::NoMatchingKey {
865 org_id: org_id.to_string(),
866 kid: kid.to_string(),
867 })
868}
869
// Unit tests for the federation resolver: configuration handling, issuer
// pattern matching, claim extraction, identity mapping and the early-exit
// paths of `validate_federated_token` that need no network access.
#[cfg(test)]
mod tests {
    use super::*;

    // Config with one exact-issuer anchor and a single `sub` identity mapping.
    fn test_config() -> FederationConfig {
        FederationConfig {
            enabled: true,
            trust_anchors: vec![FederationTrustAnchor {
                org_id: "partner-org".to_string(),
                display_name: "Partner Organization".to_string(),
                jwks_uri: Some("https://auth.partner.com/.well-known/jwks.json".to_string()),
                issuer_pattern: "https://auth.partner.com".to_string(),
                identity_mappings: vec![vellaveto_types::abac::IdentityMapping {
                    external_claim: "sub".to_string(),
                    internal_principal_type: "agent".to_string(),
                    id_template: "partner-org:{claim_value}".to_string(),
                }],
                trust_level: "limited".to_string(),
            }],
            jwks_cache_ttl_secs: 300,
            jwks_fetch_timeout_secs: 10,
            expected_audience: None,
        }
    }

    // Config with one anchor whose issuer pattern ends in a wildcard and which
    // has no identity mappings.
    fn test_config_wildcard() -> FederationConfig {
        FederationConfig {
            enabled: true,
            trust_anchors: vec![FederationTrustAnchor {
                org_id: "acme".to_string(),
                display_name: "ACME Corp".to_string(),
                jwks_uri: Some("https://auth.acme.com/.well-known/jwks.json".to_string()),
                issuer_pattern: "https://auth.acme.com/*".to_string(),
                identity_mappings: vec![],
                trust_level: "full".to_string(),
            }],
            jwks_cache_ttl_secs: 300,
            jwks_fetch_timeout_secs: 10,
            expected_audience: None,
        }
    }

    // A valid config produces one compiled anchor per configured trust anchor.
    #[test]
    fn test_new_valid_config_compiles_anchors() {
        let config = test_config();
        let client = reqwest::Client::new();
        let resolver = FederationResolver::new(&config, client);
        assert!(resolver.is_ok());
        let resolver = resolver.expect("test config should be valid");
        assert_eq!(resolver.anchors.len(), 1);
    }

    // Construction fails when an anchor does not pass `validate()`.
    #[test]
    fn test_new_invalid_anchor_fails() {
        let config = FederationConfig {
            enabled: true,
            trust_anchors: vec![FederationTrustAnchor {
                // NOTE(review): presumably the empty org_id is what validate()
                // rejects — confirm against FederationTrustAnchor::validate.
                org_id: String::new(),
                display_name: "Bad".to_string(),
                jwks_uri: None,
                issuer_pattern: "https://x.com".to_string(),
                identity_mappings: vec![],
                trust_level: "limited".to_string(),
            }],
            jwks_cache_ttl_secs: 300,
            jwks_fetch_timeout_secs: 10,
            expected_audience: None,
        };
        let client = reqwest::Client::new();
        assert!(FederationResolver::new(&config, client).is_err());
    }

    // --- issuer_pattern_matches ---

    #[test]
    fn test_issuer_pattern_exact_match() {
        assert!(issuer_pattern_matches(
            "https://auth.partner.com",
            "https://auth.partner.com"
        ));
    }

    #[test]
    fn test_issuer_pattern_no_match() {
        assert!(!issuer_pattern_matches(
            "https://auth.partner.com",
            "https://auth.evil.com"
        ));
    }

    #[test]
    fn test_issuer_pattern_glob_wildcard_suffix() {
        assert!(issuer_pattern_matches(
            "https://auth.acme.com/*",
            "https://auth.acme.com/tenant-1"
        ));
    }

    #[test]
    fn test_issuer_pattern_glob_wildcard_middle() {
        assert!(issuer_pattern_matches(
            "https://*.acme.com/auth",
            "https://tenant1.acme.com/auth"
        ));
    }

    #[test]
    fn test_issuer_pattern_glob_wildcard_no_match() {
        assert!(!issuer_pattern_matches(
            "https://auth.acme.com/*",
            "https://auth.evil.com/tenant"
        ));
    }

    // "**" must behave exactly like a single "*".
    #[test]
    fn test_issuer_pattern_consecutive_wildcards_collapsed() {
        assert!(issuer_pattern_matches(
            "https://auth.acme.com/**",
            "https://auth.acme.com/tenant-1"
        ));
        assert!(!issuer_pattern_matches(
            "https://auth.acme.com/**",
            "https://auth.evil.com/tenant"
        ));
    }

    // The literal "." between two wildcards must still be present in the issuer.
    #[test]
    fn test_issuer_pattern_star_dot_star() {
        assert!(issuer_pattern_matches(
            "https://*.*",
            "https://auth.example.com"
        ));
        assert!(!issuer_pattern_matches(
            "https://*.*",
            "https://nosubdomain"
        ));
    }

    // --- find_matching_anchor ---

    #[test]
    fn test_find_matching_anchor_exact() {
        let config = test_config();
        let client = reqwest::Client::new();
        let resolver = FederationResolver::new(&config, client).expect("valid config");
        let anchor = resolver.find_matching_anchor("https://auth.partner.com");
        assert!(anchor.is_some());
        assert_eq!(anchor.expect("should match").config.org_id, "partner-org");
    }

    #[test]
    fn test_find_matching_anchor_wildcard() {
        let config = test_config_wildcard();
        let client = reqwest::Client::new();
        let resolver = FederationResolver::new(&config, client).expect("valid config");
        let anchor = resolver.find_matching_anchor("https://auth.acme.com/tenant-1");
        assert!(anchor.is_some());
        assert_eq!(anchor.expect("should match").config.org_id, "acme");
    }

    #[test]
    fn test_find_matching_anchor_no_match() {
        let config = test_config();
        let client = reqwest::Client::new();
        let resolver = FederationResolver::new(&config, client).expect("valid config");
        assert!(resolver
            .find_matching_anchor("https://auth.evil.com")
            .is_none());
    }

    // --- extract_claim_value ---

    #[test]
    fn test_extract_claim_value_sub() {
        let claims = FederatedClaims {
            sub: Some("agent-123".to_string()),
            iss: Some("https://auth.example.com".to_string()),
            email: None,
            extra: HashMap::new(),
        };
        assert_eq!(
            extract_claim_value(&claims, "sub"),
            Some("agent-123".to_string())
        );
    }

    // A dotted path walks into nested objects; string arrays are comma-joined.
    #[test]
    fn test_extract_claim_value_nested() {
        let mut extra = HashMap::new();
        extra.insert(
            "realm_access".to_string(),
            serde_json::json!({"roles": ["admin", "user"]}),
        );
        let claims = FederatedClaims {
            sub: None,
            iss: None,
            email: None,
            extra,
        };
        assert_eq!(
            extract_claim_value(&claims, "realm_access.roles"),
            Some("admin,user".to_string())
        );
    }

    #[test]
    fn test_extract_claim_value_missing() {
        let claims = FederatedClaims {
            sub: None,
            iss: None,
            email: None,
            extra: HashMap::new(),
        };
        assert_eq!(extract_claim_value(&claims, "nonexistent"), None);
    }

    // --- apply_identity_mappings ---

    // The federation.* metadata claims are always injected.
    #[test]
    fn test_apply_identity_mappings_injects_federation_metadata() {
        let config = test_config();
        let client = reqwest::Client::new();
        let resolver = FederationResolver::new(&config, client).expect("valid config");
        let claims = FederatedClaims {
            sub: Some("agent-456".to_string()),
            iss: Some("https://auth.partner.com".to_string()),
            email: None,
            extra: HashMap::new(),
        };
        let identity = resolver.apply_identity_mappings(&resolver.anchors[0], &claims);
        assert_eq!(
            identity.claims.get("federation.org_id"),
            Some(&serde_json::Value::String("partner-org".to_string()))
        );
        assert_eq!(
            identity.claims.get("federation.trust_level"),
            Some(&serde_json::Value::String("limited".to_string()))
        );
        assert_eq!(
            identity.claims.get("federation.issuer"),
            Some(&serde_json::Value::String(
                "https://auth.partner.com".to_string()
            ))
        );
    }

    // `{claim_value}` in the template is replaced by the mapped claim.
    #[test]
    fn test_apply_identity_mappings_template_substitution() {
        let config = test_config();
        let client = reqwest::Client::new();
        let resolver = FederationResolver::new(&config, client).expect("valid config");
        let claims = FederatedClaims {
            sub: Some("agent-789".to_string()),
            iss: Some("https://auth.partner.com".to_string()),
            email: None,
            extra: HashMap::new(),
        };
        let identity = resolver.apply_identity_mappings(&resolver.anchors[0], &claims);
        assert_eq!(
            identity.claims.get("principal.id"),
            Some(&serde_json::Value::String(
                "partner-org:agent-789".to_string()
            ))
        );
        assert_eq!(
            identity.claims.get("principal.type"),
            Some(&serde_json::Value::String("agent".to_string()))
        );
    }

    // NUL, LF and CR in a claim value must not reach the rendered principal id.
    #[test]
    fn test_apply_identity_mappings_sanitizes_control_chars() {
        let config = test_config();
        let client = reqwest::Client::new();
        let resolver = FederationResolver::new(&config, client).expect("valid config");
        let claims = FederatedClaims {
            sub: Some("agent\x00\x0A\x0Dinjected".to_string()),
            iss: Some("https://auth.partner.com".to_string()),
            email: None,
            extra: HashMap::new(),
        };
        let identity = resolver.apply_identity_mappings(&resolver.anchors[0], &claims);
        let principal_id = identity
            .claims
            .get("principal.id")
            .and_then(|v| v.as_str())
            .expect("principal.id must exist");
        assert!(!principal_id.contains('\x00'));
        assert!(!principal_id.contains('\x0A'));
        assert!(!principal_id.contains('\x0D'));
        assert!(principal_id.contains("agentinjected"));
    }

    // Braces are stripped, so a claim cannot inject template placeholders.
    #[test]
    fn test_apply_identity_mappings_strips_template_syntax() {
        let config = test_config();
        let client = reqwest::Client::new();
        let resolver = FederationResolver::new(&config, client).expect("valid config");
        let claims = FederatedClaims {
            sub: Some("{claim_value}evil{org_id}".to_string()),
            iss: Some("https://auth.partner.com".to_string()),
            email: None,
            extra: HashMap::new(),
        };
        let identity = resolver.apply_identity_mappings(&resolver.anchors[0], &claims);
        let principal_id = identity
            .claims
            .get("principal.id")
            .and_then(|v| v.as_str())
            .expect("principal.id must exist");
        assert!(!principal_id.contains('{'));
        assert!(!principal_id.contains('}'));
        assert!(principal_id.contains("claim_valueevilorg_id"));
    }

    // Over-long values are cut down to the template length limit.
    #[test]
    fn test_sanitize_claim_for_template_truncates_long_values() {
        let long_value = "a".repeat(2000);
        let sanitized = sanitize_claim_for_template(&long_value);
        assert!(sanitized.len() <= MAX_TEMPLATE_CLAIM_VALUE_LEN);
    }

    // Signature algorithms map one-to-one; other key algorithms map to None.
    #[test]
    fn test_key_algorithm_to_algorithm_explicit_mapping() {
        assert_eq!(
            key_algorithm_to_algorithm(&KeyAlgorithm::RS256),
            Some(Algorithm::RS256)
        );
        assert_eq!(
            key_algorithm_to_algorithm(&KeyAlgorithm::ES256),
            Some(Algorithm::ES256)
        );
        assert_eq!(
            key_algorithm_to_algorithm(&KeyAlgorithm::PS256),
            Some(Algorithm::PS256)
        );
        assert_eq!(
            key_algorithm_to_algorithm(&KeyAlgorithm::EdDSA),
            Some(Algorithm::EdDSA)
        );
        assert_eq!(key_algorithm_to_algorithm(&KeyAlgorithm::RSA1_5), None);
    }

    // --- status / anchor_info ---

    #[test]
    fn test_status_reports_anchors() {
        let config = test_config();
        let client = reqwest::Client::new();
        let resolver = FederationResolver::new(&config, client).expect("valid config");
        let status = resolver.status();
        assert!(status.enabled);
        assert_eq!(status.trust_anchor_count, 1);
        assert_eq!(status.anchors.len(), 1);
        assert_eq!(status.anchors[0].org_id, "partner-org");
        assert_eq!(status.anchors[0].trust_level, "limited");
        // Nothing has been fetched yet, so the cache must be reported empty.
        assert!(!status.anchors[0].jwks_cached);
    }

    #[test]
    fn test_anchor_info() {
        let config = test_config();
        let client = reqwest::Client::new();
        let resolver = FederationResolver::new(&config, client).expect("valid config");
        let infos = resolver.anchor_info();
        assert_eq!(infos.len(), 1);
        assert_eq!(infos[0].org_id, "partner-org");
        assert!(infos[0].has_jwks_uri);
        assert_eq!(infos[0].identity_mapping_count, 1);
    }

    // --- extract_issuer_from_payload ---

    #[test]
    fn test_extract_issuer_from_payload_valid() {
        use base64::Engine;
        // Hand-built token; the signature segment is never inspected here.
        let header = base64::engine::general_purpose::URL_SAFE_NO_PAD
            .encode(r#"{"alg":"RS256","typ":"JWT"}"#);
        let payload = base64::engine::general_purpose::URL_SAFE_NO_PAD
            .encode(r#"{"iss":"https://auth.example.com","sub":"test"}"#);
        let token = format!("{header}.{payload}.fake-sig");
        assert_eq!(
            extract_issuer_from_payload(&token),
            Some("https://auth.example.com".to_string())
        );
    }

    #[test]
    fn test_extract_issuer_from_payload_missing_iss() {
        use base64::Engine;
        let header = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(r#"{"alg":"RS256"}"#);
        let payload = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(r#"{"sub":"test"}"#);
        let token = format!("{header}.{payload}.sig");
        assert_eq!(extract_issuer_from_payload(&token), None);
    }

    // --- validate_federated_token (paths that never touch the network) ---

    // An issuer that matches no anchor yields Ok(None) rather than an error.
    #[tokio::test]
    async fn test_validate_unmatched_issuer_returns_none() {
        let config = test_config();
        let client = reqwest::Client::new();
        let resolver = FederationResolver::new(&config, client).expect("valid config");

        use base64::Engine;
        let header = base64::engine::general_purpose::URL_SAFE_NO_PAD
            .encode(r#"{"alg":"RS256","typ":"JWT","kid":"key-1"}"#);
        let payload = base64::engine::general_purpose::URL_SAFE_NO_PAD
            .encode(r#"{"iss":"https://auth.unknown.com","sub":"test","exp":9999999999}"#);
        let token = format!("{header}.{payload}.fake-sig");

        let result = resolver.validate_federated_token(&token).await;
        assert!(result.is_ok());
        assert!(result.expect("should be Ok").is_none());
    }

    // A matching anchor without a JWKS URI fails with JwksFetchFailed.
    #[tokio::test]
    async fn test_validate_matched_issuer_no_jwks_uri_returns_error() {
        let config = FederationConfig {
            enabled: true,
            trust_anchors: vec![FederationTrustAnchor {
                org_id: "no-jwks".to_string(),
                display_name: "No JWKS".to_string(),
                // No JWKS endpoint configured for this anchor.
                jwks_uri: None,
                issuer_pattern: "https://auth.nojwks.com".to_string(),
                identity_mappings: vec![],
                trust_level: "limited".to_string(),
            }],
            jwks_cache_ttl_secs: 300,
            jwks_fetch_timeout_secs: 10,
            expected_audience: None,
        };
        let client = reqwest::Client::new();
        let resolver = FederationResolver::new(&config, client).expect("valid config");

        use base64::Engine;
        let header = base64::engine::general_purpose::URL_SAFE_NO_PAD
            .encode(r#"{"alg":"RS256","typ":"JWT","kid":"key-1"}"#);
        let payload = base64::engine::general_purpose::URL_SAFE_NO_PAD
            .encode(r#"{"iss":"https://auth.nojwks.com","sub":"test","exp":9999999999}"#);
        let token = format!("{header}.{payload}.fake-sig");

        let result = resolver.validate_federated_token(&token).await;
        assert!(result.is_err());
        match result.expect_err("should be err") {
            FederationError::JwksFetchFailed { org_id, .. } => {
                assert_eq!(org_id, "no-jwks");
            }
            other => panic!("Expected JwksFetchFailed, got: {other}"),
        }
    }
}