Skip to main content

webgates_codecs/jwt/
remote_verifier.rs

1//! Fetch, cache, refresh, and use a remote JWKS document to verify JWTs.
2//!
3//! [`RemoteJwksVerifier`] fetches, caches, and refreshes a remote JWKS document
4//! and exposes a [`RemoteJwksVerifier::verify_token`] method that both Axum and
5//! Tonic integrations can depend on without duplicating fetch/cache/refresh logic.
6//!
7//! # Behavior
8//!
9//! - **Startup**: attempts a live JWKS fetch; falls back to a persistent cache if
10//!   the fetch fails; fails closed (returns an error) when neither source provides
11//!   valid ES384 keys.
12//! - **Background refresh**: call [`RemoteJwksVerifier::start_background_refresh`]
13//!   once after bootstrap to keep keys current.
14//! - **Unknown-`kid` recovery**: on a first verification failure caused by an
15//!   unknown `kid`, the verifier performs one bounded refresh before retrying.
16//! - **Request-path verification**: all JWT validation is local; no per-request
17//!   network I/O is performed.
18//!
19//! # Example
20//!
21//! ```rust,no_run
22//! use webgates_codecs::jwt::remote_verifier::{RemoteJwksVerifier, RemoteJwksVerifierConfig};
23//! use webgates_codecs::jwt::JwtClaims;
24//!
25//! # async fn example() -> Result<(), Box<dyn std::error::Error>> {
26//! let config = RemoteJwksVerifierConfig::from_jwks_url(
27//!     "https://auth.example.com/.well-known/jwks.json",
28//! );
29//! let verifier = RemoteJwksVerifier::<JwtClaims<()>>::bootstrap(config).await?;
30//! let _refresh_handle = verifier.start_background_refresh();
31//! # Ok(())
32//! # }
33//! ```
34
35use std::path::PathBuf;
36use std::sync::Arc;
37use std::time::Duration;
38
39use reqwest::Client;
40use serde::{Serialize, de::DeserializeOwned};
41use tokio::sync::RwLock;
42use tokio::task::JoinHandle;
43
44use crate::Codec as _;
45use crate::jwt::jwks::JwksDocument;
46use crate::jwt::{JsonWebToken, JsonWebTokenOptions};
47
/// HTTP timeout applied to JWKS fetches when the configuration does not override it.
const DEFAULT_HTTP_TIMEOUT: Duration = Duration::from_secs(3);
/// Pause between background refresh attempts when the configuration does not override it.
const DEFAULT_REFRESH_INTERVAL: Duration = Duration::from_secs(60);
50
51type SharedVerifier<P> = Arc<RwLock<Option<Arc<JsonWebToken<P>>>>>;
52
/// Settings that control how a [`RemoteJwksVerifier`] fetches and caches keys.
#[derive(Debug, Clone)]
pub struct RemoteJwksVerifierConfig {
    /// Address of the remote JWKS endpoint to fetch from.
    pub jwks_url: String,
    /// Timeout applied to the HTTP client used for JWKS fetches.
    pub http_timeout: Duration,
    /// Time between two background refresh attempts.
    pub refresh_interval: Duration,
    /// Where to persist the JWKS document on disk; `None` disables the cache.
    pub cache_path: Option<PathBuf>,
}
65
66impl RemoteJwksVerifierConfig {
67    /// Build a config from a JWKS URL with default timeout and refresh interval.
68    pub fn from_jwks_url(jwks_url: impl Into<String>) -> Self {
69        Self {
70            jwks_url: jwks_url.into(),
71            http_timeout: DEFAULT_HTTP_TIMEOUT,
72            refresh_interval: DEFAULT_REFRESH_INTERVAL,
73            cache_path: None,
74        }
75    }
76
77    /// Override the HTTP request timeout.
78    #[must_use]
79    pub fn with_http_timeout(mut self, timeout: Duration) -> Self {
80        self.http_timeout = timeout;
81        self
82    }
83
84    /// Override the background refresh interval.
85    #[must_use]
86    pub fn with_refresh_interval(mut self, refresh_interval: Duration) -> Self {
87        self.refresh_interval = refresh_interval;
88        self
89    }
90
91    /// Enable a persistent local cache at the given path.
92    #[must_use]
93    pub fn with_cache_path(mut self, cache_path: impl Into<PathBuf>) -> Self {
94        self.cache_path = Some(cache_path.into());
95        self
96    }
97}
98
99/// Errors produced by [`RemoteJwksVerifier`].
100#[derive(thiserror::Error, Debug)]
101pub enum RemoteJwksVerifierError {
102    /// HTTP client construction failed.
103    #[error("failed to build HTTP client: {0}")]
104    HttpClientBuild(#[from] reqwest::Error),
105    /// JWKS fetch failed.
106    #[error("failed to fetch JWKS document from {url}: {message}")]
107    Fetch {
108        /// The URL that was fetched.
109        url: String,
110        /// The error message.
111        message: String,
112    },
113    /// JWKS response could not be parsed.
114    #[error("failed to parse JWKS response: {0}")]
115    ParseResponse(String),
116    /// No valid ES384 keys were found in the JWKS document.
117    #[error("JWKS document did not contain any valid ES384 keys")]
118    NoValidKeys,
119    /// Cache write failed.
120    #[error("failed to persist JWKS cache at {path}: {message}")]
121    CacheWrite {
122        /// The cache path.
123        path: String,
124        /// The error message.
125        message: String,
126    },
127    /// Cache read failed.
128    #[error("failed to read JWKS cache at {path}: {message}")]
129    CacheRead {
130        /// The cache path.
131        path: String,
132        /// The error message.
133        message: String,
134    },
135    /// Token has no `kid` and no fallback key was available after refresh.
136    #[error("missing JWT `kid` and refresh did not provide a fallback key")]
137    MissingKidWithoutFallback,
138    /// Token `kid` was not found even after a refresh.
139    #[error("JWT key id `{kid}` not found after refresh")]
140    UnknownKid {
141        /// The unknown key id.
142        kid: String,
143    },
144    /// Token verification failed.
145    #[error("token verification failed: {0}")]
146    Verify(String),
147    /// Startup failed because no live or cached JWKS was available.
148    #[error("startup failed because no live JWKS or cached JWKS was available")]
149    StartupNoKeys,
150}
151
152/// Transport-agnostic remote JWKS verifier.
153///
154/// Clone-cheap: all state is behind `Arc`.
155#[derive(Clone)]
156pub struct RemoteJwksVerifier<P>
157where
158    P: Serialize + DeserializeOwned + Clone + Send + Sync + 'static,
159{
160    config: RemoteJwksVerifierConfig,
161    client: Client,
162    verifier: SharedVerifier<P>,
163    refresh_lock: Arc<tokio::sync::Mutex<()>>,
164}
165
166impl<P> RemoteJwksVerifier<P>
167where
168    P: Serialize + DeserializeOwned + Clone + Send + Sync + 'static,
169{
170    /// Bootstrap the verifier: attempt a live JWKS fetch, fall back to cache,
171    /// fail closed when neither source provides valid ES384 keys.
172    ///
173    /// # Errors
174    ///
175    /// Returns [`RemoteJwksVerifierError::StartupNoKeys`] when no keys are
176    /// available from either the live endpoint or the persistent cache.
177    pub async fn bootstrap(
178        config: RemoteJwksVerifierConfig,
179    ) -> Result<Self, RemoteJwksVerifierError> {
180        let client = Client::builder().timeout(config.http_timeout).build()?;
181        let verifier: SharedVerifier<P> = Arc::new(RwLock::new(None));
182
183        let this = Self {
184            config,
185            client,
186            verifier,
187            refresh_lock: Arc::new(tokio::sync::Mutex::new(())),
188        };
189
190        let mut has_cache = false;
191        if let Some(cached) = this.load_cached_verifier().await? {
192            *this.verifier.write().await = Some(cached);
193            has_cache = true;
194            tracing::warn!("starting with cached JWKS keys while attempting live refresh");
195        }
196
197        match this.refresh().await {
198            Ok(()) => {}
199            Err(error) if has_cache => {
200                tracing::warn!(error = %error, "live JWKS refresh failed, continuing with cached keys");
201            }
202            Err(_) => return Err(RemoteJwksVerifierError::StartupNoKeys),
203        }
204
205        Ok(this)
206    }
207
208    /// Spawn a background task that refreshes the JWKS document on the configured interval.
209    ///
210    /// The returned [`JoinHandle`] should be kept alive for the lifetime of the service.
211    /// Dropping it cancels the background refresh.
212    pub fn start_background_refresh(&self) -> JoinHandle<()> {
213        let refresh_interval = self.config.refresh_interval;
214        let this = self.clone();
215
216        tokio::spawn(async move {
217            let mut ticker = tokio::time::interval(refresh_interval);
218            loop {
219                ticker.tick().await;
220                if let Err(error) = this.refresh().await {
221                    tracing::warn!(error = %error, "background JWKS refresh failed");
222                }
223            }
224        })
225    }
226
227    /// Perform a single JWKS refresh: fetch, parse, optionally persist, and swap in.
228    ///
229    /// # Errors
230    ///
231    /// Returns an error if the fetch, parse, or cache write fails.
232    pub async fn refresh(&self) -> Result<(), RemoteJwksVerifierError> {
233        let _lock = self.refresh_lock.lock().await;
234        let jwks = self.fetch_jwks().await?;
235        let codec = Arc::new(codec_from_jwks(&jwks)?);
236
237        if let Some(cache_path) = &self.config.cache_path {
238            persist_jwks_cache(cache_path, &jwks).await?;
239        }
240
241        *self.verifier.write().await = Some(codec);
242        Ok(())
243    }
244
245    /// Verify a JWT token string against the current in-memory key set.
246    ///
247    /// On a first failure caused by an unknown `kid`, performs one bounded
248    /// refresh before retrying. All verification is local; no per-request
249    /// network I/O is performed on the happy path.
250    ///
251    /// # Errors
252    ///
253    /// Returns a [`RemoteJwksVerifierError`] when the token is invalid, the
254    /// `kid` is unknown after refresh, or no keys are loaded.
255    pub async fn verify_token(&self, token: &str) -> Result<P, RemoteJwksVerifierError> {
256        match self.verify_once(token).await {
257            Ok(claims) => Ok(claims),
258            Err(RemoteJwksVerifierError::Verify(ref message))
259                if message.contains("missing `kid`") || message.contains("not configured") =>
260            {
261                let message = message.clone();
262                self.refresh().await?;
263                match self.verify_once(token).await {
264                    Ok(claims) => Ok(claims),
265                    Err(RemoteJwksVerifierError::Verify(ref refreshed_message))
266                        if refreshed_message.contains("missing `kid`") =>
267                    {
268                        Err(RemoteJwksVerifierError::MissingKidWithoutFallback)
269                    }
270                    Err(RemoteJwksVerifierError::Verify(ref refreshed_message)) => {
271                        if let Some(kid) = kid_from_token_error(refreshed_message) {
272                            Err(RemoteJwksVerifierError::UnknownKid { kid })
273                        } else {
274                            Err(RemoteJwksVerifierError::Verify(refreshed_message.clone()))
275                        }
276                    }
277                    Err(error) => {
278                        let _ = message;
279                        Err(error)
280                    }
281                }
282            }
283            Err(error) => Err(error),
284        }
285    }
286
287    async fn verify_once(&self, token: &str) -> Result<P, RemoteJwksVerifierError> {
288        let verifier = self
289            .verifier
290            .read()
291            .await
292            .clone()
293            .ok_or(RemoteJwksVerifierError::StartupNoKeys)?;
294        verifier
295            .decode(token.as_bytes())
296            .map_err(|error: crate::Error| RemoteJwksVerifierError::Verify(error.to_string()))
297    }
298
299    async fn load_cached_verifier(
300        &self,
301    ) -> Result<Option<Arc<JsonWebToken<P>>>, RemoteJwksVerifierError> {
302        let Some(cache_path) = &self.config.cache_path else {
303            return Ok(None);
304        };
305
306        if !cache_path.exists() {
307            return Ok(None);
308        }
309
310        let raw = tokio::fs::read_to_string(cache_path)
311            .await
312            .map_err(|error| RemoteJwksVerifierError::CacheRead {
313                path: cache_path.display().to_string(),
314                message: error.to_string(),
315            })?;
316        let jwks: JwksDocument = serde_json::from_str(&raw)
317            .map_err(|error| RemoteJwksVerifierError::ParseResponse(error.to_string()))?;
318        let codec = Arc::new(codec_from_jwks(&jwks)?);
319        Ok(Some(codec))
320    }
321
322    async fn fetch_jwks(&self) -> Result<JwksDocument, RemoteJwksVerifierError> {
323        let response = self
324            .client
325            .get(&self.config.jwks_url)
326            .send()
327            .await
328            .map_err(|error| RemoteJwksVerifierError::Fetch {
329                url: self.config.jwks_url.clone(),
330                message: error.to_string(),
331            })?;
332
333        if !response.status().is_success() {
334            return Err(RemoteJwksVerifierError::Fetch {
335                url: self.config.jwks_url.clone(),
336                message: format!("unexpected HTTP status {}", response.status()),
337            });
338        }
339
340        response
341            .json::<JwksDocument>()
342            .await
343            .map_err(|error| RemoteJwksVerifierError::ParseResponse(error.to_string()))
344    }
345}
346
/// Extract the key id quoted in a verification error message.
///
/// Looks for the first occurrence of ``JWT `kid` ` `` and returns the text
/// between it and the next backtick. Returns `None` when the marker is absent
/// or the key id is not terminated by a backtick.
fn kid_from_token_error(message: &str) -> Option<String> {
    let (_, after_marker) = message.split_once("JWT `kid` `")?;
    let (kid, _) = after_marker.split_once('`')?;
    Some(kid.to_owned())
}
354
355fn codec_from_jwks<P>(document: &JwksDocument) -> Result<JsonWebToken<P>, RemoteJwksVerifierError>
356where
357    P: Serialize + DeserializeOwned + Clone,
358{
359    let keys: Vec<_> = document
360        .keys
361        .iter()
362        .filter(|key| {
363            key.alg == "ES384" && key.crv == "P-384" && key.kty == "EC" && key.use_ == "sig"
364        })
365        .cloned()
366        .collect();
367
368    if keys.is_empty() {
369        return Err(RemoteJwksVerifierError::NoValidKeys);
370    }
371
372    let options = JsonWebTokenOptions::for_es384_jwks_keys(&keys)
373        .map_err(|error| RemoteJwksVerifierError::Verify(error.to_string()))?;
374    Ok(JsonWebToken::new_with_options(options))
375}
376
377async fn persist_jwks_cache(
378    cache_path: &PathBuf,
379    jwks: &JwksDocument,
380) -> Result<(), RemoteJwksVerifierError> {
381    if let Some(parent) = cache_path.parent()
382        && !parent.as_os_str().is_empty()
383    {
384        tokio::fs::create_dir_all(parent).await.map_err(|error| {
385            RemoteJwksVerifierError::CacheWrite {
386                path: parent.display().to_string(),
387                message: error.to_string(),
388            }
389        })?;
390    }
391
392    let raw = serde_json::to_string_pretty(jwks)
393        .map_err(|error| RemoteJwksVerifierError::ParseResponse(error.to_string()))?;
394    tokio::fs::write(cache_path, raw)
395        .await
396        .map_err(|error| RemoteJwksVerifierError::CacheWrite {
397            path: cache_path.display().to_string(),
398            message: error.to_string(),
399        })
400}
401
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;
    use crate::jwt::jwks::EcP384Jwk;

    /// Public key fixture in PEM form used to build a JWK for the codec test.
    const TEST_ES384_PUBLIC_KEY_PEM: &[u8] = br#"-----BEGIN PUBLIC KEY-----
MHYwEAYHKoZIzj0CAQYFK4EEACIDYgAEsjQ/XkOUJO2bXkhDzKRMW1SXp0VsMqGx
MSTG+tppqd3gOxbM8vLgWy4/B0Qdest0Gy3E8QgaKJXQV3zRczNd9zrk1dmwVl6u
Yd+JfgNIeIFP6HWeu/C3wIJ60WDBuGY1
-----END PUBLIC KEY-----
"#;

    /// A config built from only a URL carries the documented defaults.
    #[test]
    fn config_defaults_require_only_jwks_url() {
        let url = "https://example.invalid/.well-known/jwks.json";
        let defaults = RemoteJwksVerifierConfig::from_jwks_url(url);

        assert_eq!(defaults.jwks_url, url);
        assert_eq!(defaults.http_timeout, Duration::from_secs(3));
        assert_eq!(defaults.refresh_interval, Duration::from_secs(60));
        assert!(defaults.cache_path.is_none());
    }

    /// A document without keys cannot produce a codec.
    #[test]
    fn jwks_document_rejects_empty_keys() {
        let empty = JwksDocument { keys: Vec::new() };
        let outcome = codec_from_jwks::<crate::jwt::JwtClaims<()>>(&empty);
        assert!(matches!(outcome, Err(RemoteJwksVerifierError::NoValidKeys)));
    }

    /// The key id quoted in an "unknown kid" message is recovered verbatim.
    #[test]
    fn kid_parser_extracts_unknown_kid() {
        let parsed =
            kid_from_token_error("JWT `kid` `next-key` is not configured for verification");
        assert_eq!(parsed.as_deref(), Some("next-key"));
    }

    /// One valid ES384 key yields a codec holding exactly one verification key.
    #[test]
    fn codec_builds_from_valid_es384_jwks() {
        let jwk = EcP384Jwk::from_public_key_pem("key-a", TEST_ES384_PUBLIC_KEY_PEM)
            .expect("jwk generation should succeed");
        let single_key_document = JwksDocument { keys: vec![jwk] };

        let built = codec_from_jwks::<crate::jwt::JwtClaims<()>>(&single_key_document)
            .expect("codec should be created");
        assert_eq!(built.verification_key_count(), 1);
    }
}
453}