1use std::collections::HashMap;
7use std::time::{Duration, Instant};
8
9use jsonwebtoken::{Algorithm, DecodingKey, Validation, decode, decode_header};
10use serde::{Deserialize, Serialize};
11use tokio::sync::RwLock;
12use tracing::{debug, warn};
13
14#[derive(Debug, thiserror::Error)]
16#[non_exhaustive]
17pub enum JwtValidationError {
18 #[error("Invalid token: {0}")]
19 InvalidToken(String),
20 #[error("Token expired")]
21 TokenExpired,
22 #[error("Invalid audience")]
23 InvalidAudience,
24 #[error("Invalid issuer")]
25 InvalidIssuer,
26 #[error("Unsupported algorithm: {0}")]
27 UnsupportedAlgorithm(String),
28 #[error("JWKS fetch error: {0}")]
29 JwksFetchError(String),
30 #[error("Key not found: {0}")]
31 KeyNotFound(String),
32 #[error("Decoding error: {0}")]
33 DecodingError(String),
34}
35
/// Claims decoded from a validated JWT.
///
/// Every named claim is optional in the incoming JSON:
/// - missing strings become `""` and missing numbers become `0`;
/// - a missing `aud` becomes JSON `null`;
/// - a missing `scope` becomes `None`.
///
/// Any claim without a dedicated field is collected into `extra`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenClaims {
    /// Subject (`sub`); empty when absent.
    #[serde(default)]
    pub sub: String,
    /// Issuer (`iss`); empty when absent.
    #[serde(default)]
    pub iss: String,
    /// Audience (`aud`). It is kept as raw JSON because issuers send either a single string
    /// or an array of strings.
    #[serde(default)]
    pub aud: serde_json::Value,
    /// Expiration time (`exp`); `0` when absent.
    #[serde(default)]
    pub exp: u64,
    /// Issued-at time (`iat`); `0` when absent.
    #[serde(default)]
    pub iat: u64,
    /// The `scope` claim, if present.
    #[serde(default)]
    pub scope: Option<String>,
    /// Every claim not matched by a field above, keyed by claim name.
    #[serde(flatten)]
    pub extra: HashMap<String, serde_json::Value>,
}
57
/// Top-level JWKS document, i.e. `{"keys": [...]}`.
#[derive(Debug, Clone, Deserialize)]
struct JwksResponse {
    /// Keys in document order. Entries that cannot be turned into a decoding key are skipped
    /// when the cache is rebuilt.
    keys: Vec<JwkKey>,
}
62
/// One entry of a JWKS document, limited to the members this validator reads.
///
/// Every member except `kty` is optional here. Which ones are actually required depends on
/// `kty` and is checked when the entry is converted into a `DecodingKey`. JWK members not
/// listed here (for example `use`) are not captured.
#[derive(Debug, Clone, Deserialize)]
struct JwkKey {
    /// Key type. `"RSA"` and `"EC"` are handled; other types are skipped.
    kty: String,
    /// Key ID, matched against the token header's `kid`.
    kid: Option<String>,
    /// Algorithm advertised for this key, if any.
    alg: Option<String>,
    /// RSA `n` component, passed with `e` to `DecodingKey::from_rsa_components`.
    n: Option<String>,
    /// RSA `e` component, passed with `n` to `DecodingKey::from_rsa_components`.
    e: Option<String>,
    /// EC curve name; determines the ES* algorithm assigned to the key.
    crv: Option<String>,
    /// EC `x` coordinate, passed with `y` to `DecodingKey::from_ec_components`.
    x: Option<String>,
    /// EC `y` coordinate, passed with `x` to `DecodingKey::from_ec_components`.
    y: Option<String>,
}
74
75struct CachedJwks {
76 keys: HashMap<String, (DecodingKey, Algorithm)>,
77 last_refresh_at: Instant,
78}
79
80pub struct JwtValidator {
84 jwks_uri: String,
85 cached_jwks: RwLock<Option<CachedJwks>>,
86 allowed_algorithms: Vec<Algorithm>,
87 issuer: Option<String>,
88 audience: Option<String>,
89 refresh_interval: Duration,
90 http_client: reqwest::Client,
91}
92
93impl JwtValidator {
94 pub fn new(jwks_uri: impl Into<String>, audience: impl Into<String>) -> Self {
95 Self {
96 jwks_uri: jwks_uri.into(),
97 cached_jwks: RwLock::new(None),
98 allowed_algorithms: vec![Algorithm::RS256, Algorithm::ES256],
99 issuer: None,
100 audience: Some(audience.into()),
101 refresh_interval: Duration::from_secs(60),
102 http_client: reqwest::Client::new(),
103 }
104 }
105
106 pub fn with_issuer(mut self, issuer: impl Into<String>) -> Self {
107 self.issuer = Some(issuer.into());
108 self
109 }
110
111 pub fn with_algorithms(mut self, algorithms: Vec<Algorithm>) -> Self {
112 self.allowed_algorithms = algorithms;
113 self
114 }
115
116 pub fn with_refresh_interval(mut self, interval: Duration) -> Self {
117 self.refresh_interval = interval;
118 self
119 }
120
121 pub async fn validate(&self, token: &str) -> Result<TokenClaims, JwtValidationError> {
123 let header = decode_header(token)
124 .map_err(|e| JwtValidationError::DecodingError(format!("Invalid JWT header: {e}")))?;
125
126 if !self.allowed_algorithms.contains(&header.alg) {
127 return Err(JwtValidationError::UnsupportedAlgorithm(format!(
128 "{:?}",
129 header.alg
130 )));
131 }
132
133 let kid = header.kid.as_deref().unwrap_or("default").to_string();
134 let (key, jwks_alg) = self.get_decoding_key(&kid).await?;
135
136 if header.alg != jwks_alg {
138 return Err(JwtValidationError::UnsupportedAlgorithm(format!(
139 "Token uses {:?} but JWKS key '{kid}' advertises {:?}",
140 header.alg, jwks_alg
141 )));
142 }
143
144 let mut validation = Validation::new(header.alg);
145 validation.validate_exp = true;
146
147 if let Some(ref iss) = self.issuer {
148 validation.set_issuer(&[iss]);
149 }
150
151 if let Some(ref aud) = self.audience {
152 validation.set_audience(&[aud]);
153 } else {
154 validation.validate_aud = false;
155 }
156
157 let token_data =
158 decode::<TokenClaims>(token, &key, &validation).map_err(|e| match e.kind() {
159 jsonwebtoken::errors::ErrorKind::ExpiredSignature => {
160 JwtValidationError::TokenExpired
161 }
162 jsonwebtoken::errors::ErrorKind::InvalidAudience => {
163 JwtValidationError::InvalidAudience
164 }
165 jsonwebtoken::errors::ErrorKind::InvalidIssuer => JwtValidationError::InvalidIssuer,
166 _ => JwtValidationError::InvalidToken(e.to_string()),
167 })?;
168
169 Ok(token_data.claims)
170 }
171
172 async fn get_decoding_key(
173 &self,
174 kid: &str,
175 ) -> Result<(DecodingKey, Algorithm), JwtValidationError> {
176 {
178 let cache = self.cached_jwks.read().await;
179 if let Some(ref cached) = *cache {
180 if let Some((key, alg)) = cached.keys.get(kid) {
181 return Ok((key.clone(), *alg));
182 }
183 }
184 }
185
186 self.refresh_jwks().await?;
188
189 let cache = self.cached_jwks.read().await;
191 if let Some(ref cached) = *cache {
192 if let Some((key, alg)) = cached.keys.get(kid) {
193 return Ok((key.clone(), *alg));
194 }
195 }
196
197 Err(JwtValidationError::KeyNotFound(kid.to_string()))
198 }
199
200 async fn refresh_jwks(&self) -> Result<(), JwtValidationError> {
201 {
203 let cache = self.cached_jwks.read().await;
204 if let Some(ref cached) = *cache {
205 if cached.last_refresh_at.elapsed() < self.refresh_interval {
206 debug!("JWKS refresh rate-limited, skipping");
207 return Ok(());
208 }
209 }
210 }
211
212 debug!("Fetching JWKS from {}", self.jwks_uri);
213
214 let response = self
215 .http_client
216 .get(&self.jwks_uri)
217 .send()
218 .await
219 .map_err(|e| JwtValidationError::JwksFetchError(e.to_string()))?;
220
221 let jwks: JwksResponse = response
222 .json()
223 .await
224 .map_err(|e| JwtValidationError::JwksFetchError(format!("Invalid JWKS JSON: {e}")))?;
225
226 let mut keys = HashMap::new();
227
228 for key in &jwks.keys {
229 let kid = key.kid.clone().unwrap_or_else(|| "default".to_string());
230
231 match key.kty.as_str() {
232 "RSA" => {
233 if let (Some(n), Some(e)) = (&key.n, &key.e) {
234 match DecodingKey::from_rsa_components(n, e) {
235 Ok(decoding_key) => {
236 let alg = key
237 .alg
238 .as_deref()
239 .and_then(|a| match a {
240 "RS256" => Some(Algorithm::RS256),
241 "RS384" => Some(Algorithm::RS384),
242 "RS512" => Some(Algorithm::RS512),
243 _ => None,
244 })
245 .unwrap_or(Algorithm::RS256);
246 keys.insert(kid, (decoding_key, alg));
247 }
248 Err(e) => warn!("Failed to parse RSA key: {e}"),
249 }
250 }
251 }
252 "EC" => {
253 if let (Some(x), Some(y), Some(crv)) = (&key.x, &key.y, &key.crv) {
254 match DecodingKey::from_ec_components(x, y) {
255 Ok(decoding_key) => {
256 let alg = match crv.as_str() {
257 "P-256" => Algorithm::ES256,
258 "P-384" => Algorithm::ES384,
259 _ => {
260 warn!("Unsupported EC curve: {crv}");
261 continue;
262 }
263 };
264 keys.insert(kid, (decoding_key, alg));
265 }
266 Err(e) => warn!("Failed to parse EC key: {e}"),
267 }
268 }
269 }
270 other => debug!("Skipping unsupported key type: {other}"),
271 }
272 }
273
274 debug!("JWKS loaded: {} keys", keys.len());
275
276 let now = Instant::now();
277 *self.cached_jwks.write().await = Some(CachedJwks {
278 keys,
279 last_refresh_at: now,
280 });
281
282 Ok(())
283 }
284}
285
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use super::*;

    #[test]
    fn token_claims_deserializes_with_defaults() {
        let json = r#"{"sub":"user-1","iss":"https://auth.example.com","exp":999999999}"#;
        let claims: TokenClaims = serde_json::from_str(json).unwrap();
        assert_eq!(claims.sub, "user-1");
        assert_eq!(claims.iss, "https://auth.example.com");
        assert_eq!(claims.exp, 999999999);
        // Claims absent from the JSON fall back to their serde defaults.
        assert_eq!(claims.iat, 0);
        assert!(claims.aud.is_null());
        assert!(claims.scope.is_none());
        assert!(claims.extra.is_empty());
    }

    #[test]
    fn token_claims_parses_scope() {
        let json = r#"{"sub":"u","exp":1,"scope":"read write"}"#;
        let claims: TokenClaims = serde_json::from_str(json).unwrap();
        assert_eq!(claims.scope.as_deref(), Some("read write"));
        // Named fields must not leak into the flattened catch-all map.
        assert!(!claims.extra.contains_key("scope"));
    }

    #[test]
    fn token_claims_handles_array_audience() {
        let json = r#"{"sub":"u","aud":["a","b"],"exp":1}"#;
        let claims: TokenClaims = serde_json::from_str(json).unwrap();
        assert!(claims.aud.is_array());
        assert_eq!(claims.aud, serde_json::json!(["a", "b"]));
    }

    #[test]
    fn token_claims_handles_string_audience() {
        let json = r#"{"sub":"u","aud":"single","exp":1}"#;
        let claims: TokenClaims = serde_json::from_str(json).unwrap();
        assert_eq!(claims.aud, "single");
    }

    #[test]
    fn token_claims_captures_extra_fields() {
        let json = r#"{"sub":"u","exp":1,"custom_field":"custom_value"}"#;
        let claims: TokenClaims = serde_json::from_str(json).unwrap();
        assert_eq!(claims.extra.get("custom_field").unwrap(), "custom_value");
        assert_eq!(claims.extra.len(), 1);
    }

    #[test]
    fn error_types_are_distinct() {
        let errors: Vec<JwtValidationError> = vec![
            JwtValidationError::InvalidToken("bad".into()),
            JwtValidationError::TokenExpired,
            JwtValidationError::InvalidAudience,
            JwtValidationError::InvalidIssuer,
            JwtValidationError::UnsupportedAlgorithm("HS256".into()),
            JwtValidationError::JwksFetchError("network".into()),
            JwtValidationError::KeyNotFound("kid-1".into()),
            JwtValidationError::DecodingError("corrupt".into()),
        ];

        let messages: HashSet<String> = errors.iter().map(ToString::to_string).collect();
        assert!(messages.iter().all(|m| !m.is_empty()));
        // Every variant must render a message that no other variant produces.
        assert_eq!(messages.len(), errors.len(), "error messages should be distinct");
    }

    #[test]
    fn validator_builder_api() {
        let validator =
            JwtValidator::new("https://example.com/.well-known/jwks.json", "my-audience")
                .with_issuer("https://example.com")
                .with_algorithms(vec![Algorithm::RS256])
                .with_refresh_interval(Duration::from_secs(120));

        assert_eq!(validator.jwks_uri, "https://example.com/.well-known/jwks.json");
        assert_eq!(validator.audience.as_deref(), Some("my-audience"));
        assert_eq!(validator.issuer.as_deref(), Some("https://example.com"));
        assert_eq!(validator.allowed_algorithms, vec![Algorithm::RS256]);
        assert_eq!(validator.refresh_interval, Duration::from_secs(120));
    }

    #[test]
    fn validator_defaults() {
        let validator = JwtValidator::new("https://example.com/jwks", "aud");
        assert_eq!(
            validator.allowed_algorithms,
            vec![Algorithm::RS256, Algorithm::ES256]
        );
        assert!(validator.issuer.is_none());
        assert_eq!(validator.audience.as_deref(), Some("aud"));
        assert_eq!(validator.refresh_interval, Duration::from_secs(60));
    }
}