systemprompt_api/services/middleware/jwt/
context.rs1use async_trait::async_trait;
2use axum::body::Body;
3use axum::extract::Request;
4use axum::http::HeaderMap;
5use std::sync::Arc;
6
7use crate::services::middleware::context::ContextExtractor;
8use systemprompt_identifiers::{ContextId, SessionId, UserId};
9use systemprompt_models::auth::UserType;
10use systemprompt_models::execution::context::{ContextExtractionError, RequestContext};
11use systemprompt_security::TokenExtractor;
12use systemprompt_traits::{AnalyticsProvider, AuthUser, UserProvider};
13
14use super::params::{BuildContextParams, build_context, extract_common_headers};
15use super::token::{JwtExtractor, JwtUserContext};
16use super::validation::{UserCache, user_is_admin, validate_session_exists, validate_user_exists};
17
18#[derive(Clone)]
19pub struct JwtContextExtractor {
20 jwt_extractor: Arc<JwtExtractor>,
21 token_extractor: TokenExtractor,
22 analytics_provider: Arc<dyn AnalyticsProvider>,
23 user_provider: Arc<dyn UserProvider>,
24 user_cache: Arc<UserCache>,
25}
26
27impl std::fmt::Debug for JwtContextExtractor {
28 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
29 f.debug_struct("JwtContextExtractor")
30 .field("jwt_extractor", &self.jwt_extractor)
31 .field("token_extractor", &self.token_extractor)
32 .finish_non_exhaustive()
33 }
34}
35
36impl JwtContextExtractor {
37 pub fn new(
38 analytics_provider: Arc<dyn AnalyticsProvider>,
39 user_provider: Arc<dyn UserProvider>,
40 ) -> Self {
41 Self {
42 jwt_extractor: Arc::new(JwtExtractor::new()),
43 token_extractor: TokenExtractor::browser_only(),
44 analytics_provider,
45 user_provider,
46 user_cache: UserCache::new(),
47 }
48 }
49
50 fn extract_jwt_context(
51 &self,
52 headers: &HeaderMap,
53 ) -> Result<JwtUserContext, ContextExtractionError> {
54 let token = self
55 .token_extractor
56 .extract(headers)
57 .map_err(|_| ContextExtractionError::MissingAuthHeader)?;
58 self.jwt_extractor
59 .extract_user_context(&token)
60 .map_err(|e| ContextExtractionError::InvalidToken(e.to_string()))
61 }
62
63 async fn validate(
64 &self,
65 jwt_context: &JwtUserContext,
66 route_context: &str,
67 ) -> Result<AuthUser, ContextExtractionError> {
68 if jwt_context.session_id.as_str().is_empty() {
69 return Err(ContextExtractionError::MissingSessionId);
70 }
71 if jwt_context.user_id.as_str().is_empty() {
72 return Err(ContextExtractionError::MissingUserId);
73 }
74 let validated = validate_user_exists(
75 &self.user_provider,
76 &self.user_cache,
77 jwt_context,
78 route_context,
79 )
80 .await?;
81 validate_session_exists(&self.analytics_provider, jwt_context, route_context).await?;
82 Ok(validated.user)
83 }
84
85 pub async fn extract_standard(
86 &self,
87 headers: &HeaderMap,
88 ) -> Result<RequestContext, ContextExtractionError> {
89 let has_auth = headers.get("authorization").is_some();
90 let has_context_headers =
91 headers.get("x-user-id").is_some() && headers.get("x-session-id").is_some();
92
93 if has_context_headers && !has_auth {
94 return Err(ContextExtractionError::ForbiddenHeader {
95 header: "X-User-ID/X-Session-ID".to_string(),
96 reason: "Context headers require valid JWT for authentication".to_string(),
97 });
98 }
99
100 let jwt_context = self.extract_jwt_context(headers)?;
101 let user = self.validate(&jwt_context, "").await?;
102
103 let session_id = headers
104 .get("x-session-id")
105 .and_then(|h| h.to_str().ok())
106 .map_or_else(
107 || jwt_context.session_id.clone(),
108 |s| SessionId::new(s.to_string()),
109 );
110
111 let user_id = headers
112 .get("x-user-id")
113 .and_then(|h| h.to_str().ok())
114 .map_or_else(
115 || jwt_context.user_id.clone(),
116 |s| UserId::new(s.to_string()),
117 );
118
119 let context_id = headers
120 .get("x-context-id")
121 .and_then(|h| h.to_str().ok())
122 .filter(|s| !s.is_empty())
123 .and_then(|s| ContextId::try_new(s).ok())
124 .unwrap_or_else(ContextId::generate);
125
126 let (trace_id, task_id, auth_token, agent_name) =
127 extract_common_headers(&self.token_extractor, headers);
128
129 let user_type = resolve_user_type(jwt_context.user_type, &user);
130
131 Ok(build_context(BuildContextParams {
132 jwt_context,
133 session_id,
134 user_id,
135 trace_id,
136 context_id,
137 agent_name,
138 task_id,
139 auth_token,
140 user_type,
141 }))
142 }
143
144 pub async fn extract_mcp_a2a(
145 &self,
146 headers: &HeaderMap,
147 ) -> Result<RequestContext, ContextExtractionError> {
148 self.extract_standard(headers).await
149 }
150
151 pub async fn decode_for_gateway(
152 &self,
153 jwt_token: &systemprompt_identifiers::JwtToken,
154 ) -> Result<JwtUserContext, ContextExtractionError> {
155 let jwt_context = self
156 .jwt_extractor
157 .extract_user_context(jwt_token.as_str())
158 .map_err(|e| ContextExtractionError::InvalidToken(e.to_string()))?;
159
160 let _ = self.validate(&jwt_context, "gateway").await?;
161 Ok(jwt_context)
162 }
163
164 async fn extract_from_request_impl(
165 &self,
166 request: Request<Body>,
167 ) -> Result<(RequestContext, Request<Body>), ContextExtractionError> {
168 use crate::services::middleware::context::sources::{ContextIdSource, PayloadSource};
169
170 let headers = request.headers().clone();
171 let has_auth = headers.get("authorization").is_some();
172
173 if headers.get("x-context-id").is_some() && !has_auth {
174 return Err(ContextExtractionError::ForbiddenHeader {
175 header: "X-Context-ID".to_string(),
176 reason: "Context ID must be in request body (A2A spec). Use contextId field in \
177 message."
178 .to_string(),
179 });
180 }
181
182 let jwt_context = self.extract_jwt_context(&headers)?;
183 let user = self.validate(&jwt_context, " (A2A route)").await?;
184
185 let (body_bytes, reconstructed_request) =
186 PayloadSource::read_and_reconstruct(request).await?;
187
188 let context_source = PayloadSource::extract_context_source(&body_bytes)?;
189 let (context_id, task_id_from_payload) = match context_source {
190 ContextIdSource::Direct(id) => (
191 ContextId::try_new(id).map_err(|e| ContextExtractionError::InvalidHeaderValue {
192 header: "contextId".to_string(),
193 reason: e.to_string(),
194 })?,
195 None,
196 ),
197 ContextIdSource::FromTask { task_id } => (ContextId::generate(), Some(task_id)),
198 };
199
200 let (trace_id, task_id_from_header, auth_token, agent_name) =
201 extract_common_headers(&self.token_extractor, &headers);
202
203 let task_id = task_id_from_payload.or(task_id_from_header);
204 let user_type = resolve_user_type(jwt_context.user_type, &user);
205
206 let session_id = jwt_context.session_id.clone();
207 let user_id = jwt_context.user_id.clone();
208 let ctx = build_context(BuildContextParams {
209 jwt_context,
210 session_id,
211 user_id,
212 trace_id,
213 context_id,
214 agent_name,
215 task_id,
216 auth_token,
217 user_type,
218 });
219
220 Ok((ctx, reconstructed_request))
221 }
222}
223
224fn resolve_user_type(claimed: UserType, user: &AuthUser) -> UserType {
229 match claimed {
230 UserType::Admin if !user_is_admin(user) => UserType::User,
231 other => other,
232 }
233}
234
235#[async_trait]
236impl ContextExtractor for JwtContextExtractor {
237 async fn extract_from_headers(
238 &self,
239 headers: &HeaderMap,
240 ) -> Result<RequestContext, ContextExtractionError> {
241 self.extract_standard(headers).await
242 }
243
244 async fn extract_from_request(
245 &self,
246 request: Request<Body>,
247 ) -> Result<(RequestContext, Request<Body>), ContextExtractionError> {
248 self.extract_from_request_impl(request).await
249 }
250
251 async fn extract_user_only(
252 &self,
253 headers: &HeaderMap,
254 ) -> Result<RequestContext, ContextExtractionError> {
255 self.extract_standard(headers).await
256 }
257}