Skip to main content

turul_jwt_validator/
lib.rs

//! Generic JWT validator backed by a JSON Web Key Set (JWKS).
//!
//! Verifies RS256/RS384/RS512 and ES256/ES384 signatures via [`jsonwebtoken`].
//! Keys are cached in memory. When a token references an unknown `kid`, the
//! JWKS is re-fetched, subject to a rate limit.
5
6use std::collections::HashMap;
7use std::time::{Duration, Instant};
8
9use jsonwebtoken::{Algorithm, DecodingKey, Validation, decode, decode_header};
10use serde::{Deserialize, Serialize};
11use tokio::sync::RwLock;
12use tracing::{debug, warn};
13
14/// Errors from JWT validation.
15#[derive(Debug, thiserror::Error)]
16#[non_exhaustive]
17pub enum JwtValidationError {
18    #[error("Invalid token: {0}")]
19    InvalidToken(String),
20    #[error("Token expired")]
21    TokenExpired,
22    #[error("Invalid audience")]
23    InvalidAudience,
24    #[error("Invalid issuer")]
25    InvalidIssuer,
26    #[error("Unsupported algorithm: {0}")]
27    UnsupportedAlgorithm(String),
28    #[error("JWKS fetch error: {0}")]
29    JwksFetchError(String),
30    #[error("Key not found: {0}")]
31    KeyNotFound(String),
32    #[error("Decoding error: {0}")]
33    DecodingError(String),
34}
35
36/// Validated token claims extracted from a JWT.
37#[derive(Debug, Clone, Serialize, Deserialize)]
38pub struct TokenClaims {
39    #[serde(default)]
40    pub sub: String,
41    #[serde(default)]
42    pub iss: String,
43    /// Audience — can be string or array in JWT.
44    #[serde(default)]
45    pub aud: serde_json::Value,
46    #[serde(default)]
47    pub exp: u64,
48    #[serde(default)]
49    pub iat: u64,
50    /// Scopes (space-separated string).
51    #[serde(default)]
52    pub scope: Option<String>,
53    /// All other claims.
54    #[serde(flatten)]
55    pub extra: HashMap<String, serde_json::Value>,
56}
57
58#[derive(Debug, Clone, Deserialize)]
59struct JwksResponse {
60    keys: Vec<JwkKey>,
61}
62
63#[derive(Debug, Clone, Deserialize)]
64struct JwkKey {
65    kty: String,
66    kid: Option<String>,
67    alg: Option<String>,
68    n: Option<String>,
69    e: Option<String>,
70    crv: Option<String>,
71    x: Option<String>,
72    y: Option<String>,
73}
74
75struct CachedJwks {
76    keys: HashMap<String, (DecodingKey, Algorithm)>,
77    last_refresh_at: Instant,
78}
79
80/// JWT validator with JWKS caching and kid-miss refresh.
81///
82/// Supports RS256 and ES256 by default. Rate-limits JWKS fetches.
83pub struct JwtValidator {
84    jwks_uri: String,
85    cached_jwks: RwLock<Option<CachedJwks>>,
86    allowed_algorithms: Vec<Algorithm>,
87    issuer: Option<String>,
88    audience: Option<String>,
89    refresh_interval: Duration,
90    http_client: reqwest::Client,
91}
92
93impl JwtValidator {
94    pub fn new(jwks_uri: impl Into<String>, audience: impl Into<String>) -> Self {
95        Self {
96            jwks_uri: jwks_uri.into(),
97            cached_jwks: RwLock::new(None),
98            allowed_algorithms: vec![Algorithm::RS256, Algorithm::ES256],
99            issuer: None,
100            audience: Some(audience.into()),
101            refresh_interval: Duration::from_secs(60),
102            http_client: reqwest::Client::new(),
103        }
104    }
105
106    pub fn with_issuer(mut self, issuer: impl Into<String>) -> Self {
107        self.issuer = Some(issuer.into());
108        self
109    }
110
111    pub fn with_algorithms(mut self, algorithms: Vec<Algorithm>) -> Self {
112        self.allowed_algorithms = algorithms;
113        self
114    }
115
116    pub fn with_refresh_interval(mut self, interval: Duration) -> Self {
117        self.refresh_interval = interval;
118        self
119    }
120
121    /// Validate a JWT token and return the claims.
122    pub async fn validate(&self, token: &str) -> Result<TokenClaims, JwtValidationError> {
123        let header = decode_header(token)
124            .map_err(|e| JwtValidationError::DecodingError(format!("Invalid JWT header: {e}")))?;
125
126        if !self.allowed_algorithms.contains(&header.alg) {
127            return Err(JwtValidationError::UnsupportedAlgorithm(format!(
128                "{:?}",
129                header.alg
130            )));
131        }
132
133        let kid = header.kid.as_deref().unwrap_or("default").to_string();
134        let (key, jwks_alg) = self.get_decoding_key(&kid).await?;
135
136        // Cross-check: token algorithm must match JWKS-advertised algorithm
137        if header.alg != jwks_alg {
138            return Err(JwtValidationError::UnsupportedAlgorithm(format!(
139                "Token uses {:?} but JWKS key '{kid}' advertises {:?}",
140                header.alg, jwks_alg
141            )));
142        }
143
144        let mut validation = Validation::new(header.alg);
145        validation.validate_exp = true;
146
147        if let Some(ref iss) = self.issuer {
148            validation.set_issuer(&[iss]);
149        }
150
151        if let Some(ref aud) = self.audience {
152            validation.set_audience(&[aud]);
153        } else {
154            validation.validate_aud = false;
155        }
156
157        let token_data =
158            decode::<TokenClaims>(token, &key, &validation).map_err(|e| match e.kind() {
159                jsonwebtoken::errors::ErrorKind::ExpiredSignature => {
160                    JwtValidationError::TokenExpired
161                }
162                jsonwebtoken::errors::ErrorKind::InvalidAudience => {
163                    JwtValidationError::InvalidAudience
164                }
165                jsonwebtoken::errors::ErrorKind::InvalidIssuer => JwtValidationError::InvalidIssuer,
166                _ => JwtValidationError::InvalidToken(e.to_string()),
167            })?;
168
169        Ok(token_data.claims)
170    }
171
172    async fn get_decoding_key(
173        &self,
174        kid: &str,
175    ) -> Result<(DecodingKey, Algorithm), JwtValidationError> {
176        // Try cache first
177        {
178            let cache = self.cached_jwks.read().await;
179            if let Some(ref cached) = *cache {
180                if let Some((key, alg)) = cached.keys.get(kid) {
181                    return Ok((key.clone(), *alg));
182                }
183            }
184        }
185
186        // Cache miss — refresh JWKS
187        self.refresh_jwks().await?;
188
189        // Try again
190        let cache = self.cached_jwks.read().await;
191        if let Some(ref cached) = *cache {
192            if let Some((key, alg)) = cached.keys.get(kid) {
193                return Ok((key.clone(), *alg));
194            }
195        }
196
197        Err(JwtValidationError::KeyNotFound(kid.to_string()))
198    }
199
200    async fn refresh_jwks(&self) -> Result<(), JwtValidationError> {
201        // Rate limit
202        {
203            let cache = self.cached_jwks.read().await;
204            if let Some(ref cached) = *cache {
205                if cached.last_refresh_at.elapsed() < self.refresh_interval {
206                    debug!("JWKS refresh rate-limited, skipping");
207                    return Ok(());
208                }
209            }
210        }
211
212        debug!("Fetching JWKS from {}", self.jwks_uri);
213
214        let response = self
215            .http_client
216            .get(&self.jwks_uri)
217            .send()
218            .await
219            .map_err(|e| JwtValidationError::JwksFetchError(e.to_string()))?;
220
221        let jwks: JwksResponse = response
222            .json()
223            .await
224            .map_err(|e| JwtValidationError::JwksFetchError(format!("Invalid JWKS JSON: {e}")))?;
225
226        let mut keys = HashMap::new();
227
228        for key in &jwks.keys {
229            let kid = key.kid.clone().unwrap_or_else(|| "default".to_string());
230
231            match key.kty.as_str() {
232                "RSA" => {
233                    if let (Some(n), Some(e)) = (&key.n, &key.e) {
234                        match DecodingKey::from_rsa_components(n, e) {
235                            Ok(decoding_key) => {
236                                let alg = key
237                                    .alg
238                                    .as_deref()
239                                    .and_then(|a| match a {
240                                        "RS256" => Some(Algorithm::RS256),
241                                        "RS384" => Some(Algorithm::RS384),
242                                        "RS512" => Some(Algorithm::RS512),
243                                        _ => None,
244                                    })
245                                    .unwrap_or(Algorithm::RS256);
246                                keys.insert(kid, (decoding_key, alg));
247                            }
248                            Err(e) => warn!("Failed to parse RSA key: {e}"),
249                        }
250                    }
251                }
252                "EC" => {
253                    if let (Some(x), Some(y), Some(crv)) = (&key.x, &key.y, &key.crv) {
254                        match DecodingKey::from_ec_components(x, y) {
255                            Ok(decoding_key) => {
256                                let alg = match crv.as_str() {
257                                    "P-256" => Algorithm::ES256,
258                                    "P-384" => Algorithm::ES384,
259                                    _ => {
260                                        warn!("Unsupported EC curve: {crv}");
261                                        continue;
262                                    }
263                                };
264                                keys.insert(kid, (decoding_key, alg));
265                            }
266                            Err(e) => warn!("Failed to parse EC key: {e}"),
267                        }
268                    }
269                }
270                other => debug!("Skipping unsupported key type: {other}"),
271            }
272        }
273
274        debug!("JWKS loaded: {} keys", keys.len());
275
276        let now = Instant::now();
277        *self.cached_jwks.write().await = Some(CachedJwks {
278            keys,
279            last_refresh_at: now,
280        });
281
282        Ok(())
283    }
284}
285
#[cfg(test)]
mod tests {
    use super::*;

    /// Parse `TokenClaims` from a JSON literal; test inputs are always valid.
    fn claims_from(json: &str) -> TokenClaims {
        serde_json::from_str(json).unwrap()
    }

    #[test]
    fn token_claims_deserializes_with_defaults() {
        let claims =
            claims_from(r#"{"sub":"user-1","iss":"https://auth.example.com","exp":999999999}"#);

        assert_eq!(claims.sub, "user-1");
        assert_eq!(claims.iss, "https://auth.example.com");
        assert_eq!(claims.exp, 999_999_999);
        assert_eq!(claims.scope, None);
    }

    #[test]
    fn token_claims_handles_array_audience() {
        let claims = claims_from(r#"{"sub":"u","aud":["a","b"],"exp":1}"#);
        assert!(claims.aud.is_array());
    }

    #[test]
    fn token_claims_handles_string_audience() {
        let claims = claims_from(r#"{"sub":"u","aud":"single","exp":1}"#);
        assert_eq!(claims.aud, "single");
    }

    #[test]
    fn token_claims_captures_extra_fields() {
        let claims = claims_from(r#"{"sub":"u","exp":1,"custom_field":"custom_value"}"#);
        assert_eq!(claims.extra["custom_field"], "custom_value");
    }

    #[test]
    fn error_types_are_distinct() {
        let errors = [
            JwtValidationError::InvalidToken("bad".into()),
            JwtValidationError::TokenExpired,
            JwtValidationError::InvalidAudience,
            JwtValidationError::InvalidIssuer,
            JwtValidationError::UnsupportedAlgorithm("HS256".into()),
            JwtValidationError::JwksFetchError("network".into()),
            JwtValidationError::KeyNotFound("kid-1".into()),
            JwtValidationError::DecodingError("corrupt".into()),
        ];

        // Every variant must render a non-empty Display message.
        assert!(errors.iter().all(|err| !err.to_string().is_empty()));
    }

    #[test]
    fn validator_builder_api() {
        let jwks_uri = "https://example.com/.well-known/jwks.json";
        let _validator = JwtValidator::new(jwks_uri, "my-audience")
            .with_issuer("https://example.com")
            .with_algorithms(vec![Algorithm::RS256])
            .with_refresh_interval(Duration::from_secs(120));
    }
}