rust_mcp_sdk/mcp_http/middleware/cors_middleware.rs

1//! # CORS Middleware
2//!
3//! A configurable CORS middleware that follows the
4//! [WHATWG CORS specification](https://fetch.spec.whatwg.org/#http-cors-protocol).
5//!
6//! ## Features
7//! - Full preflight (`OPTIONS`) handling
8//! - Configurable origins: `*`, explicit list, or echo
9//! - Credential support (with correct `Access-Control-Allow-Origin` behavior)
10//! - Header/method validation
11//! - `Access-Control-Expose-Headers` support
12
13use crate::{
14    mcp_http::{
15        http_utils::{build_response, empty_response},
16        types::GenericBody,
17        McpAppState, Middleware, MiddlewareNext,
18    },
19    mcp_server::error::TransportServerResult,
20};
21use http::{
22    header::{
23        self, HeaderName, HeaderValue, ACCESS_CONTROL_ALLOW_CREDENTIALS,
24        ACCESS_CONTROL_ALLOW_HEADERS, ACCESS_CONTROL_ALLOW_METHODS, ACCESS_CONTROL_ALLOW_ORIGIN,
25        ACCESS_CONTROL_EXPOSE_HEADERS, ACCESS_CONTROL_MAX_AGE, ACCESS_CONTROL_REQUEST_HEADERS,
26        ACCESS_CONTROL_REQUEST_METHOD,
27    },
28    Method, Request, Response, StatusCode,
29};
30use std::{collections::HashSet, sync::Arc};
31
32/// Configuration for CORS behavior.
33///
34/// See [MDN CORS](https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS) for details.
35#[derive(Clone)]
36pub struct CorsConfig {
37    /// Which origins are allowed to make requests.
38    pub allow_origins: AllowOrigins,
39
40    /// HTTP methods allowed in preflight and actual requests.
41    pub allow_methods: Vec<Method>,
42
43    /// Request headers allowed in preflight.
44    pub allow_headers: Vec<HeaderName>,
45
46    /// Whether to allow credentials (cookies, HTTP auth, etc).
47    ///
48    /// **Important**: When `true`, `allow_origins` cannot be `Any` — browsers reject `*`.
49    pub allow_credentials: bool,
50
51    /// How long (in seconds) the preflight response can be cached.
52    pub max_age: Option<u32>,
53
54    /// Headers that should be exposed to the client JavaScript.
55    pub expose_headers: Vec<HeaderName>,
56}
57
58impl Default for CorsConfig {
59    fn default() -> Self {
60        Self {
61            allow_origins: AllowOrigins::Any,
62            allow_methods: vec![Method::GET, Method::POST, Method::OPTIONS],
63            allow_headers: vec![header::CONTENT_TYPE, header::AUTHORIZATION],
64            allow_credentials: false,
65            max_age: Some(86_400), // 24 hours
66            expose_headers: vec![],
67        }
68    }
69}
70
/// Policy for allowed origins.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum AllowOrigins {
    /// Allow any origin.
    ///
    /// Sends `Access-Control-Allow-Origin: *` when credentials are disabled. Browsers
    /// reject `*` for credentialed requests, so when `allow_credentials = true` the
    /// request's `Origin` is reflected instead (behaving like [`AllowOrigins::Echo`]).
    Any,

    /// Allow only the listed origins, matched by exact (case-sensitive) string
    /// comparison against the request's `Origin` header, e.g. `"https://example.com"`.
    List(HashSet<String>),

    /// Reflect the request's `Origin` header back, whatever it is.
    ///
    /// This accepts **all** origins. Combined with `allow_credentials = true` it lets any
    /// site make credentialed requests; use [`AllowOrigins::List`] to restrict access.
    Echo,
}
85
86/// CORS middleware implementing the `Middleware` trait.
87///
88/// Handles both **preflight** (`OPTIONS`) and **actual** requests,
89/// adding appropriate CORS headers and rejecting invalid origins/methods/headers.
90#[derive(Clone)]
91pub struct CorsMiddleware {
92    config: Arc<CorsConfig>,
93}
94
95impl CorsMiddleware {
96    /// Create a new CORS middleware with custom config.
97    pub fn new(config: CorsConfig) -> Self {
98        Self {
99            config: Arc::new(config),
100        }
101    }
102
103    /// Create a permissive CORS config — useful for public APIs or local dev.
104    ///
105    /// Allows all common methods, credentials, and common headers.
106    pub fn permissive() -> Self {
107        Self::new(CorsConfig {
108            allow_origins: AllowOrigins::Any,
109            allow_methods: vec![
110                Method::GET,
111                Method::POST,
112                Method::PUT,
113                Method::DELETE,
114                Method::PATCH,
115                Method::OPTIONS,
116                Method::HEAD,
117            ],
118            allow_headers: vec![
119                header::CONTENT_TYPE,
120                header::AUTHORIZATION,
121                header::ACCEPT,
122                header::ORIGIN,
123            ],
124            allow_credentials: true,
125            max_age: Some(86_400),
126            expose_headers: vec![],
127        })
128    }
129
130    // Internal: resolve allowed origin header value
131    fn resolve_allowed_origin(&self, origin: &str) -> Option<String> {
132        match &self.config.allow_origins {
133            AllowOrigins::Any => {
134                // Only return "*" if credentials are not allowed
135                if self.config.allow_credentials {
136                    // rule MDN , RFC 6454
137                    // If Access-Control-Allow-Credentials: true is set,
138                    // then Access-Control-Allow-Origin CANNOT be *.
139                    // It MUST be the exact origin (e.g., https://example.com).
140                    Some(origin.to_string())
141                } else {
142                    Some("*".to_string())
143                }
144            }
145            AllowOrigins::List(allowed) => {
146                if allowed.contains(origin) {
147                    Some(origin.to_string())
148                } else {
149                    None
150                }
151            }
152            AllowOrigins::Echo => Some(origin.to_string()),
153        }
154    }
155
156    // Build preflight response (204 No Content)
157    fn preflight_response(&self, origin: &str) -> Response<GenericBody> {
158        let allowed_origin = self.resolve_allowed_origin(origin);
159        let mut resp = Response::builder()
160            .status(StatusCode::NO_CONTENT)
161            .body(empty_response())
162            .expect("preflight response is static");
163
164        let headers = resp.headers_mut();
165
166        if let Some(origin) = allowed_origin {
167            headers.insert(
168                ACCESS_CONTROL_ALLOW_ORIGIN,
169                HeaderValue::from_str(&origin).expect("origin is validated"),
170            );
171        }
172
173        if self.config.allow_credentials {
174            headers.insert(
175                ACCESS_CONTROL_ALLOW_CREDENTIALS,
176                HeaderValue::from_static("true"),
177            );
178        }
179
180        if let Some(age) = self.config.max_age {
181            headers.insert(
182                ACCESS_CONTROL_MAX_AGE,
183                HeaderValue::from_str(&age.to_string()).expect("u32 is valid"),
184            );
185        }
186
187        let methods = self
188            .config
189            .allow_methods
190            .iter()
191            .map(|m| m.as_str())
192            .collect::<Vec<_>>()
193            .join(", ");
194        headers.insert(
195            ACCESS_CONTROL_ALLOW_METHODS,
196            HeaderValue::from_str(&methods).expect("methods are static"),
197        );
198
199        let headers_list = self
200            .config
201            .allow_headers
202            .iter()
203            .map(|h| h.as_str())
204            .collect::<Vec<_>>()
205            .join(", ");
206        headers.insert(
207            ACCESS_CONTROL_ALLOW_HEADERS,
208            HeaderValue::from_str(&headers_list).expect("headers are static"),
209        );
210
211        resp
212    }
213
214    // Add CORS headers to normal response
215    fn add_cors_to_response(
216        &self,
217        mut resp: Response<GenericBody>,
218        origin: &str,
219    ) -> Response<GenericBody> {
220        let allowed_origin = self.resolve_allowed_origin(origin);
221        let headers = resp.headers_mut();
222
223        if let Some(origin) = allowed_origin {
224            headers.insert(
225                ACCESS_CONTROL_ALLOW_ORIGIN,
226                HeaderValue::from_str(&origin).expect("origin is validated"),
227            );
228        }
229
230        if self.config.allow_credentials {
231            headers.insert(
232                ACCESS_CONTROL_ALLOW_CREDENTIALS,
233                HeaderValue::from_static("true"),
234            );
235        }
236
237        if !self.config.expose_headers.is_empty() {
238            let expose = self
239                .config
240                .expose_headers
241                .iter()
242                .map(|h| h.as_str())
243                .collect::<Vec<_>>()
244                .join(", ");
245            headers.insert(
246                ACCESS_CONTROL_EXPOSE_HEADERS,
247                HeaderValue::from_str(&expose).expect("expose headers are static"),
248            );
249        }
250
251        resp
252    }
253}
254
255// Middleware trait implementation
256#[async_trait::async_trait]
257impl Middleware for CorsMiddleware {
258    /// Process a request, handling preflight or adding CORS headers.
259    ///
260    /// - For `OPTIONS` with `Access-Control-Request-Method`: performs preflight.
261    /// - For other requests: passes to `next`, then adds CORS headers.
262    async fn handle<'req>(
263        &self,
264        req: Request<&'req str>,
265        state: Arc<McpAppState>,
266        next: MiddlewareNext<'req>,
267    ) -> TransportServerResult<Response<GenericBody>> {
268        let origin = req
269            .headers()
270            .get(header::ORIGIN)
271            .and_then(|v| v.to_str().ok())
272            .map(|s| s.to_string());
273
274        // Preflight: OPTIONS + Access-Control-Request-Method
275        if *req.method() == Method::OPTIONS {
276            let requested_method = req
277                .headers()
278                .get(ACCESS_CONTROL_REQUEST_METHOD)
279                .and_then(|v| v.to_str().ok())
280                .and_then(|s| s.parse::<Method>().ok());
281
282            let requested_headers = req
283                .headers()
284                .get(ACCESS_CONTROL_REQUEST_HEADERS)
285                .and_then(|v| v.to_str().ok())
286                .map(|s| {
287                    s.split(',')
288                        .map(|h| h.trim().to_ascii_lowercase())
289                        .collect::<HashSet<_>>()
290                })
291                .unwrap_or_default();
292
293            let origin = match origin {
294                Some(o) => o,
295                None => {
296                    // Some tools send preflight without Origin — allow if Any
297                    if matches!(self.config.allow_origins, AllowOrigins::Any)
298                        && !self.config.allow_credentials
299                    {
300                        return Ok(self.preflight_response("*"));
301                    } else {
302                        let response = build_response(
303                            StatusCode::BAD_REQUEST,
304                            "CORS origin missing in preflight".to_string(),
305                        );
306                        return response;
307                    }
308                }
309            };
310
311            // Validate origin
312            if self.resolve_allowed_origin(&origin).is_none() {
313                let response =
314                    build_response(StatusCode::FORBIDDEN, "CORS origin not allowed".to_string());
315                return response;
316            }
317
318            // Validate method
319            if let Some(m) = requested_method {
320                if !self.config.allow_methods.contains(&m) {
321                    let response = build_response(
322                        StatusCode::METHOD_NOT_ALLOWED,
323                        "CORS method not allowed".to_string(),
324                    );
325                    return response;
326                }
327            }
328
329            // Validate headers
330            let allowed = self
331                .config
332                .allow_headers
333                .iter()
334                .map(|h| h.as_str().to_ascii_lowercase())
335                .collect::<HashSet<_>>();
336
337            if !requested_headers.is_subset(&allowed) {
338                let response = build_response(
339                    StatusCode::BAD_REQUEST,
340                    "CORS header not allowed".to_string(),
341                );
342                return response;
343            }
344
345            // All good — return preflight
346            return Ok(self.preflight_response(&origin));
347        }
348
349        // Normal request: forward to next handler
350        let mut resp = next(req, state).await?;
351        if let Some(origin) = origin {
352            if self.resolve_allowed_origin(&origin).is_some() {
353                resp = self.add_cors_to_response(resp, &origin);
354            }
355        }
356
357        Ok(resp)
358    }
359}
360
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        id_generator::{FastIdGenerator, UuidGenerator},
        mcp_http::{types::GenericBodyExt, MiddlewareNext},
        mcp_server::{ServerHandler, ToMcpServerHandler},
        schema::{Implementation, InitializeResult, ProtocolVersion, ServerCapabilities},
        session_store::InMemorySessionStore,
    };
    use http::{header, Request, Response, StatusCode};
    use std::time::Duration;

    type TestResult = Result<(), Box<dyn std::error::Error>>;
    struct TestHandler;
    impl ServerHandler for TestHandler {}

    fn app_state() -> Arc<McpAppState> {
        let handler = TestHandler {};

        Arc::new(McpAppState {
            session_store: Arc::new(InMemorySessionStore::new()),
            id_generator: Arc::new(UuidGenerator {}),
            stream_id_gen: Arc::new(FastIdGenerator::new(Some("s_"))),
            server_details: Arc::new(InitializeResult {
                capabilities: ServerCapabilities {
                    ..Default::default()
                },
                instructions: None,
                meta: None,
                protocol_version: ProtocolVersion::V2025_06_18.to_string(),
                server_info: Implementation {
                    name: "server".to_string(),
                    title: None,
                    version: "0.1.0".to_string(),
                },
            }),
            handler: handler.to_mcp_server_handler(),
            ping_interval: Duration::from_secs(15),
            transport_options: Arc::new(rust_mcp_transport::TransportOptions::default()),
            enable_json_response: false,
            event_store: None,
        })
    }

    fn make_handler<'req>(status: StatusCode, body: &'static str) -> MiddlewareNext<'req> {
        Arc::new(move |_, _| {
            let resp = Response::builder()
                .status(status)
                .body(GenericBody::from_string(body.to_string()))
                .unwrap();
            Box::pin(async { Ok(resp) })
        })
    }

    /// Build a preflight request with an optional `Origin`, the requested method and
    /// optional requested headers.
    fn preflight_request(
        origin: Option<&'static str>,
        method: &'static str,
        requested_headers: Option<&'static str>,
    ) -> Result<Request<&'static str>, http::Error> {
        let mut builder = Request::builder()
            .method(Method::OPTIONS)
            .uri("/")
            .header(ACCESS_CONTROL_REQUEST_METHOD, method);
        if let Some(origin) = origin {
            builder = builder.header(header::ORIGIN, origin);
        }
        if let Some(requested_headers) = requested_headers {
            builder = builder.header(ACCESS_CONTROL_REQUEST_HEADERS, requested_headers);
        }
        builder.body("")
    }

    #[tokio::test]
    async fn test_preflight_allowed() -> TestResult {
        let cors = CorsMiddleware::permissive();
        let handler = make_handler(StatusCode::OK, "should not see");

        let req = Request::builder()
            .method(Method::OPTIONS)
            .uri("/")
            .header(header::ORIGIN, "https://example.com")
            .header(ACCESS_CONTROL_REQUEST_METHOD, "POST")
            .header(
                ACCESS_CONTROL_REQUEST_HEADERS,
                "content-type, authorization",
            )
            .body("")?;

        let resp = cors.handle(req, app_state(), handler).await?;

        assert_eq!(resp.status(), StatusCode::NO_CONTENT);
        assert_eq!(
            resp.headers()[ACCESS_CONTROL_ALLOW_ORIGIN],
            "https://example.com"
        );
        assert_eq!(
            resp.headers()[ACCESS_CONTROL_ALLOW_METHODS],
            "GET, POST, PUT, DELETE, PATCH, OPTIONS, HEAD"
        );
        assert_eq!(resp.headers()[ACCESS_CONTROL_MAX_AGE], "86400");
        assert_eq!(resp.headers()[ACCESS_CONTROL_ALLOW_CREDENTIALS], "true");
        Ok(())
    }

    #[tokio::test]
    async fn test_preflight_disallowed_origin() -> TestResult {
        let mut allowed = HashSet::new();
        allowed.insert("https://trusted.com".to_string());

        let cors = CorsMiddleware::new(CorsConfig {
            allow_origins: AllowOrigins::List(allowed),
            allow_methods: vec![Method::GET],
            allow_headers: vec![],
            allow_credentials: false,
            max_age: None,
            expose_headers: vec![],
        });

        let handler = make_handler(StatusCode::OK, "irrelevant");

        let req = Request::builder()
            .method(Method::OPTIONS)
            .uri("/")
            .header(header::ORIGIN, "https://evil.com")
            .header(ACCESS_CONTROL_REQUEST_METHOD, "GET")
            .body("")?;

        let resp = cors.handle(req, app_state(), handler).await?;
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
        Ok(())
    }

    #[tokio::test]
    async fn test_preflight_disallowed_method() -> TestResult {
        let cors = CorsMiddleware::new(CorsConfig {
            allow_methods: vec![Method::GET],
            ..CorsConfig::default()
        });

        let req = preflight_request(Some("https://client.com"), "DELETE", None)?;
        let resp = cors
            .handle(req, app_state(), make_handler(StatusCode::OK, "irrelevant"))
            .await?;
        assert_eq!(resp.status(), StatusCode::METHOD_NOT_ALLOWED);
        Ok(())
    }

    #[tokio::test]
    async fn test_preflight_disallowed_header() -> TestResult {
        let cors = CorsMiddleware::new(CorsConfig::default());

        let req = preflight_request(
            Some("https://client.com"),
            "POST",
            Some("content-type, x-custom-header"),
        )?;
        let resp = cors
            .handle(req, app_state(), make_handler(StatusCode::OK, "irrelevant"))
            .await?;
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
        Ok(())
    }

    #[tokio::test]
    async fn test_preflight_headers_case_insensitive() -> TestResult {
        let cors = CorsMiddleware::new(CorsConfig::default());

        let req = preflight_request(
            Some("https://client.com"),
            "POST",
            Some("Content-Type, AUTHORIZATION"),
        )?;
        let resp = cors
            .handle(req, app_state(), make_handler(StatusCode::OK, "irrelevant"))
            .await?;
        assert_eq!(resp.status(), StatusCode::NO_CONTENT);
        Ok(())
    }

    #[tokio::test]
    async fn test_preflight_without_origin_wildcard() -> TestResult {
        let cors = CorsMiddleware::new(CorsConfig::default());

        let req = preflight_request(None, "GET", None)?;
        let resp = cors
            .handle(req, app_state(), make_handler(StatusCode::OK, "irrelevant"))
            .await?;
        assert_eq!(resp.status(), StatusCode::NO_CONTENT);
        assert_eq!(resp.headers()[ACCESS_CONTROL_ALLOW_ORIGIN], "*");
        assert_eq!(resp.headers()[ACCESS_CONTROL_MAX_AGE], "86400");
        Ok(())
    }

    #[tokio::test]
    async fn test_preflight_without_origin_with_credentials_rejected() -> TestResult {
        let cors = CorsMiddleware::permissive();

        let req = preflight_request(None, "GET", None)?;
        let resp = cors
            .handle(req, app_state(), make_handler(StatusCode::OK, "irrelevant"))
            .await?;
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
        Ok(())
    }

    #[tokio::test]
    async fn test_normal_request_with_origin() -> TestResult {
        let cors = CorsMiddleware::permissive();
        let handler = make_handler(StatusCode::OK, "hello");

        let req = Request::builder()
            .method(Method::GET)
            .uri("/")
            .header(header::ORIGIN, "https://client.com")
            .body("")?;

        let resp = cors.handle(req, app_state(), handler).await?;

        assert_eq!(resp.status(), StatusCode::OK);

        assert_eq!(
            resp.headers()[ACCESS_CONTROL_ALLOW_ORIGIN],
            "https://client.com"
        );
        assert_eq!(resp.headers()[ACCESS_CONTROL_ALLOW_CREDENTIALS], "true");
        Ok(())
    }

    #[tokio::test]
    async fn test_normal_request_disallowed_origin_has_no_cors_headers() -> TestResult {
        let mut allowed = HashSet::new();
        allowed.insert("https://trusted.com".to_string());

        let cors = CorsMiddleware::new(CorsConfig {
            allow_origins: AllowOrigins::List(allowed),
            allow_credentials: true,
            ..CorsConfig::default()
        });

        let req = Request::builder()
            .method(Method::GET)
            .uri("/")
            .header(header::ORIGIN, "https://evil.com")
            .body("")?;

        let resp = cors
            .handle(req, app_state(), make_handler(StatusCode::OK, "hello"))
            .await?;

        // The inner handler still runs; the browser enforces the block.
        assert_eq!(resp.status(), StatusCode::OK);
        assert!(resp.headers().get(ACCESS_CONTROL_ALLOW_ORIGIN).is_none());
        assert!(resp.headers().get(ACCESS_CONTROL_ALLOW_CREDENTIALS).is_none());
        Ok(())
    }

    #[tokio::test]
    async fn test_wildcard_with_no_credentials() -> TestResult {
        let cors = CorsMiddleware::new(CorsConfig {
            allow_origins: AllowOrigins::Any,
            allow_methods: vec![Method::GET],
            allow_headers: vec![],
            allow_credentials: false,
            max_age: None,
            expose_headers: vec![],
        });

        let handler = make_handler(StatusCode::OK, "ok");

        let req = Request::builder()
            .method(Method::GET)
            .uri("/")
            .header(header::ORIGIN, "https://any.com")
            .body("")?;

        let resp = cors.handle(req, app_state(), handler).await?;
        assert_eq!(resp.headers()[ACCESS_CONTROL_ALLOW_ORIGIN], "*");
        Ok(())
    }

    #[tokio::test]
    async fn test_no_wildcard_with_credentials() -> TestResult {
        let cors = CorsMiddleware::new(CorsConfig {
            allow_origins: AllowOrigins::Any,
            allow_methods: vec![Method::GET],
            allow_headers: vec![],
            allow_credentials: true, // This should prevent "*"
            max_age: None,
            expose_headers: vec![],
        });

        let handler = make_handler(StatusCode::OK, "ok");

        let req = Request::builder()
            .method(Method::GET)
            .uri("/")
            .header(header::ORIGIN, "https://any.com")
            .body("")?;

        let resp = cors.handle(req, app_state(), handler).await?;

        // Should NOT have "*" even though config says Any
        let origin_header = resp
            .headers()
            .get(ACCESS_CONTROL_ALLOW_ORIGIN)
            .expect("CORS header missing");
        assert_eq!(origin_header, "https://any.com");

        // And credentials should be allowed
        let credentials_header = resp
            .headers()
            .get(ACCESS_CONTROL_ALLOW_CREDENTIALS)
            .expect("credentials header missing");
        assert_eq!(credentials_header, "true");
        Ok(())
    }

    #[tokio::test]
    async fn test_echo_origin_with_credentials() -> TestResult {
        let cors = CorsMiddleware::new(CorsConfig {
            allow_origins: AllowOrigins::Echo,
            allow_methods: vec![Method::GET],
            allow_headers: vec![],
            allow_credentials: true,
            max_age: None,
            expose_headers: vec![],
        });

        let handler = make_handler(StatusCode::OK, "ok");

        let req = Request::builder()
            .method(Method::GET)
            .uri("/")
            .header(header::ORIGIN, "https://dynamic.com")
            .body("")?;

        let resp = cors.handle(req, app_state(), handler).await?;
        assert_eq!(
            resp.headers()[ACCESS_CONTROL_ALLOW_ORIGIN],
            "https://dynamic.com"
        );
        assert_eq!(resp.headers()[ACCESS_CONTROL_ALLOW_CREDENTIALS], "true");
        Ok(())
    }

    #[tokio::test]
    async fn test_expose_headers() -> TestResult {
        let cors = CorsMiddleware::new(CorsConfig {
            allow_origins: AllowOrigins::Any,
            allow_methods: vec![Method::GET],
            allow_headers: vec![],
            allow_credentials: false,
            max_age: None,
            expose_headers: vec![HeaderName::from_static("x-ratelimit-remaining")],
        });

        let handler = make_handler(StatusCode::OK, "ok");

        let req = Request::builder()
            .method(Method::GET)
            .uri("/")
            .header(header::ORIGIN, "https://client.com")
            .body("")?;

        let resp = cors.handle(req, app_state(), handler).await?;
        assert_eq!(
            resp.headers()[ACCESS_CONTROL_EXPOSE_HEADERS],
            "x-ratelimit-remaining"
        );
        Ok(())
    }
}