rust_mcp_sdk/auth/spec/
jwk.rs1use crate::auth::{Audience, AuthClaims, AuthenticationError};
2use http::StatusCode;
3use jsonwebtoken::{decode, decode_header, jwk::Jwk, DecodingKey, TokenData, Validation};
4use serde::{Deserialize, Serialize};
5
6#[derive(Debug, Clone, Serialize, Deserialize)]
8pub struct JsonWebKeySet {
9 pub keys: Vec<Jwk>,
11}
12
13pub fn decode_token_header(token: &str) -> Result<jsonwebtoken::Header, AuthenticationError> {
14 let header =
15 decode_header(token).map_err(|err| AuthenticationError::TokenVerificationFailed {
16 description: err.to_string(),
17 status_code: Some(StatusCode::UNAUTHORIZED.as_u16()),
18 })?;
19 Ok(header)
20}
21
22impl JsonWebKeySet {
23 pub fn verify(
24 &self,
25 token: String,
26 validate_audience: Option<&Audience>,
27 validate_issuer: Option<&String>,
28 ) -> Result<TokenData<AuthClaims>, AuthenticationError> {
29 let header = decode_token_header(&token)?;
30
31 let kid = header.kid.ok_or(AuthenticationError::InvalidToken {
32 description: "Missing kid in token header",
33 })?;
34
35 let jwk = self
36 .keys
37 .iter()
38 .find(|key| key.common.key_id == Some(kid.clone()))
39 .ok_or(AuthenticationError::InvalidToken {
40 description: "No matching key found in JWKS",
41 })?;
42
43 let decoding_key = DecodingKey::from_jwk(jwk).map_err(|err| {
44 AuthenticationError::TokenVerificationFailed {
45 description: err.to_string(),
46 status_code: None,
47 }
48 })?;
49
50 let mut validation = Validation::new(header.alg);
51
52 let mut required_claims = vec![];
53 if let Some(validate_audience) = validate_audience {
54 let vec_audience = match validate_audience {
55 Audience::Single(aud) => &vec![aud.to_owned()],
56 Audience::Multiple(auds) => auds,
57 };
58 validation.set_audience(vec_audience);
59 required_claims.push("aud");
60 } else {
61 validation.validate_aud = false;
62 }
63
64 if let Some(validate_issuer) = validate_issuer {
65 validation.set_issuer(&[validate_issuer]);
66 required_claims.push("iss");
67 }
68 if !required_claims.is_empty() {
69 validation.set_required_spec_claims(&required_claims);
70 }
71
72 let token_data =
73 decode::<AuthClaims>(token, &decoding_key, &validation).map_err(|err| {
74 match err.kind() {
75 jsonwebtoken::errors::ErrorKind::InvalidToken => {
76 AuthenticationError::InvalidToken {
77 description: "Invalid token",
78 }
79 }
80 jsonwebtoken::errors::ErrorKind::ExpiredSignature => {
81 AuthenticationError::InvalidToken {
82 description: "Expired token",
83 }
84 }
85 _ => AuthenticationError::TokenVerificationFailed {
86 description: err.to_string(),
87 status_code: Some(StatusCode::BAD_REQUEST.as_u16()),
88 },
89 }
90 })?;
91
92 Ok(token_data)
93 }
94}