systemprompt_api/services/middleware/jwt/
context.rs1use async_trait::async_trait;
10use axum::body::Body;
11use axum::extract::Request;
12use axum::http::HeaderMap;
13use std::sync::Arc;
14
15use crate::services::middleware::context::ContextExtractor;
16use systemprompt_identifiers::ContextId;
17use systemprompt_models::execution::context::{ContextExtractionError, RequestContext};
18use systemprompt_security::{JwtUserContext, TokenExtractor, extract_user_context};
19use systemprompt_traits::{AnalyticsProvider, UserProvider};
20
21use super::params::{BuildContextParams, build_context, extract_common_headers};
22use super::revocation::JtiRevocationChecker;
23use super::validation::{UserCache, user_is_admin, validate_session_exists, validate_user_exists};
24
25#[derive(Clone)]
26pub struct JwtContextExtractor {
27 token_extractor: TokenExtractor,
28 analytics_provider: Arc<dyn AnalyticsProvider>,
29 user_provider: Arc<dyn UserProvider>,
30 user_cache: Arc<UserCache>,
31 jti_revocation: JtiRevocationChecker,
32}
33
34impl std::fmt::Debug for JwtContextExtractor {
35 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
36 f.debug_struct("JwtContextExtractor")
37 .field("token_extractor", &self.token_extractor)
38 .finish_non_exhaustive()
39 }
40}
41
42impl JwtContextExtractor {
43 pub fn new(
44 analytics_provider: Arc<dyn AnalyticsProvider>,
45 user_provider: Arc<dyn UserProvider>,
46 jti_revocation: JtiRevocationChecker,
47 ) -> Self {
48 Self {
49 token_extractor: TokenExtractor::browser_only(),
50 analytics_provider,
51 user_provider,
52 user_cache: UserCache::new(),
53 jti_revocation,
54 }
55 }
56
57 fn extract_jwt_context(
58 &self,
59 headers: &HeaderMap,
60 ) -> Result<JwtUserContext, ContextExtractionError> {
61 let token = self
62 .token_extractor
63 .extract(headers)
64 .map_err(|_e| ContextExtractionError::MissingAuthHeader)?;
65 extract_user_context(&token)
66 .map_err(|e| ContextExtractionError::InvalidToken(e.to_string()))
67 }
68
69 async fn validate(
70 &self,
71 jwt_context: &JwtUserContext,
72 route_context: &str,
73 ) -> Result<systemprompt_traits::AuthUser, ContextExtractionError> {
74 if jwt_context.session_id.as_str().is_empty() {
75 return Err(ContextExtractionError::MissingSessionId);
76 }
77 if jwt_context.user_id.as_str().is_empty() {
78 return Err(ContextExtractionError::MissingUserId);
79 }
80 let validated = validate_user_exists(
81 &self.user_provider,
82 &self.user_cache,
83 jwt_context,
84 route_context,
85 )
86 .await?;
87 validate_session_exists(&self.analytics_provider, jwt_context, route_context).await?;
88 self.jti_revocation
89 .ensure_not_revoked(&jwt_context.jti)
90 .await?;
91 Ok(validated.user)
92 }
93
94 pub async fn extract_standard(
95 &self,
96 headers: &HeaderMap,
97 ) -> Result<RequestContext, ContextExtractionError> {
98 let jwt_context = self.extract_jwt_context(headers)?;
99 let user = self.validate(&jwt_context, "").await?;
100
101 let context_id = headers
102 .get("x-context-id")
103 .and_then(|h| h.to_str().ok())
104 .filter(|s| !s.is_empty())
105 .and_then(|s| ContextId::try_new(s).ok())
106 .unwrap_or_else(ContextId::generate);
107
108 let (trace_id, task_id, auth_token, agent_name) =
109 extract_common_headers(&self.token_extractor, headers);
110
111 let user_type = jwt_context.user_type.reconcile_with(user_is_admin(&user));
112 let session_id = jwt_context.session_id.clone();
113 let user_id = jwt_context.user_id.clone();
114
115 Ok(build_context(BuildContextParams {
116 jwt_context,
117 session_id,
118 user_id,
119 trace_id,
120 context_id,
121 agent_name,
122 task_id,
123 auth_token,
124 user_type,
125 }))
126 }
127
128 pub async fn decode_for_gateway(
129 &self,
130 jwt_token: &systemprompt_identifiers::JwtToken,
131 ) -> Result<(JwtUserContext, systemprompt_traits::AuthUser), ContextExtractionError> {
132 let jwt_context = extract_user_context(jwt_token.as_str())
133 .map_err(|e| ContextExtractionError::InvalidToken(e.to_string()))?;
134
135 let user = self.validate(&jwt_context, "gateway").await?;
136 Ok((jwt_context, user))
137 }
138
139 async fn extract_from_request_impl(
140 &self,
141 request: Request<Body>,
142 ) -> Result<(RequestContext, Request<Body>), ContextExtractionError> {
143 use crate::services::middleware::context::sources::{ContextIdSource, PayloadSource};
144
145 let headers = request.headers().clone();
146 let has_auth = headers.get("authorization").is_some();
147
148 if headers.get("x-context-id").is_some() && !has_auth {
149 return Err(ContextExtractionError::ForbiddenHeader {
150 header: "X-Context-ID".to_owned(),
151 reason: "Context ID must be in request body (A2A spec). Use contextId field in \
152 message."
153 .to_owned(),
154 });
155 }
156
157 let jwt_context = self.extract_jwt_context(&headers)?;
158 let user = self.validate(&jwt_context, " (A2A route)").await?;
159
160 let (body_bytes, reconstructed_request) =
161 PayloadSource::read_and_reconstruct(request).await?;
162
163 let context_source = PayloadSource::extract_context_source(&body_bytes)?;
164 let (context_id, task_id_from_payload) = match context_source {
165 ContextIdSource::Direct(id) => (
166 ContextId::try_new(id).map_err(|e| ContextExtractionError::InvalidHeaderValue {
167 header: "contextId".to_owned(),
168 reason: e.to_string(),
169 })?,
170 None,
171 ),
172 ContextIdSource::FromTask { task_id } => (ContextId::generate(), Some(task_id)),
173 };
174
175 let (trace_id, task_id_from_header, auth_token, agent_name) =
176 extract_common_headers(&self.token_extractor, &headers);
177
178 let task_id = task_id_from_payload.or(task_id_from_header);
179 let user_type = jwt_context.user_type.reconcile_with(user_is_admin(&user));
180
181 let session_id = jwt_context.session_id.clone();
182 let user_id = jwt_context.user_id.clone();
183 let ctx = build_context(BuildContextParams {
184 jwt_context,
185 session_id,
186 user_id,
187 trace_id,
188 context_id,
189 agent_name,
190 task_id,
191 auth_token,
192 user_type,
193 });
194
195 Ok((ctx, reconstructed_request))
196 }
197}
198
199#[async_trait]
200impl ContextExtractor for JwtContextExtractor {
201 async fn extract_from_headers(
202 &self,
203 headers: &HeaderMap,
204 ) -> Result<RequestContext, ContextExtractionError> {
205 self.extract_standard(headers).await
206 }
207
208 async fn extract_from_request(
209 &self,
210 request: Request<Body>,
211 ) -> Result<(RequestContext, Request<Body>), ContextExtractionError> {
212 self.extract_from_request_impl(request).await
213 }
214}