systemprompt_api/services/middleware/jwt/
context.rs1use async_trait::async_trait;
2use axum::body::Body;
3use axum::extract::Request;
4use axum::http::HeaderMap;
5use std::sync::Arc;
6
7use crate::services::middleware::context::ContextExtractor;
8use systemprompt_database::DbPool;
9use systemprompt_identifiers::{
10 AgentName, ContextId, SessionId, SessionSource, TaskId, TraceId, UserId,
11};
12use systemprompt_models::execution::context::{ContextExtractionError, RequestContext};
13use systemprompt_security::{HeaderExtractor, TokenExtractor};
14use systemprompt_traits::{AnalyticsProvider, CreateSessionInput};
15use systemprompt_users::UserService;
16
17use super::token::{JwtExtractor, JwtUserContext};
18
19#[derive(Clone)]
20pub struct JwtContextExtractor {
21 jwt_extractor: Arc<JwtExtractor>,
22 token_extractor: TokenExtractor,
23 db_pool: DbPool,
24 analytics_provider: Option<Arc<dyn AnalyticsProvider>>,
25}
26
27impl std::fmt::Debug for JwtContextExtractor {
28 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
29 f.debug_struct("JwtContextExtractor")
30 .field("jwt_extractor", &self.jwt_extractor)
31 .field("token_extractor", &self.token_extractor)
32 .field("db_pool", &"DbPool")
33 .field("analytics_provider", &self.analytics_provider.is_some())
34 .finish()
35 }
36}
37
38impl JwtContextExtractor {
39 pub fn new(jwt_secret: &str, db_pool: &DbPool) -> Self {
40 Self {
41 jwt_extractor: Arc::new(JwtExtractor::new(jwt_secret)),
42 token_extractor: TokenExtractor::browser_only(),
43 db_pool: db_pool.clone(),
44 analytics_provider: None,
45 }
46 }
47
48 pub fn with_analytics_provider(mut self, provider: Arc<dyn AnalyticsProvider>) -> Self {
49 self.analytics_provider = Some(provider);
50 self
51 }
52
53 fn extract_jwt_context(
54 &self,
55 headers: &HeaderMap,
56 ) -> Result<JwtUserContext, ContextExtractionError> {
57 let token = self
58 .token_extractor
59 .extract(headers)
60 .map_err(|_| ContextExtractionError::MissingAuthHeader)?;
61 self.jwt_extractor
62 .extract_user_context(&token)
63 .map_err(|e| ContextExtractionError::InvalidToken(e.to_string()))
64 }
65
66 async fn validate_user_exists(
67 &self,
68 jwt_context: &JwtUserContext,
69 route_context: &str,
70 ) -> Result<(), ContextExtractionError> {
71 let user_service = UserService::new(&self.db_pool).map_err(|e| {
72 ContextExtractionError::DatabaseError(format!("Failed to create user service: {e}"))
73 })?;
74 let user_exists = user_service
75 .find_by_id(&jwt_context.user_id)
76 .await
77 .map_err(|e| {
78 ContextExtractionError::DatabaseError(format!(
79 "Failed to check user existence: {e}"
80 ))
81 })?;
82
83 if user_exists.is_none() {
84 tracing::info!(
85 session_id = %jwt_context.session_id.as_str(),
86 user_id = %jwt_context.user_id.as_str(),
87 route = %route_context,
88 "JWT validation failed: User no longer exists in database"
89 );
90
91 return Err(ContextExtractionError::UserNotFound(format!(
92 "User {} no longer exists",
93 jwt_context.user_id.as_str()
94 )));
95 }
96 Ok(())
97 }
98
99 async fn validate_session_exists(
100 &self,
101 jwt_context: &JwtUserContext,
102 headers: &HeaderMap,
103 route_context: &str,
104 ) -> Result<(), ContextExtractionError> {
105 let Some(analytics_provider) = &self.analytics_provider else {
106 return Ok(());
107 };
108
109 let session_exists = analytics_provider
110 .find_session_by_id(&jwt_context.session_id)
111 .await
112 .map_err(|e| {
113 ContextExtractionError::DatabaseError(format!("Failed to check session: {e}"))
114 })?
115 .is_some();
116
117 if session_exists {
118 return Ok(());
119 }
120
121 tracing::info!(
122 session_id = %jwt_context.session_id.as_str(),
123 user_id = %jwt_context.user_id.as_str(),
124 route = %route_context,
125 "Creating missing session for legacy token"
126 );
127
128 let config = systemprompt_models::Config::get().map_err(|e| {
129 ContextExtractionError::DatabaseError(format!("Failed to get config: {e}"))
130 })?;
131 let expires_at =
132 chrono::Utc::now() + chrono::Duration::seconds(config.jwt_access_token_expiration);
133 let analytics = analytics_provider.extract_analytics(headers, None);
134 let session_source = jwt_context
135 .client_id
136 .as_ref()
137 .map(|c| SessionSource::from_client_id(c.as_str()))
138 .unwrap_or(SessionSource::Api);
139
140 analytics_provider
141 .create_session(CreateSessionInput {
142 session_id: &jwt_context.session_id,
143 user_id: Some(&jwt_context.user_id),
144 analytics: &analytics,
145 session_source,
146 is_bot: false,
147 expires_at,
148 })
149 .await
150 .map_err(|e| {
151 ContextExtractionError::DatabaseError(format!("Failed to create session: {e}"))
152 })?;
153
154 Ok(())
155 }
156
157 fn extract_common_headers(
158 &self,
159 headers: &HeaderMap,
160 ) -> (TraceId, Option<TaskId>, Option<String>, AgentName) {
161 (
162 HeaderExtractor::extract_trace_id(headers),
163 HeaderExtractor::extract_task_id(headers),
164 self.token_extractor.extract(headers).ok(),
165 HeaderExtractor::extract_agent_name(headers),
166 )
167 }
168
169 fn build_context(
170 jwt_context: &JwtUserContext,
171 session_id: SessionId,
172 user_id: UserId,
173 trace_id: TraceId,
174 context_id: ContextId,
175 agent_name: AgentName,
176 task_id: Option<TaskId>,
177 auth_token: Option<String>,
178 ) -> RequestContext {
179 let mut ctx = RequestContext::new(session_id, trace_id, context_id, agent_name)
180 .with_user_id(user_id)
181 .with_user_type(jwt_context.user_type);
182
183 if let Some(client_id) = jwt_context.client_id.clone() {
184 ctx = ctx.with_client_id(client_id);
185 }
186 if let Some(t_id) = task_id {
187 ctx = ctx.with_task_id(t_id);
188 }
189 if let Some(token) = auth_token {
190 ctx = ctx.with_auth_token(token);
191 }
192 ctx
193 }
194
195 pub async fn extract_standard(
196 &self,
197 headers: &HeaderMap,
198 ) -> Result<RequestContext, ContextExtractionError> {
199 let has_auth = headers.get("authorization").is_some();
200 let has_context_headers =
201 headers.get("x-user-id").is_some() && headers.get("x-session-id").is_some();
202
203 if has_context_headers && !has_auth {
204 return Err(ContextExtractionError::ForbiddenHeader {
205 header: "X-User-ID/X-Session-ID".to_string(),
206 reason: "Context headers require valid JWT for authentication".to_string(),
207 });
208 }
209
210 let jwt_context = self.extract_jwt_context(headers)?;
211
212 if jwt_context.session_id.as_str().is_empty() {
213 return Err(ContextExtractionError::MissingSessionId);
214 }
215 if jwt_context.user_id.as_str().is_empty() {
216 return Err(ContextExtractionError::MissingUserId);
217 }
218
219 self.validate_user_exists(&jwt_context, "").await?;
220 self.validate_session_exists(&jwt_context, headers, "")
221 .await?;
222
223 let session_id = headers
224 .get("x-session-id")
225 .and_then(|h| h.to_str().ok())
226 .map_or_else(
227 || jwt_context.session_id.clone(),
228 |s| SessionId::new(s.to_string()),
229 );
230
231 let user_id = headers
232 .get("x-user-id")
233 .and_then(|h| h.to_str().ok())
234 .map_or_else(
235 || jwt_context.user_id.clone(),
236 |s| UserId::new(s.to_string()),
237 );
238
239 let context_id = headers
240 .get("x-context-id")
241 .and_then(|h| h.to_str().ok())
242 .map_or_else(
243 || ContextId::new(String::new()),
244 |s| ContextId::new(s.to_string()),
245 );
246
247 let (trace_id, task_id, auth_token, agent_name) = self.extract_common_headers(headers);
248
249 Ok(Self::build_context(
250 &jwt_context,
251 session_id,
252 user_id,
253 trace_id,
254 context_id,
255 agent_name,
256 task_id,
257 auth_token,
258 ))
259 }
260
261 pub async fn extract_mcp_a2a(
262 &self,
263 headers: &HeaderMap,
264 ) -> Result<RequestContext, ContextExtractionError> {
265 self.extract_standard(headers).await
266 }
267
268 async fn extract_from_request_impl(
269 &self,
270 request: Request<Body>,
271 ) -> Result<(RequestContext, Request<Body>), ContextExtractionError> {
272 use crate::services::middleware::context::sources::{
273 ContextIdSource, PayloadSource, TASK_BASED_CONTEXT_MARKER,
274 };
275
276 let headers = request.headers().clone();
277 let has_auth = headers.get("authorization").is_some();
278
279 if headers.get("x-context-id").is_some() && !has_auth {
280 return Err(ContextExtractionError::ForbiddenHeader {
281 header: "X-Context-ID".to_string(),
282 reason: "Context ID must be in request body (A2A spec). Use contextId field in \
283 message."
284 .to_string(),
285 });
286 }
287
288 let jwt_context = self.extract_jwt_context(&headers)?;
289
290 if jwt_context.session_id.as_str().is_empty() {
291 return Err(ContextExtractionError::MissingSessionId);
292 }
293 if jwt_context.user_id.as_str().is_empty() {
294 return Err(ContextExtractionError::MissingUserId);
295 }
296
297 self.validate_user_exists(&jwt_context, " (A2A route)")
298 .await?;
299 self.validate_session_exists(&jwt_context, &headers, " (A2A route)")
300 .await?;
301
302 let (body_bytes, reconstructed_request) =
303 PayloadSource::read_and_reconstruct(request).await?;
304
305 let context_source = PayloadSource::extract_context_source(&body_bytes)?;
306 let (context_id, task_id_from_payload) = match context_source {
307 ContextIdSource::Direct(id) => (ContextId::new(id), None),
308 ContextIdSource::FromTask { task_id } => (
309 ContextId::new(TASK_BASED_CONTEXT_MARKER),
310 Some(TaskId::new(task_id)),
311 ),
312 };
313
314 let (trace_id, task_id_from_header, auth_token, agent_name) =
315 self.extract_common_headers(&headers);
316
317 let task_id = task_id_from_payload.or(task_id_from_header);
318
319 let ctx = Self::build_context(
320 &jwt_context,
321 jwt_context.session_id.clone(),
322 jwt_context.user_id.clone(),
323 trace_id,
324 context_id,
325 agent_name,
326 task_id,
327 auth_token,
328 );
329
330 Ok((ctx, reconstructed_request))
331 }
332}
333
334#[async_trait]
335impl ContextExtractor for JwtContextExtractor {
336 async fn extract_from_headers(
337 &self,
338 headers: &HeaderMap,
339 ) -> Result<RequestContext, ContextExtractionError> {
340 self.extract_standard(headers).await
341 }
342
343 async fn extract_from_request(
344 &self,
345 request: Request<Body>,
346 ) -> Result<(RequestContext, Request<Body>), ContextExtractionError> {
347 self.extract_from_request_impl(request).await
348 }
349
350 async fn extract_user_only(
351 &self,
352 headers: &HeaderMap,
353 ) -> Result<RequestContext, ContextExtractionError> {
354 self.extract_standard(headers).await
355 }
356}