systemprompt_api/services/middleware/jwt/
context.rs1use async_trait::async_trait;
2use axum::body::Body;
3use axum::extract::Request;
4use axum::http::HeaderMap;
5use std::sync::Arc;
6
7use crate::services::middleware::context::ContextExtractor;
8use systemprompt_database::DbPool;
9use systemprompt_identifiers::{AgentName, ContextId, SessionId, TraceId, UserId};
10use systemprompt_models::execution::context::{ContextExtractionError, RequestContext};
11use systemprompt_security::TokenExtractor;
12use systemprompt_traits::AnalyticsProvider;
13
14use super::params::{BuildContextParams, build_context, extract_common_headers};
15use super::token::{JwtExtractor, JwtUserContext};
16use super::validation::{validate_session_exists, validate_user_exists};
17
18#[derive(Clone)]
19pub struct JwtContextExtractor {
20 jwt_extractor: Arc<JwtExtractor>,
21 token_extractor: TokenExtractor,
22 db_pool: DbPool,
23 analytics_provider: Option<Arc<dyn AnalyticsProvider>>,
24}
25
26impl std::fmt::Debug for JwtContextExtractor {
27 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
28 f.debug_struct("JwtContextExtractor")
29 .field("jwt_extractor", &self.jwt_extractor)
30 .field("token_extractor", &self.token_extractor)
31 .field("db_pool", &"DbPool")
32 .field("analytics_provider", &self.analytics_provider.is_some())
33 .finish()
34 }
35}
36
37impl JwtContextExtractor {
38 pub fn new(jwt_secret: &str, db_pool: &DbPool) -> Self {
39 Self {
40 jwt_extractor: Arc::new(JwtExtractor::new(jwt_secret)),
41 token_extractor: TokenExtractor::browser_only(),
42 db_pool: Arc::clone(db_pool),
43 analytics_provider: None,
44 }
45 }
46
47 pub fn with_analytics_provider(mut self, provider: Arc<dyn AnalyticsProvider>) -> Self {
48 self.analytics_provider = Some(provider);
49 self
50 }
51
52 fn extract_jwt_context(
53 &self,
54 headers: &HeaderMap,
55 ) -> Result<JwtUserContext, ContextExtractionError> {
56 let token = self
57 .token_extractor
58 .extract(headers)
59 .map_err(|_| ContextExtractionError::MissingAuthHeader)?;
60 self.jwt_extractor
61 .extract_user_context(&token)
62 .map_err(|e| ContextExtractionError::InvalidToken(e.to_string()))
63 }
64
65 pub async fn extract_standard(
66 &self,
67 headers: &HeaderMap,
68 ) -> Result<RequestContext, ContextExtractionError> {
69 let has_auth = headers.get("authorization").is_some();
70 let has_context_headers =
71 headers.get("x-user-id").is_some() && headers.get("x-session-id").is_some();
72
73 if has_context_headers && !has_auth {
74 return Err(ContextExtractionError::ForbiddenHeader {
75 header: "X-User-ID/X-Session-ID".to_string(),
76 reason: "Context headers require valid JWT for authentication".to_string(),
77 });
78 }
79
80 let jwt_context = self.extract_jwt_context(headers)?;
81
82 if jwt_context.session_id.as_str().is_empty() {
83 return Err(ContextExtractionError::MissingSessionId);
84 }
85 if jwt_context.user_id.as_str().is_empty() {
86 return Err(ContextExtractionError::MissingUserId);
87 }
88
89 validate_user_exists(&self.db_pool, &jwt_context, "").await?;
90 validate_session_exists(self.analytics_provider.as_ref(), &jwt_context, headers, "")
91 .await?;
92
93 let session_id = headers
94 .get("x-session-id")
95 .and_then(|h| h.to_str().ok())
96 .map_or_else(
97 || jwt_context.session_id.clone(),
98 |s| SessionId::new(s.to_string()),
99 );
100
101 let user_id = headers
102 .get("x-user-id")
103 .and_then(|h| h.to_str().ok())
104 .map_or_else(
105 || jwt_context.user_id.clone(),
106 |s| UserId::new(s.to_string()),
107 );
108
109 let context_id = headers
110 .get("x-context-id")
111 .and_then(|h| h.to_str().ok())
112 .map_or_else(
113 || ContextId::new(String::new()),
114 |s| ContextId::new(s.to_string()),
115 );
116
117 let (trace_id, task_id, auth_token, agent_name) =
118 extract_common_headers(&self.token_extractor, headers);
119
120 Ok(build_context(BuildContextParams {
121 jwt_context,
122 session_id,
123 user_id,
124 trace_id,
125 context_id,
126 agent_name,
127 task_id,
128 auth_token,
129 }))
130 }
131
132 pub async fn extract_mcp_a2a(
133 &self,
134 headers: &HeaderMap,
135 ) -> Result<RequestContext, ContextExtractionError> {
136 self.extract_standard(headers).await
137 }
138
139 pub async fn extract_for_gateway(
140 &self,
141 jwt_token: &systemprompt_identifiers::JwtToken,
142 ) -> Result<RequestContext, ContextExtractionError> {
143 let jwt_context = self
144 .jwt_extractor
145 .extract_user_context(jwt_token.as_str())
146 .map_err(|e| ContextExtractionError::InvalidToken(e.to_string()))?;
147
148 if jwt_context.session_id.as_str().is_empty() {
149 return Err(ContextExtractionError::MissingSessionId);
150 }
151 if jwt_context.user_id.as_str().is_empty() {
152 return Err(ContextExtractionError::MissingUserId);
153 }
154
155 validate_user_exists(&self.db_pool, &jwt_context, "gateway").await?;
156
157 let session_id = jwt_context.session_id.clone();
158 let user_id = jwt_context.user_id.clone();
159
160 Ok(build_context(BuildContextParams {
161 jwt_context,
162 session_id,
163 user_id,
164 trace_id: TraceId::generate(),
165 context_id: ContextId::new(String::new()),
166 agent_name: AgentName::system(),
167 task_id: None,
168 auth_token: Some(jwt_token.as_str().to_string()),
169 }))
170 }
171
172 async fn extract_from_request_impl(
173 &self,
174 request: Request<Body>,
175 ) -> Result<(RequestContext, Request<Body>), ContextExtractionError> {
176 use crate::services::middleware::context::sources::{
177 ContextIdSource, PayloadSource, TASK_BASED_CONTEXT_MARKER,
178 };
179
180 let headers = request.headers().clone();
181 let has_auth = headers.get("authorization").is_some();
182
183 if headers.get("x-context-id").is_some() && !has_auth {
184 return Err(ContextExtractionError::ForbiddenHeader {
185 header: "X-Context-ID".to_string(),
186 reason: "Context ID must be in request body (A2A spec). Use contextId field in \
187 message."
188 .to_string(),
189 });
190 }
191
192 let jwt_context = self.extract_jwt_context(&headers)?;
193
194 if jwt_context.session_id.as_str().is_empty() {
195 return Err(ContextExtractionError::MissingSessionId);
196 }
197 if jwt_context.user_id.as_str().is_empty() {
198 return Err(ContextExtractionError::MissingUserId);
199 }
200
201 validate_user_exists(&self.db_pool, &jwt_context, " (A2A route)").await?;
202 validate_session_exists(
203 self.analytics_provider.as_ref(),
204 &jwt_context,
205 &headers,
206 " (A2A route)",
207 )
208 .await?;
209
210 let (body_bytes, reconstructed_request) =
211 PayloadSource::read_and_reconstruct(request).await?;
212
213 let context_source = PayloadSource::extract_context_source(&body_bytes)?;
214 let (context_id, task_id_from_payload) = match context_source {
215 ContextIdSource::Direct(id) => (ContextId::new(id), None),
216 ContextIdSource::FromTask { task_id } => {
217 (ContextId::new(TASK_BASED_CONTEXT_MARKER), Some(task_id))
218 },
219 };
220
221 let (trace_id, task_id_from_header, auth_token, agent_name) =
222 extract_common_headers(&self.token_extractor, &headers);
223
224 let task_id = task_id_from_payload.or(task_id_from_header);
225
226 let session_id = jwt_context.session_id.clone();
227 let user_id = jwt_context.user_id.clone();
228 let ctx = build_context(BuildContextParams {
229 jwt_context,
230 session_id,
231 user_id,
232 trace_id,
233 context_id,
234 agent_name,
235 task_id,
236 auth_token,
237 });
238
239 Ok((ctx, reconstructed_request))
240 }
241}
242
243#[async_trait]
244impl ContextExtractor for JwtContextExtractor {
245 async fn extract_from_headers(
246 &self,
247 headers: &HeaderMap,
248 ) -> Result<RequestContext, ContextExtractionError> {
249 self.extract_standard(headers).await
250 }
251
252 async fn extract_from_request(
253 &self,
254 request: Request<Body>,
255 ) -> Result<(RequestContext, Request<Body>), ContextExtractionError> {
256 self.extract_from_request_impl(request).await
257 }
258
259 async fn extract_user_only(
260 &self,
261 headers: &HeaderMap,
262 ) -> Result<RequestContext, ContextExtractionError> {
263 self.extract_standard(headers).await
264 }
265}