// systemprompt_api/services/middleware/jwt/context.rs

1use async_trait::async_trait;
2use axum::body::Body;
3use axum::extract::Request;
4use axum::http::HeaderMap;
5use std::sync::Arc;
6
7use crate::services::middleware::context::ContextExtractor;
8use systemprompt_identifiers::{ContextId, SessionId, UserId};
9use systemprompt_models::auth::UserType;
10use systemprompt_models::execution::context::{ContextExtractionError, RequestContext};
11use systemprompt_security::TokenExtractor;
12use systemprompt_traits::{AnalyticsProvider, AuthUser, UserProvider};
13
14use super::params::{BuildContextParams, build_context, extract_common_headers};
15use super::token::{JwtExtractor, JwtUserContext};
16use super::validation::{UserCache, user_is_admin, validate_session_exists, validate_user_exists};
17
18#[derive(Clone)]
19pub struct JwtContextExtractor {
20    jwt_extractor: Arc<JwtExtractor>,
21    token_extractor: TokenExtractor,
22    analytics_provider: Arc<dyn AnalyticsProvider>,
23    user_provider: Arc<dyn UserProvider>,
24    user_cache: Arc<UserCache>,
25}
26
27impl std::fmt::Debug for JwtContextExtractor {
28    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
29        f.debug_struct("JwtContextExtractor")
30            .field("jwt_extractor", &self.jwt_extractor)
31            .field("token_extractor", &self.token_extractor)
32            .finish_non_exhaustive()
33    }
34}
35
36impl JwtContextExtractor {
37    pub fn new(
38        analytics_provider: Arc<dyn AnalyticsProvider>,
39        user_provider: Arc<dyn UserProvider>,
40    ) -> Self {
41        Self {
42            jwt_extractor: Arc::new(JwtExtractor::new()),
43            token_extractor: TokenExtractor::browser_only(),
44            analytics_provider,
45            user_provider,
46            user_cache: UserCache::new(),
47        }
48    }
49
50    fn extract_jwt_context(
51        &self,
52        headers: &HeaderMap,
53    ) -> Result<JwtUserContext, ContextExtractionError> {
54        let token = self
55            .token_extractor
56            .extract(headers)
57            .map_err(|_| ContextExtractionError::MissingAuthHeader)?;
58        self.jwt_extractor
59            .extract_user_context(&token)
60            .map_err(|e| ContextExtractionError::InvalidToken(e.to_string()))
61    }
62
63    async fn validate(
64        &self,
65        jwt_context: &JwtUserContext,
66        route_context: &str,
67    ) -> Result<AuthUser, ContextExtractionError> {
68        if jwt_context.session_id.as_str().is_empty() {
69            return Err(ContextExtractionError::MissingSessionId);
70        }
71        if jwt_context.user_id.as_str().is_empty() {
72            return Err(ContextExtractionError::MissingUserId);
73        }
74        let validated = validate_user_exists(
75            &self.user_provider,
76            &self.user_cache,
77            jwt_context,
78            route_context,
79        )
80        .await?;
81        validate_session_exists(&self.analytics_provider, jwt_context, route_context).await?;
82        Ok(validated.user)
83    }
84
85    pub async fn extract_standard(
86        &self,
87        headers: &HeaderMap,
88    ) -> Result<RequestContext, ContextExtractionError> {
89        let has_auth = headers.get("authorization").is_some();
90        let has_context_headers =
91            headers.get("x-user-id").is_some() && headers.get("x-session-id").is_some();
92
93        if has_context_headers && !has_auth {
94            return Err(ContextExtractionError::ForbiddenHeader {
95                header: "X-User-ID/X-Session-ID".to_string(),
96                reason: "Context headers require valid JWT for authentication".to_string(),
97            });
98        }
99
100        let jwt_context = self.extract_jwt_context(headers)?;
101        let user = self.validate(&jwt_context, "").await?;
102
103        let session_id = headers
104            .get("x-session-id")
105            .and_then(|h| h.to_str().ok())
106            .map_or_else(
107                || jwt_context.session_id.clone(),
108                |s| SessionId::new(s.to_string()),
109            );
110
111        let user_id = headers
112            .get("x-user-id")
113            .and_then(|h| h.to_str().ok())
114            .map_or_else(
115                || jwt_context.user_id.clone(),
116                |s| UserId::new(s.to_string()),
117            );
118
119        let context_id = headers
120            .get("x-context-id")
121            .and_then(|h| h.to_str().ok())
122            .filter(|s| !s.is_empty())
123            .and_then(|s| ContextId::try_new(s).ok())
124            .unwrap_or_else(ContextId::generate);
125
126        let (trace_id, task_id, auth_token, agent_name) =
127            extract_common_headers(&self.token_extractor, headers);
128
129        let user_type = resolve_user_type(jwt_context.user_type, &user);
130
131        Ok(build_context(BuildContextParams {
132            jwt_context,
133            session_id,
134            user_id,
135            trace_id,
136            context_id,
137            agent_name,
138            task_id,
139            auth_token,
140            user_type,
141        }))
142    }
143
144    pub async fn extract_mcp_a2a(
145        &self,
146        headers: &HeaderMap,
147    ) -> Result<RequestContext, ContextExtractionError> {
148        self.extract_standard(headers).await
149    }
150
151    pub async fn decode_for_gateway(
152        &self,
153        jwt_token: &systemprompt_identifiers::JwtToken,
154    ) -> Result<JwtUserContext, ContextExtractionError> {
155        let jwt_context = self
156            .jwt_extractor
157            .extract_user_context(jwt_token.as_str())
158            .map_err(|e| ContextExtractionError::InvalidToken(e.to_string()))?;
159
160        let _ = self.validate(&jwt_context, "gateway").await?;
161        Ok(jwt_context)
162    }
163
164    async fn extract_from_request_impl(
165        &self,
166        request: Request<Body>,
167    ) -> Result<(RequestContext, Request<Body>), ContextExtractionError> {
168        use crate::services::middleware::context::sources::{ContextIdSource, PayloadSource};
169
170        let headers = request.headers().clone();
171        let has_auth = headers.get("authorization").is_some();
172
173        if headers.get("x-context-id").is_some() && !has_auth {
174            return Err(ContextExtractionError::ForbiddenHeader {
175                header: "X-Context-ID".to_string(),
176                reason: "Context ID must be in request body (A2A spec). Use contextId field in \
177                         message."
178                    .to_string(),
179            });
180        }
181
182        let jwt_context = self.extract_jwt_context(&headers)?;
183        let user = self.validate(&jwt_context, " (A2A route)").await?;
184
185        let (body_bytes, reconstructed_request) =
186            PayloadSource::read_and_reconstruct(request).await?;
187
188        let context_source = PayloadSource::extract_context_source(&body_bytes)?;
189        let (context_id, task_id_from_payload) = match context_source {
190            ContextIdSource::Direct(id) => (
191                ContextId::try_new(id).map_err(|e| ContextExtractionError::InvalidHeaderValue {
192                    header: "contextId".to_string(),
193                    reason: e.to_string(),
194                })?,
195                None,
196            ),
197            ContextIdSource::FromTask { task_id } => (ContextId::generate(), Some(task_id)),
198        };
199
200        let (trace_id, task_id_from_header, auth_token, agent_name) =
201            extract_common_headers(&self.token_extractor, &headers);
202
203        let task_id = task_id_from_payload.or(task_id_from_header);
204        let user_type = resolve_user_type(jwt_context.user_type, &user);
205
206        let session_id = jwt_context.session_id.clone();
207        let user_id = jwt_context.user_id.clone();
208        let ctx = build_context(BuildContextParams {
209            jwt_context,
210            session_id,
211            user_id,
212            trace_id,
213            context_id,
214            agent_name,
215            task_id,
216            auth_token,
217            user_type,
218        });
219
220        Ok((ctx, reconstructed_request))
221    }
222}
223
224// Human types (Admin, User) are settled against the users row: an Admin JWT
225// for a non-admin row is rewritten to User. Machine types (Service, A2a, Mcp,
226// Anon) are trusted from the JWT — they are minted by the OAuth layer and
227// are not reflected in the users.roles column.
228fn resolve_user_type(claimed: UserType, user: &AuthUser) -> UserType {
229    match claimed {
230        UserType::Admin if !user_is_admin(user) => UserType::User,
231        other => other,
232    }
233}
234
235#[async_trait]
236impl ContextExtractor for JwtContextExtractor {
237    async fn extract_from_headers(
238        &self,
239        headers: &HeaderMap,
240    ) -> Result<RequestContext, ContextExtractionError> {
241        self.extract_standard(headers).await
242    }
243
244    async fn extract_from_request(
245        &self,
246        request: Request<Body>,
247    ) -> Result<(RequestContext, Request<Body>), ContextExtractionError> {
248        self.extract_from_request_impl(request).await
249    }
250
251    async fn extract_user_only(
252        &self,
253        headers: &HeaderMap,
254    ) -> Result<RequestContext, ContextExtractionError> {
255        self.extract_standard(headers).await
256    }
257}