1use crate::{
14 mcp_http::{types::GenericBody, GenericBodyExt, McpAppState, Middleware, MiddlewareNext},
15 mcp_server::error::TransportServerResult,
16};
17use http::{
18 header::{
19 self, HeaderName, HeaderValue, ACCESS_CONTROL_ALLOW_CREDENTIALS,
20 ACCESS_CONTROL_ALLOW_HEADERS, ACCESS_CONTROL_ALLOW_METHODS, ACCESS_CONTROL_ALLOW_ORIGIN,
21 ACCESS_CONTROL_EXPOSE_HEADERS, ACCESS_CONTROL_MAX_AGE, ACCESS_CONTROL_REQUEST_HEADERS,
22 ACCESS_CONTROL_REQUEST_METHOD,
23 },
24 Method, Request, Response, StatusCode,
25};
26use rust_mcp_transport::MCP_SESSION_ID_HEADER;
27use std::{collections::HashSet, sync::Arc};
28
29#[derive(Clone)]
33pub struct CorsConfig {
34 pub allow_origins: AllowOrigins,
36
37 pub allow_methods: Vec<Method>,
39
40 pub allow_headers: Vec<HeaderName>,
42
43 pub allow_credentials: bool,
47
48 pub max_age: Option<u32>,
50
51 pub expose_headers: Vec<HeaderName>,
53}
54
55impl Default for CorsConfig {
56 fn default() -> Self {
57 Self {
58 allow_origins: AllowOrigins::Any,
59 allow_methods: vec![Method::GET, Method::POST, Method::OPTIONS],
60 allow_headers: vec![
61 header::CONTENT_TYPE,
62 header::AUTHORIZATION,
63 HeaderName::from_static(MCP_SESSION_ID_HEADER),
64 ],
65 allow_credentials: false,
66 max_age: Some(86_400), expose_headers: vec![],
68 }
69 }
70}
71
/// Policy describing which request origins are granted CORS access.
#[derive(Clone, Debug)]
pub enum AllowOrigins {
    /// Accept every origin.
    ///
    /// The response carries `*` unless credentials are enabled in the
    /// configuration, in which case the request origin is echoed back.
    Any,

    /// Accept only origins that exactly match one of the listed strings.
    List(HashSet<String>),

    /// Accept every origin and always echo the request origin back.
    Echo,
}
86
87#[derive(Clone, Default)]
92pub struct CorsMiddleware {
93 config: Arc<CorsConfig>,
94}
95
96impl CorsMiddleware {
97 pub fn new(config: CorsConfig) -> Self {
99 Self {
100 config: Arc::new(config),
101 }
102 }
103
104 pub fn permissive() -> Self {
108 Self::new(CorsConfig {
109 allow_origins: AllowOrigins::Any,
110 allow_methods: vec![
111 Method::GET,
112 Method::POST,
113 Method::PUT,
114 Method::DELETE,
115 Method::PATCH,
116 Method::OPTIONS,
117 Method::HEAD,
118 ],
119 allow_headers: vec![
120 header::CONTENT_TYPE,
121 header::AUTHORIZATION,
122 header::ACCEPT,
123 header::ORIGIN,
124 ],
125 allow_credentials: true,
126 max_age: Some(86_400),
127 expose_headers: vec![],
128 })
129 }
130
131 fn resolve_allowed_origin(&self, origin: &str) -> Option<String> {
133 match &self.config.allow_origins {
134 AllowOrigins::Any => {
135 if self.config.allow_credentials {
137 Some(origin.to_string())
142 } else {
143 Some("*".to_string())
144 }
145 }
146 AllowOrigins::List(allowed) => {
147 if allowed.contains(origin) {
148 Some(origin.to_string())
149 } else {
150 None
151 }
152 }
153 AllowOrigins::Echo => Some(origin.to_string()),
154 }
155 }
156
157 fn preflight_response(&self, origin: &str) -> Response<GenericBody> {
159 let allowed_origin = self.resolve_allowed_origin(origin);
160 let mut resp = Response::builder()
161 .status(StatusCode::NO_CONTENT)
162 .body(GenericBody::empty())
163 .expect("preflight response is static");
164
165 let headers = resp.headers_mut();
166
167 if let Some(origin) = allowed_origin {
168 headers.insert(
169 ACCESS_CONTROL_ALLOW_ORIGIN,
170 HeaderValue::from_str(&origin).expect("origin is validated"),
171 );
172 }
173
174 if self.config.allow_credentials {
175 headers.insert(
176 ACCESS_CONTROL_ALLOW_CREDENTIALS,
177 HeaderValue::from_static("true"),
178 );
179 }
180
181 if let Some(age) = self.config.max_age {
182 headers.insert(
183 ACCESS_CONTROL_MAX_AGE,
184 HeaderValue::from_str(&age.to_string()).expect("u32 is valid"),
185 );
186 }
187
188 let methods = self
189 .config
190 .allow_methods
191 .iter()
192 .map(|m| m.as_str())
193 .collect::<Vec<_>>()
194 .join(", ");
195 headers.insert(
196 ACCESS_CONTROL_ALLOW_METHODS,
197 HeaderValue::from_str(&methods).expect("methods are static"),
198 );
199
200 let headers_list = self
201 .config
202 .allow_headers
203 .iter()
204 .map(|h| h.as_str())
205 .collect::<Vec<_>>()
206 .join(", ");
207 headers.insert(
208 ACCESS_CONTROL_ALLOW_HEADERS,
209 HeaderValue::from_str(&headers_list).expect("headers are static"),
210 );
211
212 resp
213 }
214
215 fn add_cors_to_response(
217 &self,
218 mut resp: Response<GenericBody>,
219 origin: &str,
220 ) -> Response<GenericBody> {
221 let allowed_origin = self.resolve_allowed_origin(origin);
222 let headers = resp.headers_mut();
223
224 if let Some(origin) = allowed_origin {
225 headers.insert(
226 ACCESS_CONTROL_ALLOW_ORIGIN,
227 HeaderValue::from_str(&origin).expect("origin is validated"),
228 );
229 }
230
231 if self.config.allow_credentials {
232 headers.insert(
233 ACCESS_CONTROL_ALLOW_CREDENTIALS,
234 HeaderValue::from_static("true"),
235 );
236 }
237
238 if !self.config.expose_headers.is_empty() {
239 let expose = self
240 .config
241 .expose_headers
242 .iter()
243 .map(|h| h.as_str())
244 .collect::<Vec<_>>()
245 .join(", ");
246 headers.insert(
247 ACCESS_CONTROL_EXPOSE_HEADERS,
248 HeaderValue::from_str(&expose).expect("expose headers are static"),
249 );
250 }
251
252 resp
253 }
254}
255
256#[async_trait::async_trait]
258impl Middleware for CorsMiddleware {
259 async fn handle<'req>(
264 &self,
265 req: Request<&'req str>,
266 state: Arc<McpAppState>,
267 next: MiddlewareNext<'req>,
268 ) -> TransportServerResult<Response<GenericBody>> {
269 let origin = req
270 .headers()
271 .get(header::ORIGIN)
272 .and_then(|v| v.to_str().ok())
273 .map(|s| s.to_string());
274
275 if *req.method() == Method::OPTIONS {
277 let requested_method = req
278 .headers()
279 .get(ACCESS_CONTROL_REQUEST_METHOD)
280 .and_then(|v| v.to_str().ok())
281 .and_then(|s| s.parse::<Method>().ok());
282
283 let requested_headers = req
284 .headers()
285 .get(ACCESS_CONTROL_REQUEST_HEADERS)
286 .and_then(|v| v.to_str().ok())
287 .map(|s| {
288 s.split(',')
289 .map(|h| h.trim().to_ascii_lowercase())
290 .collect::<HashSet<_>>()
291 })
292 .unwrap_or_default();
293
294 let origin = match origin {
295 Some(o) => o,
296 None => {
297 if matches!(self.config.allow_origins, AllowOrigins::Any)
299 && !self.config.allow_credentials
300 {
301 return Ok(self.preflight_response("*"));
302 } else {
303 return Ok(GenericBody::build_response(
304 StatusCode::BAD_REQUEST,
305 "CORS origin missing in preflight".to_string(),
306 None,
307 ));
308 }
309 }
310 };
311
312 if self.resolve_allowed_origin(&origin).is_none() {
314 return Ok(GenericBody::build_response(
315 StatusCode::FORBIDDEN,
316 "CORS origin not allowed".to_string(),
317 None,
318 ));
319 }
320
321 if let Some(m) = requested_method {
323 if !self.config.allow_methods.contains(&m) {
324 return Ok(GenericBody::build_response(
325 StatusCode::METHOD_NOT_ALLOWED,
326 "CORS method not allowed".to_string(),
327 None,
328 ));
329 }
330 }
331
332 let allowed = self
334 .config
335 .allow_headers
336 .iter()
337 .map(|h| h.as_str().to_ascii_lowercase())
338 .collect::<HashSet<_>>();
339
340 if !requested_headers.is_subset(&allowed) {
341 return Ok(GenericBody::build_response(
342 StatusCode::BAD_REQUEST,
343 "CORS header not allowed".to_string(),
344 None,
345 ));
346 }
347
348 return Ok(self.preflight_response(&origin));
350 }
351
352 let mut resp = next(req, state).await?;
354 if let Some(origin) = origin {
355 if self.resolve_allowed_origin(&origin).is_some() {
356 resp = self.add_cors_to_response(resp, &origin);
357 }
358 }
359
360 Ok(resp)
361 }
362}
363
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        id_generator::{FastIdGenerator, UuidGenerator},
        mcp_http::{types::GenericBodyExt, MiddlewareNext},
        mcp_server::{ServerHandler, ToMcpServerHandler},
        schema::{Implementation, InitializeResult, ProtocolVersion, ServerCapabilities},
        session_store::InMemorySessionStore,
    };
    use http::{header, Request, Response, StatusCode};
    use std::time::Duration;

    type TestResult = Result<(), Box<dyn std::error::Error>>;
    struct TestHandler;
    impl ServerHandler for TestHandler {}

    /// Minimal application state; the middleware only passes it on to `next`.
    fn app_state() -> Arc<McpAppState> {
        let handler = TestHandler {};

        Arc::new(McpAppState {
            session_store: Arc::new(InMemorySessionStore::new()),
            id_generator: Arc::new(UuidGenerator {}),
            stream_id_gen: Arc::new(FastIdGenerator::new(Some("s_"))),
            server_details: Arc::new(InitializeResult {
                capabilities: ServerCapabilities {
                    ..Default::default()
                },
                instructions: None,
                meta: None,
                protocol_version: ProtocolVersion::V2025_06_18.to_string(),
                server_info: Implementation {
                    name: "server".to_string(),
                    title: None,
                    version: "0.1.0".to_string(),
                },
            }),
            handler: handler.to_mcp_server_handler(),
            ping_interval: Duration::from_secs(15),
            transport_options: Arc::new(rust_mcp_transport::TransportOptions::default()),
            enable_json_response: false,
            event_store: None,
        })
    }

    /// Inner handler that always answers with `status` and `body`.
    fn make_handler<'req>(status: StatusCode, body: &'static str) -> MiddlewareNext<'req> {
        Box::new(move |_, _| {
            let resp = Response::builder()
                .status(status)
                .body(GenericBody::from_string(body.to_string()))
                .unwrap();
            Box::pin(async { Ok(resp) })
        })
    }

    /// Narrow policy shared by most tests: only `GET`, no extra request
    /// headers, no max-age and no exposed headers.
    fn restricted_config(allow_origins: AllowOrigins, allow_credentials: bool) -> CorsConfig {
        CorsConfig {
            allow_origins,
            allow_methods: vec![Method::GET],
            allow_headers: vec![],
            allow_credentials,
            max_age: None,
            expose_headers: vec![],
        }
    }

    /// Starts a request to `/` with the given method and optional `Origin`.
    fn request_builder(method: Method, origin: Option<&str>) -> http::request::Builder {
        let builder = Request::builder().method(method).uri("/");
        match origin {
            Some(origin) => builder.header(header::ORIGIN, origin),
            None => builder,
        }
    }

    /// Origin allow-list containing only `https://trusted.com`.
    fn trusted_only() -> AllowOrigins {
        let mut allowed = HashSet::new();
        allowed.insert("https://trusted.com".to_string());
        AllowOrigins::List(allowed)
    }

    #[tokio::test]
    async fn test_preflight_allowed() -> TestResult {
        let cors = CorsMiddleware::permissive();
        let handler = make_handler(StatusCode::OK, "should not see");

        let req = request_builder(Method::OPTIONS, Some("https://example.com"))
            .header(ACCESS_CONTROL_REQUEST_METHOD, "POST")
            .header(
                ACCESS_CONTROL_REQUEST_HEADERS,
                "content-type, authorization",
            )
            .body("")?;

        let resp = cors.handle(req, app_state(), handler).await?;

        assert_eq!(resp.status(), StatusCode::NO_CONTENT);
        assert_eq!(
            resp.headers()[ACCESS_CONTROL_ALLOW_ORIGIN],
            "https://example.com"
        );
        assert_eq!(
            resp.headers()[ACCESS_CONTROL_ALLOW_METHODS],
            "GET, POST, PUT, DELETE, PATCH, OPTIONS, HEAD"
        );
        Ok(())
    }

    #[tokio::test]
    async fn test_preflight_disallowed_origin() -> TestResult {
        let cors = CorsMiddleware::new(restricted_config(trusted_only(), false));
        let handler = make_handler(StatusCode::OK, "irrelevant");

        let req = request_builder(Method::OPTIONS, Some("https://evil.com"))
            .header(ACCESS_CONTROL_REQUEST_METHOD, "GET")
            .body("")?;

        let resp = cors.handle(req, app_state(), handler).await?;
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
        Ok(())
    }

    #[tokio::test]
    async fn test_preflight_disallowed_method() -> TestResult {
        let cors = CorsMiddleware::new(restricted_config(AllowOrigins::Any, false));
        let handler = make_handler(StatusCode::OK, "irrelevant");

        let req = request_builder(Method::OPTIONS, Some("https://example.com"))
            .header(ACCESS_CONTROL_REQUEST_METHOD, "POST")
            .body("")?;

        let resp = cors.handle(req, app_state(), handler).await?;
        assert_eq!(resp.status(), StatusCode::METHOD_NOT_ALLOWED);
        Ok(())
    }

    #[tokio::test]
    async fn test_preflight_disallowed_header() -> TestResult {
        let cors = CorsMiddleware::new(restricted_config(AllowOrigins::Any, false));
        let handler = make_handler(StatusCode::OK, "irrelevant");

        let req = request_builder(Method::OPTIONS, Some("https://example.com"))
            .header(ACCESS_CONTROL_REQUEST_METHOD, "GET")
            .header(ACCESS_CONTROL_REQUEST_HEADERS, "x-custom")
            .body("")?;

        let resp = cors.handle(req, app_state(), handler).await?;
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
        Ok(())
    }

    #[tokio::test]
    async fn test_preflight_without_origin_uses_wildcard() -> TestResult {
        let cors = CorsMiddleware::new(restricted_config(AllowOrigins::Any, false));
        let handler = make_handler(StatusCode::OK, "irrelevant");

        let req = request_builder(Method::OPTIONS, None).body("")?;

        let resp = cors.handle(req, app_state(), handler).await?;
        assert_eq!(resp.status(), StatusCode::NO_CONTENT);
        assert_eq!(resp.headers()[ACCESS_CONTROL_ALLOW_ORIGIN], "*");
        Ok(())
    }

    #[tokio::test]
    async fn test_preflight_without_origin_rejected_for_list() -> TestResult {
        let cors = CorsMiddleware::new(restricted_config(trusted_only(), false));
        let handler = make_handler(StatusCode::OK, "irrelevant");

        let req = request_builder(Method::OPTIONS, None).body("")?;

        let resp = cors.handle(req, app_state(), handler).await?;
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
        Ok(())
    }

    #[tokio::test]
    async fn test_normal_request_with_origin() -> TestResult {
        let cors = CorsMiddleware::permissive();
        let handler = make_handler(StatusCode::OK, "hello");

        let req = request_builder(Method::GET, Some("https://client.com")).body("")?;

        let resp = cors.handle(req, app_state(), handler).await?;

        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(
            resp.headers()[ACCESS_CONTROL_ALLOW_ORIGIN],
            "https://client.com"
        );
        assert_eq!(resp.headers()[ACCESS_CONTROL_ALLOW_CREDENTIALS], "true");
        Ok(())
    }

    #[tokio::test]
    async fn test_normal_request_with_disallowed_origin() -> TestResult {
        let cors = CorsMiddleware::new(restricted_config(trusted_only(), true));
        let handler = make_handler(StatusCode::OK, "hello");

        let req = request_builder(Method::GET, Some("https://evil.com")).body("")?;

        let resp = cors.handle(req, app_state(), handler).await?;

        // The inner response passes through without any CORS grant.
        assert_eq!(resp.status(), StatusCode::OK);
        assert!(resp.headers().get(ACCESS_CONTROL_ALLOW_ORIGIN).is_none());
        assert!(resp
            .headers()
            .get(ACCESS_CONTROL_ALLOW_CREDENTIALS)
            .is_none());
        Ok(())
    }

    #[tokio::test]
    async fn test_wildcard_with_no_credentials() -> TestResult {
        let cors = CorsMiddleware::new(restricted_config(AllowOrigins::Any, false));
        let handler = make_handler(StatusCode::OK, "ok");

        let req = request_builder(Method::GET, Some("https://any.com")).body("")?;

        let resp = cors.handle(req, app_state(), handler).await?;
        assert_eq!(resp.headers()[ACCESS_CONTROL_ALLOW_ORIGIN], "*");
        Ok(())
    }

    #[tokio::test]
    async fn test_no_wildcard_with_credentials() -> TestResult {
        let cors = CorsMiddleware::new(restricted_config(AllowOrigins::Any, true));
        let handler = make_handler(StatusCode::OK, "ok");

        let req = request_builder(Method::GET, Some("https://any.com")).body("")?;

        let resp = cors.handle(req, app_state(), handler).await?;

        // With credentials enabled the wildcard must be replaced by the
        // concrete request origin.
        assert_eq!(
            resp.headers()[ACCESS_CONTROL_ALLOW_ORIGIN],
            "https://any.com"
        );
        assert_eq!(resp.headers()[ACCESS_CONTROL_ALLOW_CREDENTIALS], "true");
        Ok(())
    }

    #[tokio::test]
    async fn test_echo_origin_with_credentials() -> TestResult {
        let cors = CorsMiddleware::new(restricted_config(AllowOrigins::Echo, true));
        let handler = make_handler(StatusCode::OK, "ok");

        let req = request_builder(Method::GET, Some("https://dynamic.com")).body("")?;

        let resp = cors.handle(req, app_state(), handler).await?;
        assert_eq!(
            resp.headers()[ACCESS_CONTROL_ALLOW_ORIGIN],
            "https://dynamic.com"
        );
        assert_eq!(resp.headers()[ACCESS_CONTROL_ALLOW_CREDENTIALS], "true");
        Ok(())
    }

    #[tokio::test]
    async fn test_expose_headers() -> TestResult {
        let cors = CorsMiddleware::new(CorsConfig {
            expose_headers: vec![HeaderName::from_static("x-ratelimit-remaining")],
            ..restricted_config(AllowOrigins::Any, false)
        });
        let handler = make_handler(StatusCode::OK, "ok");

        let req = request_builder(Method::GET, Some("https://client.com")).body("")?;

        let resp = cors.handle(req, app_state(), handler).await?;
        assert_eq!(
            resp.headers()[ACCESS_CONTROL_EXPOSE_HEADERS],
            "x-ratelimit-remaining"
        );
        Ok(())
    }
}