Skip to main content

systemprompt_api/services/middleware/jwt/
context.rs

1use async_trait::async_trait;
2use axum::body::Body;
3use axum::extract::Request;
4use axum::http::HeaderMap;
5use std::sync::Arc;
6
7use crate::services::middleware::context::ContextExtractor;
8use systemprompt_database::DbPool;
9use systemprompt_identifiers::{
10    AgentName, ContextId, SessionId, SessionSource, TaskId, TraceId, UserId,
11};
12use systemprompt_models::execution::context::{ContextExtractionError, RequestContext};
13use systemprompt_security::{HeaderExtractor, TokenExtractor};
14use systemprompt_traits::{AnalyticsProvider, CreateSessionInput};
15use systemprompt_users::UserService;
16
17use super::token::{JwtExtractor, JwtUserContext};
18
19struct BuildContextParams {
20    jwt_context: JwtUserContext,
21    session_id: SessionId,
22    user_id: UserId,
23    trace_id: TraceId,
24    context_id: ContextId,
25    agent_name: AgentName,
26    task_id: Option<TaskId>,
27    auth_token: Option<String>,
28}
29
30#[derive(Clone)]
31pub struct JwtContextExtractor {
32    jwt_extractor: Arc<JwtExtractor>,
33    token_extractor: TokenExtractor,
34    db_pool: DbPool,
35    analytics_provider: Option<Arc<dyn AnalyticsProvider>>,
36}
37
38impl std::fmt::Debug for JwtContextExtractor {
39    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
40        f.debug_struct("JwtContextExtractor")
41            .field("jwt_extractor", &self.jwt_extractor)
42            .field("token_extractor", &self.token_extractor)
43            .field("db_pool", &"DbPool")
44            .field("analytics_provider", &self.analytics_provider.is_some())
45            .finish()
46    }
47}
48
49impl JwtContextExtractor {
50    pub fn new(jwt_secret: &str, db_pool: &DbPool) -> Self {
51        Self {
52            jwt_extractor: Arc::new(JwtExtractor::new(jwt_secret)),
53            token_extractor: TokenExtractor::browser_only(),
54            db_pool: Arc::clone(db_pool),
55            analytics_provider: None,
56        }
57    }
58
59    pub fn with_analytics_provider(mut self, provider: Arc<dyn AnalyticsProvider>) -> Self {
60        self.analytics_provider = Some(provider);
61        self
62    }
63
64    fn extract_jwt_context(
65        &self,
66        headers: &HeaderMap,
67    ) -> Result<JwtUserContext, ContextExtractionError> {
68        let token = self
69            .token_extractor
70            .extract(headers)
71            .map_err(|_| ContextExtractionError::MissingAuthHeader)?;
72        self.jwt_extractor
73            .extract_user_context(&token)
74            .map_err(|e| ContextExtractionError::InvalidToken(e.to_string()))
75    }
76
77    async fn validate_user_exists(
78        &self,
79        jwt_context: &JwtUserContext,
80        route_context: &str,
81    ) -> Result<(), ContextExtractionError> {
82        let user_service = UserService::new(&self.db_pool).map_err(|e| {
83            ContextExtractionError::DatabaseError(format!("Failed to create user service: {e}"))
84        })?;
85        let user_exists = user_service
86            .find_by_id(&jwt_context.user_id)
87            .await
88            .map_err(|e| {
89                ContextExtractionError::DatabaseError(format!(
90                    "Failed to check user existence: {e}"
91                ))
92            })?;
93
94        if user_exists.is_none() {
95            tracing::info!(
96                session_id = %jwt_context.session_id.as_str(),
97                user_id = %jwt_context.user_id.as_str(),
98                route = %route_context,
99                "JWT validation failed: User no longer exists in database"
100            );
101
102            return Err(ContextExtractionError::UserNotFound(format!(
103                "User {} no longer exists",
104                jwt_context.user_id.as_str()
105            )));
106        }
107        Ok(())
108    }
109
110    async fn validate_session_exists(
111        &self,
112        jwt_context: &JwtUserContext,
113        headers: &HeaderMap,
114        route_context: &str,
115    ) -> Result<(), ContextExtractionError> {
116        let Some(analytics_provider) = &self.analytics_provider else {
117            return Ok(());
118        };
119
120        let session_exists = analytics_provider
121            .find_session_by_id(&jwt_context.session_id)
122            .await
123            .map_err(|e| {
124                ContextExtractionError::DatabaseError(format!("Failed to check session: {e}"))
125            })?
126            .is_some();
127
128        if session_exists {
129            return Ok(());
130        }
131
132        tracing::info!(
133            session_id = %jwt_context.session_id.as_str(),
134            user_id = %jwt_context.user_id.as_str(),
135            route = %route_context,
136            "Creating missing session for legacy token"
137        );
138
139        let config = systemprompt_models::Config::get().map_err(|e| {
140            ContextExtractionError::DatabaseError(format!("Failed to get config: {e}"))
141        })?;
142        let expires_at =
143            chrono::Utc::now() + chrono::Duration::seconds(config.jwt_access_token_expiration);
144        let analytics = analytics_provider.extract_analytics(headers, None);
145        let session_source = jwt_context
146            .client_id
147            .as_ref()
148            .map_or(SessionSource::Api, |c| {
149                SessionSource::from_client_id(c.as_str())
150            });
151
152        analytics_provider
153            .create_session(CreateSessionInput {
154                session_id: &jwt_context.session_id,
155                user_id: Some(&jwt_context.user_id),
156                analytics: &analytics,
157                session_source,
158                is_bot: false,
159                is_ai_crawler: false,
160                expires_at,
161            })
162            .await
163            .map_err(|e| {
164                ContextExtractionError::DatabaseError(format!("Failed to create session: {e}"))
165            })?;
166
167        Ok(())
168    }
169
170    fn extract_common_headers(
171        &self,
172        headers: &HeaderMap,
173    ) -> (TraceId, Option<TaskId>, Option<String>, AgentName) {
174        (
175            HeaderExtractor::extract_trace_id(headers),
176            HeaderExtractor::extract_task_id(headers),
177            self.token_extractor.extract(headers).ok(),
178            HeaderExtractor::extract_agent_name(headers),
179        )
180    }
181
182    fn build_context(params: BuildContextParams) -> RequestContext {
183        let BuildContextParams {
184            jwt_context,
185            session_id,
186            user_id,
187            trace_id,
188            context_id,
189            agent_name,
190            task_id,
191            auth_token,
192        } = params;
193        let mut ctx = RequestContext::new(session_id, trace_id, context_id, agent_name)
194            .with_user_id(user_id)
195            .with_user_type(jwt_context.user_type);
196
197        if let Some(client_id) = jwt_context.client_id {
198            ctx = ctx.with_client_id(client_id);
199        }
200        if let Some(t_id) = task_id {
201            ctx = ctx.with_task_id(t_id);
202        }
203        if let Some(token) = auth_token {
204            ctx = ctx.with_auth_token(token);
205        }
206        ctx
207    }
208
209    pub async fn extract_standard(
210        &self,
211        headers: &HeaderMap,
212    ) -> Result<RequestContext, ContextExtractionError> {
213        let has_auth = headers.get("authorization").is_some();
214        let has_context_headers =
215            headers.get("x-user-id").is_some() && headers.get("x-session-id").is_some();
216
217        if has_context_headers && !has_auth {
218            return Err(ContextExtractionError::ForbiddenHeader {
219                header: "X-User-ID/X-Session-ID".to_string(),
220                reason: "Context headers require valid JWT for authentication".to_string(),
221            });
222        }
223
224        let jwt_context = self.extract_jwt_context(headers)?;
225
226        if jwt_context.session_id.as_str().is_empty() {
227            return Err(ContextExtractionError::MissingSessionId);
228        }
229        if jwt_context.user_id.as_str().is_empty() {
230            return Err(ContextExtractionError::MissingUserId);
231        }
232
233        self.validate_user_exists(&jwt_context, "").await?;
234        self.validate_session_exists(&jwt_context, headers, "")
235            .await?;
236
237        let session_id = headers
238            .get("x-session-id")
239            .and_then(|h| h.to_str().ok())
240            .map_or_else(
241                || jwt_context.session_id.clone(),
242                |s| SessionId::new(s.to_string()),
243            );
244
245        let user_id = headers
246            .get("x-user-id")
247            .and_then(|h| h.to_str().ok())
248            .map_or_else(
249                || jwt_context.user_id.clone(),
250                |s| UserId::new(s.to_string()),
251            );
252
253        let context_id = headers
254            .get("x-context-id")
255            .and_then(|h| h.to_str().ok())
256            .map_or_else(
257                || ContextId::new(String::new()),
258                |s| ContextId::new(s.to_string()),
259            );
260
261        let (trace_id, task_id, auth_token, agent_name) = self.extract_common_headers(headers);
262
263        Ok(Self::build_context(BuildContextParams {
264            jwt_context,
265            session_id,
266            user_id,
267            trace_id,
268            context_id,
269            agent_name,
270            task_id,
271            auth_token,
272        }))
273    }
274
275    pub async fn extract_mcp_a2a(
276        &self,
277        headers: &HeaderMap,
278    ) -> Result<RequestContext, ContextExtractionError> {
279        self.extract_standard(headers).await
280    }
281
282    pub async fn extract_for_gateway(
283        &self,
284        jwt_token: &systemprompt_identifiers::JwtToken,
285    ) -> Result<RequestContext, ContextExtractionError> {
286        let jwt_context = self
287            .jwt_extractor
288            .extract_user_context(jwt_token.as_str())
289            .map_err(|e| ContextExtractionError::InvalidToken(e.to_string()))?;
290
291        if jwt_context.session_id.as_str().is_empty() {
292            return Err(ContextExtractionError::MissingSessionId);
293        }
294        if jwt_context.user_id.as_str().is_empty() {
295            return Err(ContextExtractionError::MissingUserId);
296        }
297
298        self.validate_user_exists(&jwt_context, "gateway").await?;
299
300        let session_id = jwt_context.session_id.clone();
301        let user_id = jwt_context.user_id.clone();
302
303        Ok(Self::build_context(BuildContextParams {
304            jwt_context,
305            session_id,
306            user_id,
307            trace_id: TraceId::generate(),
308            context_id: ContextId::new(String::new()),
309            agent_name: AgentName::system(),
310            task_id: None,
311            auth_token: Some(jwt_token.as_str().to_string()),
312        }))
313    }
314
315    async fn extract_from_request_impl(
316        &self,
317        request: Request<Body>,
318    ) -> Result<(RequestContext, Request<Body>), ContextExtractionError> {
319        use crate::services::middleware::context::sources::{
320            ContextIdSource, PayloadSource, TASK_BASED_CONTEXT_MARKER,
321        };
322
323        let headers = request.headers().clone();
324        let has_auth = headers.get("authorization").is_some();
325
326        if headers.get("x-context-id").is_some() && !has_auth {
327            return Err(ContextExtractionError::ForbiddenHeader {
328                header: "X-Context-ID".to_string(),
329                reason: "Context ID must be in request body (A2A spec). Use contextId field in \
330                         message."
331                    .to_string(),
332            });
333        }
334
335        let jwt_context = self.extract_jwt_context(&headers)?;
336
337        if jwt_context.session_id.as_str().is_empty() {
338            return Err(ContextExtractionError::MissingSessionId);
339        }
340        if jwt_context.user_id.as_str().is_empty() {
341            return Err(ContextExtractionError::MissingUserId);
342        }
343
344        self.validate_user_exists(&jwt_context, " (A2A route)")
345            .await?;
346        self.validate_session_exists(&jwt_context, &headers, " (A2A route)")
347            .await?;
348
349        let (body_bytes, reconstructed_request) =
350            PayloadSource::read_and_reconstruct(request).await?;
351
352        let context_source = PayloadSource::extract_context_source(&body_bytes)?;
353        let (context_id, task_id_from_payload) = match context_source {
354            ContextIdSource::Direct(id) => (ContextId::new(id), None),
355            ContextIdSource::FromTask { task_id } => (
356                ContextId::new(TASK_BASED_CONTEXT_MARKER),
357                Some(TaskId::new(task_id)),
358            ),
359        };
360
361        let (trace_id, task_id_from_header, auth_token, agent_name) =
362            self.extract_common_headers(&headers);
363
364        let task_id = task_id_from_payload.or(task_id_from_header);
365
366        let session_id = jwt_context.session_id.clone();
367        let user_id = jwt_context.user_id.clone();
368        let ctx = Self::build_context(BuildContextParams {
369            jwt_context,
370            session_id,
371            user_id,
372            trace_id,
373            context_id,
374            agent_name,
375            task_id,
376            auth_token,
377        });
378
379        Ok((ctx, reconstructed_request))
380    }
381}
382
383#[async_trait]
384impl ContextExtractor for JwtContextExtractor {
385    async fn extract_from_headers(
386        &self,
387        headers: &HeaderMap,
388    ) -> Result<RequestContext, ContextExtractionError> {
389        self.extract_standard(headers).await
390    }
391
392    async fn extract_from_request(
393        &self,
394        request: Request<Body>,
395    ) -> Result<(RequestContext, Request<Body>), ContextExtractionError> {
396        self.extract_from_request_impl(request).await
397    }
398
399    async fn extract_user_only(
400        &self,
401        headers: &HeaderMap,
402    ) -> Result<RequestContext, ContextExtractionError> {
403        self.extract_standard(headers).await
404    }
405}