Skip to main content

systemprompt_api/services/middleware/jwt/
context.rs

1//! JWT-backed request-context extractor.
2//!
3//! [`JwtContextExtractor`] implements [`ContextExtractor`] by validating the
4//! bearer token (signature, session existence, user existence, and JTI
5//! revocation) and building a `RequestContext`. It resolves the context id from
6//! the `x-context-id` header on standard routes and from the JSON-RPC body on
7//! A2A routes, and exposes a gateway decode path for pre-authenticated tokens.
8
9use async_trait::async_trait;
10use axum::body::Body;
11use axum::extract::Request;
12use axum::http::HeaderMap;
13use std::sync::Arc;
14
15use crate::services::middleware::context::ContextExtractor;
16use systemprompt_identifiers::ContextId;
17use systemprompt_models::execution::context::{ContextExtractionError, RequestContext};
18use systemprompt_security::{JwtUserContext, TokenExtractor, extract_user_context};
19use systemprompt_traits::{AnalyticsProvider, UserProvider};
20
21use super::params::{BuildContextParams, build_context, extract_common_headers};
22use super::revocation::JtiRevocationChecker;
23use super::validation::{UserCache, user_is_admin, validate_session_exists, validate_user_exists};
24
25#[derive(Clone)]
26pub struct JwtContextExtractor {
27    token_extractor: TokenExtractor,
28    analytics_provider: Arc<dyn AnalyticsProvider>,
29    user_provider: Arc<dyn UserProvider>,
30    user_cache: Arc<UserCache>,
31    jti_revocation: JtiRevocationChecker,
32}
33
34impl std::fmt::Debug for JwtContextExtractor {
35    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
36        f.debug_struct("JwtContextExtractor")
37            .field("token_extractor", &self.token_extractor)
38            .finish_non_exhaustive()
39    }
40}
41
42impl JwtContextExtractor {
43    pub fn new(
44        analytics_provider: Arc<dyn AnalyticsProvider>,
45        user_provider: Arc<dyn UserProvider>,
46        jti_revocation: JtiRevocationChecker,
47    ) -> Self {
48        Self {
49            token_extractor: TokenExtractor::browser_only(),
50            analytics_provider,
51            user_provider,
52            user_cache: UserCache::new(),
53            jti_revocation,
54        }
55    }
56
57    fn extract_jwt_context(
58        &self,
59        headers: &HeaderMap,
60    ) -> Result<JwtUserContext, ContextExtractionError> {
61        let token = self
62            .token_extractor
63            .extract(headers)
64            .map_err(|_e| ContextExtractionError::MissingAuthHeader)?;
65        extract_user_context(&token)
66            .map_err(|e| ContextExtractionError::InvalidToken(e.to_string()))
67    }
68
69    async fn validate(
70        &self,
71        jwt_context: &JwtUserContext,
72        route_context: &str,
73    ) -> Result<systemprompt_traits::AuthUser, ContextExtractionError> {
74        if jwt_context.session_id.as_str().is_empty() {
75            return Err(ContextExtractionError::MissingSessionId);
76        }
77        if jwt_context.user_id.as_str().is_empty() {
78            return Err(ContextExtractionError::MissingUserId);
79        }
80        let validated = validate_user_exists(
81            &self.user_provider,
82            &self.user_cache,
83            jwt_context,
84            route_context,
85        )
86        .await?;
87        validate_session_exists(&self.analytics_provider, jwt_context, route_context).await?;
88        self.jti_revocation
89            .ensure_not_revoked(&jwt_context.jti)
90            .await?;
91        Ok(validated.user)
92    }
93
94    pub async fn extract_standard(
95        &self,
96        headers: &HeaderMap,
97    ) -> Result<RequestContext, ContextExtractionError> {
98        let jwt_context = self.extract_jwt_context(headers)?;
99        let user = self.validate(&jwt_context, "").await?;
100
101        let context_id = headers
102            .get("x-context-id")
103            .and_then(|h| h.to_str().ok())
104            .filter(|s| !s.is_empty())
105            .and_then(|s| ContextId::try_new(s).ok())
106            .unwrap_or_else(ContextId::generate);
107
108        let (trace_id, task_id, auth_token, agent_name) =
109            extract_common_headers(&self.token_extractor, headers);
110
111        let user_type = jwt_context.user_type.reconcile_with(user_is_admin(&user));
112        let session_id = jwt_context.session_id.clone();
113        let user_id = jwt_context.user_id.clone();
114
115        Ok(build_context(BuildContextParams {
116            jwt_context,
117            session_id,
118            user_id,
119            trace_id,
120            context_id,
121            agent_name,
122            task_id,
123            auth_token,
124            user_type,
125        }))
126    }
127
128    pub async fn decode_for_gateway(
129        &self,
130        jwt_token: &systemprompt_identifiers::JwtToken,
131    ) -> Result<(JwtUserContext, systemprompt_traits::AuthUser), ContextExtractionError> {
132        let jwt_context = extract_user_context(jwt_token.as_str())
133            .map_err(|e| ContextExtractionError::InvalidToken(e.to_string()))?;
134
135        let user = self.validate(&jwt_context, "gateway").await?;
136        Ok((jwt_context, user))
137    }
138
139    async fn extract_from_request_impl(
140        &self,
141        request: Request<Body>,
142    ) -> Result<(RequestContext, Request<Body>), ContextExtractionError> {
143        use crate::services::middleware::context::sources::{ContextIdSource, PayloadSource};
144
145        let headers = request.headers().clone();
146        let has_auth = headers.get("authorization").is_some();
147
148        if headers.get("x-context-id").is_some() && !has_auth {
149            return Err(ContextExtractionError::ForbiddenHeader {
150                header: "X-Context-ID".to_owned(),
151                reason: "Context ID must be in request body (A2A spec). Use contextId field in \
152                         message."
153                    .to_owned(),
154            });
155        }
156
157        let jwt_context = self.extract_jwt_context(&headers)?;
158        let user = self.validate(&jwt_context, " (A2A route)").await?;
159
160        let (body_bytes, reconstructed_request) =
161            PayloadSource::read_and_reconstruct(request).await?;
162
163        let context_source = PayloadSource::extract_context_source(&body_bytes)?;
164        let (context_id, task_id_from_payload) = match context_source {
165            ContextIdSource::Direct(id) => (
166                ContextId::try_new(id).map_err(|e| ContextExtractionError::InvalidHeaderValue {
167                    header: "contextId".to_owned(),
168                    reason: e.to_string(),
169                })?,
170                None,
171            ),
172            ContextIdSource::FromTask { task_id } => (ContextId::generate(), Some(task_id)),
173        };
174
175        let (trace_id, task_id_from_header, auth_token, agent_name) =
176            extract_common_headers(&self.token_extractor, &headers);
177
178        let task_id = task_id_from_payload.or(task_id_from_header);
179        let user_type = jwt_context.user_type.reconcile_with(user_is_admin(&user));
180
181        let session_id = jwt_context.session_id.clone();
182        let user_id = jwt_context.user_id.clone();
183        let ctx = build_context(BuildContextParams {
184            jwt_context,
185            session_id,
186            user_id,
187            trace_id,
188            context_id,
189            agent_name,
190            task_id,
191            auth_token,
192            user_type,
193        });
194
195        Ok((ctx, reconstructed_request))
196    }
197}
198
199#[async_trait]
200impl ContextExtractor for JwtContextExtractor {
201    async fn extract_from_headers(
202        &self,
203        headers: &HeaderMap,
204    ) -> Result<RequestContext, ContextExtractionError> {
205        self.extract_standard(headers).await
206    }
207
208    async fn extract_from_request(
209        &self,
210        request: Request<Body>,
211    ) -> Result<(RequestContext, Request<Body>), ContextExtractionError> {
212        self.extract_from_request_impl(request).await
213    }
214}