Skip to main content

systemprompt_api/services/middleware/context/
middleware.rs

1use axum::extract::Request;
2use axum::middleware::Next;
3use axum::response::{IntoResponse, Response};
4use std::sync::Arc;
5use systemprompt_security::HeaderExtractor;
6use tracing::Instrument;
7
8use super::extractors::ContextExtractor;
9use super::requirements::ContextRequirement;
10use systemprompt_identifiers::{AgentName, ContextId, TraceId};
11use systemprompt_models::api::ApiError;
12use systemprompt_models::execution::context::{ContextExtractionError, RequestContext};
13
14#[derive(Debug, Clone)]
15pub struct ContextMiddleware<E> {
16    extractor: Arc<E>,
17    auth_level: ContextRequirement,
18}
19
20impl<E> ContextMiddleware<E> {
21    pub fn new(extractor: E) -> Self {
22        Self {
23            extractor: Arc::new(extractor),
24            auth_level: ContextRequirement::default(),
25        }
26    }
27
28    pub fn public(extractor: E) -> Self {
29        Self {
30            extractor: Arc::new(extractor),
31            auth_level: ContextRequirement::None,
32        }
33    }
34
35    pub fn user_only(extractor: E) -> Self {
36        Self {
37            extractor: Arc::new(extractor),
38            auth_level: ContextRequirement::UserOnly,
39        }
40    }
41
42    pub fn full(extractor: E) -> Self {
43        Self {
44            extractor: Arc::new(extractor),
45            auth_level: ContextRequirement::UserWithContext,
46        }
47    }
48
49    pub fn mcp(extractor: E) -> Self {
50        Self {
51            extractor: Arc::new(extractor),
52            auth_level: ContextRequirement::McpWithHeaders,
53        }
54    }
55
56    fn error_to_api_error(error: &ContextExtractionError) -> ApiError {
57        match error {
58            ContextExtractionError::MissingAuthHeader => {
59                ApiError::unauthorized("Missing Authorization header")
60            },
61            ContextExtractionError::InvalidToken(_) => {
62                ApiError::unauthorized("Invalid or expired JWT token")
63            },
64            ContextExtractionError::UserNotFound(_) => {
65                ApiError::unauthorized("User no longer exists")
66            },
67            ContextExtractionError::MissingSessionId => {
68                ApiError::bad_request("JWT missing required 'session_id' claim")
69            },
70            ContextExtractionError::MissingUserId => {
71                ApiError::bad_request("JWT missing required 'sub' claim")
72            },
73            ContextExtractionError::MissingContextId => ApiError::bad_request(
74                "Missing required 'x-context-id' header (for MCP routes) or contextId in body \
75                 (for A2A routes)",
76            ),
77            ContextExtractionError::MissingHeader(header) => {
78                ApiError::bad_request(format!("Missing required header: {header}"))
79            },
80            ContextExtractionError::InvalidHeaderValue { header, reason } => {
81                ApiError::bad_request(format!("Invalid header {header}: {reason}"))
82            },
83            ContextExtractionError::InvalidUserId(reason) => {
84                ApiError::bad_request(format!("Invalid user_id: {reason}"))
85            },
86            ContextExtractionError::DatabaseError(_) => {
87                ApiError::internal_error("Internal server error")
88            },
89            ContextExtractionError::ForbiddenHeader { header, reason } => {
90                ApiError::bad_request(format!(
91                    "Header '{header}' is not allowed: {reason}. Use JWT authentication instead."
92                ))
93            },
94        }
95    }
96
97    fn log_error_response(
98        error: &ContextExtractionError,
99        trace_id: &TraceId,
100        path: &str,
101        method: &str,
102    ) -> Response {
103        let _span = tracing::error_span!(
104            "context_extraction_error",
105            trace_id = %trace_id,
106            path = %path,
107            method = %method,
108        )
109        .entered();
110
111        match error {
112            ContextExtractionError::DatabaseError(e) => {
113                tracing::error!(
114                    error = %e,
115                    error_type = "database",
116                    "Context extraction failed due to database error"
117                );
118            },
119            ContextExtractionError::InvalidToken(reason) => {
120                tracing::warn!(
121                    reason = %reason,
122                    error_type = "invalid_token",
123                    "Context extraction failed: invalid token"
124                );
125            },
126            ContextExtractionError::UserNotFound(user_id) => {
127                tracing::warn!(
128                    user_id = %user_id,
129                    error_type = "user_not_found",
130                    "Context extraction failed: user not found"
131                );
132            },
133            _ => {
134                tracing::warn!(
135                    error = %error,
136                    error_type = "context_extraction",
137                    "Context extraction failed"
138                );
139            },
140        }
141
142        Self::error_to_api_error(error)
143            .with_trace_id(trace_id.as_str())
144            .with_path(path)
145            .into_response()
146    }
147}
148
149fn create_request_span(ctx: &RequestContext) -> tracing::Span {
150    tracing::info_span!(
151        "request",
152        user_id = %ctx.user_id(),
153        session_id = %ctx.session_id(),
154        trace_id = %ctx.trace_id(),
155        context_id = %ctx.context_id(),
156    )
157}
158
159impl<E: ContextExtractor> ContextMiddleware<E> {
160    pub async fn handle(&self, request: Request, next: Next) -> Response {
161        let requirement = request
162            .extensions()
163            .get::<ContextRequirement>()
164            .copied()
165            .unwrap_or(self.auth_level);
166
167        if request.extensions().get::<RequestContext>().is_some()
168            && self.auth_level == ContextRequirement::None
169        {
170            return next.run(request).await;
171        }
172
173        match requirement {
174            ContextRequirement::None => self.handle_none_requirement(request, next).await,
175            ContextRequirement::UserOnly => self.handle_user_only(request, next).await,
176            ContextRequirement::UserWithContext => {
177                self.handle_user_with_context(request, next).await
178            },
179            ContextRequirement::McpWithHeaders => self.handle_mcp_with_headers(request, next).await,
180        }
181    }
182
183    async fn handle_none_requirement(&self, mut request: Request, next: Next) -> Response {
184        let headers = request.headers();
185        let mut req_ctx = if let Some(ctx) = request.extensions().get::<RequestContext>() {
186            ctx.clone()
187        } else {
188            return ApiError::internal_error(
189                "Middleware configuration error: SessionMiddleware must run before \
190                 ContextMiddleware",
191            )
192            .into_response();
193        };
194
195        if let Some(context_id) = headers.get("x-context-id") {
196            if let Ok(id) = context_id.to_str() {
197                req_ctx.execution.context_id = ContextId::new(id.to_string());
198            }
199        }
200
201        if let Some(agent_name) = headers.get("x-agent-name") {
202            if let Ok(name) = agent_name.to_str() {
203                req_ctx.execution.agent_name = AgentName::new(name.to_string());
204            }
205        }
206
207        let span = create_request_span(&req_ctx);
208        request.extensions_mut().insert(req_ctx);
209        next.run(request).instrument(span).await
210    }
211
212    async fn handle_user_only(&self, mut request: Request, next: Next) -> Response {
213        let trace_id = HeaderExtractor::extract_trace_id(request.headers());
214        let path = request.uri().path().to_string();
215        let method = request.method().to_string();
216
217        match self.extractor.extract_user_only(request.headers()).await {
218            Ok(context) => {
219                let span = create_request_span(&context);
220                request.extensions_mut().insert(context);
221                next.run(request).instrument(span).await
222            },
223            Err(e) => Self::log_error_response(&e, &trace_id, &path, &method),
224        }
225    }
226
227    async fn handle_user_with_context(&self, request: Request, next: Next) -> Response {
228        let trace_id = HeaderExtractor::extract_trace_id(request.headers());
229        let path = request.uri().path().to_string();
230        let method = request.method().to_string();
231
232        match self.extractor.extract_from_request(request).await {
233            Ok((context, reconstructed_request)) => {
234                let span = create_request_span(&context);
235                let mut req = reconstructed_request;
236                req.extensions_mut().insert(context);
237                next.run(req).instrument(span).await
238            },
239            Err(e) => Self::log_error_response(&e, &trace_id, &path, &method),
240        }
241    }
242
243    async fn handle_mcp_with_headers(&self, request: Request, next: Next) -> Response {
244        let trace_id = HeaderExtractor::extract_trace_id(request.headers());
245        let path = request.uri().path().to_string();
246        let method = request.method().to_string();
247
248        match self.extractor.extract_from_headers(request.headers()).await {
249            Ok(context) => {
250                let span = create_request_span(&context);
251                let mut req = request;
252                req.extensions_mut().insert(context);
253                next.run(req).instrument(span).await
254            },
255            Err(e) => {
256                let fallback_ctx = request.extensions().get::<RequestContext>().cloned();
257                #[allow(clippy::single_match_else)]
258                match fallback_ctx {
259                    Some(ctx) => {
260                        tracing::debug!(
261                            error = %e,
262                            trace_id = %trace_id,
263                            "MCP header extraction failed, using session context"
264                        );
265                        let span = create_request_span(&ctx);
266                        next.run(request).instrument(span).await
267                    },
268                    None => {
269                        tracing::error!(
270                            trace_id = %trace_id,
271                            path = %path,
272                            method = %method,
273                            "Middleware configuration error: SessionMiddleware must run before ContextMiddleware"
274                        );
275                        ApiError::internal_error("Middleware configuration error")
276                            .with_trace_id(trace_id.as_str())
277                            .with_path(&path)
278                            .into_response()
279                    },
280                }
281            },
282        }
283    }
284}