systemprompt_api/services/middleware/
jti_revocation.rs1use axum::extract::{Request, State};
15use axum::http::StatusCode;
16use axum::middleware::Next;
17use axum::response::{IntoResponse, Response};
18use std::sync::Arc;
19use systemprompt_models::RequestContext;
20use systemprompt_models::api::{ApiError, ErrorCode};
21use systemprompt_oauth::repository::{JtiRevocationCache, OAuthRepository};
22
23#[derive(Clone)]
24pub struct JtiRevocationState {
25 pub repo: Arc<OAuthRepository>,
26 pub cache: Arc<JtiRevocationCache>,
27}
28
29impl std::fmt::Debug for JtiRevocationState {
30 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
31 f.debug_struct("JtiRevocationState").finish_non_exhaustive()
32 }
33}
34
35pub async fn jti_revocation_middleware(
36 State(state): State<JtiRevocationState>,
37 request: Request,
38 next: Next,
39) -> Response {
40 let jti = request
41 .extensions()
42 .get::<RequestContext>()
43 .map(|ctx| ctx.jti().to_owned())
44 .unwrap_or_default();
45
46 if jti.is_empty() {
47 return next.run(request).await;
48 }
49
50 if state.cache.peek(&jti) == Some(true) {
51 return token_revoked_response();
52 }
53 if state.cache.peek(&jti) == Some(false) {
54 return next.run(request).await;
55 }
56
57 match state.repo.is_jti_revoked(&jti).await {
58 Ok(revoked) => {
59 state.cache.record(&jti, revoked);
60 if revoked {
61 token_revoked_response()
62 } else {
63 next.run(request).await
64 }
65 },
66 Err(e) => {
67 tracing::error!(error = %e, "JTI revocation lookup failed; failing closed");
68 ApiError::new(ErrorCode::InternalError, "auth state lookup failed").into_response()
69 },
70 }
71}
72
73fn token_revoked_response() -> Response {
74 let mut resp = ApiError::new(ErrorCode::Unauthorized, "Token revoked").into_response();
75 *resp.status_mut() = StatusCode::UNAUTHORIZED;
76 resp
77}