Skip to main content

tako_rs_plugins/middleware/
jwt_auth.rs

1//! JWT (JSON Web Token) authentication middleware.
2//!
3//! Trait-based: implement [`JwtVerifier`] with your preferred JWT library
4//! and pass it to [`JwtAuth`]. Enable the `jwt-simple` cargo feature for the
5//! batteries-included verifier built on top of `jwt-simple` — it supports
6//! HMAC, RSA, RSA-PSS, ECDSA, `EdDSA` and `BLAKE2b`.
7//!
8//! v2 additions:
9//!
//! - **JWKS rotation** via [`stores::JwksProvider`](crate::stores::JwksProvider).
//!   The bundled `MultiKeyVerifier` (under the `jwt-simple` feature) selects keys by `kid`
//!   from a runtime-updatable map (`rotate_key` / `revoke_kid`), falling back to the
//!   algorithm-keyed static map when the token has no `kid` or the `kid` is not registered.
13//! - **Configurable issuer / audience / leeway** through
14//!   [`VerifyConstraints`]. Applied uniformly across every algorithm.
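//! - **Revocation list** via the [`RevocationList`] trait — a simple in-memory
//!   implementation ([`InMemoryRevocationList`]) keyed by `jti` is provided.
//! - **Optional remote introspection** via [`IntrospectionFn`] — when configured,
//!   the middleware calls it on every request that has already passed local
//!   verification, constraint and revocation checks, which makes it the hook for
//!   upstream or tenant-scoped revocation. It cannot rescue tokens the local
//!   verifier rejects, so opaque (non-JWT) tokens need a verifier that accepts them.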
20
21use std::fmt;
22use std::future::Future;
23use std::pin::Pin;
24use std::sync::Arc;
25
26use http::StatusCode;
27use http::header::AUTHORIZATION;
28use scc::HashSet as SccHashSet;
29use tako_rs_core::middleware::IntoMiddleware;
30use tako_rs_core::middleware::Next;
31use tako_rs_core::responder::Responder;
32use tako_rs_core::types::Request;
33use tako_rs_core::types::Response;
34
35/// Trait for verifying JWT tokens.
36pub trait JwtVerifier: Send + Sync + Clone + 'static {
37  /// Decoded claims inserted into request extensions.
38  type Claims: Send + Sync + Clone + 'static;
39  /// Verification error.
40  type Error: fmt::Display;
41
42  /// Verifies a raw JWT token string.
43  fn verify(&self, token: &str) -> Result<Self::Claims, Self::Error>;
44
45  /// Validate `iss` / `aud` / `leeway` constraints against the decoded claims.
46  ///
47  /// The default implementation **fails closed** when any non-default
48  /// constraint is configured — concrete verifiers MUST override this if they
49  /// want to silently accept (because they already enforce constraints
50  /// internally) or to apply their own logic. Failing closed prevents the
51  /// previous v1.x behavior where custom verifiers silently dropped the
52  /// `VerifyConstraints` configured on `JwtAuth`, leaving iss/aud/leeway
53  /// unenforced.
54  fn validate_constraints(
55    &self,
56    _claims: &Self::Claims,
57    constraints: &VerifyConstraints,
58  ) -> Result<(), ConstraintsNotSupported> {
59    if constraints.issuer.is_some()
60      || constraints.audience.is_some()
61      || constraints.leeway_secs != 0
62    {
63      Err(ConstraintsNotSupported {
64        reason: "this JwtVerifier does not override `validate_constraints`; \
65                 configure constraints on the verifier itself or implement \
66                 `validate_constraints` on your custom verifier",
67      })
68    } else {
69      Ok(())
70    }
71  }
72}
73
/// Reported by [`JwtVerifier::validate_constraints`] when the verifier cannot
/// (or won't) enforce the requested `VerifyConstraints`, or when the decoded
/// claims violate them. The middleware surfaces this as 401 Unauthorized —
/// fail-closed by design.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ConstraintsNotSupported {
  /// Human-readable diagnostic surfaced in the 401 response body.
  pub reason: &'static str,
}

impl fmt::Display for ConstraintsNotSupported {
  fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
    write!(f, "constraints not enforceable: {}", self.reason)
  }
}

// Lets callers propagate this with `?` into `Box<dyn Error>` or any
// error-reporting type, like every other error in the ecosystem.
impl std::error::Error for ConstraintsNotSupported {}
88
/// Optional global verification constraints applied on top of the verifier.
///
/// The `Default` value (no issuer, no audience, zero leeway) means "no extra
/// constraints".
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct VerifyConstraints {
  /// Required issuer (`iss` claim); `None` disables the issuer check.
  pub issuer: Option<String>,
  /// Required audience (`aud` claim); `None` disables the audience check.
  pub audience: Option<String>,
  /// Allowed clock skew in seconds.
  pub leeway_secs: u64,
}
99
/// Synchronous lookup of revoked tokens.
///
/// Deliberately not async: revocation sits on the request hot path, so remote
/// stores should be consulted through a local cache.
pub trait RevocationList
where
  Self: Send + Sync + 'static,
{
  /// Returns `true` when the token identified by `jti` must be rejected.
  fn is_revoked(&self, jti: &str) -> bool;
}
105
106/// Default in-memory revocation list keyed by `jti` (JWT ID claim).
107#[derive(Default, Clone)]
108pub struct InMemoryRevocationList {
109  inner: Arc<SccHashSet<String>>,
110}
111
112impl InMemoryRevocationList {
113  pub fn new() -> Self {
114    Self::default()
115  }
116
117  pub fn revoke(&self, jti: impl Into<String>) {
118    let _ = self.inner.insert_sync(jti.into());
119  }
120
121  pub fn unrevoke(&self, jti: &str) {
122    let _ = self.inner.remove_sync(jti);
123  }
124}
125
126impl RevocationList for InMemoryRevocationList {
127  fn is_revoked(&self, jti: &str) -> bool {
128    self.inner.contains_sync(jti)
129  }
130}
131
/// Optional remote introspection callback.
///
/// Receives the raw bearer token (the same string handed to
/// [`JwtVerifier::verify`]) and resolves to `true` when the token is still
/// valid, `false` when it has been revoked / expired upstream. The middleware
/// only calls it after local verification, constraint and revocation checks
/// have passed, and answers 401 when it resolves to `false`.
pub type IntrospectionFn =
  Arc<dyn Fn(&str) -> Pin<Box<dyn Future<Output = bool> + Send + 'static>> + Send + Sync + 'static>;

/// Closure that extracts a `jti` (or any revocation-list key) from the
/// verifier's decoded claims. Required when wiring up [`JwtAuth::revocation`].
/// Returning `None` skips the revocation check for that request.
pub type JtiExtractorFn<C> = Arc<dyn Fn(&C) -> Option<String> + Send + Sync + 'static>;

/// Pair of [`RevocationList`] and a JTI extractor used to wire revocation onto
/// a verifier. The middleware rejects the request with 401 when the extracted
/// key is reported revoked.
pub type RevocationCheck<C> = (Arc<dyn RevocationList>, JtiExtractorFn<C>);
143
144/// JWT authentication middleware.
145pub struct JwtAuth<V: JwtVerifier> {
146  verifier: V,
147  constraints: VerifyConstraints,
148  revocation: Option<RevocationCheck<V::Claims>>,
149  introspect: Option<IntrospectionFn>,
150}
151
152impl<V: JwtVerifier> JwtAuth<V> {
153  /// Creates a JWT auth middleware with the given verifier and no extra
154  /// constraints / revocation.
155  pub fn new(verifier: V) -> Self {
156    Self {
157      verifier,
158      constraints: VerifyConstraints::default(),
159      revocation: None,
160      introspect: None,
161    }
162  }
163
164  /// Sets per-claim constraints (issuer, audience, leeway).
165  pub fn constraints(mut self, c: VerifyConstraints) -> Self {
166    self.constraints = c;
167    self
168  }
169
170  /// Plugs a revocation list checked after signature verification.
171  /// `extractor` returns the revocation key (typically the `jti` claim) for
172  /// each decoded claims value.
173  pub fn revocation<R, F>(mut self, list: R, extractor: F) -> Self
174  where
175    R: RevocationList,
176    F: Fn(&V::Claims) -> Option<String> + Send + Sync + 'static,
177  {
178    self.revocation = Some((Arc::new(list), Arc::new(extractor)));
179    self
180  }
181
182  /// Plugs a remote introspection callback. The callback is invoked on every
183  /// successful local verification — short-lived caches belong inside the
184  /// callback itself.
185  pub fn introspect<F, Fut>(mut self, f: F) -> Self
186  where
187    F: Fn(&str) -> Fut + Send + Sync + 'static,
188    Fut: Future<Output = bool> + Send + 'static,
189  {
190    self.introspect = Some(Arc::new(move |t: &str| Box::pin(f(t))));
191    self
192  }
193}
194
195impl<V: JwtVerifier> IntoMiddleware for JwtAuth<V> {
196  fn into_middleware(
197    self,
198  ) -> impl Fn(Request, Next) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>>
199  + Clone
200  + Send
201  + Sync
202  + 'static {
203    let verifier = self.verifier;
204    let constraints = Arc::new(self.constraints);
205    let revocation = self.revocation;
206    let introspect = self.introspect;
207
208    move |mut req: Request, next: Next| {
209      let verifier = verifier.clone();
210      let constraints = constraints.clone();
211      let revocation = revocation.clone();
212      let introspect = introspect.clone();
213
214      Box::pin(async move {
215        // PMW-04: RFC 7235 §2.1 requires the auth scheme name to be
216        // matched case-insensitively. Sibling `bearer_auth.rs:205` already
217        // uses `eq_ignore_ascii_case`; here we previously used the
218        // case-sensitive `strip_prefix("Bearer ")` which silently 401'd
219        // any legitimate `bearer <jwt>` / `BEARER <jwt>` client.
220        let token = match req
221          .headers()
222          .get(AUTHORIZATION)
223          .and_then(|v| v.to_str().ok())
224          .and_then(|s| s.split_once(' '))
225          .filter(|(scheme, _)| scheme.eq_ignore_ascii_case("Bearer"))
226          .map(|(_, rest)| rest.trim())
227        {
228          Some(t) => t.to_string(),
229          None => {
230            return (
231              StatusCode::UNAUTHORIZED,
232              "Missing or invalid Authorization header",
233            )
234              .into_response();
235          }
236        };
237
238        let claims = match verifier.verify(&token) {
239          Ok(c) => c,
240          Err(e) => {
241            return (StatusCode::UNAUTHORIZED, format!("Invalid token: {e}")).into_response();
242          }
243        };
244
245        // Caller-controlled iss/aud/leeway. Propagate to the verifier so it
246        // can apply them. Default trait impl fails closed when constraints
247        // are configured but the verifier does not implement enforcement.
248        if let Err(e) = verifier.validate_constraints(&claims, &constraints) {
249          return (StatusCode::UNAUTHORIZED, format!("Invalid token: {e}")).into_response();
250        }
251
252        if let Some((list, extractor)) = revocation.as_ref()
253          && let Some(jti) = extractor(&claims)
254          && list.is_revoked(&jti)
255        {
256          return (StatusCode::UNAUTHORIZED, "token revoked").into_response();
257        }
258
259        if let Some(introspect) = introspect.as_ref()
260          && !introspect(&token).await
261        {
262          return (StatusCode::UNAUTHORIZED, "token introspection failed").into_response();
263        }
264
265        req.extensions_mut().insert(claims);
266        next.run(req).await.into_response()
267      })
268    }
269  }
270}
271
272#[cfg(feature = "jwt-simple")]
273mod jwt_simple_impl {
274  use std::collections::HashMap;
275  use std::sync::Arc;
276
277  use ::jwt_simple::prelude::*;
278  use serde::Serialize;
279  use serde::de::DeserializeOwned;
280  use tako_rs_core::types::BuildHasher;
281
282  /// Multi-algorithm JWT verification key wrapper.
283  pub enum AnyVerifyKey {
284    HS256(Arc<HS256Key>),
285    HS384(Arc<HS384Key>),
286    HS512(Arc<HS512Key>),
287    Blake2b(Arc<Blake2bKey>),
288    RS256(Arc<RS256PublicKey>),
289    RS384(Arc<RS384PublicKey>),
290    RS512(Arc<RS512PublicKey>),
291    PS256(Arc<PS256PublicKey>),
292    PS384(Arc<PS384PublicKey>),
293    PS512(Arc<PS512PublicKey>),
294    ES256(Arc<ES256PublicKey>),
295    ES256K(Arc<ES256kPublicKey>),
296    ES384(Arc<ES384PublicKey>),
297    EdDSA(Arc<Ed25519PublicKey>),
298  }
299
300  impl AnyVerifyKey {
301    pub fn alg_id(&self) -> &'static str {
302      match self {
303        Self::HS256(_) => "HS256",
304        Self::HS384(_) => "HS384",
305        Self::HS512(_) => "HS512",
306        Self::Blake2b(_) => "BLAKE2B",
307        Self::RS256(_) => "RS256",
308        Self::RS384(_) => "RS384",
309        Self::RS512(_) => "RS512",
310        Self::PS256(_) => "PS256",
311        Self::PS384(_) => "PS384",
312        Self::PS512(_) => "PS512",
313        Self::ES256(_) => "ES256",
314        Self::ES256K(_) => "ES256K",
315        Self::ES384(_) => "ES384",
316        Self::EdDSA(_) => "EdDSA",
317      }
318    }
319
320    fn verify_token<C>(
321      &self,
322      token: &str,
323      opts: VerificationOptions,
324    ) -> Result<JWTClaims<C>, ::jwt_simple::Error>
325    where
326      C: Serialize + DeserializeOwned,
327    {
328      let opts = Some(opts);
329      match self {
330        Self::HS256(k) => k.verify_token::<C>(token, opts),
331        Self::HS384(k) => k.verify_token::<C>(token, opts),
332        Self::HS512(k) => k.verify_token::<C>(token, opts),
333        Self::Blake2b(k) => k.verify_token::<C>(token, opts),
334        Self::RS256(k) => k.verify_token::<C>(token, opts),
335        Self::RS384(k) => k.verify_token::<C>(token, opts),
336        Self::RS512(k) => k.verify_token::<C>(token, opts),
337        Self::PS256(k) => k.verify_token::<C>(token, opts),
338        Self::PS384(k) => k.verify_token::<C>(token, opts),
339        Self::PS512(k) => k.verify_token::<C>(token, opts),
340        Self::ES256(k) => k.verify_token::<C>(token, opts),
341        Self::ES256K(k) => k.verify_token::<C>(token, opts),
342        Self::ES384(k) => k.verify_token::<C>(token, opts),
343        Self::EdDSA(k) => k.verify_token::<C>(token, opts),
344      }
345    }
346  }
347
348  impl Clone for AnyVerifyKey {
349    fn clone(&self) -> Self {
350      match self {
351        Self::HS256(k) => Self::HS256(Arc::clone(k)),
352        Self::HS384(k) => Self::HS384(Arc::clone(k)),
353        Self::HS512(k) => Self::HS512(Arc::clone(k)),
354        Self::Blake2b(k) => Self::Blake2b(Arc::clone(k)),
355        Self::RS256(k) => Self::RS256(Arc::clone(k)),
356        Self::RS384(k) => Self::RS384(Arc::clone(k)),
357        Self::RS512(k) => Self::RS512(Arc::clone(k)),
358        Self::PS256(k) => Self::PS256(Arc::clone(k)),
359        Self::PS384(k) => Self::PS384(Arc::clone(k)),
360        Self::PS512(k) => Self::PS512(Arc::clone(k)),
361        Self::ES256(k) => Self::ES256(Arc::clone(k)),
362        Self::ES256K(k) => Self::ES256K(Arc::clone(k)),
363        Self::ES384(k) => Self::ES384(Arc::clone(k)),
364        Self::EdDSA(k) => Self::EdDSA(Arc::clone(k)),
365      }
366    }
367  }
368
369  /// Multi-algorithm verifier with per-`kid` rotation.
370  ///
371  /// `keys` carries algorithm-keyed defaults; `keys_by_kid` adds an optional
372  /// kid-keyed lookup that wins when the JWT header carries `kid`. Updating
373  /// the kid map at runtime rotates without restarting.
374  pub struct MultiKeyVerifier<C> {
375    keys_by_alg: HashMap<&'static str, AnyVerifyKey, BuildHasher>,
376    keys_by_kid: super::Arc<parking_lot::RwLock<HashMap<String, AnyVerifyKey>>>,
377    constraints: super::Arc<super::VerifyConstraints>,
378    _phantom: std::marker::PhantomData<C>,
379  }
380
381  impl<C> Clone for MultiKeyVerifier<C> {
382    fn clone(&self) -> Self {
383      Self {
384        keys_by_alg: self.keys_by_alg.clone(),
385        keys_by_kid: self.keys_by_kid.clone(),
386        constraints: self.constraints.clone(),
387        _phantom: std::marker::PhantomData,
388      }
389    }
390  }
391
392  impl<C> MultiKeyVerifier<C> {
393    /// Builds a verifier with algorithm-only key selection.
394    pub fn new(keys: HashMap<&'static str, AnyVerifyKey, BuildHasher>) -> Self {
395      Self {
396        keys_by_alg: keys,
397        keys_by_kid: super::Arc::new(parking_lot::RwLock::new(HashMap::new())),
398        constraints: super::Arc::new(super::VerifyConstraints::default()),
399        _phantom: std::marker::PhantomData,
400      }
401    }
402
403    /// Adds / replaces the rotation key for `kid`.
404    pub fn rotate_key(&self, kid: impl Into<String>, key: AnyVerifyKey) {
405      self.keys_by_kid.write().insert(kid.into(), key);
406    }
407
408    /// Removes the rotation key for `kid`.
409    pub fn revoke_kid(&self, kid: &str) {
410      self.keys_by_kid.write().remove(kid);
411    }
412
413    /// Sets per-claim verification constraints.
414    pub fn constraints(mut self, c: super::VerifyConstraints) -> Self {
415      self.constraints = super::Arc::new(c);
416      self
417    }
418  }
419
420  impl<C> super::JwtVerifier for MultiKeyVerifier<C>
421  where
422    C: Clone + Serialize + DeserializeOwned + Send + Sync + 'static,
423  {
424    type Claims = JWTClaims<C>;
425    type Error = String;
426
427    fn verify(&self, token: &str) -> Result<Self::Claims, Self::Error> {
428      let meta = ::jwt_simple::token::Token::decode_metadata(token)
429        .map_err(|e| format!("Cannot decode JWT header: {e}"))?;
430
431      let alg = meta.algorithm();
432      let kid = meta.key_id();
433
434      let key = if let Some(kid) = kid {
435        let kid_map = self.keys_by_kid.read();
436        kid_map.get(kid).cloned()
437      } else {
438        None
439      };
440      let key = match key {
441        Some(k) => k,
442        None => self
443          .keys_by_alg
444          .get(alg)
445          .cloned()
446          .ok_or_else(|| format!("Algorithm {alg} not allowed"))?,
447      };
448
449      let mut opts = VerificationOptions {
450        time_tolerance: Some(::jwt_simple::prelude::Duration::from_secs(
451          self.constraints.leeway_secs,
452        )),
453        ..Default::default()
454      };
455      if let Some(iss) = &self.constraints.issuer {
456        let mut set = std::collections::HashSet::new();
457        set.insert(iss.clone());
458        opts.allowed_issuers = Some(set);
459      }
460      if let Some(aud) = &self.constraints.audience {
461        let mut set = std::collections::HashSet::new();
462        set.insert(aud.clone());
463        opts.allowed_audiences = Some(set);
464      }
465
466      key
467        .verify_token::<C>(token, opts)
468        .map_err(|e| e.to_string())
469    }
470
471    fn validate_constraints(
472      &self,
473      claims: &Self::Claims,
474      constraints: &super::VerifyConstraints,
475    ) -> Result<(), super::ConstraintsNotSupported> {
476      if let Some(expected) = &constraints.issuer
477        && claims.issuer.as_deref() != Some(expected.as_str())
478      {
479        return Err(super::ConstraintsNotSupported {
480          reason: "issuer mismatch",
481        });
482      }
483      if let Some(expected) = &constraints.audience {
484        let mut allowed = std::collections::HashSet::new();
485        allowed.insert(expected.clone());
486        match &claims.audiences {
487          Some(a) if a.contains(&allowed) => {}
488          _ => {
489            return Err(super::ConstraintsNotSupported {
490              reason: "audience mismatch",
491            });
492          }
493        }
494      }
495      // `leeway_secs` is applied to exp/nbf by the underlying verify() call
496      // when this verifier's internal `constraints.leeway_secs` is set; the
497      // middleware-level field is informational only here. If both are set
498      // and disagree, the verifier-level leeway wins for exp/nbf and the
499      // middleware-level leeway is ignored.
500      Ok(())
501    }
502  }
503}
504
// Re-export the jwt-simple backed verifier types at the module root.
#[cfg(feature = "jwt-simple")]
pub use jwt_simple_impl::{AnyVerifyKey, MultiKeyVerifier};