rust_mcp_sdk/mcp_http/
middleware.rs

1#[cfg(feature = "auth")]
2mod auth_middleware;
3mod cors_middleware;
4mod dns_rebind_protector;
5pub mod logging_middleware;
6
7use super::types::{GenericBody, RequestHandler};
8use crate::mcp_http::{McpAppState, MiddlewareNext};
9use crate::mcp_server::error::TransportServerResult;
10#[cfg(feature = "auth")]
11pub(crate) use auth_middleware::*;
12pub use cors_middleware::*;
13pub(crate) use dns_rebind_protector::*;
14use http::{Request, Response};
15use std::sync::Arc;
16
/// An asynchronous layer in the HTTP request pipeline.
///
/// A middleware receives the incoming request, the shared application state and the
/// `next` handler of the chain. It may inspect or modify the request before delegating
/// to `next`, inspect or modify the response that `next` returns, or return its own
/// response without calling `next` at all to short-circuit the rest of the chain.
///
/// The `Send + Sync + 'static` bounds allow implementors to be shared as
/// `Arc<dyn Middleware>`, which is the form [`compose`] accepts.
#[async_trait::async_trait]
pub trait Middleware: Send + Sync + 'static {
    /// Handle `req`, optionally delegating to `next`, and produce the response.
    ///
    /// `next` is the remainder of the chain (further middlewares and, finally, the
    /// request handler). Not calling it stops the request from reaching later layers.
    ///
    /// # Errors
    ///
    /// Implementations return an error in the `TransportServerResult` error variant when
    /// they fail. They typically also propagate (e.g. with `?`) any error returned by `next`.
    async fn handle<'req>(
        &self,
        req: Request<&'req str>,
        state: Arc<McpAppState>,
        next: MiddlewareNext<'req>,
    ) -> TransportServerResult<Response<GenericBody>>;
}
26
27/// Build the final handler by folding the middlewares **in reverse**.
28/// Each middleware and handler is consumed exactly once.
29pub fn compose<'a, I>(middlewares: I, final_handler: RequestHandler) -> RequestHandler
30where
31    I: IntoIterator<Item = &'a Arc<dyn Middleware>>,
32    I::IntoIter: DoubleEndedIterator,
33{
34    // Start with the final handler
35    let mut handler = final_handler;
36
37    // Fold middlewares in reverse order
38    for mw in middlewares.into_iter().rev() {
39        let mw = Arc::clone(mw);
40        let next = handler;
41
42        // Each loop iteration consumes `next` and returns a new boxed FnOnce
43        handler = Box::new(move |req: Request<&str>, state: Arc<McpAppState>| {
44            let mw = Arc::clone(&mw);
45            Box::pin(async move { mw.handle(req, state, next).await })
46        });
47    }
48
49    handler
50}
51
#[cfg(test)]
mod tests {
    use super::*;
    use crate::schema::{Implementation, InitializeResult, ProtocolVersion, ServerCapabilities};
    use crate::{
        id_generator::{FastIdGenerator, UuidGenerator},
        mcp_http::{
            middleware::{cors_middleware::CorsMiddleware, logging_middleware::LoggingMiddleware},
            types::GenericBodyExt,
        },
        mcp_server::{error::TransportServerError, ServerHandler, ToMcpServerHandler},
        session_store::InMemorySessionStore,
    };
    use async_trait::async_trait;
    use http::{HeaderName, Request, Response, StatusCode};
    use http_body_util::BodyExt;
    use std::{
        sync::{Arc, Mutex},
        time::Duration,
    };

    /// Requests recorded by [`TestMiddleware`] as `(middleware id, body, headers)`.
    type RequestLog = Arc<Mutex<Vec<(usize, String, Vec<(String, String)>)>>>;
    /// Responses recorded by [`TestMiddleware`] as `(middleware id, status code, headers)`.
    type ResponseLog = Arc<Mutex<Vec<(usize, u16, Vec<(String, String)>)>>>;

    struct TestHandler;
    impl ServerHandler for TestHandler {}

    /// Minimal application state used to drive composed handlers in tests.
    fn app_state() -> Arc<McpAppState> {
        let handler = TestHandler {};

        Arc::new(McpAppState {
            session_store: Arc::new(InMemorySessionStore::new()),
            id_generator: Arc::new(UuidGenerator {}),
            stream_id_gen: Arc::new(FastIdGenerator::new(Some("s_"))),
            server_details: Arc::new(InitializeResult {
                capabilities: ServerCapabilities {
                    ..Default::default()
                },
                instructions: None,
                meta: None,
                protocol_version: ProtocolVersion::V2025_06_18.to_string(),
                server_info: Implementation {
                    name: "server".to_string(),
                    title: None,
                    version: "0.1.0".to_string(),
                },
            }),
            handler: handler.to_mcp_server_handler(),
            ping_interval: Duration::from_secs(15),
            transport_options: Arc::new(rust_mcp_transport::TransportOptions::default()),
            enable_json_response: false,
            event_store: None,
        })
    }

    /// Helper: Convert response to string
    async fn response_string(res: Response<GenericBody>) -> String {
        let (_parts, body) = res.into_parts();
        let bytes = body.collect().await.unwrap().to_bytes();
        String::from_utf8(bytes.to_vec()).unwrap()
    }

    /// Test Middleware – records everything, modifies req/res, supports early return
    #[derive(Clone)]
    struct TestMiddleware {
        id: usize,
        request_calls: RequestLog,
        response_calls: ResponseLog,
        add_req_header: Option<(String, String)>,
        add_res_header: Option<(String, String)>,

        // ---- early return (clone-able) ----
        early_return_status: Option<StatusCode>,
        early_return_body: Option<String>,

        fail_request: bool,
        fail_response: bool,
    }

    impl TestMiddleware {
        fn new(id: usize) -> Self {
            Self {
                id,
                request_calls: Arc::new(Mutex::new(Vec::new())),
                response_calls: Arc::new(Mutex::new(Vec::new())),
                add_req_header: None,
                add_res_header: None,
                early_return_status: None,
                early_return_body: None,
                fail_request: false,
                fail_response: false,
            }
        }

        /// Record into the given logs instead of this middleware's private ones.
        /// When several middlewares share the same logs, the relative order of their calls
        /// becomes observable.
        fn with_logs(mut self, request_calls: &RequestLog, response_calls: &ResponseLog) -> Self {
            self.request_calls = Arc::clone(request_calls);
            self.response_calls = Arc::clone(response_calls);
            self
        }

        fn with_req_header(mut self, name: &str, value: &str) -> Self {
            self.add_req_header = Some((name.to_string(), value.to_string()));
            self
        }

        fn with_res_header(mut self, name: &str, value: &str) -> Self {
            self.add_res_header = Some((name.to_string(), value.to_string()));
            self
        }

        fn early_return_200(mut self) -> Self {
            self.early_return_status = Some(StatusCode::OK);
            self.early_return_body = Some(format!("early-{}", self.id));
            self
        }

        #[allow(unused)]
        fn early_return(mut self, status: StatusCode, body: impl Into<String>) -> Self {
            self.early_return_status = Some(status);
            self.early_return_body = Some(body.into());
            self
        }

        fn fail_request(mut self) -> Self {
            self.fail_request = true;
            self
        }

        fn fail_response(mut self) -> Self {
            self.fail_response = true;
            self
        }
    }

    #[async_trait]
    impl Middleware for TestMiddleware {
        async fn handle<'req>(
            &self,
            mut req: Request<&'req str>,
            state: Arc<McpAppState>,
            next: MiddlewareNext<'req>,
        ) -> TransportServerResult<Response<GenericBody>> {
            // ---- record request -------------------------------------------------
            let headers = req
                .headers()
                .iter()
                .map(|(k, v)| (k.as_str().to_string(), v.to_str().unwrap_or("").to_string()))
                .collect();
            self.request_calls
                .lock()
                .unwrap()
                .push((self.id, req.body().to_string(), headers));

            if self.fail_request {
                return Err(TransportServerError::HttpError(format!(
                    "middleware {} failed request",
                    self.id
                )));
            }

            // ---- add request header --------------------------------------------
            if let Some((name, value)) = &self.add_req_header {
                req.headers_mut().insert(
                    HeaderName::from_bytes(name.as_bytes()).unwrap(),
                    value.parse().unwrap(),
                );
            }

            // ---- early return ---------------------------------------------------
            if let (Some(status), Some(body)) = (&self.early_return_status, &self.early_return_body)
            {
                return Ok(Response::builder()
                    .status(*status)
                    .body(GenericBody::from_string(body.to_string()))
                    .unwrap());
            }

            // ---- call next ------------------------------------------------------
            let mut res = next(req, state).await?;
            // ---- add response header --------------------------------------------
            if let Some((name, value)) = &self.add_res_header {
                res.headers_mut().insert(
                    HeaderName::from_bytes(name.as_bytes()).unwrap(),
                    value.parse().unwrap(),
                );
            }

            if self.fail_response {
                return Err(TransportServerError::HttpError(format!(
                    "middleware {} failed response",
                    self.id
                )));
            }

            // ---- record response ------------------------------------------------
            let headers = res
                .headers()
                .iter()
                .map(|(k, v)| (k.as_str().to_string(), v.to_str().unwrap_or("").to_string()))
                .collect();

            self.response_calls
                .lock()
                .unwrap()
                .push((self.id, res.status().as_u16(), headers));

            Ok(res)
        }
    }

    /// Final handler – returns a fixed response
    fn final_handler(body: &'static str, status: StatusCode) -> RequestHandler {
        Box::new(move |_req, _| {
            let resp = Response::builder()
                .status(status)
                .body(GenericBody::from_string(body.to_string()))
                .unwrap();
            Box::pin(async move { Ok(resp) })
        })
    }

    // TESTS

    /// Middleware order: the first registered middleware is the outermost layer, so the
    /// request flows 1 → 2 → 3 → final and the response flows back 3 → 2 → 1.
    #[tokio::test]
    async fn test_middleware_order() {
        // One shared log per direction makes the interleaving of the middlewares observable.
        // Separate per-middleware logs could not detect a wrong ordering.
        let request_log: RequestLog = Arc::new(Mutex::new(Vec::new()));
        let response_log: ResponseLog = Arc::new(Mutex::new(Vec::new()));

        let mw1 = Arc::new(TestMiddleware::new(1).with_logs(&request_log, &response_log));
        let mw2 = Arc::new(TestMiddleware::new(2).with_logs(&request_log, &response_log));
        let mw3 = Arc::new(TestMiddleware::new(3).with_logs(&request_log, &response_log));

        let middlewares: Vec<Arc<dyn Middleware>> = vec![mw1, mw2, mw3];

        let handler = final_handler("final", StatusCode::OK);
        let composed = compose(&middlewares, handler);

        let req = Request::builder().body("").unwrap();
        let _ = composed(req, app_state()).await.unwrap();

        // request order: 1 → 2 → 3 → final
        let request_order: Vec<usize> = request_log.lock().unwrap().iter().map(|c| c.0).collect();
        assert_eq!(request_order, vec![1, 2, 3]);

        // response order: 3 → 2 → 1
        let response_order: Vec<usize> =
            response_log.lock().unwrap().iter().map(|c| c.0).collect();
        assert_eq!(response_order, vec![3, 2, 1]);
    }

    /// Request header added by earlier middleware is visible later
    #[tokio::test]
    async fn test_request_header_propagation() {
        let mw1 = Arc::new(TestMiddleware::new(1).with_req_header("x-mid", "1"));
        let mw2 = Arc::new(TestMiddleware::new(2));

        let middlewares: Vec<Arc<dyn Middleware>> = vec![mw1.clone(), mw2.clone()];
        let handler = final_handler("ok", StatusCode::OK);
        let composed = compose(&middlewares, handler);

        let req = Request::builder().body("").unwrap();
        let _ = composed(req, app_state()).await.unwrap();

        let rc = mw2.request_calls.lock().unwrap();
        let hdr = rc[0].2.iter().find(|(k, _)| k == "x-mid").map(|(_, v)| v);
        assert_eq!(hdr, Some(&"1".to_string()));
    }

    /// Response header added by later middleware is visible earlier
    #[tokio::test]
    async fn test_response_header_propagation() {
        let mw1 = Arc::new(TestMiddleware::new(1));
        let mw2 = Arc::new(TestMiddleware::new(2).with_res_header("x-mid", "1"));

        let middlewares: Vec<Arc<dyn Middleware>> = vec![mw1.clone(), mw2.clone()];
        let handler = final_handler("ok", StatusCode::OK);
        let composed = compose(&middlewares, handler);

        let req = Request::builder().body("").unwrap();
        let res = composed(req, app_state()).await.unwrap();

        let pc1 = mw1.response_calls.lock().unwrap();

        let hdr = pc1[0].2.iter().find(|(k, _)| k == "x-mid").map(|(_, v)| v);
        assert_eq!(hdr, Some(&"1".to_string()));

        assert_eq!(res.headers().get("x-mid").unwrap().to_str().unwrap(), "1");
    }

    /// Early return stops the chain
    #[tokio::test]
    async fn test_early_return_stops_chain() {
        let mw1 = Arc::new(TestMiddleware::new(1).early_return_200());
        let mw2 = Arc::new(TestMiddleware::new(2));
        let mw3 = Arc::new(TestMiddleware::new(3));

        let middlewares: Vec<Arc<dyn Middleware>> = vec![mw1.clone(), mw2.clone(), mw3.clone()];
        let handler = final_handler("should-not-see", StatusCode::OK);
        let composed = compose(&middlewares, handler);

        let req = Request::builder().body("").unwrap();
        let res = composed(req, app_state()).await.unwrap();

        assert_eq!(response_string(res).await, "early-1");

        assert!(mw2.request_calls.lock().unwrap().is_empty());
        assert!(mw3.request_calls.lock().unwrap().is_empty());
    }

    /// Request error stops response processing
    #[tokio::test]
    async fn test_request_error_stops_response_chain() {
        let mw1 = Arc::new(TestMiddleware::new(1).fail_request());
        let mw2 = Arc::new(TestMiddleware::new(2));

        let middlewares: Vec<Arc<dyn Middleware>> = vec![mw1.clone(), mw2.clone()];
        let handler = final_handler("ok", StatusCode::OK);
        let composed = compose(&middlewares, handler);

        let req = Request::builder().body("").unwrap();
        let result = composed(req, app_state()).await;

        assert!(result.is_err());
        assert!(mw2.request_calls.lock().unwrap().is_empty());
        assert!(mw2.response_calls.lock().unwrap().is_empty());
    }

    /// Response error after next()
    #[tokio::test]
    async fn test_response_error_after_next() {
        let mw1 = Arc::new(TestMiddleware::new(1).fail_response());
        let mw2 = Arc::new(TestMiddleware::new(2));

        let middlewares: Vec<Arc<dyn Middleware>> = vec![mw1.clone(), mw2.clone()];
        let handler = final_handler("ok", StatusCode::OK);
        let composed = compose(&middlewares, handler);

        let req = Request::builder().body("").unwrap();
        let result = composed(req, app_state()).await;

        assert!(result.is_err());
        assert!(!mw1.request_calls.lock().unwrap().is_empty());
        // response_calls is empty because we error before recording
        assert!(mw1.response_calls.lock().unwrap().is_empty());
    }

    /// No middleware → direct handler
    #[tokio::test]
    async fn test_no_middleware() {
        let middlewares: Vec<Arc<dyn Middleware>> = vec![];
        let handler = final_handler("direct", StatusCode::IM_A_TEAPOT);
        let composed = compose(&middlewares, handler);

        let req = Request::builder().body("").unwrap();
        let res = composed(req, app_state()).await.unwrap();

        assert_eq!(res.status(), StatusCode::IM_A_TEAPOT);
        assert_eq!(response_string(res).await, "direct");
    }

    /// Multiple headers accumulate correctly
    #[tokio::test]
    async fn test_multiple_headers_accumulate() {
        let mw1 = Arc::new(
            TestMiddleware::new(1)
                .with_req_header("x-a", "1")
                .with_res_header("x-b", "1"),
        );
        let mw2 = Arc::new(
            TestMiddleware::new(2)
                .with_req_header("x-c", "2")
                .with_res_header("x-d", "2"),
        );

        let mw3 = Arc::new(TestMiddleware::new(3));

        let middlewares: Vec<Arc<dyn Middleware>> = vec![mw1.clone(), mw2.clone(), mw3.clone()];
        let handler = final_handler("ok", StatusCode::OK);
        let composed = compose(&middlewares, handler);

        let req = Request::builder().body("").unwrap();
        let res = composed(req, app_state()).await.unwrap();

        let h = res.headers();
        assert_eq!(h["x-b"], "1");
        assert_eq!(h["x-d"], "2");

        // Request headers are NOT in response
        assert!(!h.contains_key("x-a"));
        assert!(!h.contains_key("x-c"));

        // But they were added to the request
        let req_calls_mw3 = mw3.request_calls.lock().unwrap();
        let req_headers = &req_calls_mw3[0].2;

        assert!(req_headers.iter().any(|(k, v)| k == "x-a" && v == "1"));
        assert!(req_headers.iter().any(|(k, v)| k == "x-c" && v == "2"));
    }

    /// Request body is passed unchanged
    #[tokio::test]
    async fn test_request_body_unchanged() {
        let mw1 = Arc::new(TestMiddleware::new(1));
        let mw2 = Arc::new(TestMiddleware::new(2));

        let middlewares: Vec<Arc<dyn Middleware>> = vec![mw1.clone(), mw2.clone()];
        let handler: RequestHandler = Box::new(move |req, _| {
            let body = req.into_body().to_string();
            Box::pin(async move {
                Ok(Response::builder()
                    .body(GenericBody::from_string(format!("echo:{body}")))
                    .unwrap())
            })
        });
        let composed = compose(&middlewares, handler);

        let req = Request::builder().body("secret-payload").unwrap();
        let res = composed(req, app_state()).await.unwrap();
        assert_eq!(response_string(res).await, "echo:secret-payload");
    }

    // Integration: CORS + Logger (order matters)
    #[tokio::test]
    async fn test_cors_and_logger_integration() {
        let cors = Arc::new(CorsMiddleware::permissive());
        let logger = Arc::new(LoggingMiddleware);

        // Order in the vector is the order they are *registered*.
        // compose folds in reverse, so the first entry (cors) is the outermost layer: it sees
        // the request first and the response last, while logger wraps the final handler.
        let middlewares: Vec<Arc<dyn Middleware>> = vec![cors.clone(), logger.clone()];
        let handler = final_handler("ok", StatusCode::OK);
        let composed = compose(&middlewares, handler);

        let req = Request::builder()
            .method(http::Method::GET)
            .uri("/api")
            .header("Origin", "https://example.com")
            .body("")
            .unwrap();

        let res = composed(req, app_state()).await.unwrap();

        // CORS headers added by CorsMiddleware
        assert_eq!(
            res.headers()["access-control-allow-origin"],
            "https://example.com"
        );
        assert_eq!(res.headers()["access-control-allow-credentials"], "true");
    }
}