1pub mod cors_middleware;
2pub(crate) mod dns_rebind_protector;
3pub mod logging_middleware;
4
5use super::types::{BoxFutureResponse, GenericBody, RequestHandler};
6use crate::mcp_http::{McpAppState, MiddlewareNext};
7use crate::mcp_server::error::TransportServerResult;
8use http::{Request, Response};
9use std::sync::Arc;
10
11#[async_trait::async_trait]
12pub trait Middleware: Send + Sync + 'static {
13 async fn handle<'req>(
14 &self,
15 req: Request<&'req str>,
16 state: Arc<McpAppState>,
17 next: MiddlewareNext<'req>,
18 ) -> TransportServerResult<Response<GenericBody>>;
19}
20
21pub fn compose(
23 middlewares: &Vec<Arc<dyn Middleware>>,
24 final_handler: RequestHandler,
25) -> RequestHandler {
26 let mut handler = final_handler;
27
28 for mw in middlewares.iter().rev() {
29 let mw = mw.clone();
30 let next = handler.clone();
31
32 handler = Arc::new(move |req: Request<&str>, state: Arc<McpAppState>| {
33 let mw = mw.clone();
34 let next = next.clone();
35
36 Box::pin(async move { mw.handle(req, state, next).await }) as BoxFutureResponse<'_>
37 });
38 }
39
40 handler
41}
42
43#[cfg(test)]
44mod tests {
45 use super::*;
46 use crate::schema::{Implementation, InitializeResult, ProtocolVersion, ServerCapabilities};
47 use crate::{
48 id_generator::{FastIdGenerator, UuidGenerator},
49 mcp_http::{
50 middleware::{cors_middleware::CorsMiddleware, logging_middleware::LoggingMiddleware},
51 types::GenericBodyExt,
52 },
53 mcp_server::{error::TransportServerError, ServerHandler, ToMcpServerHandler},
54 session_store::InMemorySessionStore,
55 };
56 use async_trait::async_trait;
57 use http::{HeaderName, Request, Response, StatusCode};
58 use http_body_util::BodyExt;
59 use std::{
60 sync::{Arc, Mutex},
61 time::Duration,
62 };
63 struct TestHandler;
64 impl ServerHandler for TestHandler {}
65
66 fn app_state() -> Arc<McpAppState> {
67 let handler = TestHandler {};
68
69 Arc::new(McpAppState {
70 session_store: Arc::new(InMemorySessionStore::new()),
71 id_generator: Arc::new(UuidGenerator {}),
72 stream_id_gen: Arc::new(FastIdGenerator::new(Some("s_"))),
73 server_details: Arc::new(InitializeResult {
74 capabilities: ServerCapabilities {
75 ..Default::default()
76 },
77 instructions: None,
78 meta: None,
79 protocol_version: ProtocolVersion::V2025_06_18.to_string(),
80 server_info: Implementation {
81 name: "server".to_string(),
82 title: None,
83 version: "0.1.0".to_string(),
84 },
85 }),
86 handler: handler.to_mcp_server_handler(),
87 ping_interval: Duration::from_secs(15),
88 transport_options: Arc::new(rust_mcp_transport::TransportOptions::default()),
89 enable_json_response: false,
90 event_store: None,
91 })
92 }
93
94 async fn response_string(res: Response<GenericBody>) -> String {
96 let (_parts, body) = res.into_parts();
97 let bytes = body.collect().await.unwrap().to_bytes();
98 String::from_utf8(bytes.to_vec()).unwrap()
99 }
100
101 #[derive(Clone)]
103 struct TestMiddleware {
104 id: usize,
105 request_calls: Arc<Mutex<Vec<(usize, String, Vec<(String, String)>)>>>,
106 response_calls: Arc<Mutex<Vec<(usize, u16, Vec<(String, String)>)>>>,
107 add_req_header: Option<(String, String)>,
108 add_res_header: Option<(String, String)>,
109
110 early_return_status: Option<StatusCode>,
112 early_return_body: Option<String>,
113
114 fail_request: bool,
115 fail_response: bool,
116 }
117
118 impl TestMiddleware {
119 fn new(id: usize) -> Self {
120 Self {
121 id,
122 request_calls: Arc::new(Mutex::new(Vec::new())),
123 response_calls: Arc::new(Mutex::new(Vec::new())),
124 add_req_header: None,
125 add_res_header: None,
126 early_return_status: None,
127 early_return_body: None,
128 fail_request: false,
129 fail_response: false,
130 }
131 }
132
133 fn with_req_header(mut self, name: &str, value: &str) -> Self {
134 self.add_req_header = Some((name.to_string(), value.to_string()));
135 self
136 }
137
138 fn with_res_header(mut self, name: &str, value: &str) -> Self {
139 self.add_res_header = Some((name.to_string(), value.to_string()));
140 self
141 }
142
143 fn early_return_200(mut self) -> Self {
144 self.early_return_status = Some(StatusCode::OK);
145 self.early_return_body = Some(format!("early-{}", self.id));
146 self
147 }
148
149 #[allow(unused)]
150 fn early_return(mut self, status: StatusCode, body: impl Into<String>) -> Self {
151 self.early_return_status = Some(status);
152 self.early_return_body = Some(body.into());
153 self
154 }
155
156 fn fail_request(mut self) -> Self {
157 self.fail_request = true;
158 self
159 }
160
161 fn fail_response(mut self) -> Self {
162 self.fail_response = true;
163 self
164 }
165 }
166
167 #[async_trait]
168 impl Middleware for TestMiddleware {
169 async fn handle<'req>(
170 &self,
171 mut req: Request<&'req str>,
172 state: Arc<McpAppState>,
173 next: MiddlewareNext<'req>,
174 ) -> TransportServerResult<Response<GenericBody>> {
175 let headers = req
177 .headers()
178 .iter()
179 .map(|(k, v)| (k.as_str().to_string(), v.to_str().unwrap_or("").to_string()))
180 .collect();
181 self.request_calls
182 .lock()
183 .unwrap()
184 .push((self.id, req.body().to_string(), headers));
185
186 if self.fail_request {
187 return Err(TransportServerError::HttpError(format!(
188 "middleware {} failed request",
189 self.id
190 )));
191 }
192
193 if let Some((name, value)) = &self.add_req_header {
195 req.headers_mut().insert(
196 HeaderName::from_bytes(name.as_bytes()).unwrap(),
197 value.parse().unwrap(),
198 );
199 }
200
201 if let (Some(status), Some(body)) = (&self.early_return_status, &self.early_return_body)
203 {
204 return Ok(Response::builder()
205 .status(*status)
206 .body(GenericBody::from_string(body.to_string()))
207 .unwrap());
208 }
209
210 let mut res = next(req, state).await?;
212 if let Some((name, value)) = &self.add_res_header {
214 res.headers_mut().insert(
215 HeaderName::from_bytes(name.as_bytes()).unwrap(),
216 value.parse().unwrap(),
217 );
218 }
219
220 if self.fail_response {
221 return Err(TransportServerError::HttpError(format!(
222 "middleware {} failed response",
223 self.id
224 )));
225 }
226
227 let headers = res
229 .headers()
230 .iter()
231 .map(|(k, v)| (k.as_str().to_string(), v.to_str().unwrap_or("").to_string()))
232 .collect();
233
234 self.response_calls
235 .lock()
236 .unwrap()
237 .push((self.id, res.status().as_u16(), headers));
238
239 Ok(res)
240 }
241 }
242
243 fn final_handler(body: &'static str, status: StatusCode) -> RequestHandler {
245 Arc::new(move |_req, _| {
246 let resp = Response::builder()
247 .status(status)
248 .body(GenericBody::from_string(body.to_string()))
249 .unwrap();
250 Box::pin(async move { Ok(resp) })
251 })
252 }
253
254 #[tokio::test]
258 async fn test_middleware_order() {
259 let mw1 = Arc::new(TestMiddleware::new(1));
260 let mw2 = Arc::new(TestMiddleware::new(2));
261 let mw3 = Arc::new(TestMiddleware::new(3));
262
263 let middlewares: Vec<Arc<dyn Middleware>> = vec![mw1.clone(), mw2.clone(), mw3.clone()];
264 let handler = final_handler("final", StatusCode::OK);
265 let composed = compose(&middlewares, handler);
266
267 let req = Request::builder().body("").unwrap();
268 let _ = composed(req, app_state()).await.unwrap();
269
270 let rc3 = mw3.request_calls.lock().unwrap();
272 let rc2 = mw2.request_calls.lock().unwrap();
273 let rc1 = mw1.request_calls.lock().unwrap();
274 assert_eq!(rc3[0].0, 3);
275 assert_eq!(rc2[0].0, 2);
276 assert_eq!(rc1[0].0, 1);
277
278 let pc1 = mw1.response_calls.lock().unwrap();
280 let pc2 = mw2.response_calls.lock().unwrap();
281 let pc3 = mw3.response_calls.lock().unwrap();
282 assert_eq!(pc1[0].0, 1);
283 assert_eq!(pc2[0].0, 2);
284 assert_eq!(pc3[0].0, 3);
285 }
286
287 #[tokio::test]
289 async fn test_request_header_propagation() {
290 let mw1 = Arc::new(TestMiddleware::new(1).with_req_header("x-mid", "1"));
291 let mw2 = Arc::new(TestMiddleware::new(2));
292
293 let middlewares: Vec<Arc<dyn Middleware>> = vec![mw1.clone(), mw2.clone()];
294 let handler = final_handler("ok", StatusCode::OK);
295 let composed = compose(&middlewares, handler);
296
297 let req = Request::builder().body("").unwrap();
298 let _ = composed(req, app_state()).await.unwrap();
299
300 let rc = mw2.request_calls.lock().unwrap();
301 let hdr = rc[0].2.iter().find(|(k, _)| k == "x-mid").map(|(_, v)| v);
302 assert_eq!(hdr, Some(&"1".to_string()));
303 }
304
305 #[tokio::test]
307 async fn test_response_header_propagation() {
308 let mw1 = Arc::new(TestMiddleware::new(1));
309 let mw2 = Arc::new(TestMiddleware::new(2).with_res_header("x-mid", "1"));
310
311 let middlewares: Vec<Arc<dyn Middleware>> = vec![mw1.clone(), mw2.clone()];
312 let handler = final_handler("ok", StatusCode::OK);
313 let composed = compose(&middlewares, handler);
314
315 let req = Request::builder().body("").unwrap();
316 let res = composed(req, app_state()).await.unwrap();
317
318 let pc1 = mw1.response_calls.lock().unwrap();
319
320 let hdr = pc1[0].2.iter().find(|(k, _)| k == "x-mid").map(|(_, v)| v);
321 assert_eq!(hdr, Some(&"1".to_string()));
322
323 assert_eq!(res.headers().get("x-mid").unwrap().to_str().unwrap(), "1");
324 }
325
326 #[tokio::test]
328 async fn test_early_return_stops_chain() {
329 let mw1 = Arc::new(TestMiddleware::new(1).early_return_200());
330 let mw2 = Arc::new(TestMiddleware::new(2));
331 let mw3 = Arc::new(TestMiddleware::new(3));
332
333 let middlewares: Vec<Arc<dyn Middleware>> = vec![mw1.clone(), mw2.clone(), mw3.clone()];
334 let handler = final_handler("should-not-see", StatusCode::OK);
335 let composed = compose(&middlewares, handler);
336
337 let req = Request::builder().body("").unwrap();
338 let res = composed(req, app_state()).await.unwrap();
339
340 assert_eq!(response_string(res).await, "early-1");
341
342 assert!(mw2.request_calls.lock().unwrap().is_empty());
343 assert!(mw3.request_calls.lock().unwrap().is_empty());
344 }
345
346 #[tokio::test]
348 async fn test_request_error_stops_response_chain() {
349 let mw1 = Arc::new(TestMiddleware::new(1).fail_request());
350 let mw2 = Arc::new(TestMiddleware::new(2));
351
352 let middlewares: Vec<Arc<dyn Middleware>> = vec![mw1.clone(), mw2.clone()];
353 let handler = final_handler("ok", StatusCode::OK);
354 let composed = compose(&middlewares, handler);
355
356 let req = Request::builder().body("").unwrap();
357 let result = composed(req, app_state()).await;
358
359 assert!(result.is_err());
360 assert!(mw2.request_calls.lock().unwrap().is_empty());
361 assert!(mw2.response_calls.lock().unwrap().is_empty());
362 }
363
364 #[tokio::test]
366 async fn test_response_error_after_next() {
367 let mw1 = Arc::new(TestMiddleware::new(1).fail_response());
368 let mw2 = Arc::new(TestMiddleware::new(2));
369
370 let middlewares: Vec<Arc<dyn Middleware>> = vec![mw1.clone(), mw2.clone()];
371 let handler = final_handler("ok", StatusCode::OK);
372 let composed = compose(&middlewares, handler);
373
374 let req = Request::builder().body("").unwrap();
375 let result = composed(req, app_state()).await;
376
377 assert!(result.is_err());
378 assert!(!mw1.request_calls.lock().unwrap().is_empty());
379 assert!(mw1.response_calls.lock().unwrap().is_empty());
381 }
382
383 #[tokio::test]
385 async fn test_no_middleware() {
386 let middlewares: Vec<Arc<dyn Middleware>> = vec![];
387 let handler = final_handler("direct", StatusCode::IM_A_TEAPOT);
388 let composed = compose(&middlewares, handler);
389
390 let req = Request::builder().body("").unwrap();
391 let res = composed(req, app_state()).await.unwrap();
392
393 assert_eq!(res.status(), StatusCode::IM_A_TEAPOT);
394 assert_eq!(response_string(res).await, "direct");
395 }
396
397 #[tokio::test]
399 async fn test_multiple_headers_accumulate() {
400 let mw1 = Arc::new(
401 TestMiddleware::new(1)
402 .with_req_header("x-a", "1")
403 .with_res_header("x-b", "1"),
404 );
405 let mw2 = Arc::new(
406 TestMiddleware::new(2)
407 .with_req_header("x-c", "2")
408 .with_res_header("x-d", "2"),
409 );
410
411 let mw3 = Arc::new(TestMiddleware::new(3));
412
413 let middlewares: Vec<Arc<dyn Middleware>> = vec![mw1.clone(), mw2.clone(), mw3.clone()];
414 let handler = final_handler("ok", StatusCode::OK);
415 let composed = compose(&middlewares, handler);
416
417 let req = Request::builder().body("").unwrap();
418 let res = composed(req, app_state()).await.unwrap();
419
420 let h = res.headers();
421 assert_eq!(h["x-b"], "1");
422 assert_eq!(h["x-d"], "2");
423
424 assert!(!h.contains_key("x-a"));
426 assert!(!h.contains_key("x-c"));
427
428 let req_calls_mw3 = mw3.request_calls.lock().unwrap();
430 let req_headers = &req_calls_mw3[0].2;
431
432 assert!(req_headers.iter().any(|(k, v)| k == "x-a" && v == "1"));
433 assert!(req_headers.iter().any(|(k, v)| k == "x-c" && v == "2"));
434 }
435
436 #[tokio::test]
438 async fn test_request_body_unchanged() {
439 let mw1 = Arc::new(TestMiddleware::new(1));
440 let mw2 = Arc::new(TestMiddleware::new(2));
441
442 let middlewares: Vec<Arc<dyn Middleware>> = vec![mw1.clone(), mw2.clone()];
443 let handler: RequestHandler = Arc::new(move |req, _| {
444 let body = req.into_body().to_string();
445 Box::pin(async move {
446 Ok(Response::builder()
447 .body(GenericBody::from_string(format!("echo:{body}")))
448 .unwrap())
449 })
450 });
451 let composed = compose(&middlewares, handler);
452
453 let req = Request::builder().body("secret-payload").unwrap();
454 let res = composed(req, app_state()).await.unwrap();
455 assert_eq!(response_string(res).await, "echo:secret-payload");
456 }
457
458 #[tokio::test]
460 async fn test_cors_and_logger_integration() {
461 let cors = Arc::new(CorsMiddleware::permissive());
462 let logger = Arc::new(LoggingMiddleware);
463
464 let middlewares: Vec<Arc<dyn Middleware>> = vec![cors.clone(), logger.clone()];
467 let handler = final_handler("ok", StatusCode::OK);
468 let composed = compose(&middlewares, handler);
469
470 let req = Request::builder()
471 .method(http::Method::GET)
472 .uri("/api")
473 .header("Origin", "https://example.com")
474 .body("")
475 .unwrap();
476
477 let res = composed(req, app_state()).await.unwrap();
478
479 assert_eq!(
481 res.headers()["access-control-allow-origin"],
482 "https://example.com"
483 );
484 assert_eq!(res.headers()["access-control-allow-credentials"], "true");
485 }
486}