1#[cfg(feature = "auth")]
2mod auth_middleware;
3mod cors_middleware;
4mod dns_rebind_protector;
5pub mod logging_middleware;
6
7use super::types::{GenericBody, RequestHandler};
8use crate::mcp_http::{McpAppState, MiddlewareNext};
9use crate::mcp_server::error::TransportServerResult;
10#[cfg(feature = "auth")]
11pub(crate) use auth_middleware::*;
12pub use cors_middleware::*;
13pub(crate) use dns_rebind_protector::*;
14use http::{Request, Response};
15use std::sync::Arc;
16
17#[async_trait::async_trait]
18pub trait Middleware: Send + Sync + 'static {
19 async fn handle<'req>(
20 &self,
21 req: Request<&'req str>,
22 state: Arc<McpAppState>,
23 next: MiddlewareNext<'req>,
24 ) -> TransportServerResult<Response<GenericBody>>;
25}
26
27pub fn compose<'a, I>(middlewares: I, final_handler: RequestHandler) -> RequestHandler
30where
31 I: IntoIterator<Item = &'a Arc<dyn Middleware>>,
32 I::IntoIter: DoubleEndedIterator,
33{
34 let mut handler = final_handler;
36
37 for mw in middlewares.into_iter().rev() {
39 let mw = Arc::clone(mw);
40 let next = handler;
41
42 handler = Box::new(move |req: Request<&str>, state: Arc<McpAppState>| {
44 let mw = Arc::clone(&mw);
45 Box::pin(async move { mw.handle(req, state, next).await })
46 });
47 }
48
49 handler
50}
51
#[cfg(test)]
mod tests {
    use super::*;
    use crate::schema::{Implementation, InitializeResult, ProtocolVersion, ServerCapabilities};
    use crate::{
        id_generator::{FastIdGenerator, UuidGenerator},
        mcp_http::{
            middleware::{cors_middleware::CorsMiddleware, logging_middleware::LoggingMiddleware},
            types::GenericBodyExt,
        },
        mcp_server::{error::TransportServerError, ServerHandler, ToMcpServerHandler},
        session_store::InMemorySessionStore,
    };
    use async_trait::async_trait;
    use http::{HeaderMap, HeaderName, Request, Response, StatusCode};
    use http_body_util::BodyExt;
    use std::{
        sync::{Arc, Mutex},
        time::Duration,
    };

    /// Owned `(name, value)` snapshot of a header map.
    type HeaderPairs = Vec<(String, String)>;
    /// Shared log of `(middleware id, request body, request headers)` entries.
    type RequestLog = Arc<Mutex<Vec<(usize, String, HeaderPairs)>>>;
    /// Shared log of `(middleware id, response status, response headers)` entries.
    type ResponseLog = Arc<Mutex<Vec<(usize, u16, HeaderPairs)>>>;

    struct TestHandler;
    impl ServerHandler for TestHandler {}

    /// Builds an `McpAppState` backed by an in-memory session store and default
    /// capabilities, used to drive the composed handlers in these tests.
    fn app_state() -> Arc<McpAppState> {
        let handler = TestHandler;

        Arc::new(McpAppState {
            session_store: Arc::new(InMemorySessionStore::new()),
            id_generator: Arc::new(UuidGenerator {}),
            stream_id_gen: Arc::new(FastIdGenerator::new(Some("s_"))),
            server_details: Arc::new(InitializeResult {
                capabilities: ServerCapabilities::default(),
                instructions: None,
                meta: None,
                protocol_version: ProtocolVersion::V2025_06_18.to_string(),
                server_info: Implementation {
                    name: "server".to_string(),
                    title: None,
                    version: "0.1.0".to_string(),
                },
            }),
            handler: handler.to_mcp_server_handler(),
            ping_interval: Duration::from_secs(15),
            transport_options: Arc::new(rust_mcp_transport::TransportOptions::default()),
            enable_json_response: false,
            event_store: None,
        })
    }

    /// Drains a response body and returns it as UTF-8 text (panics on invalid UTF-8).
    async fn response_string(res: Response<GenericBody>) -> String {
        let bytes = res.into_body().collect().await.unwrap().to_bytes();
        String::from_utf8(bytes.to_vec()).unwrap()
    }

    /// Copies `headers` into owned pairs. Values that `to_str` rejects are recorded as "".
    fn header_pairs(headers: &HeaderMap) -> HeaderPairs {
        headers
            .iter()
            .map(|(k, v)| (k.as_str().to_string(), v.to_str().unwrap_or("").to_string()))
            .collect()
    }

    /// Inserts `name: value` into `headers`, replacing any existing value for `name`.
    fn insert_header(headers: &mut HeaderMap, (name, value): &(String, String)) {
        headers.insert(
            HeaderName::from_bytes(name.as_bytes()).unwrap(),
            value.parse().unwrap(),
        );
    }

    /// Configurable middleware that records each request it receives and each response
    /// it passes back, so tests can observe how the composed chain behaves.
    #[derive(Clone)]
    struct TestMiddleware {
        id: usize,
        request_calls: RequestLog,
        response_calls: ResponseLog,
        add_req_header: Option<(String, String)>,
        add_res_header: Option<(String, String)>,
        // When set, reply with this status/body instead of calling `next`.
        short_circuit: Option<(StatusCode, String)>,
        fail_request: bool,
        fail_response: bool,
    }

    impl TestMiddleware {
        fn new(id: usize) -> Self {
            Self {
                id,
                request_calls: Arc::new(Mutex::new(Vec::new())),
                response_calls: Arc::new(Mutex::new(Vec::new())),
                add_req_header: None,
                add_res_header: None,
                short_circuit: None,
                fail_request: false,
                fail_response: false,
            }
        }

        /// Records into the same logs as `other`, so the relative order of calls across
        /// several middlewares becomes observable.
        fn sharing_logs_with(mut self, other: &TestMiddleware) -> Self {
            self.request_calls = Arc::clone(&other.request_calls);
            self.response_calls = Arc::clone(&other.response_calls);
            self
        }

        fn with_req_header(mut self, name: &str, value: &str) -> Self {
            self.add_req_header = Some((name.to_string(), value.to_string()));
            self
        }

        fn with_res_header(mut self, name: &str, value: &str) -> Self {
            self.add_res_header = Some((name.to_string(), value.to_string()));
            self
        }

        /// Short-circuits with `200 OK` and the body `early-{id}`.
        fn early_return_200(self) -> Self {
            let body = format!("early-{}", self.id);
            self.early_return(StatusCode::OK, body)
        }

        /// Short-circuits with the given status and body instead of calling `next`.
        fn early_return(mut self, status: StatusCode, body: impl Into<String>) -> Self {
            self.short_circuit = Some((status, body.into()));
            self
        }

        fn fail_request(mut self) -> Self {
            self.fail_request = true;
            self
        }

        fn fail_response(mut self) -> Self {
            self.fail_response = true;
            self
        }
    }

    #[async_trait]
    impl Middleware for TestMiddleware {
        async fn handle<'req>(
            &self,
            mut req: Request<&'req str>,
            state: Arc<McpAppState>,
            next: MiddlewareNext<'req>,
        ) -> TransportServerResult<Response<GenericBody>> {
            // The lock guard is a temporary dropped at the end of this statement, so it is
            // never held across an `.await`.
            self.request_calls.lock().unwrap().push((
                self.id,
                req.body().to_string(),
                header_pairs(req.headers()),
            ));

            if self.fail_request {
                return Err(TransportServerError::HttpError(format!(
                    "middleware {} failed request",
                    self.id
                )));
            }

            if let Some(header) = &self.add_req_header {
                insert_header(req.headers_mut(), header);
            }

            if let Some((status, body)) = &self.short_circuit {
                return Ok(Response::builder()
                    .status(*status)
                    .body(GenericBody::from_string(body.clone()))
                    .unwrap());
            }

            let mut res = next(req, state).await?;

            if let Some(header) = &self.add_res_header {
                insert_header(res.headers_mut(), header);
            }

            if self.fail_response {
                return Err(TransportServerError::HttpError(format!(
                    "middleware {} failed response",
                    self.id
                )));
            }

            self.response_calls.lock().unwrap().push((
                self.id,
                res.status().as_u16(),
                header_pairs(res.headers()),
            ));

            Ok(res)
        }
    }

    /// Terminal handler that always replies with `status` and `body`.
    fn final_handler(body: &'static str, status: StatusCode) -> RequestHandler {
        Box::new(move |_req, _| {
            let resp = Response::builder()
                .status(status)
                .body(GenericBody::from_string(body.to_string()))
                .unwrap();
            Box::pin(async move { Ok(resp) })
        })
    }

    #[tokio::test]
    async fn test_middleware_order() {
        // All three middlewares write into mw1's logs, so their interleaving is observable.
        let mw1 = Arc::new(TestMiddleware::new(1));
        let mw2 = Arc::new(TestMiddleware::new(2).sharing_logs_with(&mw1));
        let mw3 = Arc::new(TestMiddleware::new(3).sharing_logs_with(&mw1));

        let middlewares: Vec<Arc<dyn Middleware>> = vec![mw1.clone(), mw2, mw3];

        let handler = final_handler("final", StatusCode::OK);
        let composed = compose(&middlewares, handler);

        let req = Request::builder().body("").unwrap();
        let _ = composed(req, app_state()).await.unwrap();

        // The first middleware in the list is the outermost layer: it sees the request
        // first and the response last.
        let request_order: Vec<usize> = mw1
            .request_calls
            .lock()
            .unwrap()
            .iter()
            .map(|(id, _, _)| *id)
            .collect();
        assert_eq!(request_order, vec![1, 2, 3]);

        let response_order: Vec<usize> = mw1
            .response_calls
            .lock()
            .unwrap()
            .iter()
            .map(|(id, _, _)| *id)
            .collect();
        assert_eq!(response_order, vec![3, 2, 1]);
    }

    #[tokio::test]
    async fn test_request_header_propagation() {
        let mw1 = Arc::new(TestMiddleware::new(1).with_req_header("x-mid", "1"));
        let mw2 = Arc::new(TestMiddleware::new(2));

        let middlewares: Vec<Arc<dyn Middleware>> = vec![mw1.clone(), mw2.clone()];
        let handler = final_handler("ok", StatusCode::OK);
        let composed = compose(&middlewares, handler);

        let req = Request::builder().body("").unwrap();
        let _ = composed(req, app_state()).await.unwrap();

        let rc = mw2.request_calls.lock().unwrap();
        let hdr = rc[0].2.iter().find(|(k, _)| k == "x-mid").map(|(_, v)| v);
        assert_eq!(hdr, Some(&"1".to_string()));
    }

    #[tokio::test]
    async fn test_response_header_propagation() {
        let mw1 = Arc::new(TestMiddleware::new(1));
        let mw2 = Arc::new(TestMiddleware::new(2).with_res_header("x-mid", "1"));

        let middlewares: Vec<Arc<dyn Middleware>> = vec![mw1.clone(), mw2.clone()];
        let handler = final_handler("ok", StatusCode::OK);
        let composed = compose(&middlewares, handler);

        let req = Request::builder().body("").unwrap();
        let res = composed(req, app_state()).await.unwrap();

        let pc1 = mw1.response_calls.lock().unwrap();

        let hdr = pc1[0].2.iter().find(|(k, _)| k == "x-mid").map(|(_, v)| v);
        assert_eq!(hdr, Some(&"1".to_string()));

        assert_eq!(res.headers().get("x-mid").unwrap().to_str().unwrap(), "1");
    }

    #[tokio::test]
    async fn test_early_return_stops_chain() {
        let mw1 = Arc::new(TestMiddleware::new(1).early_return_200());
        let mw2 = Arc::new(TestMiddleware::new(2));
        let mw3 = Arc::new(TestMiddleware::new(3));

        let middlewares: Vec<Arc<dyn Middleware>> = vec![mw1.clone(), mw2.clone(), mw3.clone()];
        let handler = final_handler("should-not-see", StatusCode::OK);
        let composed = compose(&middlewares, handler);

        let req = Request::builder().body("").unwrap();
        let res = composed(req, app_state()).await.unwrap();

        assert_eq!(response_string(res).await, "early-1");

        // The short-circuiting layer returns before recording any response.
        assert!(mw1.response_calls.lock().unwrap().is_empty());
        assert!(mw2.request_calls.lock().unwrap().is_empty());
        assert!(mw3.request_calls.lock().unwrap().is_empty());
    }

    #[tokio::test]
    async fn test_early_return_status_reaches_outer_layers() {
        let mw1 = Arc::new(TestMiddleware::new(1));
        let mw2 = Arc::new(
            TestMiddleware::new(2).early_return(StatusCode::UNAUTHORIZED, "denied"),
        );
        let mw3 = Arc::new(TestMiddleware::new(3));

        let middlewares: Vec<Arc<dyn Middleware>> = vec![mw1.clone(), mw2, mw3.clone()];
        let handler = final_handler("should-not-see", StatusCode::OK);
        let composed = compose(&middlewares, handler);

        let req = Request::builder().body("").unwrap();
        let res = composed(req, app_state()).await.unwrap();

        assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
        assert_eq!(response_string(res).await, "denied");

        // The outer layer still post-processes the short-circuited response...
        let outer_statuses: Vec<u16> = mw1
            .response_calls
            .lock()
            .unwrap()
            .iter()
            .map(|(_, status, _)| *status)
            .collect();
        assert_eq!(outer_statuses, vec![401]);
        // ...while the inner layer is never reached.
        assert!(mw3.request_calls.lock().unwrap().is_empty());
    }

    #[tokio::test]
    async fn test_request_error_stops_response_chain() {
        let mw1 = Arc::new(TestMiddleware::new(1).fail_request());
        let mw2 = Arc::new(TestMiddleware::new(2));

        let middlewares: Vec<Arc<dyn Middleware>> = vec![mw1.clone(), mw2.clone()];
        let handler = final_handler("ok", StatusCode::OK);
        let composed = compose(&middlewares, handler);

        let req = Request::builder().body("").unwrap();
        let result = composed(req, app_state()).await;

        assert!(result.is_err());
        assert!(mw2.request_calls.lock().unwrap().is_empty());
        assert!(mw2.response_calls.lock().unwrap().is_empty());
    }

    #[tokio::test]
    async fn test_response_error_after_next() {
        let mw1 = Arc::new(TestMiddleware::new(1).fail_response());
        let mw2 = Arc::new(TestMiddleware::new(2));

        let middlewares: Vec<Arc<dyn Middleware>> = vec![mw1.clone(), mw2.clone()];
        let handler = final_handler("ok", StatusCode::OK);
        let composed = compose(&middlewares, handler);

        let req = Request::builder().body("").unwrap();
        let result = composed(req, app_state()).await;

        assert!(result.is_err());
        assert!(!mw1.request_calls.lock().unwrap().is_empty());
        assert!(mw1.response_calls.lock().unwrap().is_empty());
        // The inner layer completed normally before the outer one failed.
        assert_eq!(mw2.response_calls.lock().unwrap().len(), 1);
    }

    #[tokio::test]
    async fn test_no_middleware() {
        let middlewares: Vec<Arc<dyn Middleware>> = vec![];
        let handler = final_handler("direct", StatusCode::IM_A_TEAPOT);
        let composed = compose(&middlewares, handler);

        let req = Request::builder().body("").unwrap();
        let res = composed(req, app_state()).await.unwrap();

        assert_eq!(res.status(), StatusCode::IM_A_TEAPOT);
        assert_eq!(response_string(res).await, "direct");
    }

    #[tokio::test]
    async fn test_multiple_headers_accumulate() {
        let mw1 = Arc::new(
            TestMiddleware::new(1)
                .with_req_header("x-a", "1")
                .with_res_header("x-b", "1"),
        );
        let mw2 = Arc::new(
            TestMiddleware::new(2)
                .with_req_header("x-c", "2")
                .with_res_header("x-d", "2"),
        );

        let mw3 = Arc::new(TestMiddleware::new(3));

        let middlewares: Vec<Arc<dyn Middleware>> = vec![mw1.clone(), mw2.clone(), mw3.clone()];
        let handler = final_handler("ok", StatusCode::OK);
        let composed = compose(&middlewares, handler);

        let req = Request::builder().body("").unwrap();
        let res = composed(req, app_state()).await.unwrap();

        let h = res.headers();
        assert_eq!(h["x-b"], "1");
        assert_eq!(h["x-d"], "2");

        // Request headers must not leak into the response.
        assert!(!h.contains_key("x-a"));
        assert!(!h.contains_key("x-c"));

        let req_calls_mw3 = mw3.request_calls.lock().unwrap();
        let req_headers = &req_calls_mw3[0].2;

        assert!(req_headers.iter().any(|(k, v)| k == "x-a" && v == "1"));
        assert!(req_headers.iter().any(|(k, v)| k == "x-c" && v == "2"));
    }

    #[tokio::test]
    async fn test_request_body_unchanged() {
        let mw1 = Arc::new(TestMiddleware::new(1));
        let mw2 = Arc::new(TestMiddleware::new(2));

        let middlewares: Vec<Arc<dyn Middleware>> = vec![mw1.clone(), mw2.clone()];
        let handler: RequestHandler = Box::new(move |req, _| {
            let body = req.into_body().to_string();
            Box::pin(async move {
                Ok(Response::builder()
                    .body(GenericBody::from_string(format!("echo:{body}")))
                    .unwrap())
            })
        });
        let composed = compose(&middlewares, handler);

        let req = Request::builder().body("secret-payload").unwrap();
        let res = composed(req, app_state()).await.unwrap();
        assert_eq!(response_string(res).await, "echo:secret-payload");
    }

    #[tokio::test]
    async fn test_cors_and_logger_integration() {
        let cors = Arc::new(CorsMiddleware::permissive());
        let logger = Arc::new(LoggingMiddleware);

        let middlewares: Vec<Arc<dyn Middleware>> = vec![cors.clone(), logger.clone()];
        let handler = final_handler("ok", StatusCode::OK);
        let composed = compose(&middlewares, handler);

        let req = Request::builder()
            .method(http::Method::GET)
            .uri("/api")
            .header("Origin", "https://example.com")
            .body("")
            .unwrap();

        let res = composed(req, app_state()).await.unwrap();

        assert_eq!(
            res.headers()["access-control-allow-origin"],
            "https://example.com"
        );
        assert_eq!(res.headers()["access-control-allow-credentials"], "true");
    }
}