rust_mcp_sdk/auth/spec/
jwk.rs

1use crate::auth::{Audience, AuthClaims, AuthenticationError};
2use http::StatusCode;
3use jsonwebtoken::{decode, decode_header, jwk::Jwk, DecodingKey, TokenData, Validation};
4use serde::{Deserialize, Serialize};
5
6/// A JSON Web Key Set (JWKS) containing a list of JSON Web Keys.
7#[derive(Debug, Clone, Serialize, Deserialize)]
8pub struct JsonWebKeySet {
9    /// List of JSON Web Keys.
10    pub keys: Vec<Jwk>,
11}
12
13pub fn decode_token_header(token: &str) -> Result<jsonwebtoken::Header, AuthenticationError> {
14    let header =
15        decode_header(token).map_err(|err| AuthenticationError::TokenVerificationFailed {
16            description: err.to_string(),
17            status_code: Some(StatusCode::UNAUTHORIZED.as_u16()),
18        })?;
19    Ok(header)
20}
21
22impl JsonWebKeySet {
23    pub fn verify(
24        &self,
25        token: String,
26        validate_audience: Option<&Audience>,
27        validate_issuer: Option<&String>,
28    ) -> Result<TokenData<AuthClaims>, AuthenticationError> {
29        let header = decode_token_header(&token)?;
30
31        let kid = header.kid.ok_or(AuthenticationError::InvalidToken {
32            description: "Missing kid in token header",
33        })?;
34
35        let jwk = self
36            .keys
37            .iter()
38            .find(|key| key.common.key_id == Some(kid.clone()))
39            .ok_or(AuthenticationError::InvalidToken {
40                description: "No matching key found in JWKS",
41            })?;
42
43        let decoding_key = DecodingKey::from_jwk(jwk).map_err(|err| {
44            AuthenticationError::TokenVerificationFailed {
45                description: err.to_string(),
46                status_code: None,
47            }
48        })?;
49
50        let mut validation = Validation::new(header.alg);
51
52        let mut required_claims = vec![];
53        if let Some(validate_audience) = validate_audience {
54            let vec_audience = match validate_audience {
55                Audience::Single(aud) => &vec![aud.to_owned()],
56                Audience::Multiple(auds) => auds,
57            };
58            validation.set_audience(vec_audience);
59            required_claims.push("aud");
60        } else {
61            validation.validate_aud = false;
62        }
63
64        if let Some(validate_issuer) = validate_issuer {
65            validation.set_issuer(&[validate_issuer]);
66            required_claims.push("iss");
67        }
68        if !required_claims.is_empty() {
69            validation.set_required_spec_claims(&required_claims);
70        }
71
72        let token_data =
73            decode::<AuthClaims>(token, &decoding_key, &validation).map_err(|err| {
74                match err.kind() {
75                    jsonwebtoken::errors::ErrorKind::InvalidToken => {
76                        AuthenticationError::InvalidToken {
77                            description: "Invalid token",
78                        }
79                    }
80                    jsonwebtoken::errors::ErrorKind::ExpiredSignature => {
81                        AuthenticationError::InvalidToken {
82                            description: "Expired token",
83                        }
84                    }
85                    _ => AuthenticationError::TokenVerificationFailed {
86                        description: err.to_string(),
87                        status_code: Some(StatusCode::BAD_REQUEST.as_u16()),
88                    },
89                }
90            })?;
91
92        Ok(token_data)
93    }
94}