rust_mcp_sdk/mcp_http/
middleware.rs

1pub mod cors_middleware;
2pub(crate) mod dns_rebind_protector;
3pub mod logging_middleware;
4
5use super::types::{BoxFutureResponse, GenericBody, RequestHandler};
6use crate::mcp_http::{McpAppState, MiddlewareNext};
7use crate::mcp_server::error::TransportServerResult;
8use http::{Request, Response};
9use std::sync::Arc;
10
11#[async_trait::async_trait]
12pub trait Middleware: Send + Sync + 'static {
13    async fn handle<'req>(
14        &self,
15        req: Request<&'req str>,
16        state: Arc<McpAppState>,
17        next: MiddlewareNext<'req>,
18    ) -> TransportServerResult<Response<GenericBody>>;
19}
20
21/// Build the final handler by folding the middlewares **in reverse**.
22pub fn compose(
23    middlewares: &Vec<Arc<dyn Middleware>>,
24    final_handler: RequestHandler,
25) -> RequestHandler {
26    let mut handler = final_handler;
27
28    for mw in middlewares.iter().rev() {
29        let mw = mw.clone();
30        let next = handler.clone();
31
32        handler = Arc::new(move |req: Request<&str>, state: Arc<McpAppState>| {
33            let mw = mw.clone();
34            let next = next.clone();
35
36            Box::pin(async move { mw.handle(req, state, next).await }) as BoxFutureResponse<'_>
37        });
38    }
39
40    handler
41}
42
43#[cfg(test)]
44mod tests {
45    use super::*;
46    use crate::schema::{Implementation, InitializeResult, ProtocolVersion, ServerCapabilities};
47    use crate::{
48        id_generator::{FastIdGenerator, UuidGenerator},
49        mcp_http::{
50            middleware::{cors_middleware::CorsMiddleware, logging_middleware::LoggingMiddleware},
51            types::GenericBodyExt,
52        },
53        mcp_server::{error::TransportServerError, ServerHandler, ToMcpServerHandler},
54        session_store::InMemorySessionStore,
55    };
56    use async_trait::async_trait;
57    use http::{HeaderName, Request, Response, StatusCode};
58    use http_body_util::BodyExt;
59    use std::{
60        sync::{Arc, Mutex},
61        time::Duration,
62    };
63    struct TestHandler;
64    impl ServerHandler for TestHandler {}
65
66    fn app_state() -> Arc<McpAppState> {
67        let handler = TestHandler {};
68
69        Arc::new(McpAppState {
70            session_store: Arc::new(InMemorySessionStore::new()),
71            id_generator: Arc::new(UuidGenerator {}),
72            stream_id_gen: Arc::new(FastIdGenerator::new(Some("s_"))),
73            server_details: Arc::new(InitializeResult {
74                capabilities: ServerCapabilities {
75                    ..Default::default()
76                },
77                instructions: None,
78                meta: None,
79                protocol_version: ProtocolVersion::V2025_06_18.to_string(),
80                server_info: Implementation {
81                    name: "server".to_string(),
82                    title: None,
83                    version: "0.1.0".to_string(),
84                },
85            }),
86            handler: handler.to_mcp_server_handler(),
87            ping_interval: Duration::from_secs(15),
88            transport_options: Arc::new(rust_mcp_transport::TransportOptions::default()),
89            enable_json_response: false,
90            event_store: None,
91        })
92    }
93
94    /// Helper: Convert response to string
95    async fn response_string(res: Response<GenericBody>) -> String {
96        let (_parts, body) = res.into_parts();
97        let bytes = body.collect().await.unwrap().to_bytes();
98        String::from_utf8(bytes.to_vec()).unwrap()
99    }
100
101    /// Test Middleware – records everything, modifies req/res, supports early return
102    #[derive(Clone)]
103    struct TestMiddleware {
104        id: usize,
105        request_calls: Arc<Mutex<Vec<(usize, String, Vec<(String, String)>)>>>,
106        response_calls: Arc<Mutex<Vec<(usize, u16, Vec<(String, String)>)>>>,
107        add_req_header: Option<(String, String)>,
108        add_res_header: Option<(String, String)>,
109
110        // ---- early return (clone-able) ----
111        early_return_status: Option<StatusCode>,
112        early_return_body: Option<String>,
113
114        fail_request: bool,
115        fail_response: bool,
116    }
117
118    impl TestMiddleware {
119        fn new(id: usize) -> Self {
120            Self {
121                id,
122                request_calls: Arc::new(Mutex::new(Vec::new())),
123                response_calls: Arc::new(Mutex::new(Vec::new())),
124                add_req_header: None,
125                add_res_header: None,
126                early_return_status: None,
127                early_return_body: None,
128                fail_request: false,
129                fail_response: false,
130            }
131        }
132
133        fn with_req_header(mut self, name: &str, value: &str) -> Self {
134            self.add_req_header = Some((name.to_string(), value.to_string()));
135            self
136        }
137
138        fn with_res_header(mut self, name: &str, value: &str) -> Self {
139            self.add_res_header = Some((name.to_string(), value.to_string()));
140            self
141        }
142
143        fn early_return_200(mut self) -> Self {
144            self.early_return_status = Some(StatusCode::OK);
145            self.early_return_body = Some(format!("early-{}", self.id));
146            self
147        }
148
149        #[allow(unused)]
150        fn early_return(mut self, status: StatusCode, body: impl Into<String>) -> Self {
151            self.early_return_status = Some(status);
152            self.early_return_body = Some(body.into());
153            self
154        }
155
156        fn fail_request(mut self) -> Self {
157            self.fail_request = true;
158            self
159        }
160
161        fn fail_response(mut self) -> Self {
162            self.fail_response = true;
163            self
164        }
165    }
166
167    #[async_trait]
168    impl Middleware for TestMiddleware {
169        async fn handle<'req>(
170            &self,
171            mut req: Request<&'req str>,
172            state: Arc<McpAppState>,
173            next: MiddlewareNext<'req>,
174        ) -> TransportServerResult<Response<GenericBody>> {
175            // ---- record request -------------------------------------------------
176            let headers = req
177                .headers()
178                .iter()
179                .map(|(k, v)| (k.as_str().to_string(), v.to_str().unwrap_or("").to_string()))
180                .collect();
181            self.request_calls
182                .lock()
183                .unwrap()
184                .push((self.id, req.body().to_string(), headers));
185
186            if self.fail_request {
187                return Err(TransportServerError::HttpError(format!(
188                    "middleware {} failed request",
189                    self.id
190                )));
191            }
192
193            // ---- add request header --------------------------------------------
194            if let Some((name, value)) = &self.add_req_header {
195                req.headers_mut().insert(
196                    HeaderName::from_bytes(name.as_bytes()).unwrap(),
197                    value.parse().unwrap(),
198                );
199            }
200
201            // ---- early return ---------------------------------------------------
202            if let (Some(status), Some(body)) = (&self.early_return_status, &self.early_return_body)
203            {
204                return Ok(Response::builder()
205                    .status(*status)
206                    .body(GenericBody::from_string(body.to_string()))
207                    .unwrap());
208            }
209
210            // ---- call next ------------------------------------------------------
211            let mut res = next(req, state).await?;
212            // ---- add response header --------------------------------------------
213            if let Some((name, value)) = &self.add_res_header {
214                res.headers_mut().insert(
215                    HeaderName::from_bytes(name.as_bytes()).unwrap(),
216                    value.parse().unwrap(),
217                );
218            }
219
220            if self.fail_response {
221                return Err(TransportServerError::HttpError(format!(
222                    "middleware {} failed response",
223                    self.id
224                )));
225            }
226
227            // ---- record response ------------------------------------------------
228            let headers = res
229                .headers()
230                .iter()
231                .map(|(k, v)| (k.as_str().to_string(), v.to_str().unwrap_or("").to_string()))
232                .collect();
233
234            self.response_calls
235                .lock()
236                .unwrap()
237                .push((self.id, res.status().as_u16(), headers));
238
239            Ok(res)
240        }
241    }
242
243    /// Final handler – returns a fixed response
244    fn final_handler(body: &'static str, status: StatusCode) -> RequestHandler {
245        Arc::new(move |_req, _| {
246            let resp = Response::builder()
247                .status(status)
248                .body(GenericBody::from_string(body.to_string()))
249                .unwrap();
250            Box::pin(async move { Ok(resp) })
251        })
252    }
253
254    // TESTS
255
256    /// Middleware order (request → final → response)
257    #[tokio::test]
258    async fn test_middleware_order() {
259        let mw1 = Arc::new(TestMiddleware::new(1));
260        let mw2 = Arc::new(TestMiddleware::new(2));
261        let mw3 = Arc::new(TestMiddleware::new(3));
262
263        let middlewares: Vec<Arc<dyn Middleware>> = vec![mw1.clone(), mw2.clone(), mw3.clone()];
264        let handler = final_handler("final", StatusCode::OK);
265        let composed = compose(&middlewares, handler);
266
267        let req = Request::builder().body("").unwrap();
268        let _ = composed(req, app_state()).await.unwrap();
269
270        // request order: 3 → 2 → 1 → final
271        let rc3 = mw3.request_calls.lock().unwrap();
272        let rc2 = mw2.request_calls.lock().unwrap();
273        let rc1 = mw1.request_calls.lock().unwrap();
274        assert_eq!(rc3[0].0, 3);
275        assert_eq!(rc2[0].0, 2);
276        assert_eq!(rc1[0].0, 1);
277
278        // response order: 1 → 2 → 3
279        let pc1 = mw1.response_calls.lock().unwrap();
280        let pc2 = mw2.response_calls.lock().unwrap();
281        let pc3 = mw3.response_calls.lock().unwrap();
282        assert_eq!(pc1[0].0, 1);
283        assert_eq!(pc2[0].0, 2);
284        assert_eq!(pc3[0].0, 3);
285    }
286
287    /// Request header added by earlier middleware is visible later
288    #[tokio::test]
289    async fn test_request_header_propagation() {
290        let mw1 = Arc::new(TestMiddleware::new(1).with_req_header("x-mid", "1"));
291        let mw2 = Arc::new(TestMiddleware::new(2));
292
293        let middlewares: Vec<Arc<dyn Middleware>> = vec![mw1.clone(), mw2.clone()];
294        let handler = final_handler("ok", StatusCode::OK);
295        let composed = compose(&middlewares, handler);
296
297        let req = Request::builder().body("").unwrap();
298        let _ = composed(req, app_state()).await.unwrap();
299
300        let rc = mw2.request_calls.lock().unwrap();
301        let hdr = rc[0].2.iter().find(|(k, _)| k == "x-mid").map(|(_, v)| v);
302        assert_eq!(hdr, Some(&"1".to_string()));
303    }
304
305    /// Response header added by later middleware is visible earlier
306    #[tokio::test]
307    async fn test_response_header_propagation() {
308        let mw1 = Arc::new(TestMiddleware::new(1));
309        let mw2 = Arc::new(TestMiddleware::new(2).with_res_header("x-mid", "1"));
310
311        let middlewares: Vec<Arc<dyn Middleware>> = vec![mw1.clone(), mw2.clone()];
312        let handler = final_handler("ok", StatusCode::OK);
313        let composed = compose(&middlewares, handler);
314
315        let req = Request::builder().body("").unwrap();
316        let res = composed(req, app_state()).await.unwrap();
317
318        let pc1 = mw1.response_calls.lock().unwrap();
319
320        let hdr = pc1[0].2.iter().find(|(k, _)| k == "x-mid").map(|(_, v)| v);
321        assert_eq!(hdr, Some(&"1".to_string()));
322
323        assert_eq!(res.headers().get("x-mid").unwrap().to_str().unwrap(), "1");
324    }
325
326    /// Early return stops the chain
327    #[tokio::test]
328    async fn test_early_return_stops_chain() {
329        let mw1 = Arc::new(TestMiddleware::new(1).early_return_200());
330        let mw2 = Arc::new(TestMiddleware::new(2));
331        let mw3 = Arc::new(TestMiddleware::new(3));
332
333        let middlewares: Vec<Arc<dyn Middleware>> = vec![mw1.clone(), mw2.clone(), mw3.clone()];
334        let handler = final_handler("should-not-see", StatusCode::OK);
335        let composed = compose(&middlewares, handler);
336
337        let req = Request::builder().body("").unwrap();
338        let res = composed(req, app_state()).await.unwrap();
339
340        assert_eq!(response_string(res).await, "early-1");
341
342        assert!(mw2.request_calls.lock().unwrap().is_empty());
343        assert!(mw3.request_calls.lock().unwrap().is_empty());
344    }
345
346    /// Request error stops response processing
347    #[tokio::test]
348    async fn test_request_error_stops_response_chain() {
349        let mw1 = Arc::new(TestMiddleware::new(1).fail_request());
350        let mw2 = Arc::new(TestMiddleware::new(2));
351
352        let middlewares: Vec<Arc<dyn Middleware>> = vec![mw1.clone(), mw2.clone()];
353        let handler = final_handler("ok", StatusCode::OK);
354        let composed = compose(&middlewares, handler);
355
356        let req = Request::builder().body("").unwrap();
357        let result = composed(req, app_state()).await;
358
359        assert!(result.is_err());
360        assert!(mw2.request_calls.lock().unwrap().is_empty());
361        assert!(mw2.response_calls.lock().unwrap().is_empty());
362    }
363
364    ///Response error after next()
365    #[tokio::test]
366    async fn test_response_error_after_next() {
367        let mw1 = Arc::new(TestMiddleware::new(1).fail_response());
368        let mw2 = Arc::new(TestMiddleware::new(2));
369
370        let middlewares: Vec<Arc<dyn Middleware>> = vec![mw1.clone(), mw2.clone()];
371        let handler = final_handler("ok", StatusCode::OK);
372        let composed = compose(&middlewares, handler);
373
374        let req = Request::builder().body("").unwrap();
375        let result = composed(req, app_state()).await;
376
377        assert!(result.is_err());
378        assert!(!mw1.request_calls.lock().unwrap().is_empty());
379        // response_calls is empty because we error before recording
380        assert!(mw1.response_calls.lock().unwrap().is_empty());
381    }
382
383    /// No middleware → direct handler
384    #[tokio::test]
385    async fn test_no_middleware() {
386        let middlewares: Vec<Arc<dyn Middleware>> = vec![];
387        let handler = final_handler("direct", StatusCode::IM_A_TEAPOT);
388        let composed = compose(&middlewares, handler);
389
390        let req = Request::builder().body("").unwrap();
391        let res = composed(req, app_state()).await.unwrap();
392
393        assert_eq!(res.status(), StatusCode::IM_A_TEAPOT);
394        assert_eq!(response_string(res).await, "direct");
395    }
396
397    /// Multiple headers accumulate correctly
398    #[tokio::test]
399    async fn test_multiple_headers_accumulate() {
400        let mw1 = Arc::new(
401            TestMiddleware::new(1)
402                .with_req_header("x-a", "1")
403                .with_res_header("x-b", "1"),
404        );
405        let mw2 = Arc::new(
406            TestMiddleware::new(2)
407                .with_req_header("x-c", "2")
408                .with_res_header("x-d", "2"),
409        );
410
411        let mw3 = Arc::new(TestMiddleware::new(3));
412
413        let middlewares: Vec<Arc<dyn Middleware>> = vec![mw1.clone(), mw2.clone(), mw3.clone()];
414        let handler = final_handler("ok", StatusCode::OK);
415        let composed = compose(&middlewares, handler);
416
417        let req = Request::builder().body("").unwrap();
418        let res = composed(req, app_state()).await.unwrap();
419
420        let h = res.headers();
421        assert_eq!(h["x-b"], "1");
422        assert_eq!(h["x-d"], "2");
423
424        // Request headers are NOT in response
425        assert!(!h.contains_key("x-a"));
426        assert!(!h.contains_key("x-c"));
427
428        // But they were added to the request
429        let req_calls_mw3 = mw3.request_calls.lock().unwrap();
430        let req_headers = &req_calls_mw3[0].2;
431
432        assert!(req_headers.iter().any(|(k, v)| k == "x-a" && v == "1"));
433        assert!(req_headers.iter().any(|(k, v)| k == "x-c" && v == "2"));
434    }
435
436    /// Request body is passed unchanged
437    #[tokio::test]
438    async fn test_request_body_unchanged() {
439        let mw1 = Arc::new(TestMiddleware::new(1));
440        let mw2 = Arc::new(TestMiddleware::new(2));
441
442        let middlewares: Vec<Arc<dyn Middleware>> = vec![mw1.clone(), mw2.clone()];
443        let handler: RequestHandler = Arc::new(move |req, _| {
444            let body = req.into_body().to_string();
445            Box::pin(async move {
446                Ok(Response::builder()
447                    .body(GenericBody::from_string(format!("echo:{body}")))
448                    .unwrap())
449            })
450        });
451        let composed = compose(&middlewares, handler);
452
453        let req = Request::builder().body("secret-payload").unwrap();
454        let res = composed(req, app_state()).await.unwrap();
455        assert_eq!(response_string(res).await, "echo:secret-payload");
456    }
457
458    // Integration: CORS + Logger (order matters)
459    #[tokio::test]
460    async fn test_cors_and_logger_integration() {
461        let cors = Arc::new(CorsMiddleware::permissive());
462        let logger = Arc::new(LoggingMiddleware);
463
464        // Order in the vector is the order they are *registered*.
465        // compose folds in reverse, so logger runs *first* (request) and *last* (response).
466        let middlewares: Vec<Arc<dyn Middleware>> = vec![cors.clone(), logger.clone()];
467        let handler = final_handler("ok", StatusCode::OK);
468        let composed = compose(&middlewares, handler);
469
470        let req = Request::builder()
471            .method(http::Method::GET)
472            .uri("/api")
473            .header("Origin", "https://example.com")
474            .body("")
475            .unwrap();
476
477        let res = composed(req, app_state()).await.unwrap();
478
479        // CORS headers added by CorsMiddleware
480        assert_eq!(
481            res.headers()["access-control-allow-origin"],
482            "https://example.com"
483        );
484        assert_eq!(res.headers()["access-control-allow-credentials"], "true");
485    }
486}