Skip to main content

systemprompt_api/services/middleware/jwt/
context.rs

1use async_trait::async_trait;
2use axum::body::Body;
3use axum::extract::Request;
4use axum::http::HeaderMap;
5use std::sync::Arc;
6
7use crate::services::middleware::context::ContextExtractor;
8use systemprompt_database::DbPool;
9use systemprompt_identifiers::{
10    AgentName, ContextId, SessionId, SessionSource, TaskId, TraceId, UserId,
11};
12use systemprompt_models::execution::context::{ContextExtractionError, RequestContext};
13use systemprompt_security::{HeaderExtractor, TokenExtractor};
14use systemprompt_traits::{AnalyticsProvider, CreateSessionInput};
15use systemprompt_users::UserService;
16
17use super::token::{JwtExtractor, JwtUserContext};
18
19#[derive(Clone)]
20pub struct JwtContextExtractor {
21    jwt_extractor: Arc<JwtExtractor>,
22    token_extractor: TokenExtractor,
23    db_pool: DbPool,
24    analytics_provider: Option<Arc<dyn AnalyticsProvider>>,
25}
26
27impl std::fmt::Debug for JwtContextExtractor {
28    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
29        f.debug_struct("JwtContextExtractor")
30            .field("jwt_extractor", &self.jwt_extractor)
31            .field("token_extractor", &self.token_extractor)
32            .field("db_pool", &"DbPool")
33            .field("analytics_provider", &self.analytics_provider.is_some())
34            .finish()
35    }
36}
37
38impl JwtContextExtractor {
39    pub fn new(jwt_secret: &str, db_pool: &DbPool) -> Self {
40        Self {
41            jwt_extractor: Arc::new(JwtExtractor::new(jwt_secret)),
42            token_extractor: TokenExtractor::browser_only(),
43            db_pool: db_pool.clone(),
44            analytics_provider: None,
45        }
46    }
47
48    pub fn with_analytics_provider(mut self, provider: Arc<dyn AnalyticsProvider>) -> Self {
49        self.analytics_provider = Some(provider);
50        self
51    }
52
53    fn extract_jwt_context(
54        &self,
55        headers: &HeaderMap,
56    ) -> Result<JwtUserContext, ContextExtractionError> {
57        let token = self
58            .token_extractor
59            .extract(headers)
60            .map_err(|_| ContextExtractionError::MissingAuthHeader)?;
61        self.jwt_extractor
62            .extract_user_context(&token)
63            .map_err(|e| ContextExtractionError::InvalidToken(e.to_string()))
64    }
65
66    async fn validate_user_exists(
67        &self,
68        jwt_context: &JwtUserContext,
69        route_context: &str,
70    ) -> Result<(), ContextExtractionError> {
71        let user_service = UserService::new(&self.db_pool).map_err(|e| {
72            ContextExtractionError::DatabaseError(format!("Failed to create user service: {e}"))
73        })?;
74        let user_exists = user_service
75            .find_by_id(&jwt_context.user_id)
76            .await
77            .map_err(|e| {
78                ContextExtractionError::DatabaseError(format!(
79                    "Failed to check user existence: {e}"
80                ))
81            })?;
82
83        if user_exists.is_none() {
84            tracing::info!(
85                session_id = %jwt_context.session_id.as_str(),
86                user_id = %jwt_context.user_id.as_str(),
87                route = %route_context,
88                "JWT validation failed: User no longer exists in database"
89            );
90
91            return Err(ContextExtractionError::UserNotFound(format!(
92                "User {} no longer exists",
93                jwt_context.user_id.as_str()
94            )));
95        }
96        Ok(())
97    }
98
99    async fn validate_session_exists(
100        &self,
101        jwt_context: &JwtUserContext,
102        headers: &HeaderMap,
103        route_context: &str,
104    ) -> Result<(), ContextExtractionError> {
105        let Some(analytics_provider) = &self.analytics_provider else {
106            return Ok(());
107        };
108
109        let session_exists = analytics_provider
110            .find_session_by_id(&jwt_context.session_id)
111            .await
112            .map_err(|e| {
113                ContextExtractionError::DatabaseError(format!("Failed to check session: {e}"))
114            })?
115            .is_some();
116
117        if session_exists {
118            return Ok(());
119        }
120
121        tracing::info!(
122            session_id = %jwt_context.session_id.as_str(),
123            user_id = %jwt_context.user_id.as_str(),
124            route = %route_context,
125            "Creating missing session for legacy token"
126        );
127
128        let config = systemprompt_models::Config::get().map_err(|e| {
129            ContextExtractionError::DatabaseError(format!("Failed to get config: {e}"))
130        })?;
131        let expires_at =
132            chrono::Utc::now() + chrono::Duration::seconds(config.jwt_access_token_expiration);
133        let analytics = analytics_provider.extract_analytics(headers, None);
134        let session_source = jwt_context
135            .client_id
136            .as_ref()
137            .map(|c| SessionSource::from_client_id(c.as_str()))
138            .unwrap_or(SessionSource::Api);
139
140        analytics_provider
141            .create_session(CreateSessionInput {
142                session_id: &jwt_context.session_id,
143                user_id: Some(&jwt_context.user_id),
144                analytics: &analytics,
145                session_source,
146                is_bot: false,
147                expires_at,
148            })
149            .await
150            .map_err(|e| {
151                ContextExtractionError::DatabaseError(format!("Failed to create session: {e}"))
152            })?;
153
154        Ok(())
155    }
156
157    fn extract_common_headers(
158        &self,
159        headers: &HeaderMap,
160    ) -> (TraceId, Option<TaskId>, Option<String>, AgentName) {
161        (
162            HeaderExtractor::extract_trace_id(headers),
163            HeaderExtractor::extract_task_id(headers),
164            self.token_extractor.extract(headers).ok(),
165            HeaderExtractor::extract_agent_name(headers),
166        )
167    }
168
169    fn build_context(
170        jwt_context: &JwtUserContext,
171        session_id: SessionId,
172        user_id: UserId,
173        trace_id: TraceId,
174        context_id: ContextId,
175        agent_name: AgentName,
176        task_id: Option<TaskId>,
177        auth_token: Option<String>,
178    ) -> RequestContext {
179        let mut ctx = RequestContext::new(session_id, trace_id, context_id, agent_name)
180            .with_user_id(user_id)
181            .with_user_type(jwt_context.user_type);
182
183        if let Some(client_id) = jwt_context.client_id.clone() {
184            ctx = ctx.with_client_id(client_id);
185        }
186        if let Some(t_id) = task_id {
187            ctx = ctx.with_task_id(t_id);
188        }
189        if let Some(token) = auth_token {
190            ctx = ctx.with_auth_token(token);
191        }
192        ctx
193    }
194
195    pub async fn extract_standard(
196        &self,
197        headers: &HeaderMap,
198    ) -> Result<RequestContext, ContextExtractionError> {
199        let has_auth = headers.get("authorization").is_some();
200        let has_context_headers =
201            headers.get("x-user-id").is_some() && headers.get("x-session-id").is_some();
202
203        if has_context_headers && !has_auth {
204            return Err(ContextExtractionError::ForbiddenHeader {
205                header: "X-User-ID/X-Session-ID".to_string(),
206                reason: "Context headers require valid JWT for authentication".to_string(),
207            });
208        }
209
210        let jwt_context = self.extract_jwt_context(headers)?;
211
212        if jwt_context.session_id.as_str().is_empty() {
213            return Err(ContextExtractionError::MissingSessionId);
214        }
215        if jwt_context.user_id.as_str().is_empty() {
216            return Err(ContextExtractionError::MissingUserId);
217        }
218
219        self.validate_user_exists(&jwt_context, "").await?;
220        self.validate_session_exists(&jwt_context, headers, "")
221            .await?;
222
223        let session_id = headers
224            .get("x-session-id")
225            .and_then(|h| h.to_str().ok())
226            .map_or_else(
227                || jwt_context.session_id.clone(),
228                |s| SessionId::new(s.to_string()),
229            );
230
231        let user_id = headers
232            .get("x-user-id")
233            .and_then(|h| h.to_str().ok())
234            .map_or_else(
235                || jwt_context.user_id.clone(),
236                |s| UserId::new(s.to_string()),
237            );
238
239        let context_id = headers
240            .get("x-context-id")
241            .and_then(|h| h.to_str().ok())
242            .map_or_else(
243                || ContextId::new(String::new()),
244                |s| ContextId::new(s.to_string()),
245            );
246
247        let (trace_id, task_id, auth_token, agent_name) = self.extract_common_headers(headers);
248
249        Ok(Self::build_context(
250            &jwt_context,
251            session_id,
252            user_id,
253            trace_id,
254            context_id,
255            agent_name,
256            task_id,
257            auth_token,
258        ))
259    }
260
261    pub async fn extract_mcp_a2a(
262        &self,
263        headers: &HeaderMap,
264    ) -> Result<RequestContext, ContextExtractionError> {
265        self.extract_standard(headers).await
266    }
267
268    async fn extract_from_request_impl(
269        &self,
270        request: Request<Body>,
271    ) -> Result<(RequestContext, Request<Body>), ContextExtractionError> {
272        use crate::services::middleware::context::sources::{
273            ContextIdSource, PayloadSource, TASK_BASED_CONTEXT_MARKER,
274        };
275
276        let headers = request.headers().clone();
277        let has_auth = headers.get("authorization").is_some();
278
279        if headers.get("x-context-id").is_some() && !has_auth {
280            return Err(ContextExtractionError::ForbiddenHeader {
281                header: "X-Context-ID".to_string(),
282                reason: "Context ID must be in request body (A2A spec). Use contextId field in \
283                         message."
284                    .to_string(),
285            });
286        }
287
288        let jwt_context = self.extract_jwt_context(&headers)?;
289
290        if jwt_context.session_id.as_str().is_empty() {
291            return Err(ContextExtractionError::MissingSessionId);
292        }
293        if jwt_context.user_id.as_str().is_empty() {
294            return Err(ContextExtractionError::MissingUserId);
295        }
296
297        self.validate_user_exists(&jwt_context, " (A2A route)")
298            .await?;
299        self.validate_session_exists(&jwt_context, &headers, " (A2A route)")
300            .await?;
301
302        let (body_bytes, reconstructed_request) =
303            PayloadSource::read_and_reconstruct(request).await?;
304
305        let context_source = PayloadSource::extract_context_source(&body_bytes)?;
306        let (context_id, task_id_from_payload) = match context_source {
307            ContextIdSource::Direct(id) => (ContextId::new(id), None),
308            ContextIdSource::FromTask { task_id } => (
309                ContextId::new(TASK_BASED_CONTEXT_MARKER),
310                Some(TaskId::new(task_id)),
311            ),
312        };
313
314        let (trace_id, task_id_from_header, auth_token, agent_name) =
315            self.extract_common_headers(&headers);
316
317        let task_id = task_id_from_payload.or(task_id_from_header);
318
319        let ctx = Self::build_context(
320            &jwt_context,
321            jwt_context.session_id.clone(),
322            jwt_context.user_id.clone(),
323            trace_id,
324            context_id,
325            agent_name,
326            task_id,
327            auth_token,
328        );
329
330        Ok((ctx, reconstructed_request))
331    }
332}
333
334#[async_trait]
335impl ContextExtractor for JwtContextExtractor {
336    async fn extract_from_headers(
337        &self,
338        headers: &HeaderMap,
339    ) -> Result<RequestContext, ContextExtractionError> {
340        self.extract_standard(headers).await
341    }
342
343    async fn extract_from_request(
344        &self,
345        request: Request<Body>,
346    ) -> Result<(RequestContext, Request<Body>), ContextExtractionError> {
347        self.extract_from_request_impl(request).await
348    }
349
350    async fn extract_user_only(
351        &self,
352        headers: &HeaderMap,
353    ) -> Result<RequestContext, ContextExtractionError> {
354        self.extract_standard(headers).await
355    }
356}