1use jsonwebtoken::{
26 decode, decode_header,
27 jwk::{JwkSet, KeyAlgorithm},
28 Algorithm, DecodingKey, TokenData, Validation,
29};
30use serde::{Deserialize, Serialize};
31use std::{
32 collections::HashMap,
33 time::{Duration, Instant},
34};
35use tokio::sync::RwLock;
36
37use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
38use sha2::{Digest, Sha256};
39
40fn contains_control_chars(s: &str) -> bool {
46 vellaveto_types::has_dangerous_chars(s)
47}
48
49use vellaveto_types::uri_util::normalize_dpop_htu;
53
/// How DPoP (RFC 9449) proofs are handled by
/// [`OAuthValidator::validate_dpop_proof`].
///
/// Serialized in snake_case: `"off"`, `"optional"`, `"required"`.
#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum DpopMode {
    /// DPoP is disabled; `validate_dpop_proof` returns `Ok(())` without
    /// inspecting the request. This is the default.
    #[default]
    Off,
    /// A proof is validated when a `DPoP` header value is supplied.
    Optional,
    /// Every request must carry a valid proof, and `validate_token` also
    /// requires the access token to carry a non-empty `cnf.jkt` claim.
    Required,
}
66
/// Configuration for [`OAuthValidator`].
///
/// Covers which issuer and audience to trust, where to find signing keys, and
/// how strictly to enforce scopes, resource indicators and DPoP.
#[derive(Debug, Clone)]
pub struct OAuthConfig {
    /// Expected `iss` claim. It is also the base of the default JWKS URL
    /// (see [`OAuthConfig::effective_jwks_uri`]).
    pub issuer: String,

    /// Expected audience. It is passed to the JWT library and, when the token
    /// carries an `aud` claim, must appear among its values.
    pub audience: String,

    /// Explicit JWKS URL. When `None`, `{issuer}/.well-known/jwks.json` is used.
    pub jwks_uri: Option<String>,

    /// Scopes that must all appear in the token's space-delimited `scope`
    /// claim. An empty list disables the check.
    pub required_scopes: Vec<String>,

    /// NOTE(review): not read by any code in this module; its meaning is
    /// decided by callers. Confirm before relying on it.
    pub pass_through: bool,

    /// `alg` values accepted on access tokens. Any other `alg` is rejected
    /// before key lookup. See [`default_allowed_algorithms`].
    pub allowed_algorithms: Vec<Algorithm>,

    /// When set, the token's `resource` claim must equal this value exactly
    /// (RFC 8707). A missing claim counts as a mismatch.
    pub expected_resource: Option<String>,

    /// Leeway handed to the JWT library (`Validation::leeway`, whole seconds)
    /// for the `exp`/`nbf` checks.
    pub clock_skew_leeway: Duration,

    /// When `true`, tokens without an `aud` claim are rejected with
    /// `OAuthError::MissingAudience`.
    pub require_audience: bool,

    /// How DPoP proofs are handled.
    pub dpop_mode: DpopMode,

    /// `alg` values accepted in DPoP proof headers.
    /// See [`default_dpop_allowed_algorithms`].
    pub dpop_allowed_algorithms: Vec<Algorithm>,

    /// When `true`, a DPoP proof must carry an `ath` claim equal to the
    /// unpadded base64url SHA-256 of the access token.
    pub dpop_require_ath: bool,

    /// Maximum allowed distance between a proof's `iat` and the current time.
    /// It is also the leeway for the proof's optional `exp`/`nbf` and the basis
    /// of the replay-cache retention window.
    pub dpop_max_clock_skew: Duration,
}
124
125pub fn default_allowed_algorithms() -> Vec<Algorithm> {
131 vec![
132 Algorithm::RS256,
133 Algorithm::RS384,
134 Algorithm::RS512,
135 Algorithm::ES256,
136 Algorithm::ES384,
137 Algorithm::PS256,
138 Algorithm::PS384,
139 Algorithm::PS512,
140 Algorithm::EdDSA,
141 ]
142}
143
144pub fn default_dpop_allowed_algorithms() -> Vec<Algorithm> {
148 vec![Algorithm::ES256, Algorithm::EdDSA]
149}
150
151impl OAuthConfig {
152 pub fn effective_jwks_uri(&self) -> String {
154 self.jwks_uri.clone().unwrap_or_else(|| {
155 let base = self.issuer.trim_end_matches('/');
156 format!("{}/.well-known/jwks.json", base)
157 })
158 }
159}
160
161pub fn extract_bearer_token(auth_header: &str) -> Result<&str, OAuthError> {
163 let token = if auth_header.len() > 7 && auth_header[..7].eq_ignore_ascii_case("bearer ") {
166 &auth_header[7..]
167 } else {
168 return Err(OAuthError::InvalidFormat);
169 };
170
171 if token.is_empty() {
172 return Err(OAuthError::InvalidFormat);
173 }
174
175 Ok(token)
176}
177
/// Claims read from a validated access token.
///
/// Every field defaults when the claim is absent: an empty string, `0`, an
/// empty list, or `None`. Checks that need a claim must therefore test for
/// the default explicitly.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OAuthClaims {
    /// Subject (`sub`).
    #[serde(default)]
    pub sub: String,

    /// Issuer (`iss`).
    #[serde(default)]
    pub iss: String,

    /// Audiences (`aud`). Accepts a single string or an array of strings;
    /// `null` yields an empty list.
    #[serde(default, deserialize_with = "deserialize_aud")]
    pub aud: Vec<String>,

    /// Expiration time (`exp`), in seconds since the Unix epoch.
    #[serde(default)]
    pub exp: u64,

    /// Issued-at time (`iat`), in seconds since the Unix epoch.
    #[serde(default)]
    pub iat: u64,

    /// Space-delimited scopes (`scope`); see [`OAuthClaims::scopes`].
    #[serde(default)]
    pub scope: String,

    /// RFC 8707 resource indicator (`resource`). It is compared against
    /// `OAuthConfig::expected_resource` when that is set.
    #[serde(default)]
    pub resource: Option<String>,

    /// Confirmation claim (`cnf`, RFC 7800). For DPoP-bound tokens it carries
    /// the key thumbprint in `jkt`.
    #[serde(default)]
    pub cnf: Option<OAuthConfirmationClaim>,
}
215
/// The `cnf` (confirmation) claim of an access token.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OAuthConfirmationClaim {
    /// JWK SHA-256 thumbprint (unpadded base64url) of the key the token is
    /// bound to. It is compared, after trimming, with the thumbprint of the
    /// DPoP proof's embedded key.
    #[serde(default)]
    pub jkt: Option<String>,
}
223
/// Claims of a DPoP proof JWT (RFC 9449 §4.2).
///
/// Every field defaults when absent. A missing claim is therefore reported by
/// the checks in `validate_dpop_proof` (for example "missing jti") rather than
/// failing deserialization.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct DpopClaims {
    /// HTTP method the proof is bound to.
    #[serde(default)]
    htm: String,
    /// HTTP target URI the proof is bound to.
    #[serde(default)]
    htu: String,
    /// Issued-at time in seconds since the Unix epoch; `0` when absent.
    #[serde(default)]
    iat: u64,
    /// Unique proof identifier, used as (part of) the replay-cache key.
    #[serde(default)]
    jti: String,
    /// Unpadded base64url SHA-256 hash of the access token, when present.
    #[serde(default)]
    ath: Option<String>,
    /// Optional expiry, in seconds since the Unix epoch.
    #[serde(default)]
    exp: Option<u64>,
    /// Optional not-before time, in seconds since the Unix epoch.
    #[serde(default)]
    nbf: Option<u64>,
    /// All other claims. They are captured only so their count can be bounded.
    #[serde(flatten)]
    _extra: serde_json::Map<String, serde_json::Value>,
}
254
255impl OAuthClaims {
256 pub fn scopes(&self) -> Vec<&str> {
258 if self.scope.is_empty() {
259 Vec::new()
260 } else {
261 self.scope.split(' ').filter(|s| !s.is_empty()).collect()
262 }
263 }
264}
265
266fn deserialize_aud<'de, D>(deserializer: D) -> Result<Vec<String>, D::Error>
268where
269 D: serde::Deserializer<'de>,
270{
271 use serde::de;
272
273 struct AudVisitor;
274
275 impl<'de> de::Visitor<'de> for AudVisitor {
276 type Value = Vec<String>;
277
278 fn expecting(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
279 f.write_str("a string or array of strings")
280 }
281
282 fn visit_str<E: de::Error>(self, v: &str) -> Result<Self::Value, E> {
283 Ok(vec![v.to_string()])
284 }
285
286 fn visit_seq<A: de::SeqAccess<'de>>(self, mut seq: A) -> Result<Self::Value, A::Error> {
287 let mut values = Vec::new();
288 while let Some(v) = seq.next_element::<String>()? {
289 values.push(v);
290 }
291 Ok(values)
292 }
293
294 fn visit_none<E: de::Error>(self) -> Result<Self::Value, E> {
295 Ok(Vec::new())
296 }
297
298 fn visit_unit<E: de::Error>(self) -> Result<Self::Value, E> {
299 Ok(Vec::new())
300 }
301 }
302
303 deserializer.deserialize_any(AudVisitor)
304}
305
306fn jwk_required_str_field<'a>(
307 obj: &'a serde_json::Map<String, serde_json::Value>,
308 field: &str,
309) -> Result<&'a str, OAuthError> {
310 obj.get(field)
311 .and_then(serde_json::Value::as_str)
312 .ok_or_else(|| OAuthError::InvalidDpopProof(format!("DPoP JWK missing '{field}' field")))
313}
314
315fn dpop_jwk_thumbprint_sha256(jwk: &jsonwebtoken::jwk::Jwk) -> Result<String, OAuthError> {
320 let value = serde_json::to_value(jwk)
321 .map_err(|e| OAuthError::InvalidDpopProof(format!("invalid DPoP JWK: {e}")))?;
322 let obj = value
323 .as_object()
324 .ok_or_else(|| OAuthError::InvalidDpopProof("invalid DPoP JWK object".to_string()))?;
325
326 let kty = jwk_required_str_field(obj, "kty")?;
327
328 let canonical = match kty {
329 "EC" => {
330 let crv = jwk_required_str_field(obj, "crv")?;
331 let x = jwk_required_str_field(obj, "x")?;
332 let y = jwk_required_str_field(obj, "y")?;
333 format!(
334 r#"{{"crv":{},"kty":{},"x":{},"y":{}}}"#,
335 serde_json::to_string(crv).map_err(|e| {
336 OAuthError::InvalidDpopProof(format!("failed to encode JWK curve: {e}"))
337 })?,
338 serde_json::to_string(kty).map_err(|e| {
339 OAuthError::InvalidDpopProof(format!("failed to encode JWK type: {e}"))
340 })?,
341 serde_json::to_string(x).map_err(|e| {
342 OAuthError::InvalidDpopProof(format!("failed to encode JWK x: {e}"))
343 })?,
344 serde_json::to_string(y).map_err(|e| {
345 OAuthError::InvalidDpopProof(format!("failed to encode JWK y: {e}"))
346 })?
347 )
348 }
349 "OKP" => {
350 let crv = jwk_required_str_field(obj, "crv")?;
351 let x = jwk_required_str_field(obj, "x")?;
352 format!(
353 r#"{{"crv":{},"kty":{},"x":{}}}"#,
354 serde_json::to_string(crv).map_err(|e| {
355 OAuthError::InvalidDpopProof(format!("failed to encode JWK curve: {e}"))
356 })?,
357 serde_json::to_string(kty).map_err(|e| {
358 OAuthError::InvalidDpopProof(format!("failed to encode JWK type: {e}"))
359 })?,
360 serde_json::to_string(x).map_err(|e| {
361 OAuthError::InvalidDpopProof(format!("failed to encode JWK x: {e}"))
362 })?
363 )
364 }
365 "RSA" => {
366 let e = jwk_required_str_field(obj, "e")?;
367 let n = jwk_required_str_field(obj, "n")?;
368 format!(
369 r#"{{"e":{},"kty":{},"n":{}}}"#,
370 serde_json::to_string(e).map_err(|err| {
371 OAuthError::InvalidDpopProof(format!("failed to encode JWK e: {err}"))
372 })?,
373 serde_json::to_string(kty).map_err(|err| {
374 OAuthError::InvalidDpopProof(format!("failed to encode JWK type: {err}"))
375 })?,
376 serde_json::to_string(n).map_err(|err| {
377 OAuthError::InvalidDpopProof(format!("failed to encode JWK n: {err}"))
378 })?
379 )
380 }
381 _ => {
382 return Err(OAuthError::InvalidDpopProof(format!(
383 "unsupported DPoP JWK key type '{kty}'"
384 )));
385 }
386 };
387
388 Ok(URL_SAFE_NO_PAD.encode(Sha256::digest(canonical.as_bytes())))
389}
390
/// Errors produced while extracting and validating OAuth access tokens and
/// DPoP proofs.
#[derive(Debug, thiserror::Error)]
pub enum OAuthError {
    /// NOTE(review): not constructed in this module; presumably returned by
    /// callers when the `Authorization` header is absent. Confirm.
    #[error("missing Authorization header")]
    MissingToken,

    /// The header is not `Bearer <token>`, or the token part is empty.
    #[error("invalid Authorization header format (expected: Bearer <token>)")]
    InvalidFormat,

    /// Any error from the `jsonwebtoken` crate: malformed header, bad
    /// signature, failed `exp`/`nbf`/`iss`/`aud` checks. Converted via `?`.
    #[error("JWT validation failed: {0}")]
    JwtError(#[from] jsonwebtoken::errors::Error),

    /// At least one configured required scope is absent from the token.
    /// Both fields are space-joined scope lists.
    #[error("insufficient scope: required {required}, found {found}")]
    InsufficientScope { required: String, found: String },

    /// The JWKS request failed, returned a non-success status, exceeded the
    /// size limit, or was not valid JWKS JSON.
    #[error("JWKS fetch failed: {0}")]
    JwksFetchFailed(String),

    /// No usable key in the JWKS matched the token's `kid` and `alg`.
    #[error("no matching key found in JWKS for kid '{0}'")]
    NoMatchingKey(String),

    /// The token's (or proof's) `alg` is not in the configured allowlist.
    #[error("disallowed algorithm: {0:?} is not in the allowed list")]
    DisallowedAlgorithm(Algorithm),

    /// The token has no `kid` while the JWKS holds this many keys, so the
    /// signing key cannot be chosen unambiguously.
    #[error("token missing 'kid' header but JWKS contains {0} keys — ambiguous key selection")]
    MissingKid(usize),

    /// The token's `resource` differs from the expected one. `token` is empty
    /// when the claim is missing.
    #[error("resource mismatch: token resource '{token}' does not match expected '{expected}' (RFC 8707)")]
    ResourceMismatch { expected: String, token: String },

    /// `require_audience` is set and the token has no `aud` claim.
    #[error("token missing required 'aud' claim")]
    MissingAudience,

    /// The token's `aud` values (space-joined in `found`) do not include the
    /// configured audience.
    #[error("token audience mismatch: expected '{expected}', found '{found}'")]
    AudienceMismatch { expected: String, found: String },

    /// Returned by [`verify_pkce_support`] when the authorization-server
    /// metadata does not list `S256`.
    #[error("authorization server does not support PKCE (S256)")]
    PkceNotSupported,

    /// A DPoP proof was required but no (non-blank) `DPoP` header was supplied.
    #[error("missing DPoP proof header")]
    MissingDpopProof,

    /// The DPoP proof (or the token's DPoP binding) failed a check; the
    /// string describes which.
    #[error("invalid DPoP proof: {0}")]
    InvalidDpopProof(String),

    /// The proof's replay key was already seen within the replay window.
    #[error("DPoP replay detected")]
    DpopReplayDetected,

    /// A string claim of the access token contains characters rejected by
    /// `contains_control_chars`.
    #[error("JWT claim contains control characters")]
    ClaimControlCharacters,
}
444
/// A fetched JWKS together with the moment it was fetched, used for TTL checks.
struct CachedJwks {
    keys: JwkSet,
    fetched_at: Instant,
}
450
/// Validates OAuth 2.0 JWT access tokens and DPoP proofs according to an
/// [`OAuthConfig`].
pub struct OAuthValidator {
    config: OAuthConfig,
    /// Client used to fetch the JWKS document.
    http_client: reqwest::Client,
    /// Most recently cached JWKS; `None` until one has been stored.
    jwks_cache: RwLock<Option<CachedJwks>>,
    /// How long a cached JWKS is served before a refetch (`new` sets 300 s).
    cache_ttl: Duration,
    /// DPoP replay cache: replay key (`jti`, or `jti:ath` when `ath` is
    /// non-empty) mapped to the Unix time (seconds) it was first accepted.
    dpop_jti_cache: RwLock<HashMap<String, u64>>,
}
465
466impl OAuthValidator {
467 pub fn new(config: OAuthConfig, http_client: reqwest::Client) -> Self {
471 if config.dpop_mode != DpopMode::Off && !config.dpop_require_ath {
475 tracing::warn!(
476 "DPoP mode is {:?} but dpop_require_ath is false — \
477 access-token binding (RFC 9449 §4.3) is NOT enforced",
478 config.dpop_mode
479 );
480 }
481 Self {
482 config,
483 http_client,
484 jwks_cache: RwLock::new(None),
485 cache_ttl: Duration::from_secs(300), dpop_jti_cache: RwLock::new(HashMap::new()),
487 }
488 }
489
490 pub async fn validate_token(&self, auth_header: &str) -> Result<OAuthClaims, OAuthError> {
494 let token = extract_bearer_token(auth_header)?;
495
496 let header = decode_header(token)?;
498
499 if !self.config.allowed_algorithms.contains(&header.alg) {
506 return Err(OAuthError::DisallowedAlgorithm(header.alg));
507 }
508
509 let kid = header.kid.clone().unwrap_or_default();
510
511 let decoding_key = self.get_decoding_key(&kid, &header.alg).await?;
513
514 let mut validation = Validation::new(header.alg);
516 validation.set_issuer(&[&self.config.issuer]);
517 validation.set_audience(&[&self.config.audience]);
518 validation.validate_exp = true;
519 validation.validate_nbf = true; validation.leeway = self.config.clock_skew_leeway.as_secs();
521
522 let token_data: TokenData<OAuthClaims> = decode(token, &decoding_key, &validation)?;
524 let claims = token_data.claims;
525
526 if contains_control_chars(&claims.sub)
530 || contains_control_chars(&claims.iss)
531 || contains_control_chars(&claims.scope)
532 || claims.aud.iter().any(|a| contains_control_chars(a))
533 || claims
534 .resource
535 .as_deref()
536 .is_some_and(contains_control_chars)
537 || claims
538 .cnf
539 .as_ref()
540 .and_then(|cnf| cnf.jkt.as_deref())
541 .is_some_and(contains_control_chars)
542 {
543 tracing::warn!("SECURITY: Rejecting JWT with control characters in claims");
544 return Err(OAuthError::ClaimControlCharacters);
545 }
546
547 if claims.aud.is_empty() {
548 if self.config.require_audience {
549 return Err(OAuthError::MissingAudience);
550 }
551 } else if !claims.aud.iter().any(|aud| aud == &self.config.audience) {
552 return Err(OAuthError::AudienceMismatch {
553 expected: self.config.audience.clone(),
554 found: claims.aud.join(" "),
555 });
556 }
557
558 if !self.config.required_scopes.is_empty() {
560 let token_scopes = claims.scopes();
561 for required in &self.config.required_scopes {
562 if !token_scopes.contains(&required.as_str()) {
563 return Err(OAuthError::InsufficientScope {
564 required: self.config.required_scopes.join(" "),
565 found: claims.scope.clone(),
566 });
567 }
568 }
569 }
570
571 if let Some(ref expected_resource) = self.config.expected_resource {
575 match &claims.resource {
576 Some(token_resource) if token_resource == expected_resource => {
577 }
579 Some(token_resource) => {
580 return Err(OAuthError::ResourceMismatch {
581 expected: expected_resource.clone(),
582 token: token_resource.clone(),
583 });
584 }
585 None => {
586 return Err(OAuthError::ResourceMismatch {
587 expected: expected_resource.clone(),
588 token: String::new(),
589 });
590 }
591 }
592 }
593
594 if self.config.dpop_mode == DpopMode::Required {
595 let token_jkt = claims
596 .cnf
597 .as_ref()
598 .and_then(|cnf| cnf.jkt.as_deref())
599 .map(str::trim)
600 .filter(|jkt| !jkt.is_empty());
601
602 if token_jkt.is_none() {
603 return Err(OAuthError::InvalidDpopProof(
604 "missing cnf.jkt in access token for required DPoP mode".to_string(),
605 ));
606 }
607 }
608
609 Ok(claims)
610 }
611
612 pub async fn validate_dpop_proof(
614 &self,
615 dpop_header: Option<&str>,
616 access_token: &str,
617 expected_method: &str,
618 expected_uri: &str,
619 token_claims: Option<&OAuthClaims>,
620 ) -> Result<(), OAuthError> {
621 match self.config.dpop_mode {
622 DpopMode::Off => return Ok(()),
623 DpopMode::Optional if dpop_header.is_none() => return Ok(()),
624 DpopMode::Required if dpop_header.is_none() => {
625 return Err(OAuthError::MissingDpopProof)
626 }
627 _ => {}
628 }
629
630 let proof_jwt = dpop_header
631 .map(str::trim)
632 .filter(|v| !v.is_empty())
633 .ok_or(OAuthError::MissingDpopProof)?;
634
635 let header = decode_header(proof_jwt)?;
636
637 if !self.config.dpop_allowed_algorithms.contains(&header.alg) {
638 return Err(OAuthError::DisallowedAlgorithm(header.alg));
639 }
640
641 let has_dpop_typ = header
642 .typ
643 .as_deref()
644 .map(|typ| typ.eq_ignore_ascii_case("dpop+jwt"))
645 .unwrap_or(false);
646 if !has_dpop_typ {
647 return Err(OAuthError::InvalidDpopProof(
648 "missing typ=dpop+jwt header".to_string(),
649 ));
650 }
651
652 let jwk = header.jwk.ok_or_else(|| {
653 OAuthError::InvalidDpopProof("missing embedded JWK in DPoP header".to_string())
654 })?;
655 let decoding_key = DecodingKey::from_jwk(&jwk)
656 .map_err(|e| OAuthError::InvalidDpopProof(format!("invalid DPoP JWK: {}", e)))?;
657
658 let mut validation = Validation::new(header.alg);
659 validation.validate_exp = false;
660 validation.validate_nbf = false;
661 validation.required_spec_claims.clear();
662 let token_data: TokenData<DpopClaims> = decode(proof_jwt, &decoding_key, &validation)?;
663 let claims = token_data.claims;
664
665 if contains_control_chars(&claims.htm)
669 || contains_control_chars(&claims.htu)
670 || contains_control_chars(&claims.jti)
671 {
672 return Err(OAuthError::InvalidDpopProof(
673 "DPoP claims contain control or format characters".to_string(),
674 ));
675 }
676
677 if claims.htm.is_empty() || !claims.htm.eq_ignore_ascii_case(expected_method) {
678 return Err(OAuthError::InvalidDpopProof(format!(
679 "htm mismatch: expected '{}', got '{}'",
680 expected_method, claims.htm
681 )));
682 }
683
684 if !claims.htu.is_ascii() {
688 return Err(OAuthError::InvalidDpopProof(
689 "htu contains non-ASCII characters".to_string(),
690 ));
691 }
692
693 if normalize_dpop_htu(&claims.htu) != normalize_dpop_htu(expected_uri) {
699 return Err(OAuthError::InvalidDpopProof(format!(
700 "htu mismatch: expected '{}', got '{}'",
701 expected_uri, claims.htu
702 )));
703 }
704
705 if claims.jti.trim().is_empty() {
706 return Err(OAuthError::InvalidDpopProof("missing jti".to_string()));
707 }
708 if claims.jti.len() > 256 {
710 return Err(OAuthError::InvalidDpopProof(
711 "jti exceeds maximum length".to_string(),
712 ));
713 }
714
715 let now = chrono::Utc::now().timestamp();
716 let iat = claims.iat as i64;
717 let skew = self.config.dpop_max_clock_skew.as_secs() as i64;
718 if (now - iat).abs() > skew {
719 return Err(OAuthError::InvalidDpopProof(format!(
720 "iat outside allowed skew window (iat={}, now={})",
721 claims.iat, now
722 )));
723 }
724
725 if let Some(exp) = claims.exp {
727 if (exp as i64) < now - skew {
728 return Err(OAuthError::InvalidDpopProof(
729 "DPoP proof expired".to_string(),
730 ));
731 }
732 }
733 if let Some(nbf) = claims.nbf {
734 if (nbf as i64) > now + skew {
735 return Err(OAuthError::InvalidDpopProof(
736 "DPoP proof not yet valid (nbf in future)".to_string(),
737 ));
738 }
739 }
740
741 if claims._extra.len() > 20 {
744 return Err(OAuthError::InvalidDpopProof(
745 "DPoP proof contains too many unknown claims".to_string(),
746 ));
747 }
748
749 if self.config.dpop_require_ath {
750 let expected_ath = URL_SAFE_NO_PAD.encode(Sha256::digest(access_token.as_bytes()));
751 match claims.ath.as_deref() {
752 Some(ath) if ath == expected_ath => {}
753 _ => {
754 return Err(OAuthError::InvalidDpopProof(
755 "ath mismatch for access token binding".to_string(),
756 ));
757 }
758 }
759 }
760
761 if let Some(token_jkt) = token_claims
762 .and_then(|c| c.cnf.as_ref())
763 .and_then(|cnf| cnf.jkt.as_deref())
764 .map(str::trim)
765 .filter(|jkt| !jkt.is_empty())
766 {
767 let proof_jkt = dpop_jwk_thumbprint_sha256(&jwk)?;
768 if proof_jkt != token_jkt {
769 return Err(OAuthError::InvalidDpopProof(
770 "cnf.jkt does not match DPoP proof key thumbprint".to_string(),
771 ));
772 }
773 }
774
775 let now_u64 = now.max(0) as u64;
777 let replay_window = std::cmp::max((skew.max(0) as u64) * 2, 600);
778 let oldest_allowed = now_u64.saturating_sub(replay_window);
779
780 let replay_key = match claims.ath.as_deref() {
783 Some(ath) if !ath.is_empty() => format!("{}:{}", claims.jti, ath),
784 _ => claims.jti.clone(),
785 };
786 if replay_key.len() > 512 {
787 return Err(OAuthError::InvalidDpopProof(
788 "DPoP replay key exceeds maximum length".to_string(),
789 ));
790 }
791
792 const MAX_JTI_CACHE_SIZE: usize = 8192;
798
799 let mut cache = self.dpop_jti_cache.write().await;
800
801 cache.retain(|_, ts| *ts >= oldest_allowed);
803
804 if cache.contains_key(&replay_key) {
806 return Err(OAuthError::DpopReplayDetected);
807 }
808
809 if cache.len() >= MAX_JTI_CACHE_SIZE {
811 return Err(OAuthError::InvalidDpopProof(
812 "DPoP replay cache at capacity — try again later".to_string(),
813 ));
814 }
815
816 cache.insert(replay_key, now_u64);
817
818 Ok(())
819 }
820
821 async fn get_decoding_key(
827 &self,
828 kid: &str,
829 alg: &Algorithm,
830 ) -> Result<DecodingKey, OAuthError> {
831 {
833 let cache = self.jwks_cache.read().await;
834 if let Some(cached) = cache.as_ref() {
835 if cached.fetched_at.elapsed() < self.cache_ttl {
836 if let Some(key) = find_key_in_jwks(&cached.keys, kid, alg) {
837 return Ok(key);
838 }
839 }
840 }
841 }
842 let mut cache = self.jwks_cache.write().await;
846
847 if let Some(cached) = cache.as_ref() {
849 if cached.fetched_at.elapsed() < self.cache_ttl {
850 if let Some(key) = find_key_in_jwks(&cached.keys, kid, alg) {
851 return Ok(key);
852 }
853 }
854 }
855
856 let jwks = self.fetch_jwks().await?;
858
859 if kid.is_empty() && jwks.keys.len() > 1 {
863 return Err(OAuthError::MissingKid(jwks.keys.len()));
864 }
865
866 let key = find_key_in_jwks(&jwks, kid, alg)
867 .ok_or_else(|| OAuthError::NoMatchingKey(kid.to_string()))?;
868
869 *cache = Some(CachedJwks {
871 keys: jwks,
872 fetched_at: Instant::now(),
873 });
874
875 Ok(key)
876 }
877
878 async fn fetch_jwks(&self) -> Result<JwkSet, OAuthError> {
880 let uri = self.config.effective_jwks_uri();
881
882 tracing::debug!("Fetching JWKS from {}", uri);
883
884 let response = self
885 .http_client
886 .get(&uri)
887 .timeout(Duration::from_secs(10))
888 .send()
889 .await
890 .map_err(|e| OAuthError::JwksFetchFailed(format!("request failed: {}", e)))?;
891
892 if !response.status().is_success() {
893 return Err(OAuthError::JwksFetchFailed(format!(
894 "HTTP {}",
895 response.status()
896 )));
897 }
898
899 const MAX_JWKS_BODY_SIZE: usize = 1024 * 1024;
902
903 if let Some(len) = response.content_length() {
904 if len as usize > MAX_JWKS_BODY_SIZE {
905 return Err(OAuthError::JwksFetchFailed(format!(
906 "JWKS Content-Length {} exceeds {} byte limit",
907 len, MAX_JWKS_BODY_SIZE
908 )));
909 }
910 }
911
912 let capacity = std::cmp::min(
915 response.content_length().unwrap_or(8192) as usize,
916 MAX_JWKS_BODY_SIZE,
917 );
918 let mut body = Vec::with_capacity(capacity);
919 let mut response = response;
920 while let Some(chunk) = response
921 .chunk()
922 .await
923 .map_err(|e| OAuthError::JwksFetchFailed(format!("body read failed: {}", e)))?
924 {
925 if body.len().saturating_add(chunk.len()) > MAX_JWKS_BODY_SIZE {
926 return Err(OAuthError::JwksFetchFailed(format!(
927 "JWKS response exceeds {} byte limit",
928 MAX_JWKS_BODY_SIZE
929 )));
930 }
931 body.extend_from_slice(&chunk);
932 }
933
934 let jwks: JwkSet = serde_json::from_slice(&body)
935 .map_err(|e| OAuthError::JwksFetchFailed(format!("invalid JWKS JSON: {}", e)))?;
936
937 tracing::info!("Fetched {} keys from JWKS endpoint", jwks.keys.len());
938
939 Ok(jwks)
940 }
941
942 pub fn config(&self) -> &OAuthConfig {
944 &self.config
945 }
946}
947
948fn key_algorithm_to_algorithm(ka: &KeyAlgorithm) -> Option<Algorithm> {
953 match ka {
954 KeyAlgorithm::HS256 => Some(Algorithm::HS256),
955 KeyAlgorithm::HS384 => Some(Algorithm::HS384),
956 KeyAlgorithm::HS512 => Some(Algorithm::HS512),
957 KeyAlgorithm::ES256 => Some(Algorithm::ES256),
958 KeyAlgorithm::ES384 => Some(Algorithm::ES384),
959 KeyAlgorithm::RS256 => Some(Algorithm::RS256),
960 KeyAlgorithm::RS384 => Some(Algorithm::RS384),
961 KeyAlgorithm::RS512 => Some(Algorithm::RS512),
962 KeyAlgorithm::PS256 => Some(Algorithm::PS256),
963 KeyAlgorithm::PS384 => Some(Algorithm::PS384),
964 KeyAlgorithm::PS512 => Some(Algorithm::PS512),
965 KeyAlgorithm::EdDSA => Some(Algorithm::EdDSA),
966 _ => None,
968 }
969}
970
971fn find_key_in_jwks(jwks: &JwkSet, kid: &str, alg: &Algorithm) -> Option<DecodingKey> {
980 for key in &jwks.keys {
981 if !kid.is_empty() {
983 match &key.common.key_id {
984 Some(key_kid) if key_kid == kid => {} Some(_) => continue, None => continue, }
988 }
989
990 if let Some(ref key_alg) = key.common.key_algorithm {
992 match key_algorithm_to_algorithm(key_alg) {
993 Some(mapped) if &mapped == alg => {} _ => continue, }
996 }
997
998 if let Ok(dk) = DecodingKey::from_jwk(key) {
1000 return Some(dk);
1001 }
1002 }
1003 None
1004}
1005
1006pub fn verify_pkce_support(metadata: &serde_json::Value) -> Result<(), OAuthError> {
1039 let supported = metadata
1040 .get("code_challenge_methods_supported")
1041 .and_then(|v| v.as_array())
1042 .map(|arr| arr.iter().any(|m| m.as_str() == Some("S256")))
1043 .unwrap_or(false);
1044
1045 if !supported {
1046 return Err(OAuthError::PkceNotSupported);
1047 }
1048 Ok(())
1049}
1050
1051#[cfg(test)]
1052mod tests {
1053 use super::*;
1054 use vellaveto_types::uri_util::decode_unreserved_percent;
1055
1056 #[test]
1057 fn test_oauth_config_effective_jwks_uri_explicit() {
1058 let config = OAuthConfig {
1059 issuer: "https://auth.example.com".to_string(),
1060 audience: "mcp-server".to_string(),
1061 jwks_uri: Some("https://auth.example.com/keys".to_string()),
1062 required_scopes: vec![],
1063 pass_through: false,
1064 allowed_algorithms: default_allowed_algorithms(),
1065 expected_resource: None,
1066 clock_skew_leeway: Duration::from_secs(30),
1067 require_audience: true,
1068 dpop_mode: DpopMode::Off,
1069 dpop_allowed_algorithms: default_dpop_allowed_algorithms(),
1070 dpop_require_ath: true,
1071 dpop_max_clock_skew: Duration::from_secs(300),
1072 };
1073 assert_eq!(config.effective_jwks_uri(), "https://auth.example.com/keys");
1074 }
1075
1076 #[test]
1077 fn test_oauth_config_effective_jwks_uri_wellknown() {
1078 let config = OAuthConfig {
1079 issuer: "https://auth.example.com".to_string(),
1080 audience: "mcp-server".to_string(),
1081 jwks_uri: None,
1082 required_scopes: vec![],
1083 pass_through: false,
1084 allowed_algorithms: default_allowed_algorithms(),
1085 expected_resource: None,
1086 clock_skew_leeway: Duration::from_secs(30),
1087 require_audience: true,
1088 dpop_mode: DpopMode::Off,
1089 dpop_allowed_algorithms: default_dpop_allowed_algorithms(),
1090 dpop_require_ath: true,
1091 dpop_max_clock_skew: Duration::from_secs(300),
1092 };
1093 assert_eq!(
1094 config.effective_jwks_uri(),
1095 "https://auth.example.com/.well-known/jwks.json"
1096 );
1097 }
1098
1099 #[test]
1100 fn test_oauth_config_effective_jwks_uri_trailing_slash() {
1101 let config = OAuthConfig {
1102 issuer: "https://auth.example.com/".to_string(),
1103 audience: "mcp-server".to_string(),
1104 jwks_uri: None,
1105 required_scopes: vec![],
1106 pass_through: false,
1107 allowed_algorithms: default_allowed_algorithms(),
1108 expected_resource: None,
1109 clock_skew_leeway: Duration::from_secs(30),
1110 require_audience: true,
1111 dpop_mode: DpopMode::Off,
1112 dpop_allowed_algorithms: default_dpop_allowed_algorithms(),
1113 dpop_require_ath: true,
1114 dpop_max_clock_skew: Duration::from_secs(300),
1115 };
1116 assert_eq!(
1117 config.effective_jwks_uri(),
1118 "https://auth.example.com/.well-known/jwks.json"
1119 );
1120 }
1121
1122 #[test]
1123 fn test_oauth_claims_scopes_parsing() {
1124 let claims = OAuthClaims {
1125 sub: "user-123".to_string(),
1126 iss: "https://auth.example.com".to_string(),
1127 aud: vec!["mcp-server".to_string()],
1128 exp: 0,
1129 iat: 0,
1130 scope: "tools.call resources.read admin".to_string(),
1131 resource: None,
1132 cnf: None,
1133 };
1134 let scopes = claims.scopes();
1135 assert_eq!(scopes, vec!["tools.call", "resources.read", "admin"]);
1136 }
1137
1138 #[test]
1139 fn test_oauth_claims_empty_scope() {
1140 let claims = OAuthClaims {
1141 sub: "user-123".to_string(),
1142 iss: "https://auth.example.com".to_string(),
1143 aud: vec![],
1144 exp: 0,
1145 iat: 0,
1146 scope: String::new(),
1147 resource: None,
1148 cnf: None,
1149 };
1150 let scopes = claims.scopes();
1151 assert!(scopes.is_empty());
1152 }
1153
1154 #[test]
1155 fn test_deserialize_aud_string() {
1156 let json = r#"{"sub":"user","aud":"mcp-server","scope":""}"#;
1157 let claims: OAuthClaims = serde_json::from_str(json).unwrap();
1158 assert_eq!(claims.aud, vec!["mcp-server"]);
1159 }
1160
1161 #[test]
1162 fn test_deserialize_aud_array() {
1163 let json = r#"{"sub":"user","aud":["mcp-server","other"],"scope":""}"#;
1164 let claims: OAuthClaims = serde_json::from_str(json).unwrap();
1165 assert_eq!(claims.aud, vec!["mcp-server", "other"]);
1166 }
1167
1168 #[test]
1169 fn test_oauth_error_display() {
1170 let err = OAuthError::MissingToken;
1171 assert_eq!(err.to_string(), "missing Authorization header");
1172
1173 let err = OAuthError::InsufficientScope {
1174 required: "tools.call admin".to_string(),
1175 found: "tools.call".to_string(),
1176 };
1177 assert!(err.to_string().contains("insufficient scope"));
1178 }
1179
1180 #[test]
1182 fn test_default_allowed_algorithms_excludes_hmac() {
1183 let allowed = default_allowed_algorithms();
1184 assert!(!allowed.contains(&Algorithm::HS256));
1185 assert!(!allowed.contains(&Algorithm::HS384));
1186 assert!(!allowed.contains(&Algorithm::HS512));
1187 }
1188
1189 #[test]
1190 fn test_default_allowed_algorithms_includes_asymmetric() {
1191 let allowed = default_allowed_algorithms();
1192 assert!(allowed.contains(&Algorithm::RS256));
1193 assert!(allowed.contains(&Algorithm::ES256));
1194 assert!(allowed.contains(&Algorithm::PS256));
1195 assert!(allowed.contains(&Algorithm::EdDSA));
1196 }
1197
1198 #[test]
1199 fn test_disallowed_algorithm_error_display() {
1200 let err = OAuthError::DisallowedAlgorithm(Algorithm::HS256);
1201 assert!(err.to_string().contains("disallowed algorithm"));
1202 assert!(err.to_string().contains("HS256"));
1203 }
1204
1205 #[test]
1206 fn test_missing_kid_error_display() {
1207 let err = OAuthError::MissingKid(3);
1208 assert!(err.to_string().contains("missing 'kid'"));
1209 assert!(err.to_string().contains("3 keys"));
1210 }
1211
1212 #[test]
1214 fn test_key_algorithm_to_algorithm_all_signing() {
1215 assert_eq!(
1216 key_algorithm_to_algorithm(&KeyAlgorithm::HS256),
1217 Some(Algorithm::HS256)
1218 );
1219 assert_eq!(
1220 key_algorithm_to_algorithm(&KeyAlgorithm::RS256),
1221 Some(Algorithm::RS256)
1222 );
1223 assert_eq!(
1224 key_algorithm_to_algorithm(&KeyAlgorithm::ES256),
1225 Some(Algorithm::ES256)
1226 );
1227 assert_eq!(
1228 key_algorithm_to_algorithm(&KeyAlgorithm::PS256),
1229 Some(Algorithm::PS256)
1230 );
1231 assert_eq!(
1232 key_algorithm_to_algorithm(&KeyAlgorithm::EdDSA),
1233 Some(Algorithm::EdDSA)
1234 );
1235 }
1236
1237 #[test]
1238 fn test_key_algorithm_to_algorithm_encryption_returns_none() {
1239 assert_eq!(key_algorithm_to_algorithm(&KeyAlgorithm::RSA1_5), None);
1240 assert_eq!(key_algorithm_to_algorithm(&KeyAlgorithm::RSA_OAEP), None);
1241 assert_eq!(
1242 key_algorithm_to_algorithm(&KeyAlgorithm::RSA_OAEP_256),
1243 None
1244 );
1245 }
1246
1247 #[test]
1249 fn test_resource_mismatch_error_display() {
1250 let err = OAuthError::ResourceMismatch {
1251 expected: "https://mcp.example.com".to_string(),
1252 token: "https://other.example.com".to_string(),
1253 };
1254 let msg = err.to_string();
1255 assert!(msg.contains("resource mismatch"));
1256 assert!(msg.contains("https://mcp.example.com"));
1257 assert!(msg.contains("https://other.example.com"));
1258 assert!(msg.contains("RFC 8707"));
1259 }
1260
1261 #[test]
1262 fn test_resource_mismatch_missing_claim_error_display() {
1263 let err = OAuthError::ResourceMismatch {
1264 expected: "https://mcp.example.com".to_string(),
1265 token: String::new(),
1266 };
1267 let msg = err.to_string();
1268 assert!(msg.contains("resource mismatch"));
1269 assert!(msg.contains("https://mcp.example.com"));
1270 }
1271
1272 #[test]
1273 fn test_deserialize_claims_with_resource() {
1274 let json =
1275 r#"{"sub":"user","aud":"mcp-server","scope":"","resource":"https://mcp.example.com"}"#;
1276 let claims: OAuthClaims = serde_json::from_str(json).unwrap();
1277 assert_eq!(claims.resource, Some("https://mcp.example.com".to_string()));
1278 }
1279
1280 #[test]
1281 fn test_deserialize_claims_without_resource() {
1282 let json = r#"{"sub":"user","aud":"mcp-server","scope":""}"#;
1283 let claims: OAuthClaims = serde_json::from_str(json).unwrap();
1284 assert_eq!(claims.resource, None);
1285 }
1286
1287 #[test]
1288 fn test_deserialize_claims_with_cnf_jkt() {
1289 let json = r#"{"sub":"user","aud":"mcp-server","scope":"","cnf":{"jkt":"thumbprint-123"}}"#;
1290 let claims: OAuthClaims = serde_json::from_str(json).unwrap();
1291 let jkt = claims
1292 .cnf
1293 .as_ref()
1294 .and_then(|cnf| cnf.jkt.as_deref())
1295 .expect("cnf.jkt must deserialize");
1296 assert_eq!(jkt, "thumbprint-123");
1297 }
1298
1299 #[test]
1300 fn test_dpop_jwk_thumbprint_sha256_rsa() {
1301 let jwk: jsonwebtoken::jwk::Jwk = serde_json::from_value(serde_json::json!({
1302 "kty": "RSA",
1303 "n": "0vx7agoebGcQSuuPiLJXZptN9nndrQmbXEps2aiAFbWhM78LhWx4cbbfAAtVT86zwu1RK7aPFFxuhDR1L6tSoc_BJECPebWKRXjBZCiFV4n3oknjhMstn64tZ_2W-5JsGY4Hc5n9yBXArwl93lqt7_RN5w6Cf0h4QyQ5v-65YGjQR0_FDW2QvzqY368QQMicAtaSqzs8KJZgnYb9c7d0zgdAZHzu6qMQvRL5hajrn1n91CbOpbISD08qNLyrdkt-bFTWhAI4vMQFh6WeZu0fM4lFd2NcRwr3XPksINHaQ-G_xBniIqbw0Ls1jF44-csFCur-kEgU8awapJzKnqDKgw",
1304 "e": "AQAB"
1305 }))
1306 .expect("valid RSA JWK");
1307
1308 let thumbprint = dpop_jwk_thumbprint_sha256(&jwk).expect("thumbprint should compute");
1309 assert_eq!(thumbprint, "NzbLsXh8uDCcd-6MNwXF4W_7noWXFZAfHkxZsRGC9Xs");
1310 }
1311
1312 #[test]
1313 fn test_dpop_jwk_thumbprint_rejects_unsupported_key_type() {
1314 let jwk: jsonwebtoken::jwk::Jwk =
1315 serde_json::from_value(serde_json::json!({"kty": "oct", "k": "AQAB"}))
1316 .expect("valid octet JWK");
1317
1318 let err = dpop_jwk_thumbprint_sha256(&jwk).expect_err("octet keys are not valid for DPoP");
1319 assert!(err.to_string().contains("unsupported DPoP JWK key type"));
1320 }
1321
1322 #[test]
1323 fn test_clock_skew_leeway_configurable() {
1324 let config = OAuthConfig {
1325 issuer: "https://auth.example.com".to_string(),
1326 audience: "mcp-server".to_string(),
1327 jwks_uri: None,
1328 required_scopes: vec![],
1329 pass_through: false,
1330 allowed_algorithms: default_allowed_algorithms(),
1331 expected_resource: None,
1332 clock_skew_leeway: Duration::from_secs(60),
1333 require_audience: true,
1334 dpop_mode: DpopMode::Off,
1335 dpop_allowed_algorithms: default_dpop_allowed_algorithms(),
1336 dpop_require_ath: true,
1337 dpop_max_clock_skew: Duration::from_secs(300),
1338 };
1339 assert_eq!(config.clock_skew_leeway, Duration::from_secs(60));
1340 }
1341
1342 #[test]
1343 fn test_deserialize_missing_aud_yields_empty_vec() {
1344 let json = r#"{"sub":"user","scope":"read"}"#;
1345 let claims: OAuthClaims = serde_json::from_str(json).unwrap();
1346 assert!(claims.aud.is_empty());
1347 }
1348
1349 #[test]
1350 fn test_missing_audience_error_display() {
1351 let err = OAuthError::MissingAudience;
1352 assert_eq!(err.to_string(), "token missing required 'aud' claim");
1353 }
1354
1355 #[test]
1356 fn test_audience_mismatch_error_display() {
1357 let err = OAuthError::AudienceMismatch {
1358 expected: "mcp-server".to_string(),
1359 found: "other-aud".to_string(),
1360 };
1361 let msg = err.to_string();
1362 assert!(msg.contains("audience mismatch"));
1363 assert!(msg.contains("mcp-server"));
1364 assert!(msg.contains("other-aud"));
1365 }
1366
1367 #[test]
1372 fn test_verify_pkce_support_s256_supported() {
1373 let metadata = serde_json::json!({
1374 "issuer": "https://auth.example.com",
1375 "code_challenge_methods_supported": ["S256", "plain"]
1376 });
1377 assert!(verify_pkce_support(&metadata).is_ok());
1378 }
1379
1380 #[test]
1381 fn test_verify_pkce_support_s256_only() {
1382 let metadata = serde_json::json!({
1383 "issuer": "https://auth.example.com",
1384 "code_challenge_methods_supported": ["S256"]
1385 });
1386 assert!(verify_pkce_support(&metadata).is_ok());
1387 }
1388
1389 #[test]
1390 fn test_verify_pkce_support_missing_field() {
1391 let metadata = serde_json::json!({
1392 "issuer": "https://auth.example.com"
1393 });
1394 let result = verify_pkce_support(&metadata);
1395 assert!(matches!(result, Err(OAuthError::PkceNotSupported)));
1396 }
1397
1398 #[test]
1399 fn test_verify_pkce_support_plain_only() {
1400 let metadata = serde_json::json!({
1402 "issuer": "https://auth.example.com",
1403 "code_challenge_methods_supported": ["plain"]
1404 });
1405 let result = verify_pkce_support(&metadata);
1406 assert!(matches!(result, Err(OAuthError::PkceNotSupported)));
1407 }
1408
1409 #[test]
1410 fn test_verify_pkce_support_empty_array() {
1411 let metadata = serde_json::json!({
1412 "issuer": "https://auth.example.com",
1413 "code_challenge_methods_supported": []
1414 });
1415 let result = verify_pkce_support(&metadata);
1416 assert!(matches!(result, Err(OAuthError::PkceNotSupported)));
1417 }
1418
1419 #[test]
1420 fn test_pkce_not_supported_error_display() {
1421 let err = OAuthError::PkceNotSupported;
1422 assert!(err.to_string().contains("PKCE"));
1423 assert!(err.to_string().contains("S256"));
1424 }
1425
    // RSA private key in PKCS#8 PEM form ("BEGIN PRIVATE KEY") used to sign the test
    // JWTs (RS256) in this module. Its public half is served by the mock JWKS server
    // via `TEST_RSA_N` / `TEST_RSA_E`; the two must stay in sync or every e2e test
    // that expects acceptance will fail.
    // NOTE: test fixture only — never reuse this key outside tests.
    const TEST_RSA_PRIVATE_PEM: &str = r#"-----BEGIN PRIVATE KEY-----
MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQDXsFTcrmrrw3RK
Ll2pK5mhqySdvuoQY/U1CwFXmu1S2BAEnh+/Yzsilc/LWJjBmcDdmY88NC+F8PhO
q6+hQjWZR08QewinBg69w2+TRqr4x09XXZm/w3Y+jOlspHR85PISy8sqkHzGk3o4
cLNCxDkw2mwQaVDQQz1YJ0x22+IoOZniRTntUK1yAyI0jqhpZJjn9dY+CbDt/H8B
nGollAlhKQFizDAIMYOL/duJLJv1jtgv5hwvH94tSYgLGzJcufmxvioBBD4ZcxDq
Lk2vNdC5ETVkS9GeyYfuQbHW55lYSACe2NfYwgwYc3PO6X6PlJxuksL7JfR6kivU
WkYHOTVLAgMBAAECggEATuXElRkEKYvMrRn6ztgREa9N7JoaerZlyupkqkwUxfod
GeNRj6vXxNXyNdsJvb/laeozF/2q6J715aktzJowiwonpMqsppQzrjygQspV3jzi
C/5EMH5qcYUQGdqqdck1t6Rug/poeicWTTEEkca/eNxdLT+o/RWrieSONuhF+Ro0
7S60Dc4tFRA6XBDayikUzuFd2XoroRfoukC+HcC7mHMQdPHNt7QjORJNitdjwP+P
BcwNm61043sz9VSdW9FMtrdpg+pndbzRiYwYCDRt7r0hhUSZ4cojY6Tyoqexa5BD
7W5jDTmySO5/Jzl3QGvtevyKVx3x6DQE9858W8kucQKBgQDynqb19kVK6IFPccFX
01D9qVg0vZ1WoAR30s79DkoKA/NyM3sjP431p38Kkj1QomERBSCb0O1OOfsfjAEO
2SoEqTSa+2cgDYikQ16IEqKfbucilDNMTsYQz9Jwx16BEJRGoz+lbt52exhZN+nf
qfVtuwIvlb46bksxh5pXJ9L7wwKBgQDjlXatVjiZgigpGwmj4/Uj9tI0c+AmtooX
zA/2B3GJdXbRVtMvFsQB73/d7U2lCwUHJmG56FYS1Dg8C88Xn7nE6kBSYoeguxCA
bLPBbCGtPD/VeGP2ymxLxsULLiiRx4S+6K7hUulCkg9m3CWkc3m5AH5lrHVEXgi1
YadcFKMv2QKBgAuaJqXQdxPT9osUB4jppA/dT0iGYMXJtSz9ucREMKo18ihd6d+P
pHxA3ERnJeN7QGUN97c70H1TLH0fttU88VNzu/5FU3Mm8ofYaObc7UXuicMPjzxw
7+vR5GBcSFqnrk+Kcvq4SI8l584sbFSzzfbHYJ1h7czhhVsC/xB36RD9AoGBANQ8
JXGer6fQrp0u3r2dL5Y7bmqGCWpw3rU0k0nwRRxYk9bDbqxCQcZAUHFpBPi+HxE8
5PQXTHXAvTSaGqXASeDuR8/MnQjyioAJX1Uo/vrr7eeonyieO4IrOsSjZigU9aGH
otb0mB2B0qUs9lm3arNxV25/9tgsDVkBWa7QfCJ5AoGBAJF/XTU+YnGjQYGgxvfg
Ma5j2E3NRga/10ncKjDKbRNzLXk887xp4kl68vDTayAKGLu+ndYQ9dMpHCUTPky5
2KGQijoG2H/1Ri4JE8dGa+RbjG3gMIRIdbYApn/Q4nrAadrWrDLaTpbnAhhL95FJ
TfzccotDw2uXy3Xbwy/kdpfK
-----END PRIVATE KEY-----"#;
1468
    // Public half of `TEST_RSA_PRIVATE_PEM`, expressed as JWK members: the modulus
    // (`n`), base64url-encoded without padding, as published by the mock JWKS server.
    const TEST_RSA_N: &str = "17BU3K5q68N0Si5dqSuZoasknb7qEGP1NQsBV5rtUtgQBJ4fv2M7IpXPy1iYwZnA3ZmPPDQvhfD4TquvoUI1mUdPEHsIpwYOvcNvk0aq-MdPV12Zv8N2PozpbKR0fOTyEsvLKpB8xpN6OHCzQsQ5MNpsEGlQ0EM9WCdMdtviKDmZ4kU57VCtcgMiNI6oaWSY5_XWPgmw7fx_AZxqJZQJYSkBYswwCDGDi_3biSyb9Y7YL-YcLx_eLUmICxsyXLn5sb4qAQQ-GXMQ6i5NrzXQuRE1ZEvRnsmH7kGx1ueZWEgAntjX2MIMGHNzzul-j5ScbpLC-yX0epIr1FpGBzk1Sw";
    // Public exponent (`e`): "AQAB" is base64url for 65537.
    const TEST_RSA_E: &str = "AQAB";
1472
1473 fn test_jwks_json(kid: &str) -> String {
1475 serde_json::json!({
1476 "keys": [{
1477 "kty": "RSA",
1478 "use": "sig",
1479 "alg": "RS256",
1480 "kid": kid,
1481 "n": TEST_RSA_N,
1482 "e": TEST_RSA_E
1483 }]
1484 })
1485 .to_string()
1486 }
1487
1488 async fn start_mock_jwks_server(
1494 jwks_json: String,
1495 ) -> Option<(String, tokio::task::JoinHandle<()>)> {
1496 use axum::{routing::get, Router};
1497 use std::net::SocketAddr;
1498
1499 let app = Router::new().route(
1500 "/.well-known/jwks.json",
1501 get(move || {
1502 let json = jwks_json.clone();
1503 async move {
1504 (
1505 [(
1506 http::header::CONTENT_TYPE,
1507 http::HeaderValue::from_static("application/json"),
1508 )],
1509 json,
1510 )
1511 }
1512 }),
1513 );
1514
1515 let listener = match tokio::net::TcpListener::bind("127.0.0.1:0").await {
1516 Ok(listener) => listener,
1517 Err(error) if error.kind() == std::io::ErrorKind::PermissionDenied => {
1518 eprintln!("skipping oauth e2e test: cannot bind local jwks server: {error}");
1519 return None;
1520 }
1521 Err(error) => panic!("bind to random port: {error}"),
1522 };
1523 let addr: SocketAddr = listener.local_addr().expect("local addr");
1524 let base_url = format!("http://127.0.0.1:{}", addr.port());
1525
1526 let handle = tokio::spawn(async move {
1527 axum::serve(listener, app)
1528 .await
1529 .expect("mock JWKS server failed");
1530 });
1531
1532 tokio::time::sleep(std::time::Duration::from_millis(50)).await;
1534
1535 Some((base_url, handle))
1536 }
1537
1538 fn test_oauth_config(jwks_url: String) -> OAuthConfig {
1540 OAuthConfig {
1541 issuer: "https://auth.example.com".to_string(),
1542 audience: "mcp-server".to_string(),
1543 jwks_uri: Some(jwks_url),
1544 required_scopes: vec!["tools.call".to_string()],
1545 pass_through: false,
1546 allowed_algorithms: default_allowed_algorithms(),
1547 expected_resource: None,
1548 clock_skew_leeway: Duration::from_secs(30),
1549 require_audience: true,
1550 dpop_mode: DpopMode::Off,
1551 dpop_allowed_algorithms: default_dpop_allowed_algorithms(),
1552 dpop_require_ath: true,
1553 dpop_max_clock_skew: Duration::from_secs(300),
1554 }
1555 }
1556
1557 fn sign_test_jwt(claims: &serde_json::Value, kid: &str) -> String {
1559 use jsonwebtoken::{encode, EncodingKey, Header};
1560
1561 let key =
1562 EncodingKey::from_rsa_pem(TEST_RSA_PRIVATE_PEM.as_bytes()).expect("valid RSA PEM");
1563 let mut header = Header::new(Algorithm::RS256);
1564 header.kid = Some(kid.to_string());
1565
1566 encode(&header, claims, &key).expect("JWT signing must succeed")
1567 }
1568
1569 fn valid_claims() -> serde_json::Value {
1571 let now = chrono::Utc::now().timestamp() as u64;
1572 serde_json::json!({
1573 "sub": "user-123",
1574 "iss": "https://auth.example.com",
1575 "aud": "mcp-server",
1576 "exp": now + 3600,
1577 "iat": now,
1578 "nbf": now - 10,
1579 "scope": "tools.call resources.read"
1580 })
1581 }
1582
1583 #[tokio::test]
1584 async fn test_e2e_valid_jwt_accepted() {
1585 let kid = "test-key-1";
1586 let jwks = test_jwks_json(kid);
1587 let Some((base_url, _handle)) = start_mock_jwks_server(jwks).await else {
1588 return;
1589 };
1590
1591 let config = test_oauth_config(format!("{}/.well-known/jwks.json", base_url));
1592 let validator = OAuthValidator::new(config, reqwest::Client::new());
1593
1594 let token = sign_test_jwt(&valid_claims(), kid);
1595 let auth_header = format!("Bearer {}", token);
1596
1597 let claims = validator
1598 .validate_token(&auth_header)
1599 .await
1600 .expect("valid JWT must be accepted");
1601 assert_eq!(claims.sub, "user-123");
1602 assert_eq!(claims.iss, "https://auth.example.com");
1603 assert!(claims.scopes().contains(&"tools.call"));
1604 }
1605
1606 #[tokio::test]
1607 async fn test_e2e_expired_jwt_rejected() {
1608 let kid = "test-key-1";
1609 let jwks = test_jwks_json(kid);
1610 let Some((base_url, _handle)) = start_mock_jwks_server(jwks).await else {
1611 return;
1612 };
1613
1614 let config = test_oauth_config(format!("{}/.well-known/jwks.json", base_url));
1615 let validator = OAuthValidator::new(config, reqwest::Client::new());
1616
1617 let now = chrono::Utc::now().timestamp() as u64;
1618 let mut claims = valid_claims();
1619 claims["exp"] = serde_json::json!(now - 600); let token = sign_test_jwt(&claims, kid);
1622 let auth_header = format!("Bearer {}", token);
1623
1624 let err = validator
1625 .validate_token(&auth_header)
1626 .await
1627 .expect_err("expired JWT must be rejected");
1628 assert!(
1629 matches!(err, OAuthError::JwtError(_)),
1630 "expected JwtError for expired token, got: {err}"
1631 );
1632 }
1633
1634 #[tokio::test]
1635 async fn test_e2e_wrong_algorithm_rejected() {
1636 let kid = "test-key-1";
1637 let jwks = test_jwks_json(kid);
1638 let Some((base_url, _handle)) = start_mock_jwks_server(jwks).await else {
1639 return;
1640 };
1641
1642 let mut config = test_oauth_config(format!("{}/.well-known/jwks.json", base_url));
1644 config.allowed_algorithms = vec![Algorithm::ES256];
1645
1646 let validator = OAuthValidator::new(config, reqwest::Client::new());
1647
1648 let token = sign_test_jwt(&valid_claims(), kid);
1650 let auth_header = format!("Bearer {}", token);
1651
1652 let err = validator
1653 .validate_token(&auth_header)
1654 .await
1655 .expect_err("RS256 JWT must be rejected when only ES256 is allowed");
1656 assert!(
1657 matches!(err, OAuthError::DisallowedAlgorithm(Algorithm::RS256)),
1658 "expected DisallowedAlgorithm(RS256), got: {err}"
1659 );
1660 }
1661
1662 #[tokio::test]
1663 async fn test_e2e_wrong_issuer_rejected() {
1664 let kid = "test-key-1";
1665 let jwks = test_jwks_json(kid);
1666 let Some((base_url, _handle)) = start_mock_jwks_server(jwks).await else {
1667 return;
1668 };
1669
1670 let config = test_oauth_config(format!("{}/.well-known/jwks.json", base_url));
1671 let validator = OAuthValidator::new(config, reqwest::Client::new());
1672
1673 let mut claims = valid_claims();
1674 claims["iss"] = serde_json::json!("https://evil.example.com");
1675
1676 let token = sign_test_jwt(&claims, kid);
1677 let auth_header = format!("Bearer {}", token);
1678
1679 let err = validator
1680 .validate_token(&auth_header)
1681 .await
1682 .expect_err("wrong issuer must be rejected");
1683 assert!(
1684 matches!(err, OAuthError::JwtError(_)),
1685 "expected JwtError for issuer mismatch, got: {err}"
1686 );
1687 }
1688
1689 #[tokio::test]
1690 async fn test_e2e_wrong_audience_rejected() {
1691 let kid = "test-key-1";
1692 let jwks = test_jwks_json(kid);
1693 let Some((base_url, _handle)) = start_mock_jwks_server(jwks).await else {
1694 return;
1695 };
1696
1697 let config = test_oauth_config(format!("{}/.well-known/jwks.json", base_url));
1698 let validator = OAuthValidator::new(config, reqwest::Client::new());
1699
1700 let mut claims = valid_claims();
1701 claims["aud"] = serde_json::json!("wrong-audience");
1702
1703 let token = sign_test_jwt(&claims, kid);
1704 let auth_header = format!("Bearer {}", token);
1705
1706 let err = validator
1707 .validate_token(&auth_header)
1708 .await
1709 .expect_err("wrong audience must be rejected");
1710 assert!(
1712 matches!(
1713 err,
1714 OAuthError::JwtError(_) | OAuthError::AudienceMismatch { .. }
1715 ),
1716 "expected audience rejection, got: {err}"
1717 );
1718 }
1719
1720 #[tokio::test]
1721 async fn test_e2e_missing_required_scope_rejected() {
1722 let kid = "test-key-1";
1723 let jwks = test_jwks_json(kid);
1724 let Some((base_url, _handle)) = start_mock_jwks_server(jwks).await else {
1725 return;
1726 };
1727
1728 let config = test_oauth_config(format!("{}/.well-known/jwks.json", base_url));
1729 let validator = OAuthValidator::new(config, reqwest::Client::new());
1730
1731 let mut claims = valid_claims();
1732 claims["scope"] = serde_json::json!("resources.read"); let token = sign_test_jwt(&claims, kid);
1735 let auth_header = format!("Bearer {}", token);
1736
1737 let err = validator
1738 .validate_token(&auth_header)
1739 .await
1740 .expect_err("missing required scope must be rejected");
1741 assert!(
1742 matches!(err, OAuthError::InsufficientScope { .. }),
1743 "expected InsufficientScope, got: {err}"
1744 );
1745 }
1746
1747 #[tokio::test]
1748 async fn test_e2e_resource_mismatch_rejected() {
1749 let kid = "test-key-1";
1750 let jwks = test_jwks_json(kid);
1751 let Some((base_url, _handle)) = start_mock_jwks_server(jwks).await else {
1752 return;
1753 };
1754
1755 let mut config = test_oauth_config(format!("{}/.well-known/jwks.json", base_url));
1756 config.expected_resource = Some("https://mcp.example.com".to_string());
1757
1758 let validator = OAuthValidator::new(config, reqwest::Client::new());
1759
1760 let mut claims = valid_claims();
1762 claims["resource"] = serde_json::json!("https://evil.example.com");
1763
1764 let token = sign_test_jwt(&claims, kid);
1765 let auth_header = format!("Bearer {}", token);
1766
1767 let err = validator
1768 .validate_token(&auth_header)
1769 .await
1770 .expect_err("resource mismatch must be rejected");
1771 assert!(
1772 matches!(err, OAuthError::ResourceMismatch { .. }),
1773 "expected ResourceMismatch, got: {err}"
1774 );
1775 }
1776
1777 #[tokio::test]
1778 async fn test_e2e_resource_missing_when_required_rejected() {
1779 let kid = "test-key-1";
1780 let jwks = test_jwks_json(kid);
1781 let Some((base_url, _handle)) = start_mock_jwks_server(jwks).await else {
1782 return;
1783 };
1784
1785 let mut config = test_oauth_config(format!("{}/.well-known/jwks.json", base_url));
1786 config.expected_resource = Some("https://mcp.example.com".to_string());
1787
1788 let validator = OAuthValidator::new(config, reqwest::Client::new());
1789
1790 let token = sign_test_jwt(&valid_claims(), kid);
1792 let auth_header = format!("Bearer {}", token);
1793
1794 let err = validator
1795 .validate_token(&auth_header)
1796 .await
1797 .expect_err("missing resource when required must be rejected");
1798 assert!(
1799 matches!(err, OAuthError::ResourceMismatch { .. }),
1800 "expected ResourceMismatch, got: {err}"
1801 );
1802 }
1803
1804 #[tokio::test]
1805 async fn test_e2e_resource_match_accepted() {
1806 let kid = "test-key-1";
1807 let jwks = test_jwks_json(kid);
1808 let Some((base_url, _handle)) = start_mock_jwks_server(jwks).await else {
1809 return;
1810 };
1811
1812 let mut config = test_oauth_config(format!("{}/.well-known/jwks.json", base_url));
1813 config.expected_resource = Some("https://mcp.example.com".to_string());
1814
1815 let validator = OAuthValidator::new(config, reqwest::Client::new());
1816
1817 let mut claims = valid_claims();
1818 claims["resource"] = serde_json::json!("https://mcp.example.com");
1819
1820 let token = sign_test_jwt(&claims, kid);
1821 let auth_header = format!("Bearer {}", token);
1822
1823 let result = validator.validate_token(&auth_header).await;
1824 assert!(
1825 result.is_ok(),
1826 "matching resource must be accepted: {result:?}"
1827 );
1828 }
1829
1830 #[tokio::test]
1831 async fn test_e2e_kid_mismatch_rejected() {
1832 let jwks = test_jwks_json("server-key-1");
1833 let Some((base_url, _handle)) = start_mock_jwks_server(jwks).await else {
1834 return;
1835 };
1836
1837 let config = test_oauth_config(format!("{}/.well-known/jwks.json", base_url));
1838 let validator = OAuthValidator::new(config, reqwest::Client::new());
1839
1840 let token = sign_test_jwt(&valid_claims(), "wrong-key");
1842 let auth_header = format!("Bearer {}", token);
1843
1844 let err = validator
1845 .validate_token(&auth_header)
1846 .await
1847 .expect_err("kid mismatch must be rejected");
1848 assert!(
1849 matches!(err, OAuthError::NoMatchingKey(_)),
1850 "expected NoMatchingKey, got: {err}"
1851 );
1852 }
1853
1854 #[tokio::test]
1855 async fn test_e2e_multi_key_jwks_no_kid_rejected() {
1856 let jwks = serde_json::json!({
1858 "keys": [
1859 {
1860 "kty": "RSA", "use": "sig", "alg": "RS256",
1861 "kid": "key-1",
1862 "n": TEST_RSA_N, "e": TEST_RSA_E
1863 },
1864 {
1865 "kty": "RSA", "use": "sig", "alg": "RS256",
1866 "kid": "key-2",
1867 "n": TEST_RSA_N, "e": TEST_RSA_E
1868 }
1869 ]
1870 })
1871 .to_string();
1872 let Some((base_url, _handle)) = start_mock_jwks_server(jwks).await else {
1873 return;
1874 };
1875
1876 let config = test_oauth_config(format!("{}/.well-known/jwks.json", base_url));
1877 let validator = OAuthValidator::new(config, reqwest::Client::new());
1878
1879 let key = jsonwebtoken::EncodingKey::from_rsa_pem(TEST_RSA_PRIVATE_PEM.as_bytes())
1881 .expect("valid RSA PEM");
1882 let mut header = jsonwebtoken::Header::new(Algorithm::RS256);
1883 header.kid = None; let token = jsonwebtoken::encode(&header, &valid_claims(), &key).expect("JWT signing");
1885 let auth_header = format!("Bearer {}", token);
1886
1887 let err = validator
1888 .validate_token(&auth_header)
1889 .await
1890 .expect_err("missing kid with multi-key JWKS must be rejected");
1891 assert!(
1892 matches!(err, OAuthError::MissingKid(2)),
1893 "expected MissingKid(2), got: {err}"
1894 );
1895 }
1896
1897 #[tokio::test]
1898 async fn test_e2e_tampered_signature_rejected() {
1899 let kid = "test-key-1";
1900 let jwks = test_jwks_json(kid);
1901 let Some((base_url, _handle)) = start_mock_jwks_server(jwks).await else {
1902 return;
1903 };
1904
1905 let config = test_oauth_config(format!("{}/.well-known/jwks.json", base_url));
1906 let validator = OAuthValidator::new(config, reqwest::Client::new());
1907
1908 let token = sign_test_jwt(&valid_claims(), kid);
1909 let mut tampered = token.clone();
1911 let last_char = tampered.pop().unwrap_or('A');
1912 tampered.push(if last_char == 'A' { 'B' } else { 'A' });
1913
1914 let auth_header = format!("Bearer {}", tampered);
1915
1916 let err = validator
1917 .validate_token(&auth_header)
1918 .await
1919 .expect_err("tampered signature must be rejected");
1920 assert!(
1921 matches!(err, OAuthError::JwtError(_)),
1922 "expected JwtError for signature tampering, got: {err}"
1923 );
1924 }
1925
1926 #[tokio::test]
1927 async fn test_e2e_missing_audience_with_require_audience_rejected() {
1928 let kid = "test-key-1";
1929 let jwks = test_jwks_json(kid);
1930 let Some((base_url, _handle)) = start_mock_jwks_server(jwks).await else {
1931 return;
1932 };
1933
1934 let mut config = test_oauth_config(format!("{}/.well-known/jwks.json", base_url));
1935 config.require_audience = true;
1936
1937 let validator = OAuthValidator::new(config, reqwest::Client::new());
1938
1939 let now = chrono::Utc::now().timestamp() as u64;
1941 let claims = serde_json::json!({
1942 "sub": "user-123",
1943 "iss": "https://auth.example.com",
1944 "exp": now + 3600,
1945 "iat": now,
1946 "nbf": now - 10,
1947 "scope": "tools.call"
1948 });
1949
1950 let token = sign_test_jwt(&claims, kid);
1951 let auth_header = format!("Bearer {}", token);
1952
1953 let err = validator
1954 .validate_token(&auth_header)
1955 .await
1956 .expect_err("missing aud with require_audience must be rejected");
1957 assert!(
1958 matches!(err, OAuthError::JwtError(_) | OAuthError::MissingAudience),
1959 "expected audience rejection, got: {err}"
1960 );
1961 }
1962
1963 #[tokio::test]
1964 async fn test_e2e_bearer_case_insensitive() {
1965 let kid = "test-key-1";
1966 let jwks = test_jwks_json(kid);
1967 let Some((base_url, _handle)) = start_mock_jwks_server(jwks).await else {
1968 return;
1969 };
1970
1971 let config = test_oauth_config(format!("{}/.well-known/jwks.json", base_url));
1972 let validator = OAuthValidator::new(config, reqwest::Client::new());
1973
1974 let token = sign_test_jwt(&valid_claims(), kid);
1975
1976 let auth_header = format!("BEARER {}", token);
1978 let result = validator.validate_token(&auth_header).await;
1979 assert!(
1980 result.is_ok(),
1981 "BEARER (uppercase) must be accepted: {result:?}"
1982 );
1983 }
1984
1985 #[tokio::test]
1986 async fn test_e2e_not_before_future_rejected() {
1987 let kid = "test-key-1";
1988 let jwks = test_jwks_json(kid);
1989 let Some((base_url, _handle)) = start_mock_jwks_server(jwks).await else {
1990 return;
1991 };
1992
1993 let config = test_oauth_config(format!("{}/.well-known/jwks.json", base_url));
1994 let validator = OAuthValidator::new(config, reqwest::Client::new());
1995
1996 let now = chrono::Utc::now().timestamp() as u64;
1997 let mut claims = valid_claims();
1998 claims["nbf"] = serde_json::json!(now + 600); let token = sign_test_jwt(&claims, kid);
2001 let auth_header = format!("Bearer {}", token);
2002
2003 let err = validator
2004 .validate_token(&auth_header)
2005 .await
2006 .expect_err("token with future nbf must be rejected");
2007 assert!(
2008 matches!(err, OAuthError::JwtError(_)),
2009 "expected JwtError for nbf in the future, got: {err}"
2010 );
2011 }
2012
2013 #[tokio::test]
2014 async fn test_e2e_dpop_required_but_no_cnf_jkt_rejected() {
2015 let kid = "test-key-1";
2016 let jwks = test_jwks_json(kid);
2017 let Some((base_url, _handle)) = start_mock_jwks_server(jwks).await else {
2018 return;
2019 };
2020
2021 let mut config = test_oauth_config(format!("{}/.well-known/jwks.json", base_url));
2022 config.dpop_mode = DpopMode::Required;
2023
2024 let validator = OAuthValidator::new(config, reqwest::Client::new());
2025
2026 let token = sign_test_jwt(&valid_claims(), kid);
2028 let auth_header = format!("Bearer {}", token);
2029
2030 let err = validator
2031 .validate_token(&auth_header)
2032 .await
2033 .expect_err("DPoP required but no cnf.jkt must be rejected");
2034 assert!(
2035 matches!(err, OAuthError::InvalidDpopProof(_)),
2036 "expected InvalidDpopProof, got: {err}"
2037 );
2038 }
2039
2040 #[test]
2043 fn test_contains_control_chars_rejects_dangerous_chars() {
2044 assert!(!contains_control_chars("normal-user@example.com"));
2046 assert!(!contains_control_chars("admin"));
2047 assert!(!contains_control_chars(""));
2048
2049 assert!(contains_control_chars("user\x00name")); assert!(contains_control_chars("user\x1bname")); assert!(contains_control_chars("user\x07name")); assert!(contains_control_chars("user\u{200B}name")); assert!(contains_control_chars("user\u{202E}name")); assert!(contains_control_chars("user\u{FEFF}name")); assert!(contains_control_chars("line1\nline2")); assert!(contains_control_chars("col1\tcol2")); assert!(contains_control_chars("x\u{00AD}y")); assert!(contains_control_chars("x\u{FFF9}y")); }
2069
2070 #[test]
2075 fn test_decode_unreserved_percent_decodes_unreserved_chars() {
2076 assert_eq!(decode_unreserved_percent("%2D"), "-");
2077 assert_eq!(decode_unreserved_percent("%2E"), ".");
2078 assert_eq!(decode_unreserved_percent("%5F"), "_");
2079 assert_eq!(decode_unreserved_percent("%7E"), "~");
2080 assert_eq!(decode_unreserved_percent("%41"), "A");
2081 assert_eq!(decode_unreserved_percent("%61"), "a");
2082 assert_eq!(decode_unreserved_percent("%30"), "0");
2083 }
2084
2085 #[test]
2086 fn test_decode_unreserved_percent_keeps_reserved_encoded() {
2087 assert_eq!(decode_unreserved_percent("%2F"), "%2F"); assert_eq!(decode_unreserved_percent("%40"), "%40"); assert_eq!(decode_unreserved_percent("%3A"), "%3A"); assert_eq!(decode_unreserved_percent("%00"), "%00"); assert_eq!(decode_unreserved_percent("%20"), "%20"); assert_eq!(decode_unreserved_percent("%3F"), "%3F"); assert_eq!(decode_unreserved_percent("%23"), "%23"); }
2095
2096 #[test]
2097 fn test_decode_unreserved_percent_normalizes_hex_case() {
2098 assert_eq!(decode_unreserved_percent("%2d"), "-");
2100 assert_eq!(decode_unreserved_percent("%7e"), "~");
2101 assert_eq!(decode_unreserved_percent("%2f"), "%2F");
2103 assert_eq!(decode_unreserved_percent("%3a"), "%3A");
2104 }
2105
2106 #[test]
2107 fn test_decode_unreserved_percent_incomplete_sequences() {
2108 assert_eq!(decode_unreserved_percent("foo%"), "foo%");
2109 assert_eq!(decode_unreserved_percent("foo%2"), "foo%2");
2110 assert_eq!(decode_unreserved_percent("%"), "%");
2111 assert_eq!(decode_unreserved_percent("%G0"), "%G0"); }
2113
2114 #[test]
2115 fn test_decode_unreserved_percent_mixed_content() {
2116 assert_eq!(
2117 decode_unreserved_percent("foo%2Dbar%2Fbaz"),
2118 "foo-bar%2Fbaz"
2119 );
2120 assert_eq!(decode_unreserved_percent(""), "");
2121 assert_eq!(decode_unreserved_percent("no-encoding"), "no-encoding");
2122 assert_eq!(decode_unreserved_percent("a%2Db%2Ec%5Fd%7Ee"), "a-b.c_d~e");
2123 }
2124
2125 }