systemprompt_api/services/middleware/jwt/
context.rs1use async_trait::async_trait;
2use axum::body::Body;
3use axum::extract::Request;
4use axum::http::HeaderMap;
5use std::sync::Arc;
6
7use crate::services::middleware::context::ContextExtractor;
8use systemprompt_identifiers::ContextId;
9use systemprompt_models::execution::context::{ContextExtractionError, RequestContext};
10use systemprompt_security::{JwtUserContext, TokenExtractor, extract_user_context};
11use systemprompt_traits::{AnalyticsProvider, UserProvider};
12
13use super::params::{BuildContextParams, build_context, extract_common_headers};
14use super::validation::{UserCache, user_is_admin, validate_session_exists, validate_user_exists};
15
16#[derive(Clone)]
17pub struct JwtContextExtractor {
18 token_extractor: TokenExtractor,
19 analytics_provider: Arc<dyn AnalyticsProvider>,
20 user_provider: Arc<dyn UserProvider>,
21 user_cache: Arc<UserCache>,
22}
23
24impl std::fmt::Debug for JwtContextExtractor {
25 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
26 f.debug_struct("JwtContextExtractor")
27 .field("token_extractor", &self.token_extractor)
28 .finish_non_exhaustive()
29 }
30}
31
32impl JwtContextExtractor {
33 pub fn new(
34 analytics_provider: Arc<dyn AnalyticsProvider>,
35 user_provider: Arc<dyn UserProvider>,
36 ) -> Self {
37 Self {
38 token_extractor: TokenExtractor::browser_only(),
39 analytics_provider,
40 user_provider,
41 user_cache: UserCache::new(),
42 }
43 }
44
45 fn extract_jwt_context(
46 &self,
47 headers: &HeaderMap,
48 ) -> Result<JwtUserContext, ContextExtractionError> {
49 let token = self
50 .token_extractor
51 .extract(headers)
52 .map_err(|_e| ContextExtractionError::MissingAuthHeader)?;
53 extract_user_context(&token)
54 .map_err(|e| ContextExtractionError::InvalidToken(e.to_string()))
55 }
56
57 async fn validate(
58 &self,
59 jwt_context: &JwtUserContext,
60 route_context: &str,
61 ) -> Result<systemprompt_traits::AuthUser, ContextExtractionError> {
62 if jwt_context.session_id.as_str().is_empty() {
63 return Err(ContextExtractionError::MissingSessionId);
64 }
65 if jwt_context.user_id.as_str().is_empty() {
66 return Err(ContextExtractionError::MissingUserId);
67 }
68 let validated = validate_user_exists(
69 &self.user_provider,
70 &self.user_cache,
71 jwt_context,
72 route_context,
73 )
74 .await?;
75 validate_session_exists(&self.analytics_provider, jwt_context, route_context).await?;
76 Ok(validated.user)
77 }
78
79 pub async fn extract_standard(
80 &self,
81 headers: &HeaderMap,
82 ) -> Result<RequestContext, ContextExtractionError> {
83 let jwt_context = self.extract_jwt_context(headers)?;
84 let user = self.validate(&jwt_context, "").await?;
85
86 let context_id = headers
87 .get("x-context-id")
88 .and_then(|h| h.to_str().ok())
89 .filter(|s| !s.is_empty())
90 .and_then(|s| ContextId::try_new(s).ok())
91 .unwrap_or_else(ContextId::generate);
92
93 let (trace_id, task_id, auth_token, agent_name) =
94 extract_common_headers(&self.token_extractor, headers);
95
96 let user_type = jwt_context.user_type.reconcile_with(user_is_admin(&user));
97 let session_id = jwt_context.session_id.clone();
98 let user_id = jwt_context.user_id.clone();
99
100 Ok(build_context(BuildContextParams {
101 jwt_context,
102 session_id,
103 user_id,
104 trace_id,
105 context_id,
106 agent_name,
107 task_id,
108 auth_token,
109 user_type,
110 }))
111 }
112
113 pub async fn decode_for_gateway(
114 &self,
115 jwt_token: &systemprompt_identifiers::JwtToken,
116 ) -> Result<JwtUserContext, ContextExtractionError> {
117 let jwt_context = extract_user_context(jwt_token.as_str())
118 .map_err(|e| ContextExtractionError::InvalidToken(e.to_string()))?;
119
120 self.validate(&jwt_context, "gateway").await?;
121 Ok(jwt_context)
122 }
123
124 async fn extract_from_request_impl(
125 &self,
126 request: Request<Body>,
127 ) -> Result<(RequestContext, Request<Body>), ContextExtractionError> {
128 use crate::services::middleware::context::sources::{ContextIdSource, PayloadSource};
129
130 let headers = request.headers().clone();
131 let has_auth = headers.get("authorization").is_some();
132
133 if headers.get("x-context-id").is_some() && !has_auth {
134 return Err(ContextExtractionError::ForbiddenHeader {
135 header: "X-Context-ID".to_owned(),
136 reason: "Context ID must be in request body (A2A spec). Use contextId field in \
137 message."
138 .to_owned(),
139 });
140 }
141
142 let jwt_context = self.extract_jwt_context(&headers)?;
143 let user = self.validate(&jwt_context, " (A2A route)").await?;
144
145 let (body_bytes, reconstructed_request) =
146 PayloadSource::read_and_reconstruct(request).await?;
147
148 let context_source = PayloadSource::extract_context_source(&body_bytes)?;
149 let (context_id, task_id_from_payload) = match context_source {
150 ContextIdSource::Direct(id) => (
151 ContextId::try_new(id).map_err(|e| ContextExtractionError::InvalidHeaderValue {
152 header: "contextId".to_owned(),
153 reason: e.to_string(),
154 })?,
155 None,
156 ),
157 ContextIdSource::FromTask { task_id } => (ContextId::generate(), Some(task_id)),
158 };
159
160 let (trace_id, task_id_from_header, auth_token, agent_name) =
161 extract_common_headers(&self.token_extractor, &headers);
162
163 let task_id = task_id_from_payload.or(task_id_from_header);
164 let user_type = jwt_context.user_type.reconcile_with(user_is_admin(&user));
165
166 let session_id = jwt_context.session_id.clone();
167 let user_id = jwt_context.user_id.clone();
168 let ctx = build_context(BuildContextParams {
169 jwt_context,
170 session_id,
171 user_id,
172 trace_id,
173 context_id,
174 agent_name,
175 task_id,
176 auth_token,
177 user_type,
178 });
179
180 Ok((ctx, reconstructed_request))
181 }
182}
183
184#[async_trait]
185impl ContextExtractor for JwtContextExtractor {
186 async fn extract_from_headers(
187 &self,
188 headers: &HeaderMap,
189 ) -> Result<RequestContext, ContextExtractionError> {
190 self.extract_standard(headers).await
191 }
192
193 async fn extract_from_request(
194 &self,
195 request: Request<Body>,
196 ) -> Result<(RequestContext, Request<Body>), ContextExtractionError> {
197 self.extract_from_request_impl(request).await
198 }
199
200 async fn extract_user_only(
201 &self,
202 headers: &HeaderMap,
203 ) -> Result<RequestContext, ContextExtractionError> {
204 self.extract_standard(headers).await
205 }
206}