Skip to main content

systemprompt_api/services/middleware/jwt/
context.rs

1use async_trait::async_trait;
2use axum::body::Body;
3use axum::extract::Request;
4use axum::http::HeaderMap;
5use std::sync::Arc;
6
7use crate::services::middleware::context::ContextExtractor;
8use systemprompt_database::DbPool;
9use systemprompt_identifiers::{AgentName, ContextId, SessionId, TraceId, UserId};
10use systemprompt_models::execution::context::{ContextExtractionError, RequestContext};
11use systemprompt_security::TokenExtractor;
12use systemprompt_traits::AnalyticsProvider;
13
14use super::params::{BuildContextParams, build_context, extract_common_headers};
15use super::token::{JwtExtractor, JwtUserContext};
16use super::validation::{validate_session_exists, validate_user_exists};
17
18#[derive(Clone)]
19pub struct JwtContextExtractor {
20    jwt_extractor: Arc<JwtExtractor>,
21    token_extractor: TokenExtractor,
22    db_pool: DbPool,
23    analytics_provider: Option<Arc<dyn AnalyticsProvider>>,
24}
25
26impl std::fmt::Debug for JwtContextExtractor {
27    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
28        f.debug_struct("JwtContextExtractor")
29            .field("jwt_extractor", &self.jwt_extractor)
30            .field("token_extractor", &self.token_extractor)
31            .field("db_pool", &"DbPool")
32            .field("analytics_provider", &self.analytics_provider.is_some())
33            .finish()
34    }
35}
36
37impl JwtContextExtractor {
38    pub fn new(jwt_secret: &str, db_pool: &DbPool) -> Self {
39        Self {
40            jwt_extractor: Arc::new(JwtExtractor::new(jwt_secret)),
41            token_extractor: TokenExtractor::browser_only(),
42            db_pool: Arc::clone(db_pool),
43            analytics_provider: None,
44        }
45    }
46
47    pub fn with_analytics_provider(mut self, provider: Arc<dyn AnalyticsProvider>) -> Self {
48        self.analytics_provider = Some(provider);
49        self
50    }
51
52    fn extract_jwt_context(
53        &self,
54        headers: &HeaderMap,
55    ) -> Result<JwtUserContext, ContextExtractionError> {
56        let token = self
57            .token_extractor
58            .extract(headers)
59            .map_err(|_| ContextExtractionError::MissingAuthHeader)?;
60        self.jwt_extractor
61            .extract_user_context(&token)
62            .map_err(|e| ContextExtractionError::InvalidToken(e.to_string()))
63    }
64
65    pub async fn extract_standard(
66        &self,
67        headers: &HeaderMap,
68    ) -> Result<RequestContext, ContextExtractionError> {
69        let has_auth = headers.get("authorization").is_some();
70        let has_context_headers =
71            headers.get("x-user-id").is_some() && headers.get("x-session-id").is_some();
72
73        if has_context_headers && !has_auth {
74            return Err(ContextExtractionError::ForbiddenHeader {
75                header: "X-User-ID/X-Session-ID".to_string(),
76                reason: "Context headers require valid JWT for authentication".to_string(),
77            });
78        }
79
80        let jwt_context = self.extract_jwt_context(headers)?;
81
82        if jwt_context.session_id.as_str().is_empty() {
83            return Err(ContextExtractionError::MissingSessionId);
84        }
85        if jwt_context.user_id.as_str().is_empty() {
86            return Err(ContextExtractionError::MissingUserId);
87        }
88
89        validate_user_exists(&self.db_pool, &jwt_context, "").await?;
90        validate_session_exists(self.analytics_provider.as_ref(), &jwt_context, headers, "")
91            .await?;
92
93        let session_id = headers
94            .get("x-session-id")
95            .and_then(|h| h.to_str().ok())
96            .map_or_else(
97                || jwt_context.session_id.clone(),
98                |s| SessionId::new(s.to_string()),
99            );
100
101        let user_id = headers
102            .get("x-user-id")
103            .and_then(|h| h.to_str().ok())
104            .map_or_else(
105                || jwt_context.user_id.clone(),
106                |s| UserId::new(s.to_string()),
107            );
108
109        let context_id = headers
110            .get("x-context-id")
111            .and_then(|h| h.to_str().ok())
112            .map_or_else(
113                || ContextId::new(String::new()),
114                |s| ContextId::new(s.to_string()),
115            );
116
117        let (trace_id, task_id, auth_token, agent_name) =
118            extract_common_headers(&self.token_extractor, headers);
119
120        Ok(build_context(BuildContextParams {
121            jwt_context,
122            session_id,
123            user_id,
124            trace_id,
125            context_id,
126            agent_name,
127            task_id,
128            auth_token,
129        }))
130    }
131
132    pub async fn extract_mcp_a2a(
133        &self,
134        headers: &HeaderMap,
135    ) -> Result<RequestContext, ContextExtractionError> {
136        self.extract_standard(headers).await
137    }
138
139    pub async fn extract_for_gateway(
140        &self,
141        jwt_token: &systemprompt_identifiers::JwtToken,
142    ) -> Result<RequestContext, ContextExtractionError> {
143        let jwt_context = self
144            .jwt_extractor
145            .extract_user_context(jwt_token.as_str())
146            .map_err(|e| ContextExtractionError::InvalidToken(e.to_string()))?;
147
148        if jwt_context.session_id.as_str().is_empty() {
149            return Err(ContextExtractionError::MissingSessionId);
150        }
151        if jwt_context.user_id.as_str().is_empty() {
152            return Err(ContextExtractionError::MissingUserId);
153        }
154
155        validate_user_exists(&self.db_pool, &jwt_context, "gateway").await?;
156
157        let session_id = jwt_context.session_id.clone();
158        let user_id = jwt_context.user_id.clone();
159
160        Ok(build_context(BuildContextParams {
161            jwt_context,
162            session_id,
163            user_id,
164            trace_id: TraceId::generate(),
165            context_id: ContextId::new(String::new()),
166            agent_name: AgentName::system(),
167            task_id: None,
168            auth_token: Some(jwt_token.as_str().to_string()),
169        }))
170    }
171
172    async fn extract_from_request_impl(
173        &self,
174        request: Request<Body>,
175    ) -> Result<(RequestContext, Request<Body>), ContextExtractionError> {
176        use crate::services::middleware::context::sources::{
177            ContextIdSource, PayloadSource, TASK_BASED_CONTEXT_MARKER,
178        };
179
180        let headers = request.headers().clone();
181        let has_auth = headers.get("authorization").is_some();
182
183        if headers.get("x-context-id").is_some() && !has_auth {
184            return Err(ContextExtractionError::ForbiddenHeader {
185                header: "X-Context-ID".to_string(),
186                reason: "Context ID must be in request body (A2A spec). Use contextId field in \
187                         message."
188                    .to_string(),
189            });
190        }
191
192        let jwt_context = self.extract_jwt_context(&headers)?;
193
194        if jwt_context.session_id.as_str().is_empty() {
195            return Err(ContextExtractionError::MissingSessionId);
196        }
197        if jwt_context.user_id.as_str().is_empty() {
198            return Err(ContextExtractionError::MissingUserId);
199        }
200
201        validate_user_exists(&self.db_pool, &jwt_context, " (A2A route)").await?;
202        validate_session_exists(
203            self.analytics_provider.as_ref(),
204            &jwt_context,
205            &headers,
206            " (A2A route)",
207        )
208        .await?;
209
210        let (body_bytes, reconstructed_request) =
211            PayloadSource::read_and_reconstruct(request).await?;
212
213        let context_source = PayloadSource::extract_context_source(&body_bytes)?;
214        let (context_id, task_id_from_payload) = match context_source {
215            ContextIdSource::Direct(id) => (ContextId::new(id), None),
216            ContextIdSource::FromTask { task_id } => {
217                (ContextId::new(TASK_BASED_CONTEXT_MARKER), Some(task_id))
218            },
219        };
220
221        let (trace_id, task_id_from_header, auth_token, agent_name) =
222            extract_common_headers(&self.token_extractor, &headers);
223
224        let task_id = task_id_from_payload.or(task_id_from_header);
225
226        let session_id = jwt_context.session_id.clone();
227        let user_id = jwt_context.user_id.clone();
228        let ctx = build_context(BuildContextParams {
229            jwt_context,
230            session_id,
231            user_id,
232            trace_id,
233            context_id,
234            agent_name,
235            task_id,
236            auth_token,
237        });
238
239        Ok((ctx, reconstructed_request))
240    }
241}
242
243#[async_trait]
244impl ContextExtractor for JwtContextExtractor {
245    async fn extract_from_headers(
246        &self,
247        headers: &HeaderMap,
248    ) -> Result<RequestContext, ContextExtractionError> {
249        self.extract_standard(headers).await
250    }
251
252    async fn extract_from_request(
253        &self,
254        request: Request<Body>,
255    ) -> Result<(RequestContext, Request<Body>), ContextExtractionError> {
256        self.extract_from_request_impl(request).await
257    }
258
259    async fn extract_user_only(
260        &self,
261        headers: &HeaderMap,
262    ) -> Result<RequestContext, ContextExtractionError> {
263        self.extract_standard(headers).await
264    }
265}