Skip to main content

systemprompt_api/services/middleware/
jti_revocation.rs

1//! JTI revocation tower layer.
2//!
3//! Runs after [`crate::services::middleware::context::ContextMiddleware`] has
4//! built the [`RequestContext`] and attached it to request extensions. The
5//! JWT itself was already validated upstream (signature, audience, expiry);
6//! this layer adds the one stateful check JWT validation cannot do — has the
7//! token been explicitly revoked?
8//!
9//! - Anonymous / system contexts (empty `jti`) → no-op.
10//! - Cache hit (revoked) → 401 immediately.
11//! - Cache hit (fresh negative) → next.
12//! - Cache miss → DB lookup, cache the result, then 401 or next.
13
14use axum::extract::{Request, State};
15use axum::http::StatusCode;
16use axum::middleware::Next;
17use axum::response::{IntoResponse, Response};
18use std::sync::Arc;
19use systemprompt_models::RequestContext;
20use systemprompt_models::api::{ApiError, ErrorCode};
21use systemprompt_oauth::repository::{JtiRevocationCache, OAuthRepository};
22
23#[derive(Clone)]
24pub struct JtiRevocationState {
25    pub repo: Arc<OAuthRepository>,
26    pub cache: Arc<JtiRevocationCache>,
27}
28
29impl std::fmt::Debug for JtiRevocationState {
30    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
31        f.debug_struct("JtiRevocationState").finish_non_exhaustive()
32    }
33}
34
35pub async fn jti_revocation_middleware(
36    State(state): State<JtiRevocationState>,
37    request: Request,
38    next: Next,
39) -> Response {
40    let jti = request
41        .extensions()
42        .get::<RequestContext>()
43        .map(|ctx| ctx.jti().to_owned())
44        .unwrap_or_default();
45
46    if jti.is_empty() {
47        return next.run(request).await;
48    }
49
50    if state.cache.peek(&jti) == Some(true) {
51        return token_revoked_response();
52    }
53    if state.cache.peek(&jti) == Some(false) {
54        return next.run(request).await;
55    }
56
57    match state.repo.is_jti_revoked(&jti).await {
58        Ok(revoked) => {
59            state.cache.record(&jti, revoked);
60            if revoked {
61                token_revoked_response()
62            } else {
63                next.run(request).await
64            }
65        },
66        Err(e) => {
67            tracing::error!(error = %e, "JTI revocation lookup failed; failing closed");
68            ApiError::new(ErrorCode::InternalError, "auth state lookup failed").into_response()
69        },
70    }
71}
72
73fn token_revoked_response() -> Response {
74    let mut resp = ApiError::new(ErrorCode::Unauthorized, "Token revoked").into_response();
75    *resp.status_mut() = StatusCode::UNAUTHORIZED;
76    resp
77}