rust_mcp_sdk/mcp_http/middleware/
cors_middleware.rs

1//! # CORS Middleware
2//!
3//! A configurable CORS middleware that follows the
4//! [WHATWG CORS specification](https://fetch.spec.whatwg.org/#http-cors-protocol).
5//!
6//! ## Features
7//! - Full preflight (`OPTIONS`) handling
8//! - Configurable origins: `*`, explicit list, or echo
9//! - Credential support (with correct `Access-Control-Allow-Origin` behavior)
10//! - Header/method validation
11//! - `Access-Control-Expose-Headers` support
12
13use crate::{
14    mcp_http::{types::GenericBody, GenericBodyExt, McpAppState, Middleware, MiddlewareNext},
15    mcp_server::error::TransportServerResult,
16};
17use http::{
18    header::{
19        self, HeaderName, HeaderValue, ACCESS_CONTROL_ALLOW_CREDENTIALS,
20        ACCESS_CONTROL_ALLOW_HEADERS, ACCESS_CONTROL_ALLOW_METHODS, ACCESS_CONTROL_ALLOW_ORIGIN,
21        ACCESS_CONTROL_EXPOSE_HEADERS, ACCESS_CONTROL_MAX_AGE, ACCESS_CONTROL_REQUEST_HEADERS,
22        ACCESS_CONTROL_REQUEST_METHOD,
23    },
24    Method, Request, Response, StatusCode,
25};
26use rust_mcp_transport::MCP_SESSION_ID_HEADER;
27use std::{collections::HashSet, sync::Arc};
28
29/// Configuration for CORS behavior.
30///
31/// See [MDN CORS](https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS) for details.
32#[derive(Clone)]
33pub struct CorsConfig {
34    /// Which origins are allowed to make requests.
35    pub allow_origins: AllowOrigins,
36
37    /// HTTP methods allowed in preflight and actual requests.
38    pub allow_methods: Vec<Method>,
39
40    /// Request headers allowed in preflight.
41    pub allow_headers: Vec<HeaderName>,
42
43    /// Whether to allow credentials (cookies, HTTP auth, etc).
44    ///
45    /// **Important**: When `true`, `allow_origins` cannot be `Any` - browsers reject `*`.
46    pub allow_credentials: bool,
47
48    /// How long (in seconds) the preflight response can be cached.
49    pub max_age: Option<u32>,
50
51    /// Headers that should be exposed to the client JavaScript.
52    pub expose_headers: Vec<HeaderName>,
53}
54
55impl Default for CorsConfig {
56    fn default() -> Self {
57        Self {
58            allow_origins: AllowOrigins::Any,
59            allow_methods: vec![Method::GET, Method::POST, Method::OPTIONS],
60            allow_headers: vec![
61                header::CONTENT_TYPE,
62                header::AUTHORIZATION,
63                HeaderName::from_static(MCP_SESSION_ID_HEADER),
64            ],
65            allow_credentials: false,
66            max_age: Some(86_400), // 24 hours
67            expose_headers: vec![],
68        }
69    }
70}
71
/// Policy deciding which request origins receive CORS headers.
#[derive(Debug, Clone)]
pub enum AllowOrigins {
    /// Accept every origin and answer with the wildcard `*`.
    ///
    /// `*` **cannot** be combined with `allow_credentials = true`.
    Any,

    /// Accept only the listed origins (exact string match on `Origin`).
    List(HashSet<String>),

    /// Reflect the request's `Origin` header back
    /// (required when `allow_credentials = true`).
    Echo,
}
86
87/// CORS middleware implementing the `Middleware` trait.
88///
89/// Handles both **preflight** (`OPTIONS`) and **actual** requests,
90/// adding appropriate CORS headers and rejecting invalid origins/methods/headers.
91#[derive(Clone, Default)]
92pub struct CorsMiddleware {
93    config: Arc<CorsConfig>,
94}
95
96impl CorsMiddleware {
97    /// Create a new CORS middleware with custom config.
98    pub fn new(config: CorsConfig) -> Self {
99        Self {
100            config: Arc::new(config),
101        }
102    }
103
104    /// Create a permissive CORS config - useful for public APIs or local dev.
105    ///
106    /// Allows all common methods, credentials, and common headers.
107    pub fn permissive() -> Self {
108        Self::new(CorsConfig {
109            allow_origins: AllowOrigins::Any,
110            allow_methods: vec![
111                Method::GET,
112                Method::POST,
113                Method::PUT,
114                Method::DELETE,
115                Method::PATCH,
116                Method::OPTIONS,
117                Method::HEAD,
118            ],
119            allow_headers: vec![
120                header::CONTENT_TYPE,
121                header::AUTHORIZATION,
122                header::ACCEPT,
123                header::ORIGIN,
124            ],
125            allow_credentials: true,
126            max_age: Some(86_400),
127            expose_headers: vec![],
128        })
129    }
130
131    // Internal: resolve allowed origin header value
132    fn resolve_allowed_origin(&self, origin: &str) -> Option<String> {
133        match &self.config.allow_origins {
134            AllowOrigins::Any => {
135                // Only return "*" if credentials are not allowed
136                if self.config.allow_credentials {
137                    // rule MDN , RFC 6454
138                    // If Access-Control-Allow-Credentials: true is set,
139                    // then Access-Control-Allow-Origin CANNOT be *.
140                    // It MUST be the exact origin (e.g., https://example.com).
141                    Some(origin.to_string())
142                } else {
143                    Some("*".to_string())
144                }
145            }
146            AllowOrigins::List(allowed) => {
147                if allowed.contains(origin) {
148                    Some(origin.to_string())
149                } else {
150                    None
151                }
152            }
153            AllowOrigins::Echo => Some(origin.to_string()),
154        }
155    }
156
157    // Build preflight response (204 No Content)
158    fn preflight_response(&self, origin: &str) -> Response<GenericBody> {
159        let allowed_origin = self.resolve_allowed_origin(origin);
160        let mut resp = Response::builder()
161            .status(StatusCode::NO_CONTENT)
162            .body(GenericBody::empty())
163            .expect("preflight response is static");
164
165        let headers = resp.headers_mut();
166
167        if let Some(origin) = allowed_origin {
168            headers.insert(
169                ACCESS_CONTROL_ALLOW_ORIGIN,
170                HeaderValue::from_str(&origin).expect("origin is validated"),
171            );
172        }
173
174        if self.config.allow_credentials {
175            headers.insert(
176                ACCESS_CONTROL_ALLOW_CREDENTIALS,
177                HeaderValue::from_static("true"),
178            );
179        }
180
181        if let Some(age) = self.config.max_age {
182            headers.insert(
183                ACCESS_CONTROL_MAX_AGE,
184                HeaderValue::from_str(&age.to_string()).expect("u32 is valid"),
185            );
186        }
187
188        let methods = self
189            .config
190            .allow_methods
191            .iter()
192            .map(|m| m.as_str())
193            .collect::<Vec<_>>()
194            .join(", ");
195        headers.insert(
196            ACCESS_CONTROL_ALLOW_METHODS,
197            HeaderValue::from_str(&methods).expect("methods are static"),
198        );
199
200        let headers_list = self
201            .config
202            .allow_headers
203            .iter()
204            .map(|h| h.as_str())
205            .collect::<Vec<_>>()
206            .join(", ");
207        headers.insert(
208            ACCESS_CONTROL_ALLOW_HEADERS,
209            HeaderValue::from_str(&headers_list).expect("headers are static"),
210        );
211
212        resp
213    }
214
215    // Add CORS headers to normal response
216    fn add_cors_to_response(
217        &self,
218        mut resp: Response<GenericBody>,
219        origin: &str,
220    ) -> Response<GenericBody> {
221        let allowed_origin = self.resolve_allowed_origin(origin);
222        let headers = resp.headers_mut();
223
224        if let Some(origin) = allowed_origin {
225            headers.insert(
226                ACCESS_CONTROL_ALLOW_ORIGIN,
227                HeaderValue::from_str(&origin).expect("origin is validated"),
228            );
229        }
230
231        if self.config.allow_credentials {
232            headers.insert(
233                ACCESS_CONTROL_ALLOW_CREDENTIALS,
234                HeaderValue::from_static("true"),
235            );
236        }
237
238        if !self.config.expose_headers.is_empty() {
239            let expose = self
240                .config
241                .expose_headers
242                .iter()
243                .map(|h| h.as_str())
244                .collect::<Vec<_>>()
245                .join(", ");
246            headers.insert(
247                ACCESS_CONTROL_EXPOSE_HEADERS,
248                HeaderValue::from_str(&expose).expect("expose headers are static"),
249            );
250        }
251
252        resp
253    }
254}
255
256// Middleware trait implementation
257#[async_trait::async_trait]
258impl Middleware for CorsMiddleware {
259    /// Process a request, handling preflight or adding CORS headers.
260    ///
261    /// - For `OPTIONS` with `Access-Control-Request-Method`: performs preflight.
262    /// - For other requests: passes to `next`, then adds CORS headers.
263    async fn handle<'req>(
264        &self,
265        req: Request<&'req str>,
266        state: Arc<McpAppState>,
267        next: MiddlewareNext<'req>,
268    ) -> TransportServerResult<Response<GenericBody>> {
269        let origin = req
270            .headers()
271            .get(header::ORIGIN)
272            .and_then(|v| v.to_str().ok())
273            .map(|s| s.to_string());
274
275        // Preflight: OPTIONS + Access-Control-Request-Method
276        if *req.method() == Method::OPTIONS {
277            let requested_method = req
278                .headers()
279                .get(ACCESS_CONTROL_REQUEST_METHOD)
280                .and_then(|v| v.to_str().ok())
281                .and_then(|s| s.parse::<Method>().ok());
282
283            let requested_headers = req
284                .headers()
285                .get(ACCESS_CONTROL_REQUEST_HEADERS)
286                .and_then(|v| v.to_str().ok())
287                .map(|s| {
288                    s.split(',')
289                        .map(|h| h.trim().to_ascii_lowercase())
290                        .collect::<HashSet<_>>()
291                })
292                .unwrap_or_default();
293
294            let origin = match origin {
295                Some(o) => o,
296                None => {
297                    // Some tools send preflight without Origin - allow if Any
298                    if matches!(self.config.allow_origins, AllowOrigins::Any)
299                        && !self.config.allow_credentials
300                    {
301                        return Ok(self.preflight_response("*"));
302                    } else {
303                        return Ok(GenericBody::build_response(
304                            StatusCode::BAD_REQUEST,
305                            "CORS origin missing in preflight".to_string(),
306                            None,
307                        ));
308                    }
309                }
310            };
311
312            // Validate origin
313            if self.resolve_allowed_origin(&origin).is_none() {
314                return Ok(GenericBody::build_response(
315                    StatusCode::FORBIDDEN,
316                    "CORS origin not allowed".to_string(),
317                    None,
318                ));
319            }
320
321            // Validate method
322            if let Some(m) = requested_method {
323                if !self.config.allow_methods.contains(&m) {
324                    return Ok(GenericBody::build_response(
325                        StatusCode::METHOD_NOT_ALLOWED,
326                        "CORS method not allowed".to_string(),
327                        None,
328                    ));
329                }
330            }
331
332            // Validate headers
333            let allowed = self
334                .config
335                .allow_headers
336                .iter()
337                .map(|h| h.as_str().to_ascii_lowercase())
338                .collect::<HashSet<_>>();
339
340            if !requested_headers.is_subset(&allowed) {
341                return Ok(GenericBody::build_response(
342                    StatusCode::BAD_REQUEST,
343                    "CORS header not allowed".to_string(),
344                    None,
345                ));
346            }
347
348            // All good - return preflight
349            return Ok(self.preflight_response(&origin));
350        }
351
352        // Normal request: forward to next handler
353        let mut resp = next(req, state).await?;
354        if let Some(origin) = origin {
355            if self.resolve_allowed_origin(&origin).is_some() {
356                resp = self.add_cors_to_response(resp, &origin);
357            }
358        }
359
360        Ok(resp)
361    }
362}
363
364#[cfg(test)]
365mod tests {
366    use super::*;
367    use crate::{
368        id_generator::{FastIdGenerator, UuidGenerator},
369        mcp_http::{types::GenericBodyExt, MiddlewareNext},
370        mcp_server::{ServerHandler, ToMcpServerHandler},
371        schema::{Implementation, InitializeResult, ProtocolVersion, ServerCapabilities},
372        session_store::InMemorySessionStore,
373    };
374    use http::{header, Request, Response, StatusCode};
375    use std::time::Duration;
376
377    type TestResult = Result<(), Box<dyn std::error::Error>>;
378    struct TestHandler;
379    impl ServerHandler for TestHandler {}
380
381    fn app_state() -> Arc<McpAppState> {
382        let handler = TestHandler {};
383
384        Arc::new(McpAppState {
385            session_store: Arc::new(InMemorySessionStore::new()),
386            id_generator: Arc::new(UuidGenerator {}),
387            stream_id_gen: Arc::new(FastIdGenerator::new(Some("s_"))),
388            server_details: Arc::new(InitializeResult {
389                capabilities: ServerCapabilities {
390                    ..Default::default()
391                },
392                instructions: None,
393                meta: None,
394                protocol_version: ProtocolVersion::V2025_06_18.to_string(),
395                server_info: Implementation {
396                    name: "server".to_string(),
397                    title: None,
398                    version: "0.1.0".to_string(),
399                },
400            }),
401            handler: handler.to_mcp_server_handler(),
402            ping_interval: Duration::from_secs(15),
403            transport_options: Arc::new(rust_mcp_transport::TransportOptions::default()),
404            enable_json_response: false,
405            event_store: None,
406        })
407    }
408
409    fn make_handler<'req>(status: StatusCode, body: &'static str) -> MiddlewareNext<'req> {
410        Box::new(move |_, _| {
411            let resp = Response::builder()
412                .status(status)
413                .body(GenericBody::from_string(body.to_string()))
414                .unwrap();
415            Box::pin(async { Ok(resp) })
416        })
417    }
418
419    #[tokio::test]
420    async fn test_preflight_allowed() -> TestResult {
421        let cors = CorsMiddleware::permissive();
422        let handler = make_handler(StatusCode::OK, "should not see");
423
424        let req = Request::builder()
425            .method(Method::OPTIONS)
426            .uri("/")
427            .header(header::ORIGIN, "https://example.com")
428            .header(ACCESS_CONTROL_REQUEST_METHOD, "POST")
429            .header(
430                ACCESS_CONTROL_REQUEST_HEADERS,
431                "content-type, authorization",
432            )
433            .body("")?;
434
435        let resp = cors.handle(req, app_state(), handler).await?;
436
437        assert_eq!(resp.status(), StatusCode::NO_CONTENT);
438        assert_eq!(
439            resp.headers()[ACCESS_CONTROL_ALLOW_ORIGIN],
440            "https://example.com"
441        );
442        assert_eq!(
443            resp.headers()[ACCESS_CONTROL_ALLOW_METHODS],
444            "GET, POST, PUT, DELETE, PATCH, OPTIONS, HEAD"
445        );
446        Ok(())
447    }
448
449    #[tokio::test]
450    async fn test_preflight_disallowed_origin() -> TestResult {
451        let mut allowed = HashSet::new();
452        allowed.insert("https://trusted.com".to_string());
453
454        let cors = CorsMiddleware::new(CorsConfig {
455            allow_origins: AllowOrigins::List(allowed),
456            allow_methods: vec![Method::GET],
457            allow_headers: vec![],
458            allow_credentials: false,
459            max_age: None,
460            expose_headers: vec![],
461        });
462
463        let handler = make_handler(StatusCode::OK, "irrelevant");
464
465        let req = Request::builder()
466            .method(Method::OPTIONS)
467            .uri("/")
468            .header(header::ORIGIN, "https://evil.com")
469            .header(ACCESS_CONTROL_REQUEST_METHOD, "GET")
470            .body("")?;
471
472        let result: Response<GenericBody> = cors.handle(req, app_state(), handler).await.unwrap();
473        let (parts, _body) = result.into_parts();
474        assert_eq!(parts.status, 403);
475        Ok(())
476    }
477
478    #[tokio::test]
479    async fn test_normal_request_with_origin() -> TestResult {
480        let cors = CorsMiddleware::permissive();
481        let handler = make_handler(StatusCode::OK, "hello");
482
483        let req = Request::builder()
484            .method(Method::GET)
485            .uri("/")
486            .header(header::ORIGIN, "https://client.com")
487            .body("")?;
488
489        let resp = cors.handle(req, app_state(), handler).await?;
490
491        assert_eq!(resp.status(), StatusCode::OK);
492
493        assert_eq!(
494            resp.headers()[ACCESS_CONTROL_ALLOW_ORIGIN],
495            "https://client.com"
496        );
497        assert_eq!(resp.headers()[ACCESS_CONTROL_ALLOW_CREDENTIALS], "true");
498        Ok(())
499    }
500
501    #[tokio::test]
502    async fn test_wildcard_with_no_credentials() -> TestResult {
503        let cors = CorsMiddleware::new(CorsConfig {
504            allow_origins: AllowOrigins::Any,
505            allow_methods: vec![Method::GET],
506            allow_headers: vec![],
507            allow_credentials: false,
508            max_age: None,
509            expose_headers: vec![],
510        });
511
512        let handler = make_handler(StatusCode::OK, "ok");
513
514        let req = Request::builder()
515            .method(Method::GET)
516            .uri("/")
517            .header(header::ORIGIN, "https://any.com")
518            .body("")?;
519
520        let resp = cors.handle(req, app_state(), handler).await?;
521        assert_eq!(resp.headers()[ACCESS_CONTROL_ALLOW_ORIGIN], "*");
522        Ok(())
523    }
524
525    #[tokio::test]
526    async fn test_no_wildcard_with_credentials() -> TestResult {
527        let cors = CorsMiddleware::new(CorsConfig {
528            allow_origins: AllowOrigins::Any,
529            allow_methods: vec![Method::GET],
530            allow_headers: vec![],
531            allow_credentials: true, // This should prevent "*"
532            max_age: None,
533            expose_headers: vec![],
534        });
535
536        let handler = make_handler(StatusCode::OK, "ok");
537
538        let req = Request::builder()
539            .method(Method::GET)
540            .uri("/")
541            .header(header::ORIGIN, "https://any.com")
542            .body("")?;
543
544        let resp = cors.handle(req, app_state(), handler).await?;
545
546        // Should NOT have "*" even though config says Any
547        let origin_header = resp
548            .headers()
549            .get(ACCESS_CONTROL_ALLOW_ORIGIN)
550            .expect("CORS header missing");
551        assert_eq!(origin_header, "https://any.com");
552
553        // And credentials should be allowed
554        assert_eq!(
555            resp.headers()
556                .get(ACCESS_CONTROL_ALLOW_CREDENTIALS)
557                .unwrap(),
558            "true"
559        );
560        Ok(())
561    }
562
563    #[tokio::test]
564    async fn test_echo_origin_with_credentials() -> TestResult {
565        let cors = CorsMiddleware::new(CorsConfig {
566            allow_origins: AllowOrigins::Echo,
567            allow_methods: vec![Method::GET],
568            allow_headers: vec![],
569            allow_credentials: true,
570            max_age: None,
571            expose_headers: vec![],
572        });
573
574        let handler = make_handler(StatusCode::OK, "ok");
575
576        let req = Request::builder()
577            .method(Method::GET)
578            .uri("/")
579            .header(header::ORIGIN, "https://dynamic.com")
580            .body("")?;
581
582        let resp = cors.handle(req, app_state(), handler).await?;
583        assert_eq!(
584            resp.headers()[ACCESS_CONTROL_ALLOW_ORIGIN],
585            "https://dynamic.com"
586        );
587        assert_eq!(resp.headers()[ACCESS_CONTROL_ALLOW_CREDENTIALS], "true");
588        Ok(())
589    }
590
591    #[tokio::test]
592    async fn test_expose_headers() -> TestResult {
593        let cors = CorsMiddleware::new(CorsConfig {
594            allow_origins: AllowOrigins::Any,
595            allow_methods: vec![Method::GET],
596            allow_headers: vec![],
597            allow_credentials: false,
598            max_age: None,
599            expose_headers: vec![HeaderName::from_static("x-ratelimit-remaining")],
600        });
601
602        let handler = make_handler(StatusCode::OK, "ok");
603
604        let req = Request::builder()
605            .method(Method::GET)
606            .uri("/")
607            .header(header::ORIGIN, "https://client.com")
608            .body("")?;
609
610        let resp = cors.handle(req, app_state(), handler).await?;
611        assert_eq!(
612            resp.headers()[ACCESS_CONTROL_EXPOSE_HEADERS],
613            "x-ratelimit-remaining"
614        );
615        Ok(())
616    }
617}