1use async_trait::async_trait;
2use axum::body::Body;
3use axum::extract::Request;
4use axum::http::HeaderMap;
5use std::sync::Arc;
6
7use crate::services::middleware::context::ContextExtractor;
8use systemprompt_database::DbPool;
9use systemprompt_identifiers::{
10 AgentName, ContextId, SessionId, SessionSource, TaskId, TraceId, UserId,
11};
12use systemprompt_models::execution::context::{ContextExtractionError, RequestContext};
13use systemprompt_security::{HeaderExtractor, TokenExtractor};
14use systemprompt_traits::{AnalyticsProvider, CreateSessionInput};
15use systemprompt_users::UserService;
16
17use super::token::{JwtExtractor, JwtUserContext};
18
19struct BuildContextParams {
20 jwt_context: JwtUserContext,
21 session_id: SessionId,
22 user_id: UserId,
23 trace_id: TraceId,
24 context_id: ContextId,
25 agent_name: AgentName,
26 task_id: Option<TaskId>,
27 auth_token: Option<String>,
28}
29
30#[derive(Clone)]
31pub struct JwtContextExtractor {
32 jwt_extractor: Arc<JwtExtractor>,
33 token_extractor: TokenExtractor,
34 db_pool: DbPool,
35 analytics_provider: Option<Arc<dyn AnalyticsProvider>>,
36}
37
38impl std::fmt::Debug for JwtContextExtractor {
39 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
40 f.debug_struct("JwtContextExtractor")
41 .field("jwt_extractor", &self.jwt_extractor)
42 .field("token_extractor", &self.token_extractor)
43 .field("db_pool", &"DbPool")
44 .field("analytics_provider", &self.analytics_provider.is_some())
45 .finish()
46 }
47}
48
49impl JwtContextExtractor {
50 pub fn new(jwt_secret: &str, db_pool: &DbPool) -> Self {
51 Self {
52 jwt_extractor: Arc::new(JwtExtractor::new(jwt_secret)),
53 token_extractor: TokenExtractor::browser_only(),
54 db_pool: Arc::clone(db_pool),
55 analytics_provider: None,
56 }
57 }
58
59 pub fn with_analytics_provider(mut self, provider: Arc<dyn AnalyticsProvider>) -> Self {
60 self.analytics_provider = Some(provider);
61 self
62 }
63
64 fn extract_jwt_context(
65 &self,
66 headers: &HeaderMap,
67 ) -> Result<JwtUserContext, ContextExtractionError> {
68 let token = self
69 .token_extractor
70 .extract(headers)
71 .map_err(|_| ContextExtractionError::MissingAuthHeader)?;
72 self.jwt_extractor
73 .extract_user_context(&token)
74 .map_err(|e| ContextExtractionError::InvalidToken(e.to_string()))
75 }
76
77 async fn validate_user_exists(
78 &self,
79 jwt_context: &JwtUserContext,
80 route_context: &str,
81 ) -> Result<(), ContextExtractionError> {
82 let user_service = UserService::new(&self.db_pool).map_err(|e| {
83 ContextExtractionError::DatabaseError(format!("Failed to create user service: {e}"))
84 })?;
85 let user_exists = user_service
86 .find_by_id(&jwt_context.user_id)
87 .await
88 .map_err(|e| {
89 ContextExtractionError::DatabaseError(format!(
90 "Failed to check user existence: {e}"
91 ))
92 })?;
93
94 if user_exists.is_none() {
95 tracing::info!(
96 session_id = %jwt_context.session_id.as_str(),
97 user_id = %jwt_context.user_id.as_str(),
98 route = %route_context,
99 "JWT validation failed: User no longer exists in database"
100 );
101
102 return Err(ContextExtractionError::UserNotFound(format!(
103 "User {} no longer exists",
104 jwt_context.user_id.as_str()
105 )));
106 }
107 Ok(())
108 }
109
110 async fn validate_session_exists(
111 &self,
112 jwt_context: &JwtUserContext,
113 headers: &HeaderMap,
114 route_context: &str,
115 ) -> Result<(), ContextExtractionError> {
116 let Some(analytics_provider) = &self.analytics_provider else {
117 return Ok(());
118 };
119
120 let session_exists = analytics_provider
121 .find_session_by_id(&jwt_context.session_id)
122 .await
123 .map_err(|e| {
124 ContextExtractionError::DatabaseError(format!("Failed to check session: {e}"))
125 })?
126 .is_some();
127
128 if session_exists {
129 return Ok(());
130 }
131
132 tracing::info!(
133 session_id = %jwt_context.session_id.as_str(),
134 user_id = %jwt_context.user_id.as_str(),
135 route = %route_context,
136 "Creating missing session for legacy token"
137 );
138
139 let config = systemprompt_models::Config::get().map_err(|e| {
140 ContextExtractionError::DatabaseError(format!("Failed to get config: {e}"))
141 })?;
142 let expires_at =
143 chrono::Utc::now() + chrono::Duration::seconds(config.jwt_access_token_expiration);
144 let analytics = analytics_provider.extract_analytics(headers, None);
145 let session_source = jwt_context
146 .client_id
147 .as_ref()
148 .map_or(SessionSource::Api, |c| {
149 SessionSource::from_client_id(c.as_str())
150 });
151
152 analytics_provider
153 .create_session(CreateSessionInput {
154 session_id: &jwt_context.session_id,
155 user_id: Some(&jwt_context.user_id),
156 analytics: &analytics,
157 session_source,
158 is_bot: false,
159 is_ai_crawler: false,
160 expires_at,
161 })
162 .await
163 .map_err(|e| {
164 ContextExtractionError::DatabaseError(format!("Failed to create session: {e}"))
165 })?;
166
167 Ok(())
168 }
169
170 fn extract_common_headers(
171 &self,
172 headers: &HeaderMap,
173 ) -> (TraceId, Option<TaskId>, Option<String>, AgentName) {
174 (
175 HeaderExtractor::extract_trace_id(headers),
176 HeaderExtractor::extract_task_id(headers),
177 self.token_extractor.extract(headers).ok(),
178 HeaderExtractor::extract_agent_name(headers),
179 )
180 }
181
182 fn build_context(params: BuildContextParams) -> RequestContext {
183 let BuildContextParams {
184 jwt_context,
185 session_id,
186 user_id,
187 trace_id,
188 context_id,
189 agent_name,
190 task_id,
191 auth_token,
192 } = params;
193 let mut ctx = RequestContext::new(session_id, trace_id, context_id, agent_name)
194 .with_user_id(user_id)
195 .with_user_type(jwt_context.user_type);
196
197 if let Some(client_id) = jwt_context.client_id {
198 ctx = ctx.with_client_id(client_id);
199 }
200 if let Some(t_id) = task_id {
201 ctx = ctx.with_task_id(t_id);
202 }
203 if let Some(token) = auth_token {
204 ctx = ctx.with_auth_token(token);
205 }
206 ctx
207 }
208
209 pub async fn extract_standard(
210 &self,
211 headers: &HeaderMap,
212 ) -> Result<RequestContext, ContextExtractionError> {
213 let has_auth = headers.get("authorization").is_some();
214 let has_context_headers =
215 headers.get("x-user-id").is_some() && headers.get("x-session-id").is_some();
216
217 if has_context_headers && !has_auth {
218 return Err(ContextExtractionError::ForbiddenHeader {
219 header: "X-User-ID/X-Session-ID".to_string(),
220 reason: "Context headers require valid JWT for authentication".to_string(),
221 });
222 }
223
224 let jwt_context = self.extract_jwt_context(headers)?;
225
226 if jwt_context.session_id.as_str().is_empty() {
227 return Err(ContextExtractionError::MissingSessionId);
228 }
229 if jwt_context.user_id.as_str().is_empty() {
230 return Err(ContextExtractionError::MissingUserId);
231 }
232
233 self.validate_user_exists(&jwt_context, "").await?;
234 self.validate_session_exists(&jwt_context, headers, "")
235 .await?;
236
237 let session_id = headers
238 .get("x-session-id")
239 .and_then(|h| h.to_str().ok())
240 .map_or_else(
241 || jwt_context.session_id.clone(),
242 |s| SessionId::new(s.to_string()),
243 );
244
245 let user_id = headers
246 .get("x-user-id")
247 .and_then(|h| h.to_str().ok())
248 .map_or_else(
249 || jwt_context.user_id.clone(),
250 |s| UserId::new(s.to_string()),
251 );
252
253 let context_id = headers
254 .get("x-context-id")
255 .and_then(|h| h.to_str().ok())
256 .map_or_else(
257 || ContextId::new(String::new()),
258 |s| ContextId::new(s.to_string()),
259 );
260
261 let (trace_id, task_id, auth_token, agent_name) = self.extract_common_headers(headers);
262
263 Ok(Self::build_context(BuildContextParams {
264 jwt_context,
265 session_id,
266 user_id,
267 trace_id,
268 context_id,
269 agent_name,
270 task_id,
271 auth_token,
272 }))
273 }
274
275 pub async fn extract_mcp_a2a(
276 &self,
277 headers: &HeaderMap,
278 ) -> Result<RequestContext, ContextExtractionError> {
279 self.extract_standard(headers).await
280 }
281
282 pub async fn extract_for_gateway(
283 &self,
284 jwt_token: &systemprompt_identifiers::JwtToken,
285 ) -> Result<RequestContext, ContextExtractionError> {
286 let jwt_context = self
287 .jwt_extractor
288 .extract_user_context(jwt_token.as_str())
289 .map_err(|e| ContextExtractionError::InvalidToken(e.to_string()))?;
290
291 if jwt_context.session_id.as_str().is_empty() {
292 return Err(ContextExtractionError::MissingSessionId);
293 }
294 if jwt_context.user_id.as_str().is_empty() {
295 return Err(ContextExtractionError::MissingUserId);
296 }
297
298 self.validate_user_exists(&jwt_context, "gateway").await?;
299
300 let session_id = jwt_context.session_id.clone();
301 let user_id = jwt_context.user_id.clone();
302
303 Ok(Self::build_context(BuildContextParams {
304 jwt_context,
305 session_id,
306 user_id,
307 trace_id: TraceId::generate(),
308 context_id: ContextId::new(String::new()),
309 agent_name: AgentName::system(),
310 task_id: None,
311 auth_token: Some(jwt_token.as_str().to_string()),
312 }))
313 }
314
315 async fn extract_from_request_impl(
316 &self,
317 request: Request<Body>,
318 ) -> Result<(RequestContext, Request<Body>), ContextExtractionError> {
319 use crate::services::middleware::context::sources::{
320 ContextIdSource, PayloadSource, TASK_BASED_CONTEXT_MARKER,
321 };
322
323 let headers = request.headers().clone();
324 let has_auth = headers.get("authorization").is_some();
325
326 if headers.get("x-context-id").is_some() && !has_auth {
327 return Err(ContextExtractionError::ForbiddenHeader {
328 header: "X-Context-ID".to_string(),
329 reason: "Context ID must be in request body (A2A spec). Use contextId field in \
330 message."
331 .to_string(),
332 });
333 }
334
335 let jwt_context = self.extract_jwt_context(&headers)?;
336
337 if jwt_context.session_id.as_str().is_empty() {
338 return Err(ContextExtractionError::MissingSessionId);
339 }
340 if jwt_context.user_id.as_str().is_empty() {
341 return Err(ContextExtractionError::MissingUserId);
342 }
343
344 self.validate_user_exists(&jwt_context, " (A2A route)")
345 .await?;
346 self.validate_session_exists(&jwt_context, &headers, " (A2A route)")
347 .await?;
348
349 let (body_bytes, reconstructed_request) =
350 PayloadSource::read_and_reconstruct(request).await?;
351
352 let context_source = PayloadSource::extract_context_source(&body_bytes)?;
353 let (context_id, task_id_from_payload) = match context_source {
354 ContextIdSource::Direct(id) => (ContextId::new(id), None),
355 ContextIdSource::FromTask { task_id } => (
356 ContextId::new(TASK_BASED_CONTEXT_MARKER),
357 Some(TaskId::new(task_id)),
358 ),
359 };
360
361 let (trace_id, task_id_from_header, auth_token, agent_name) =
362 self.extract_common_headers(&headers);
363
364 let task_id = task_id_from_payload.or(task_id_from_header);
365
366 let session_id = jwt_context.session_id.clone();
367 let user_id = jwt_context.user_id.clone();
368 let ctx = Self::build_context(BuildContextParams {
369 jwt_context,
370 session_id,
371 user_id,
372 trace_id,
373 context_id,
374 agent_name,
375 task_id,
376 auth_token,
377 });
378
379 Ok((ctx, reconstructed_request))
380 }
381}
382
383#[async_trait]
384impl ContextExtractor for JwtContextExtractor {
385 async fn extract_from_headers(
386 &self,
387 headers: &HeaderMap,
388 ) -> Result<RequestContext, ContextExtractionError> {
389 self.extract_standard(headers).await
390 }
391
392 async fn extract_from_request(
393 &self,
394 request: Request<Body>,
395 ) -> Result<(RequestContext, Request<Body>), ContextExtractionError> {
396 self.extract_from_request_impl(request).await
397 }
398
399 async fn extract_user_only(
400 &self,
401 headers: &HeaderMap,
402 ) -> Result<RequestContext, ContextExtractionError> {
403 self.extract_standard(headers).await
404 }
405}