1use std::path::PathBuf;
36use std::sync::Arc;
37use std::time::Duration;
38
39use reqwest::Client;
40use serde::{Serialize, de::DeserializeOwned};
41use tokio::sync::RwLock;
42use tokio::task::JoinHandle;
43
44use crate::Codec as _;
45use crate::jwt::jwks::JwksDocument;
46use crate::jwt::{JsonWebToken, JsonWebTokenOptions};
47
/// Default timeout applied to the HTTP client used for JWKS requests (3 seconds).
const DEFAULT_HTTP_TIMEOUT: Duration = Duration::from_millis(3_000);
/// Default period between background JWKS refreshes (60 seconds).
const DEFAULT_REFRESH_INTERVAL: Duration = Duration::from_millis(60_000);
50
/// Slot holding the active verifier: `None` until keys are first installed, then replaced
/// wholesale (a new inner `Arc`) on each successful refresh. Readers clone the inner `Arc`
/// and release the lock, so a swap does not disturb a verification already in progress.
type SharedVerifier<P> = Arc<RwLock<Option<Arc<JsonWebToken<P>>>>>;
52
53#[derive(Clone, Debug)]
55pub struct RemoteJwksVerifierConfig {
56 pub jwks_url: String,
58 pub http_timeout: Duration,
60 pub refresh_interval: Duration,
62 pub cache_path: Option<PathBuf>,
64}
65
66impl RemoteJwksVerifierConfig {
67 pub fn from_jwks_url(jwks_url: impl Into<String>) -> Self {
69 Self {
70 jwks_url: jwks_url.into(),
71 http_timeout: DEFAULT_HTTP_TIMEOUT,
72 refresh_interval: DEFAULT_REFRESH_INTERVAL,
73 cache_path: None,
74 }
75 }
76
77 #[must_use]
79 pub fn with_http_timeout(mut self, timeout: Duration) -> Self {
80 self.http_timeout = timeout;
81 self
82 }
83
84 #[must_use]
86 pub fn with_refresh_interval(mut self, refresh_interval: Duration) -> Self {
87 self.refresh_interval = refresh_interval;
88 self
89 }
90
91 #[must_use]
93 pub fn with_cache_path(mut self, cache_path: impl Into<PathBuf>) -> Self {
94 self.cache_path = Some(cache_path.into());
95 self
96 }
97}
98
/// Errors produced while building, refreshing, or verifying with a [`RemoteJwksVerifier`].
#[derive(thiserror::Error, Debug)]
pub enum RemoteJwksVerifierError {
    /// The `reqwest` HTTP client could not be constructed.
    ///
    /// NOTE(review): because of `#[from]`, any `?` on a `reqwest::Error` in this module is
    /// converted into this variant, even if it is not a client-build failure.
    #[error("failed to build HTTP client: {0}")]
    HttpClientBuild(#[from] reqwest::Error),
    /// The JWKS request failed at the transport level or returned a non-success status.
    #[error("failed to fetch JWKS document from {url}: {message}")]
    Fetch {
        /// URL that was requested.
        url: String,
        /// Transport error text or the unexpected HTTP status.
        message: String,
    },
    /// A JWKS document (live response or cache file) could not be decoded, or could not be
    /// serialized for the cache.
    #[error("failed to parse JWKS response: {0}")]
    ParseResponse(String),
    /// The document had no key with `alg=ES384`, `crv=P-384`, `kty=EC` and `use=sig`.
    #[error("JWKS document did not contain any valid ES384 keys")]
    NoValidKeys,
    /// The cache file, or its parent directory, could not be written.
    #[error("failed to persist JWKS cache at {path}: {message}")]
    CacheWrite {
        /// Path that could not be written.
        path: String,
        /// Underlying I/O error text.
        message: String,
    },
    /// The cache file could not be read.
    #[error("failed to read JWKS cache at {path}: {message}")]
    CacheRead {
        /// Path that could not be read.
        path: String,
        /// Underlying I/O error text.
        message: String,
    },
    /// The token still failed with a missing-`kid` error after an on-demand refresh.
    #[error("missing JWT `kid` and refresh did not provide a fallback key")]
    MissingKidWithoutFallback,
    /// The token's `kid` was still not configured after an on-demand refresh.
    #[error("JWT key id `{kid}` not found after refresh")]
    UnknownKid {
        /// Key id extracted from the verification error message.
        kid: String,
    },
    /// Token verification failed; carries the underlying error's message. Also used when
    /// verifier options cannot be built from the JWKS keys.
    #[error("token verification failed: {0}")]
    Verify(String),
    /// No verifier is installed: startup found neither live nor cached keys.
    #[error("startup failed because no live JWKS or cached JWKS was available")]
    StartupNoKeys,
}
151
/// Verifies JWTs against ES384 keys taken from a remote JWKS document, with an optional
/// on-disk cache and periodic or on-demand refresh.
///
/// `P` is the decoded claims type returned by `verify_token`. Clones share the same
/// verifier slot and refresh lock through `Arc`s, so a refresh through one clone is
/// visible to all of them.
#[derive(Clone)]
pub struct RemoteJwksVerifier<P>
where
    P: Serialize + DeserializeOwned + Clone + Send + Sync + 'static,
{
    /// Source URL, HTTP timeout, refresh period and cache location.
    config: RemoteJwksVerifierConfig,
    /// HTTP client built with `config.http_timeout`.
    client: Client,
    /// Currently installed verifier, swapped wholesale on each successful refresh.
    verifier: SharedVerifier<P>,
    /// Serializes refreshes so only one fetch-and-swap runs at a time.
    refresh_lock: Arc<tokio::sync::Mutex<()>>,
}
165
166impl<P> RemoteJwksVerifier<P>
167where
168 P: Serialize + DeserializeOwned + Clone + Send + Sync + 'static,
169{
170 pub async fn bootstrap(
178 config: RemoteJwksVerifierConfig,
179 ) -> Result<Self, RemoteJwksVerifierError> {
180 let client = Client::builder().timeout(config.http_timeout).build()?;
181 let verifier: SharedVerifier<P> = Arc::new(RwLock::new(None));
182
183 let this = Self {
184 config,
185 client,
186 verifier,
187 refresh_lock: Arc::new(tokio::sync::Mutex::new(())),
188 };
189
190 let mut has_cache = false;
191 if let Some(cached) = this.load_cached_verifier().await? {
192 *this.verifier.write().await = Some(cached);
193 has_cache = true;
194 tracing::warn!("starting with cached JWKS keys while attempting live refresh");
195 }
196
197 match this.refresh().await {
198 Ok(()) => {}
199 Err(error) if has_cache => {
200 tracing::warn!(error = %error, "live JWKS refresh failed, continuing with cached keys");
201 }
202 Err(_) => return Err(RemoteJwksVerifierError::StartupNoKeys),
203 }
204
205 Ok(this)
206 }
207
208 pub fn start_background_refresh(&self) -> JoinHandle<()> {
213 let refresh_interval = self.config.refresh_interval;
214 let this = self.clone();
215
216 tokio::spawn(async move {
217 let mut ticker = tokio::time::interval(refresh_interval);
218 loop {
219 ticker.tick().await;
220 if let Err(error) = this.refresh().await {
221 tracing::warn!(error = %error, "background JWKS refresh failed");
222 }
223 }
224 })
225 }
226
227 pub async fn refresh(&self) -> Result<(), RemoteJwksVerifierError> {
233 let _lock = self.refresh_lock.lock().await;
234 let jwks = self.fetch_jwks().await?;
235 let codec = Arc::new(codec_from_jwks(&jwks)?);
236
237 if let Some(cache_path) = &self.config.cache_path {
238 persist_jwks_cache(cache_path, &jwks).await?;
239 }
240
241 *self.verifier.write().await = Some(codec);
242 Ok(())
243 }
244
245 pub async fn verify_token(&self, token: &str) -> Result<P, RemoteJwksVerifierError> {
256 match self.verify_once(token).await {
257 Ok(claims) => Ok(claims),
258 Err(RemoteJwksVerifierError::Verify(ref message))
259 if message.contains("missing `kid`") || message.contains("not configured") =>
260 {
261 let message = message.clone();
262 self.refresh().await?;
263 match self.verify_once(token).await {
264 Ok(claims) => Ok(claims),
265 Err(RemoteJwksVerifierError::Verify(ref refreshed_message))
266 if refreshed_message.contains("missing `kid`") =>
267 {
268 Err(RemoteJwksVerifierError::MissingKidWithoutFallback)
269 }
270 Err(RemoteJwksVerifierError::Verify(ref refreshed_message)) => {
271 if let Some(kid) = kid_from_token_error(refreshed_message) {
272 Err(RemoteJwksVerifierError::UnknownKid { kid })
273 } else {
274 Err(RemoteJwksVerifierError::Verify(refreshed_message.clone()))
275 }
276 }
277 Err(error) => {
278 let _ = message;
279 Err(error)
280 }
281 }
282 }
283 Err(error) => Err(error),
284 }
285 }
286
287 async fn verify_once(&self, token: &str) -> Result<P, RemoteJwksVerifierError> {
288 let verifier = self
289 .verifier
290 .read()
291 .await
292 .clone()
293 .ok_or(RemoteJwksVerifierError::StartupNoKeys)?;
294 verifier
295 .decode(token.as_bytes())
296 .map_err(|error: crate::Error| RemoteJwksVerifierError::Verify(error.to_string()))
297 }
298
299 async fn load_cached_verifier(
300 &self,
301 ) -> Result<Option<Arc<JsonWebToken<P>>>, RemoteJwksVerifierError> {
302 let Some(cache_path) = &self.config.cache_path else {
303 return Ok(None);
304 };
305
306 if !cache_path.exists() {
307 return Ok(None);
308 }
309
310 let raw = tokio::fs::read_to_string(cache_path)
311 .await
312 .map_err(|error| RemoteJwksVerifierError::CacheRead {
313 path: cache_path.display().to_string(),
314 message: error.to_string(),
315 })?;
316 let jwks: JwksDocument = serde_json::from_str(&raw)
317 .map_err(|error| RemoteJwksVerifierError::ParseResponse(error.to_string()))?;
318 let codec = Arc::new(codec_from_jwks(&jwks)?);
319 Ok(Some(codec))
320 }
321
322 async fn fetch_jwks(&self) -> Result<JwksDocument, RemoteJwksVerifierError> {
323 let response = self
324 .client
325 .get(&self.config.jwks_url)
326 .send()
327 .await
328 .map_err(|error| RemoteJwksVerifierError::Fetch {
329 url: self.config.jwks_url.clone(),
330 message: error.to_string(),
331 })?;
332
333 if !response.status().is_success() {
334 return Err(RemoteJwksVerifierError::Fetch {
335 url: self.config.jwks_url.clone(),
336 message: format!("unexpected HTTP status {}", response.status()),
337 });
338 }
339
340 response
341 .json::<JwksDocument>()
342 .await
343 .map_err(|error| RemoteJwksVerifierError::ParseResponse(error.to_string()))
344 }
345}
346
/// Extracts the key id from a verification error message containing the marker
/// `` JWT `kid` `<kid>` ``.
///
/// Returns the text between the first marker and the next backtick, or `None` if either
/// the marker or the closing backtick is missing.
fn kid_from_token_error(message: &str) -> Option<String> {
    const KID_MARKER: &str = "JWT `kid` `";
    let (_, after_marker) = message.split_once(KID_MARKER)?;
    let (kid, _) = after_marker.split_once('`')?;
    Some(kid.to_owned())
}
354
355fn codec_from_jwks<P>(document: &JwksDocument) -> Result<JsonWebToken<P>, RemoteJwksVerifierError>
356where
357 P: Serialize + DeserializeOwned + Clone,
358{
359 let keys: Vec<_> = document
360 .keys
361 .iter()
362 .filter(|key| {
363 key.alg == "ES384" && key.crv == "P-384" && key.kty == "EC" && key.use_ == "sig"
364 })
365 .cloned()
366 .collect();
367
368 if keys.is_empty() {
369 return Err(RemoteJwksVerifierError::NoValidKeys);
370 }
371
372 let options = JsonWebTokenOptions::for_es384_jwks_keys(&keys)
373 .map_err(|error| RemoteJwksVerifierError::Verify(error.to_string()))?;
374 Ok(JsonWebToken::new_with_options(options))
375}
376
377async fn persist_jwks_cache(
378 cache_path: &PathBuf,
379 jwks: &JwksDocument,
380) -> Result<(), RemoteJwksVerifierError> {
381 if let Some(parent) = cache_path.parent()
382 && !parent.as_os_str().is_empty()
383 {
384 tokio::fs::create_dir_all(parent).await.map_err(|error| {
385 RemoteJwksVerifierError::CacheWrite {
386 path: parent.display().to_string(),
387 message: error.to_string(),
388 }
389 })?;
390 }
391
392 let raw = serde_json::to_string_pretty(jwks)
393 .map_err(|error| RemoteJwksVerifierError::ParseResponse(error.to_string()))?;
394 tokio::fs::write(cache_path, raw)
395 .await
396 .map_err(|error| RemoteJwksVerifierError::CacheWrite {
397 path: cache_path.display().to_string(),
398 message: error.to_string(),
399 })
400}
401
#[cfg(test)]
// Tests are allowed to panic on unexpected failures.
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;
    use crate::jwt::jwks::EcP384Jwk;

    #[test]
    fn config_defaults_require_only_jwks_url() {
        let config = RemoteJwksVerifierConfig::from_jwks_url(
            "https://example.invalid/.well-known/jwks.json",
        );

        assert_eq!(
            config.jwks_url,
            "https://example.invalid/.well-known/jwks.json"
        );
        assert_eq!(config.http_timeout, Duration::from_secs(3));
        assert_eq!(config.refresh_interval, Duration::from_secs(60));
        assert!(config.cache_path.is_none());
    }

    #[test]
    fn jwks_document_rejects_empty_keys() {
        let document = JwksDocument { keys: vec![] };
        let result = codec_from_jwks::<crate::jwt::JwtClaims<()>>(&document);
        assert!(matches!(result, Err(RemoteJwksVerifierError::NoValidKeys)));
    }

    #[test]
    fn kid_parser_extracts_unknown_kid() {
        let message = "JWT `kid` `next-key` is not configured for verification";
        assert_eq!(kid_from_token_error(message).as_deref(), Some("next-key"));
    }

    #[test]
    fn kid_parser_returns_none_without_marker() {
        assert_eq!(
            kid_from_token_error("token verification failed: bad signature"),
            None
        );
    }

    #[test]
    fn kid_parser_returns_none_for_unterminated_kid() {
        assert_eq!(kid_from_token_error("JWT `kid` `next-key"), None);
    }

    #[test]
    fn codec_builds_from_valid_es384_jwks() {
        // Fixed P-384 public key used only to build a JWK.
        const TEST_ES384_PUBLIC_KEY_PEM: &[u8] = br#"-----BEGIN PUBLIC KEY-----
MHYwEAYHKoZIzj0CAQYFK4EEACIDYgAEsjQ/XkOUJO2bXkhDzKRMW1SXp0VsMqGx
MSTG+tppqd3gOxbM8vLgWy4/B0Qdest0Gy3E8QgaKJXQV3zRczNd9zrk1dmwVl6u
Yd+JfgNIeIFP6HWeu/C3wIJ60WDBuGY1
-----END PUBLIC KEY-----
"#;

        let key = EcP384Jwk::from_public_key_pem("key-a", TEST_ES384_PUBLIC_KEY_PEM)
            .expect("jwk generation should succeed");
        let document = JwksDocument { keys: vec![key] };

        let codec = codec_from_jwks::<crate::jwt::JwtClaims<()>>(&document)
            .expect("codec should be created");
        assert_eq!(codec.verification_key_count(), 1);
    }
}