1use std::fmt;
22use std::future::Future;
23use std::pin::Pin;
24use std::sync::Arc;
25
26use http::StatusCode;
27use http::header::AUTHORIZATION;
28use scc::HashSet as SccHashSet;
29use tako_rs_core::middleware::IntoMiddleware;
30use tako_rs_core::middleware::Next;
31use tako_rs_core::responder::Responder;
32use tako_rs_core::types::Request;
33use tako_rs_core::types::Response;
34
35pub trait JwtVerifier: Send + Sync + Clone + 'static {
37 type Claims: Send + Sync + Clone + 'static;
39 type Error: fmt::Display;
41
42 fn verify(&self, token: &str) -> Result<Self::Claims, Self::Error>;
44
45 fn validate_constraints(
55 &self,
56 _claims: &Self::Claims,
57 constraints: &VerifyConstraints,
58 ) -> Result<(), ConstraintsNotSupported> {
59 if constraints.issuer.is_some()
60 || constraints.audience.is_some()
61 || constraints.leeway_secs != 0
62 {
63 Err(ConstraintsNotSupported {
64 reason: "this JwtVerifier does not override `validate_constraints`; \
65 configure constraints on the verifier itself or implement \
66 `validate_constraints` on your custom verifier",
67 })
68 } else {
69 Ok(())
70 }
71 }
72}
73
/// Error returned by [`JwtVerifier::validate_constraints`].
///
/// Used both when a verifier cannot enforce the requested constraints at all and when it can
/// but the claims do not satisfy them; `reason` carries a static description of which.
#[derive(Debug, Clone)]
pub struct ConstraintsNotSupported {
    /// Static, human-readable explanation of the failure.
    pub reason: &'static str,
}

impl fmt::Display for ConstraintsNotSupported {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "constraints not enforceable: {}", self.reason)
    }
}

// Lets callers propagate this with `?` into `Box<dyn Error>` / `anyhow`-style error types.
impl std::error::Error for ConstraintsNotSupported {}
88
/// Additional claim requirements checked after a token's signature verifies.
///
/// The `Default` value requests nothing: no issuer, no audience and zero leeway.
#[derive(Debug, Default, Clone)]
pub struct VerifyConstraints {
    /// Required issuer (`iss` claim), if any.
    pub issuer: Option<String>,
    /// Required audience (`aud` claim), if any.
    pub audience: Option<String>,
    /// Clock-skew tolerance in seconds for time-based claims.
    pub leeway_secs: u64,
}
99
/// Synchronous lookup of revoked token identifiers (`jti` values).
///
/// Consulted by the middleware after verification whenever a revocation check is configured.
pub trait RevocationList
where
    Self: Send + Sync + 'static,
{
    /// Returns `true` if the token identified by `jti` has been revoked.
    fn is_revoked(&self, jti: &str) -> bool;
}
105
106#[derive(Default, Clone)]
108pub struct InMemoryRevocationList {
109 inner: Arc<SccHashSet<String>>,
110}
111
112impl InMemoryRevocationList {
113 pub fn new() -> Self {
114 Self::default()
115 }
116
117 pub fn revoke(&self, jti: impl Into<String>) {
118 let _ = self.inner.insert_sync(jti.into());
119 }
120
121 pub fn unrevoke(&self, jti: &str) {
122 let _ = self.inner.remove_sync(jti);
123 }
124}
125
126impl RevocationList for InMemoryRevocationList {
127 fn is_revoked(&self, jti: &str) -> bool {
128 self.inner.contains_sync(jti)
129 }
130}
131
/// Boxed, sendable future resolving to an introspection verdict.
type IntrospectionFuture = Pin<Box<dyn Future<Output = bool> + Send + 'static>>;

/// Async introspection hook: receives the raw bearer token and resolves to `true` if the
/// token should be accepted. A `false` result makes the middleware reply `401`.
pub type IntrospectionFn = Arc<dyn Fn(&str) -> IntrospectionFuture + Send + Sync + 'static>;

/// Extracts the token identifier (`jti`) from verified claims of type `C`, or `None` if the
/// claims carry no identifier.
pub type JtiExtractorFn<C> = Arc<dyn Fn(&C) -> Option<String> + Send + Sync + 'static>;
140
141pub type RevocationCheck<C> = (Arc<dyn RevocationList>, JtiExtractorFn<C>);
143
144pub struct JwtAuth<V: JwtVerifier> {
146 verifier: V,
147 constraints: VerifyConstraints,
148 revocation: Option<RevocationCheck<V::Claims>>,
149 introspect: Option<IntrospectionFn>,
150}
151
152impl<V: JwtVerifier> JwtAuth<V> {
153 pub fn new(verifier: V) -> Self {
156 Self {
157 verifier,
158 constraints: VerifyConstraints::default(),
159 revocation: None,
160 introspect: None,
161 }
162 }
163
164 pub fn constraints(mut self, c: VerifyConstraints) -> Self {
166 self.constraints = c;
167 self
168 }
169
170 pub fn revocation<R, F>(mut self, list: R, extractor: F) -> Self
174 where
175 R: RevocationList,
176 F: Fn(&V::Claims) -> Option<String> + Send + Sync + 'static,
177 {
178 self.revocation = Some((Arc::new(list), Arc::new(extractor)));
179 self
180 }
181
182 pub fn introspect<F, Fut>(mut self, f: F) -> Self
186 where
187 F: Fn(&str) -> Fut + Send + Sync + 'static,
188 Fut: Future<Output = bool> + Send + 'static,
189 {
190 self.introspect = Some(Arc::new(move |t: &str| Box::pin(f(t))));
191 self
192 }
193}
194
195impl<V: JwtVerifier> IntoMiddleware for JwtAuth<V> {
196 fn into_middleware(
197 self,
198 ) -> impl Fn(Request, Next) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>>
199 + Clone
200 + Send
201 + Sync
202 + 'static {
203 let verifier = self.verifier;
204 let constraints = Arc::new(self.constraints);
205 let revocation = self.revocation;
206 let introspect = self.introspect;
207
208 move |mut req: Request, next: Next| {
209 let verifier = verifier.clone();
210 let constraints = constraints.clone();
211 let revocation = revocation.clone();
212 let introspect = introspect.clone();
213
214 Box::pin(async move {
215 let token = match req
221 .headers()
222 .get(AUTHORIZATION)
223 .and_then(|v| v.to_str().ok())
224 .and_then(|s| s.split_once(' '))
225 .filter(|(scheme, _)| scheme.eq_ignore_ascii_case("Bearer"))
226 .map(|(_, rest)| rest.trim())
227 {
228 Some(t) => t.to_string(),
229 None => {
230 return (
231 StatusCode::UNAUTHORIZED,
232 "Missing or invalid Authorization header",
233 )
234 .into_response();
235 }
236 };
237
238 let claims = match verifier.verify(&token) {
239 Ok(c) => c,
240 Err(e) => {
241 return (StatusCode::UNAUTHORIZED, format!("Invalid token: {e}")).into_response();
242 }
243 };
244
245 if let Err(e) = verifier.validate_constraints(&claims, &constraints) {
249 return (StatusCode::UNAUTHORIZED, format!("Invalid token: {e}")).into_response();
250 }
251
252 if let Some((list, extractor)) = revocation.as_ref()
253 && let Some(jti) = extractor(&claims)
254 && list.is_revoked(&jti)
255 {
256 return (StatusCode::UNAUTHORIZED, "token revoked").into_response();
257 }
258
259 if let Some(introspect) = introspect.as_ref()
260 && !introspect(&token).await
261 {
262 return (StatusCode::UNAUTHORIZED, "token introspection failed").into_response();
263 }
264
265 req.extensions_mut().insert(claims);
266 next.run(req).await.into_response()
267 })
268 }
269 }
270}
271
272#[cfg(feature = "jwt-simple")]
273mod jwt_simple_impl {
274 use std::collections::HashMap;
275 use std::sync::Arc;
276
277 use ::jwt_simple::prelude::*;
278 use serde::Serialize;
279 use serde::de::DeserializeOwned;
280 use tako_rs_core::types::BuildHasher;
281
282 pub enum AnyVerifyKey {
284 HS256(Arc<HS256Key>),
285 HS384(Arc<HS384Key>),
286 HS512(Arc<HS512Key>),
287 Blake2b(Arc<Blake2bKey>),
288 RS256(Arc<RS256PublicKey>),
289 RS384(Arc<RS384PublicKey>),
290 RS512(Arc<RS512PublicKey>),
291 PS256(Arc<PS256PublicKey>),
292 PS384(Arc<PS384PublicKey>),
293 PS512(Arc<PS512PublicKey>),
294 ES256(Arc<ES256PublicKey>),
295 ES256K(Arc<ES256kPublicKey>),
296 ES384(Arc<ES384PublicKey>),
297 EdDSA(Arc<Ed25519PublicKey>),
298 }
299
300 impl AnyVerifyKey {
301 pub fn alg_id(&self) -> &'static str {
302 match self {
303 Self::HS256(_) => "HS256",
304 Self::HS384(_) => "HS384",
305 Self::HS512(_) => "HS512",
306 Self::Blake2b(_) => "BLAKE2B",
307 Self::RS256(_) => "RS256",
308 Self::RS384(_) => "RS384",
309 Self::RS512(_) => "RS512",
310 Self::PS256(_) => "PS256",
311 Self::PS384(_) => "PS384",
312 Self::PS512(_) => "PS512",
313 Self::ES256(_) => "ES256",
314 Self::ES256K(_) => "ES256K",
315 Self::ES384(_) => "ES384",
316 Self::EdDSA(_) => "EdDSA",
317 }
318 }
319
320 fn verify_token<C>(
321 &self,
322 token: &str,
323 opts: VerificationOptions,
324 ) -> Result<JWTClaims<C>, ::jwt_simple::Error>
325 where
326 C: Serialize + DeserializeOwned,
327 {
328 let opts = Some(opts);
329 match self {
330 Self::HS256(k) => k.verify_token::<C>(token, opts),
331 Self::HS384(k) => k.verify_token::<C>(token, opts),
332 Self::HS512(k) => k.verify_token::<C>(token, opts),
333 Self::Blake2b(k) => k.verify_token::<C>(token, opts),
334 Self::RS256(k) => k.verify_token::<C>(token, opts),
335 Self::RS384(k) => k.verify_token::<C>(token, opts),
336 Self::RS512(k) => k.verify_token::<C>(token, opts),
337 Self::PS256(k) => k.verify_token::<C>(token, opts),
338 Self::PS384(k) => k.verify_token::<C>(token, opts),
339 Self::PS512(k) => k.verify_token::<C>(token, opts),
340 Self::ES256(k) => k.verify_token::<C>(token, opts),
341 Self::ES256K(k) => k.verify_token::<C>(token, opts),
342 Self::ES384(k) => k.verify_token::<C>(token, opts),
343 Self::EdDSA(k) => k.verify_token::<C>(token, opts),
344 }
345 }
346 }
347
348 impl Clone for AnyVerifyKey {
349 fn clone(&self) -> Self {
350 match self {
351 Self::HS256(k) => Self::HS256(Arc::clone(k)),
352 Self::HS384(k) => Self::HS384(Arc::clone(k)),
353 Self::HS512(k) => Self::HS512(Arc::clone(k)),
354 Self::Blake2b(k) => Self::Blake2b(Arc::clone(k)),
355 Self::RS256(k) => Self::RS256(Arc::clone(k)),
356 Self::RS384(k) => Self::RS384(Arc::clone(k)),
357 Self::RS512(k) => Self::RS512(Arc::clone(k)),
358 Self::PS256(k) => Self::PS256(Arc::clone(k)),
359 Self::PS384(k) => Self::PS384(Arc::clone(k)),
360 Self::PS512(k) => Self::PS512(Arc::clone(k)),
361 Self::ES256(k) => Self::ES256(Arc::clone(k)),
362 Self::ES256K(k) => Self::ES256K(Arc::clone(k)),
363 Self::ES384(k) => Self::ES384(Arc::clone(k)),
364 Self::EdDSA(k) => Self::EdDSA(Arc::clone(k)),
365 }
366 }
367 }
368
369 pub struct MultiKeyVerifier<C> {
375 keys_by_alg: HashMap<&'static str, AnyVerifyKey, BuildHasher>,
376 keys_by_kid: super::Arc<parking_lot::RwLock<HashMap<String, AnyVerifyKey>>>,
377 constraints: super::Arc<super::VerifyConstraints>,
378 _phantom: std::marker::PhantomData<C>,
379 }
380
381 impl<C> Clone for MultiKeyVerifier<C> {
382 fn clone(&self) -> Self {
383 Self {
384 keys_by_alg: self.keys_by_alg.clone(),
385 keys_by_kid: self.keys_by_kid.clone(),
386 constraints: self.constraints.clone(),
387 _phantom: std::marker::PhantomData,
388 }
389 }
390 }
391
392 impl<C> MultiKeyVerifier<C> {
393 pub fn new(keys: HashMap<&'static str, AnyVerifyKey, BuildHasher>) -> Self {
395 Self {
396 keys_by_alg: keys,
397 keys_by_kid: super::Arc::new(parking_lot::RwLock::new(HashMap::new())),
398 constraints: super::Arc::new(super::VerifyConstraints::default()),
399 _phantom: std::marker::PhantomData,
400 }
401 }
402
403 pub fn rotate_key(&self, kid: impl Into<String>, key: AnyVerifyKey) {
405 self.keys_by_kid.write().insert(kid.into(), key);
406 }
407
408 pub fn revoke_kid(&self, kid: &str) {
410 self.keys_by_kid.write().remove(kid);
411 }
412
413 pub fn constraints(mut self, c: super::VerifyConstraints) -> Self {
415 self.constraints = super::Arc::new(c);
416 self
417 }
418 }
419
420 impl<C> super::JwtVerifier for MultiKeyVerifier<C>
421 where
422 C: Clone + Serialize + DeserializeOwned + Send + Sync + 'static,
423 {
424 type Claims = JWTClaims<C>;
425 type Error = String;
426
427 fn verify(&self, token: &str) -> Result<Self::Claims, Self::Error> {
428 let meta = ::jwt_simple::token::Token::decode_metadata(token)
429 .map_err(|e| format!("Cannot decode JWT header: {e}"))?;
430
431 let alg = meta.algorithm();
432 let kid = meta.key_id();
433
434 let key = if let Some(kid) = kid {
435 let kid_map = self.keys_by_kid.read();
436 kid_map.get(kid).cloned()
437 } else {
438 None
439 };
440 let key = match key {
441 Some(k) => k,
442 None => self
443 .keys_by_alg
444 .get(alg)
445 .cloned()
446 .ok_or_else(|| format!("Algorithm {alg} not allowed"))?,
447 };
448
449 let mut opts = VerificationOptions {
450 time_tolerance: Some(::jwt_simple::prelude::Duration::from_secs(
451 self.constraints.leeway_secs,
452 )),
453 ..Default::default()
454 };
455 if let Some(iss) = &self.constraints.issuer {
456 let mut set = std::collections::HashSet::new();
457 set.insert(iss.clone());
458 opts.allowed_issuers = Some(set);
459 }
460 if let Some(aud) = &self.constraints.audience {
461 let mut set = std::collections::HashSet::new();
462 set.insert(aud.clone());
463 opts.allowed_audiences = Some(set);
464 }
465
466 key
467 .verify_token::<C>(token, opts)
468 .map_err(|e| e.to_string())
469 }
470
471 fn validate_constraints(
472 &self,
473 claims: &Self::Claims,
474 constraints: &super::VerifyConstraints,
475 ) -> Result<(), super::ConstraintsNotSupported> {
476 if let Some(expected) = &constraints.issuer
477 && claims.issuer.as_deref() != Some(expected.as_str())
478 {
479 return Err(super::ConstraintsNotSupported {
480 reason: "issuer mismatch",
481 });
482 }
483 if let Some(expected) = &constraints.audience {
484 let mut allowed = std::collections::HashSet::new();
485 allowed.insert(expected.clone());
486 match &claims.audiences {
487 Some(a) if a.contains(&allowed) => {}
488 _ => {
489 return Err(super::ConstraintsNotSupported {
490 reason: "audience mismatch",
491 });
492 }
493 }
494 }
495 Ok(())
501 }
502 }
503}
504
505#[cfg(feature = "jwt-simple")]
506pub use jwt_simple_impl::AnyVerifyKey;
507#[cfg(feature = "jwt-simple")]
508pub use jwt_simple_impl::MultiKeyVerifier;