1use crate::{
14 mcp_http::{
15 http_utils::{build_response, empty_response},
16 types::GenericBody,
17 McpAppState, Middleware, MiddlewareNext,
18 },
19 mcp_server::error::TransportServerResult,
20};
21use http::{
22 header::{
23 self, HeaderName, HeaderValue, ACCESS_CONTROL_ALLOW_CREDENTIALS,
24 ACCESS_CONTROL_ALLOW_HEADERS, ACCESS_CONTROL_ALLOW_METHODS, ACCESS_CONTROL_ALLOW_ORIGIN,
25 ACCESS_CONTROL_EXPOSE_HEADERS, ACCESS_CONTROL_MAX_AGE, ACCESS_CONTROL_REQUEST_HEADERS,
26 ACCESS_CONTROL_REQUEST_METHOD,
27 },
28 Method, Request, Response, StatusCode,
29};
30use std::{collections::HashSet, sync::Arc};
31
32#[derive(Clone)]
36pub struct CorsConfig {
37 pub allow_origins: AllowOrigins,
39
40 pub allow_methods: Vec<Method>,
42
43 pub allow_headers: Vec<HeaderName>,
45
46 pub allow_credentials: bool,
50
51 pub max_age: Option<u32>,
53
54 pub expose_headers: Vec<HeaderName>,
56}
57
58impl Default for CorsConfig {
59 fn default() -> Self {
60 Self {
61 allow_origins: AllowOrigins::Any,
62 allow_methods: vec![Method::GET, Method::POST, Method::OPTIONS],
63 allow_headers: vec![header::CONTENT_TYPE, header::AUTHORIZATION],
64 allow_credentials: false,
65 max_age: Some(86_400), expose_headers: vec![],
67 }
68 }
69}
70
/// Policy deciding which request origins are granted CORS access.
#[derive(Debug, Clone)]
pub enum AllowOrigins {
    /// Every origin is allowed. Answered with `*`, or with the request's own origin when
    /// credentials are enabled (browsers reject `*` on credentialed responses).
    Any,

    /// Only origins that exactly (case-sensitively) match one of these strings,
    /// e.g. `"https://app.example.com"`.
    List(HashSet<String>),

    /// Every origin is allowed and always echoed back verbatim, whatever the
    /// credentials setting.
    Echo,
}
85
86#[derive(Clone)]
91pub struct CorsMiddleware {
92 config: Arc<CorsConfig>,
93}
94
95impl CorsMiddleware {
96 pub fn new(config: CorsConfig) -> Self {
98 Self {
99 config: Arc::new(config),
100 }
101 }
102
103 pub fn permissive() -> Self {
107 Self::new(CorsConfig {
108 allow_origins: AllowOrigins::Any,
109 allow_methods: vec![
110 Method::GET,
111 Method::POST,
112 Method::PUT,
113 Method::DELETE,
114 Method::PATCH,
115 Method::OPTIONS,
116 Method::HEAD,
117 ],
118 allow_headers: vec![
119 header::CONTENT_TYPE,
120 header::AUTHORIZATION,
121 header::ACCEPT,
122 header::ORIGIN,
123 ],
124 allow_credentials: true,
125 max_age: Some(86_400),
126 expose_headers: vec![],
127 })
128 }
129
130 fn resolve_allowed_origin(&self, origin: &str) -> Option<String> {
132 match &self.config.allow_origins {
133 AllowOrigins::Any => {
134 if self.config.allow_credentials {
136 Some(origin.to_string())
141 } else {
142 Some("*".to_string())
143 }
144 }
145 AllowOrigins::List(allowed) => {
146 if allowed.contains(origin) {
147 Some(origin.to_string())
148 } else {
149 None
150 }
151 }
152 AllowOrigins::Echo => Some(origin.to_string()),
153 }
154 }
155
156 fn preflight_response(&self, origin: &str) -> Response<GenericBody> {
158 let allowed_origin = self.resolve_allowed_origin(origin);
159 let mut resp = Response::builder()
160 .status(StatusCode::NO_CONTENT)
161 .body(empty_response())
162 .expect("preflight response is static");
163
164 let headers = resp.headers_mut();
165
166 if let Some(origin) = allowed_origin {
167 headers.insert(
168 ACCESS_CONTROL_ALLOW_ORIGIN,
169 HeaderValue::from_str(&origin).expect("origin is validated"),
170 );
171 }
172
173 if self.config.allow_credentials {
174 headers.insert(
175 ACCESS_CONTROL_ALLOW_CREDENTIALS,
176 HeaderValue::from_static("true"),
177 );
178 }
179
180 if let Some(age) = self.config.max_age {
181 headers.insert(
182 ACCESS_CONTROL_MAX_AGE,
183 HeaderValue::from_str(&age.to_string()).expect("u32 is valid"),
184 );
185 }
186
187 let methods = self
188 .config
189 .allow_methods
190 .iter()
191 .map(|m| m.as_str())
192 .collect::<Vec<_>>()
193 .join(", ");
194 headers.insert(
195 ACCESS_CONTROL_ALLOW_METHODS,
196 HeaderValue::from_str(&methods).expect("methods are static"),
197 );
198
199 let headers_list = self
200 .config
201 .allow_headers
202 .iter()
203 .map(|h| h.as_str())
204 .collect::<Vec<_>>()
205 .join(", ");
206 headers.insert(
207 ACCESS_CONTROL_ALLOW_HEADERS,
208 HeaderValue::from_str(&headers_list).expect("headers are static"),
209 );
210
211 resp
212 }
213
214 fn add_cors_to_response(
216 &self,
217 mut resp: Response<GenericBody>,
218 origin: &str,
219 ) -> Response<GenericBody> {
220 let allowed_origin = self.resolve_allowed_origin(origin);
221 let headers = resp.headers_mut();
222
223 if let Some(origin) = allowed_origin {
224 headers.insert(
225 ACCESS_CONTROL_ALLOW_ORIGIN,
226 HeaderValue::from_str(&origin).expect("origin is validated"),
227 );
228 }
229
230 if self.config.allow_credentials {
231 headers.insert(
232 ACCESS_CONTROL_ALLOW_CREDENTIALS,
233 HeaderValue::from_static("true"),
234 );
235 }
236
237 if !self.config.expose_headers.is_empty() {
238 let expose = self
239 .config
240 .expose_headers
241 .iter()
242 .map(|h| h.as_str())
243 .collect::<Vec<_>>()
244 .join(", ");
245 headers.insert(
246 ACCESS_CONTROL_EXPOSE_HEADERS,
247 HeaderValue::from_str(&expose).expect("expose headers are static"),
248 );
249 }
250
251 resp
252 }
253}
254
255#[async_trait::async_trait]
257impl Middleware for CorsMiddleware {
258 async fn handle<'req>(
263 &self,
264 req: Request<&'req str>,
265 state: Arc<McpAppState>,
266 next: MiddlewareNext<'req>,
267 ) -> TransportServerResult<Response<GenericBody>> {
268 let origin = req
269 .headers()
270 .get(header::ORIGIN)
271 .and_then(|v| v.to_str().ok())
272 .map(|s| s.to_string());
273
274 if *req.method() == Method::OPTIONS {
276 let requested_method = req
277 .headers()
278 .get(ACCESS_CONTROL_REQUEST_METHOD)
279 .and_then(|v| v.to_str().ok())
280 .and_then(|s| s.parse::<Method>().ok());
281
282 let requested_headers = req
283 .headers()
284 .get(ACCESS_CONTROL_REQUEST_HEADERS)
285 .and_then(|v| v.to_str().ok())
286 .map(|s| {
287 s.split(',')
288 .map(|h| h.trim().to_ascii_lowercase())
289 .collect::<HashSet<_>>()
290 })
291 .unwrap_or_default();
292
293 let origin = match origin {
294 Some(o) => o,
295 None => {
296 if matches!(self.config.allow_origins, AllowOrigins::Any)
298 && !self.config.allow_credentials
299 {
300 return Ok(self.preflight_response("*"));
301 } else {
302 let response = build_response(
303 StatusCode::BAD_REQUEST,
304 "CORS origin missing in preflight".to_string(),
305 );
306 return response;
307 }
308 }
309 };
310
311 if self.resolve_allowed_origin(&origin).is_none() {
313 let response =
314 build_response(StatusCode::FORBIDDEN, "CORS origin not allowed".to_string());
315 return response;
316 }
317
318 if let Some(m) = requested_method {
320 if !self.config.allow_methods.contains(&m) {
321 let response = build_response(
322 StatusCode::METHOD_NOT_ALLOWED,
323 "CORS method not allowed".to_string(),
324 );
325 return response;
326 }
327 }
328
329 let allowed = self
331 .config
332 .allow_headers
333 .iter()
334 .map(|h| h.as_str().to_ascii_lowercase())
335 .collect::<HashSet<_>>();
336
337 if !requested_headers.is_subset(&allowed) {
338 let response = build_response(
339 StatusCode::BAD_REQUEST,
340 "CORS header not allowed".to_string(),
341 );
342 return response;
343 }
344
345 return Ok(self.preflight_response(&origin));
347 }
348
349 let mut resp = next(req, state).await?;
351 if let Some(origin) = origin {
352 if self.resolve_allowed_origin(&origin).is_some() {
353 resp = self.add_cors_to_response(resp, &origin);
354 }
355 }
356
357 Ok(resp)
358 }
359}
360
// Unit tests for `CorsMiddleware`: preflight acceptance/rejection, origin resolution for
// the `AllowOrigins` policies, credentials and exposed headers.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        id_generator::{FastIdGenerator, UuidGenerator},
        mcp_http::{types::GenericBodyExt, MiddlewareNext},
        mcp_server::{ServerHandler, ToMcpServerHandler},
        schema::{Implementation, InitializeResult, ProtocolVersion, ServerCapabilities},
        session_store::InMemorySessionStore,
    };
    use http::{header, Request, Response, StatusCode};
    use std::time::Duration;

    type TestResult = Result<(), Box<dyn std::error::Error>>;

    // Handler relying entirely on the `ServerHandler` default methods.
    struct TestHandler;
    impl ServerHandler for TestHandler {}

    // Builds an `McpAppState` with an in-memory session store and default transport
    // options. The CORS middleware never inspects it; `Middleware::handle` just needs one.
    fn app_state() -> Arc<McpAppState> {
        let handler = TestHandler {};

        Arc::new(McpAppState {
            session_store: Arc::new(InMemorySessionStore::new()),
            id_generator: Arc::new(UuidGenerator {}),
            stream_id_gen: Arc::new(FastIdGenerator::new(Some("s_"))),
            server_details: Arc::new(InitializeResult {
                capabilities: ServerCapabilities {
                    ..Default::default()
                },
                instructions: None,
                meta: None,
                protocol_version: ProtocolVersion::V2025_06_18.to_string(),
                server_info: Implementation {
                    name: "server".to_string(),
                    title: None,
                    version: "0.1.0".to_string(),
                },
            }),
            handler: handler.to_mcp_server_handler(),
            ping_interval: Duration::from_secs(15),
            transport_options: Arc::new(rust_mcp_transport::TransportOptions::default()),
            enable_json_response: false,
            event_store: None,
        })
    }

    // Returns a `next` handler that ignores its request and state and always answers
    // with `status` and `body`.
    fn make_handler<'req>(status: StatusCode, body: &'static str) -> MiddlewareNext<'req> {
        Arc::new(move |_, _| {
            let resp = Response::builder()
                .status(status)
                .body(GenericBody::from_string(body.to_string()))
                .unwrap();
            Box::pin(async { Ok(resp) })
        })
    }

    // A valid preflight under the permissive policy gets 204, the echoed origin
    // (credentials are on) and the full method list in configuration order.
    #[tokio::test]
    async fn test_preflight_allowed() -> TestResult {
        let cors = CorsMiddleware::permissive();
        let handler = make_handler(StatusCode::OK, "should not see");

        let req = Request::builder()
            .method(Method::OPTIONS)
            .uri("/")
            .header(header::ORIGIN, "https://example.com")
            .header(ACCESS_CONTROL_REQUEST_METHOD, "POST")
            .header(
                ACCESS_CONTROL_REQUEST_HEADERS,
                "content-type, authorization",
            )
            .body("")?;

        let resp = cors.handle(req, app_state(), handler).await?;

        assert_eq!(resp.status(), StatusCode::NO_CONTENT);
        assert_eq!(
            resp.headers()[ACCESS_CONTROL_ALLOW_ORIGIN],
            "https://example.com"
        );
        assert_eq!(
            resp.headers()[ACCESS_CONTROL_ALLOW_METHODS],
            "GET, POST, PUT, DELETE, PATCH, OPTIONS, HEAD"
        );
        Ok(())
    }

    // A preflight from an origin outside an explicit allow-list is rejected with 403.
    #[tokio::test]
    async fn test_preflight_disallowed_origin() -> TestResult {
        let mut allowed = HashSet::new();
        allowed.insert("https://trusted.com".to_string());

        let cors = CorsMiddleware::new(CorsConfig {
            allow_origins: AllowOrigins::List(allowed),
            allow_methods: vec![Method::GET],
            allow_headers: vec![],
            allow_credentials: false,
            max_age: None,
            expose_headers: vec![],
        });

        let handler = make_handler(StatusCode::OK, "irrelevant");

        let req = Request::builder()
            .method(Method::OPTIONS)
            .uri("/")
            .header(header::ORIGIN, "https://evil.com")
            .header(ACCESS_CONTROL_REQUEST_METHOD, "GET")
            .body("")?;

        let result: Response<GenericBody> = cors.handle(req, app_state(), handler).await.unwrap();
        let (parts, _body) = result.into_parts();
        assert_eq!(parts.status, 403);
        Ok(())
    }

    // A non-preflight request reaches the inner handler, and its response gains the
    // echoed origin plus the credentials header.
    #[tokio::test]
    async fn test_normal_request_with_origin() -> TestResult {
        let cors = CorsMiddleware::permissive();
        let handler = make_handler(StatusCode::OK, "hello");

        let req = Request::builder()
            .method(Method::GET)
            .uri("/")
            .header(header::ORIGIN, "https://client.com")
            .body("")?;

        let resp = cors.handle(req, app_state(), handler).await?;

        assert_eq!(resp.status(), StatusCode::OK);

        assert_eq!(
            resp.headers()[ACCESS_CONTROL_ALLOW_ORIGIN],
            "https://client.com"
        );
        assert_eq!(resp.headers()[ACCESS_CONTROL_ALLOW_CREDENTIALS], "true");
        Ok(())
    }

    // `AllowOrigins::Any` without credentials answers with the `*` wildcard.
    #[tokio::test]
    async fn test_wildcard_with_no_credentials() -> TestResult {
        let cors = CorsMiddleware::new(CorsConfig {
            allow_origins: AllowOrigins::Any,
            allow_methods: vec![Method::GET],
            allow_headers: vec![],
            allow_credentials: false,
            max_age: None,
            expose_headers: vec![],
        });

        let handler = make_handler(StatusCode::OK, "ok");

        let req = Request::builder()
            .method(Method::GET)
            .uri("/")
            .header(header::ORIGIN, "https://any.com")
            .body("")?;

        let resp = cors.handle(req, app_state(), handler).await?;
        assert_eq!(resp.headers()[ACCESS_CONTROL_ALLOW_ORIGIN], "*");
        Ok(())
    }

    // `AllowOrigins::Any` with credentials must echo the origin instead of sending `*`.
    #[tokio::test]
    async fn test_no_wildcard_with_credentials() -> TestResult {
        let cors = CorsMiddleware::new(CorsConfig {
            allow_origins: AllowOrigins::Any,
            allow_methods: vec![Method::GET],
            allow_headers: vec![],
            allow_credentials: true,
            max_age: None,
            expose_headers: vec![],
        });

        let handler = make_handler(StatusCode::OK, "ok");

        let req = Request::builder()
            .method(Method::GET)
            .uri("/")
            .header(header::ORIGIN, "https://any.com")
            .body("")?;

        let resp = cors.handle(req, app_state(), handler).await?;

        let origin_header = resp
            .headers()
            .get(ACCESS_CONTROL_ALLOW_ORIGIN)
            .expect("CORS header missing");
        assert_eq!(origin_header, "https://any.com");

        assert_eq!(
            resp.headers()
                .get(ACCESS_CONTROL_ALLOW_CREDENTIALS)
                .unwrap(),
            "true"
        );
        Ok(())
    }

    // `AllowOrigins::Echo` reflects whatever origin the request carries.
    #[tokio::test]
    async fn test_echo_origin_with_credentials() -> TestResult {
        let cors = CorsMiddleware::new(CorsConfig {
            allow_origins: AllowOrigins::Echo,
            allow_methods: vec![Method::GET],
            allow_headers: vec![],
            allow_credentials: true,
            max_age: None,
            expose_headers: vec![],
        });

        let handler = make_handler(StatusCode::OK, "ok");

        let req = Request::builder()
            .method(Method::GET)
            .uri("/")
            .header(header::ORIGIN, "https://dynamic.com")
            .body("")?;

        let resp = cors.handle(req, app_state(), handler).await?;
        assert_eq!(
            resp.headers()[ACCESS_CONTROL_ALLOW_ORIGIN],
            "https://dynamic.com"
        );
        assert_eq!(resp.headers()[ACCESS_CONTROL_ALLOW_CREDENTIALS], "true");
        Ok(())
    }

    // Configured `expose_headers` are listed in `Access-Control-Expose-Headers`.
    #[tokio::test]
    async fn test_expose_headers() -> TestResult {
        let cors = CorsMiddleware::new(CorsConfig {
            allow_origins: AllowOrigins::Any,
            allow_methods: vec![Method::GET],
            allow_headers: vec![],
            allow_credentials: false,
            max_age: None,
            expose_headers: vec![HeaderName::from_static("x-ratelimit-remaining")],
        });

        let handler = make_handler(StatusCode::OK, "ok");

        let req = Request::builder()
            .method(Method::GET)
            .uri("/")
            .header(header::ORIGIN, "https://client.com")
            .body("")?;

        let resp = cors.handle(req, app_state(), handler).await?;
        assert_eq!(
            resp.headers()[ACCESS_CONTROL_EXPOSE_HEADERS],
            "x-ratelimit-remaining"
        );
        Ok(())
    }
}