systemprompt_api/services/middleware/context/
middleware.rs1use axum::extract::Request;
2use axum::middleware::Next;
3use axum::response::{IntoResponse, Response};
4use std::sync::Arc;
5use systemprompt_security::HeaderExtractor;
6use tracing::Instrument;
7
8use super::extractors::ContextExtractor;
9use super::requirements::ContextRequirement;
10use systemprompt_identifiers::{AgentName, ContextId, TraceId};
11use systemprompt_models::api::ApiError;
12use systemprompt_models::execution::context::{ContextExtractionError, RequestContext};
13
14#[derive(Debug, Clone)]
15pub struct ContextMiddleware<E> {
16 extractor: Arc<E>,
17 auth_level: ContextRequirement,
18}
19
20impl<E> ContextMiddleware<E> {
21 pub fn new(extractor: E) -> Self {
22 Self {
23 extractor: Arc::new(extractor),
24 auth_level: ContextRequirement::default(),
25 }
26 }
27
28 pub fn public(extractor: E) -> Self {
29 Self {
30 extractor: Arc::new(extractor),
31 auth_level: ContextRequirement::None,
32 }
33 }
34
35 pub fn user_only(extractor: E) -> Self {
36 Self {
37 extractor: Arc::new(extractor),
38 auth_level: ContextRequirement::UserOnly,
39 }
40 }
41
42 pub fn full(extractor: E) -> Self {
43 Self {
44 extractor: Arc::new(extractor),
45 auth_level: ContextRequirement::UserWithContext,
46 }
47 }
48
49 pub fn mcp(extractor: E) -> Self {
50 Self {
51 extractor: Arc::new(extractor),
52 auth_level: ContextRequirement::McpWithHeaders,
53 }
54 }
55
56 fn error_to_api_error(error: &ContextExtractionError) -> ApiError {
57 match error {
58 ContextExtractionError::MissingAuthHeader => {
59 ApiError::unauthorized("Missing Authorization header")
60 },
61 ContextExtractionError::InvalidToken(_) => {
62 ApiError::unauthorized("Invalid or expired JWT token")
63 },
64 ContextExtractionError::UserNotFound(_) => {
65 ApiError::unauthorized("User no longer exists")
66 },
67 ContextExtractionError::MissingSessionId => {
68 ApiError::bad_request("JWT missing required 'session_id' claim")
69 },
70 ContextExtractionError::MissingUserId => {
71 ApiError::bad_request("JWT missing required 'sub' claim")
72 },
73 ContextExtractionError::MissingContextId => ApiError::bad_request(
74 "Missing required 'x-context-id' header (for MCP routes) or contextId in body \
75 (for A2A routes)",
76 ),
77 ContextExtractionError::MissingHeader(header) => {
78 ApiError::bad_request(format!("Missing required header: {header}"))
79 },
80 ContextExtractionError::InvalidHeaderValue { header, reason } => {
81 ApiError::bad_request(format!("Invalid header {header}: {reason}"))
82 },
83 ContextExtractionError::InvalidUserId(reason) => {
84 ApiError::bad_request(format!("Invalid user_id: {reason}"))
85 },
86 ContextExtractionError::DatabaseError(_) => {
87 ApiError::internal_error("Internal server error")
88 },
89 ContextExtractionError::ForbiddenHeader { header, reason } => {
90 ApiError::bad_request(format!(
91 "Header '{header}' is not allowed: {reason}. Use JWT authentication instead."
92 ))
93 },
94 }
95 }
96
97 fn log_error_response(
98 error: &ContextExtractionError,
99 trace_id: &TraceId,
100 path: &str,
101 method: &str,
102 ) -> Response {
103 let _span = tracing::error_span!(
104 "context_extraction_error",
105 trace_id = %trace_id,
106 path = %path,
107 method = %method,
108 )
109 .entered();
110
111 match error {
112 ContextExtractionError::DatabaseError(e) => {
113 tracing::error!(
114 error = %e,
115 error_type = "database",
116 "Context extraction failed due to database error"
117 );
118 },
119 ContextExtractionError::InvalidToken(reason) => {
120 tracing::warn!(
121 reason = %reason,
122 error_type = "invalid_token",
123 "Context extraction failed: invalid token"
124 );
125 },
126 ContextExtractionError::UserNotFound(user_id) => {
127 tracing::warn!(
128 user_id = %user_id,
129 error_type = "user_not_found",
130 "Context extraction failed: user not found"
131 );
132 },
133 _ => {
134 tracing::warn!(
135 error = %error,
136 error_type = "context_extraction",
137 "Context extraction failed"
138 );
139 },
140 }
141
142 Self::error_to_api_error(error)
143 .with_trace_id(trace_id.as_str())
144 .with_path(path)
145 .into_response()
146 }
147}
148
149fn create_request_span(ctx: &RequestContext) -> tracing::Span {
150 tracing::info_span!(
151 "request",
152 user_id = %ctx.user_id(),
153 session_id = %ctx.session_id(),
154 trace_id = %ctx.trace_id(),
155 context_id = %ctx.context_id(),
156 )
157}
158
159impl<E: ContextExtractor> ContextMiddleware<E> {
160 pub async fn handle(&self, request: Request, next: Next) -> Response {
161 let requirement = request
162 .extensions()
163 .get::<ContextRequirement>()
164 .copied()
165 .unwrap_or(self.auth_level);
166
167 if request.extensions().get::<RequestContext>().is_some()
168 && self.auth_level == ContextRequirement::None
169 {
170 return next.run(request).await;
171 }
172
173 match requirement {
174 ContextRequirement::None => self.handle_none_requirement(request, next).await,
175 ContextRequirement::UserOnly => self.handle_user_only(request, next).await,
176 ContextRequirement::UserWithContext => {
177 self.handle_user_with_context(request, next).await
178 },
179 ContextRequirement::McpWithHeaders => self.handle_mcp_with_headers(request, next).await,
180 }
181 }
182
183 async fn handle_none_requirement(&self, mut request: Request, next: Next) -> Response {
184 let headers = request.headers();
185 let mut req_ctx = if let Some(ctx) = request.extensions().get::<RequestContext>() {
186 ctx.clone()
187 } else {
188 return ApiError::internal_error(
189 "Middleware configuration error: SessionMiddleware must run before \
190 ContextMiddleware",
191 )
192 .into_response();
193 };
194
195 if let Some(context_id) = headers.get("x-context-id") {
196 if let Ok(id) = context_id.to_str() {
197 req_ctx.execution.context_id = ContextId::new(id.to_string());
198 }
199 }
200
201 if let Some(agent_name) = headers.get("x-agent-name") {
202 if let Ok(name) = agent_name.to_str() {
203 req_ctx.execution.agent_name = AgentName::new(name.to_string());
204 }
205 }
206
207 let span = create_request_span(&req_ctx);
208 request.extensions_mut().insert(req_ctx);
209 next.run(request).instrument(span).await
210 }
211
212 async fn handle_user_only(&self, mut request: Request, next: Next) -> Response {
213 let trace_id = HeaderExtractor::extract_trace_id(request.headers());
214 let path = request.uri().path().to_string();
215 let method = request.method().to_string();
216
217 match self.extractor.extract_user_only(request.headers()).await {
218 Ok(context) => {
219 let span = create_request_span(&context);
220 request.extensions_mut().insert(context);
221 next.run(request).instrument(span).await
222 },
223 Err(e) => Self::log_error_response(&e, &trace_id, &path, &method),
224 }
225 }
226
227 async fn handle_user_with_context(&self, request: Request, next: Next) -> Response {
228 let trace_id = HeaderExtractor::extract_trace_id(request.headers());
229 let path = request.uri().path().to_string();
230 let method = request.method().to_string();
231
232 match self.extractor.extract_from_request(request).await {
233 Ok((context, reconstructed_request)) => {
234 let span = create_request_span(&context);
235 let mut req = reconstructed_request;
236 req.extensions_mut().insert(context);
237 next.run(req).instrument(span).await
238 },
239 Err(e) => Self::log_error_response(&e, &trace_id, &path, &method),
240 }
241 }
242
243 async fn handle_mcp_with_headers(&self, request: Request, next: Next) -> Response {
244 let trace_id = HeaderExtractor::extract_trace_id(request.headers());
245 let path = request.uri().path().to_string();
246 let method = request.method().to_string();
247
248 match self.extractor.extract_from_headers(request.headers()).await {
249 Ok(context) => {
250 let span = create_request_span(&context);
251 let mut req = request;
252 req.extensions_mut().insert(context);
253 next.run(req).instrument(span).await
254 },
255 Err(e) => {
256 let fallback_ctx = request.extensions().get::<RequestContext>().cloned();
257 #[allow(clippy::single_match_else)]
258 match fallback_ctx {
259 Some(ctx) => {
260 tracing::debug!(
261 error = %e,
262 trace_id = %trace_id,
263 "MCP header extraction failed, using session context"
264 );
265 let span = create_request_span(&ctx);
266 next.run(request).instrument(span).await
267 },
268 None => {
269 tracing::error!(
270 trace_id = %trace_id,
271 path = %path,
272 method = %method,
273 "Middleware configuration error: SessionMiddleware must run before ContextMiddleware"
274 );
275 ApiError::internal_error("Middleware configuration error")
276 .with_trace_id(trace_id.as_str())
277 .with_path(&path)
278 .into_response()
279 },
280 }
281 },
282 }
283 }
284}