Skip to main content

systemprompt_api/services/middleware/jwt/context.rs

1use async_trait::async_trait;
2use axum::body::Body;
3use axum::extract::Request;
4use axum::http::HeaderMap;
5use std::sync::Arc;
6
7use crate::services::middleware::context::ContextExtractor;
8use systemprompt_identifiers::ContextId;
9use systemprompt_models::execution::context::{ContextExtractionError, RequestContext};
10use systemprompt_security::{JwtUserContext, TokenExtractor, extract_user_context};
11use systemprompt_traits::{AnalyticsProvider, UserProvider};
12
13use super::params::{BuildContextParams, build_context, extract_common_headers};
14use super::validation::{UserCache, user_is_admin, validate_session_exists, validate_user_exists};
15
16#[derive(Clone)]
17pub struct JwtContextExtractor {
18    token_extractor: TokenExtractor,
19    analytics_provider: Arc<dyn AnalyticsProvider>,
20    user_provider: Arc<dyn UserProvider>,
21    user_cache: Arc<UserCache>,
22}
23
24impl std::fmt::Debug for JwtContextExtractor {
25    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
26        f.debug_struct("JwtContextExtractor")
27            .field("token_extractor", &self.token_extractor)
28            .finish_non_exhaustive()
29    }
30}
31
32impl JwtContextExtractor {
33    pub fn new(
34        analytics_provider: Arc<dyn AnalyticsProvider>,
35        user_provider: Arc<dyn UserProvider>,
36    ) -> Self {
37        Self {
38            token_extractor: TokenExtractor::browser_only(),
39            analytics_provider,
40            user_provider,
41            user_cache: UserCache::new(),
42        }
43    }
44
45    fn extract_jwt_context(
46        &self,
47        headers: &HeaderMap,
48    ) -> Result<JwtUserContext, ContextExtractionError> {
49        let token = self
50            .token_extractor
51            .extract(headers)
52            .map_err(|_e| ContextExtractionError::MissingAuthHeader)?;
53        extract_user_context(&token)
54            .map_err(|e| ContextExtractionError::InvalidToken(e.to_string()))
55    }
56
57    async fn validate(
58        &self,
59        jwt_context: &JwtUserContext,
60        route_context: &str,
61    ) -> Result<systemprompt_traits::AuthUser, ContextExtractionError> {
62        if jwt_context.session_id.as_str().is_empty() {
63            return Err(ContextExtractionError::MissingSessionId);
64        }
65        if jwt_context.user_id.as_str().is_empty() {
66            return Err(ContextExtractionError::MissingUserId);
67        }
68        let validated = validate_user_exists(
69            &self.user_provider,
70            &self.user_cache,
71            jwt_context,
72            route_context,
73        )
74        .await?;
75        validate_session_exists(&self.analytics_provider, jwt_context, route_context).await?;
76        Ok(validated.user)
77    }
78
79    pub async fn extract_standard(
80        &self,
81        headers: &HeaderMap,
82    ) -> Result<RequestContext, ContextExtractionError> {
83        let jwt_context = self.extract_jwt_context(headers)?;
84        let user = self.validate(&jwt_context, "").await?;
85
86        let context_id = headers
87            .get("x-context-id")
88            .and_then(|h| h.to_str().ok())
89            .filter(|s| !s.is_empty())
90            .and_then(|s| ContextId::try_new(s).ok())
91            .unwrap_or_else(ContextId::generate);
92
93        let (trace_id, task_id, auth_token, agent_name) =
94            extract_common_headers(&self.token_extractor, headers);
95
96        let user_type = jwt_context.user_type.reconcile_with(user_is_admin(&user));
97        let session_id = jwt_context.session_id.clone();
98        let user_id = jwt_context.user_id.clone();
99
100        Ok(build_context(BuildContextParams {
101            jwt_context,
102            session_id,
103            user_id,
104            trace_id,
105            context_id,
106            agent_name,
107            task_id,
108            auth_token,
109            user_type,
110        }))
111    }
112
113    pub async fn decode_for_gateway(
114        &self,
115        jwt_token: &systemprompt_identifiers::JwtToken,
116    ) -> Result<JwtUserContext, ContextExtractionError> {
117        let jwt_context = extract_user_context(jwt_token.as_str())
118            .map_err(|e| ContextExtractionError::InvalidToken(e.to_string()))?;
119
120        self.validate(&jwt_context, "gateway").await?;
121        Ok(jwt_context)
122    }
123
124    async fn extract_from_request_impl(
125        &self,
126        request: Request<Body>,
127    ) -> Result<(RequestContext, Request<Body>), ContextExtractionError> {
128        use crate::services::middleware::context::sources::{ContextIdSource, PayloadSource};
129
130        let headers = request.headers().clone();
131        let has_auth = headers.get("authorization").is_some();
132
133        if headers.get("x-context-id").is_some() && !has_auth {
134            return Err(ContextExtractionError::ForbiddenHeader {
135                header: "X-Context-ID".to_owned(),
136                reason: "Context ID must be in request body (A2A spec). Use contextId field in \
137                         message."
138                    .to_owned(),
139            });
140        }
141
142        let jwt_context = self.extract_jwt_context(&headers)?;
143        let user = self.validate(&jwt_context, " (A2A route)").await?;
144
145        let (body_bytes, reconstructed_request) =
146            PayloadSource::read_and_reconstruct(request).await?;
147
148        let context_source = PayloadSource::extract_context_source(&body_bytes)?;
149        let (context_id, task_id_from_payload) = match context_source {
150            ContextIdSource::Direct(id) => (
151                ContextId::try_new(id).map_err(|e| ContextExtractionError::InvalidHeaderValue {
152                    header: "contextId".to_owned(),
153                    reason: e.to_string(),
154                })?,
155                None,
156            ),
157            ContextIdSource::FromTask { task_id } => (ContextId::generate(), Some(task_id)),
158        };
159
160        let (trace_id, task_id_from_header, auth_token, agent_name) =
161            extract_common_headers(&self.token_extractor, &headers);
162
163        let task_id = task_id_from_payload.or(task_id_from_header);
164        let user_type = jwt_context.user_type.reconcile_with(user_is_admin(&user));
165
166        let session_id = jwt_context.session_id.clone();
167        let user_id = jwt_context.user_id.clone();
168        let ctx = build_context(BuildContextParams {
169            jwt_context,
170            session_id,
171            user_id,
172            trace_id,
173            context_id,
174            agent_name,
175            task_id,
176            auth_token,
177            user_type,
178        });
179
180        Ok((ctx, reconstructed_request))
181    }
182}
183
184#[async_trait]
185impl ContextExtractor for JwtContextExtractor {
186    async fn extract_from_headers(
187        &self,
188        headers: &HeaderMap,
189    ) -> Result<RequestContext, ContextExtractionError> {
190        self.extract_standard(headers).await
191    }
192
193    async fn extract_from_request(
194        &self,
195        request: Request<Body>,
196    ) -> Result<(RequestContext, Request<Body>), ContextExtractionError> {
197        self.extract_from_request_impl(request).await
198    }
199
200    async fn extract_user_only(
201        &self,
202        headers: &HeaderMap,
203    ) -> Result<RequestContext, ContextExtractionError> {
204        self.extract_standard(headers).await
205    }
206}